diff --git a/commands.go b/commands.go index 4c8e683..3718634 100644 --- a/commands.go +++ b/commands.go @@ -16,15 +16,31 @@ package maubot -type CommandHandler func(*Event) +type CommandHandler func(*Event) CommandHandlerResult type CommandSpec struct { - Commands []Command `json:"commands"` - PassiveCommands []PassiveCommand `json:"passive_commands"` + Commands []Command `json:"commands,omitempty"` + PassiveCommands []PassiveCommand `json:"passive_commands,omitempty"` +} + +func (spec *CommandSpec) Clone() *CommandSpec { + return &CommandSpec{ + Commands: append([]Command(nil), spec.Commands...), + PassiveCommands: append([]PassiveCommand(nil), spec.PassiveCommands...), + } +} + +func (spec *CommandSpec) Merge(otherSpecs ...*CommandSpec) { + for _, otherSpec := range otherSpecs { + spec.Commands = append(spec.Commands, otherSpec.Commands...) + spec.PassiveCommands = append(spec.PassiveCommands, otherSpec.PassiveCommands...) + } } func (spec *CommandSpec) Equals(otherSpec *CommandSpec) bool { - if len(spec.Commands) != len(otherSpec.Commands) || len(spec.PassiveCommands) != len(otherSpec.PassiveCommands) { + if otherSpec == nil || + len(spec.Commands) != len(otherSpec.Commands) || + len(spec.PassiveCommands) != len(otherSpec.PassiveCommands) { return false } diff --git a/database/clients.go b/database/clients.go index 660bd10..1717ba2 100644 --- a/database/clients.go +++ b/database/clients.go @@ -20,6 +20,7 @@ import ( "maubot.xyz" log "maunium.net/go/maulogger" "database/sql" + "sort" ) type MatrixClient struct { @@ -37,7 +38,7 @@ type MatrixClient struct { DisplayName string `json:"display_name"` AvatarURL string `json:"avatar_url"` - Commands map[string]*CommandSpec `json:"commandspecs"` + CommandSpecs map[string]*CommandSpec `json:"command_specs"` } type MatrixClientStatic struct { @@ -91,32 +92,63 @@ func (mcs *MatrixClientStatic) New() *MatrixClient { func (mxc *MatrixClient) Scan(row Scannable) *MatrixClient { err := row.Scan(&mxc.UserID, &mxc.Homeserver, &mxc.AccessToken, &mxc.NextBatch, &mxc.FilterID, &mxc.Sync, &mxc.AutoJoinRooms, &mxc.DisplayName, &mxc.AvatarURL) if err != nil { - log.Fatalln("Database scan failed:", err) + log.Errorln("MatrixClient scan failed:", err) + return mxc } mxc.LoadCommandSpecs() return mxc } func (mxc *MatrixClient) SetCommandSpec(owner string, newSpec *maubot.CommandSpec) bool { - spec := mxc.db.CommandSpec.GetOrCreate(owner, mxc.UserID) - if newSpec.Equals(spec.CommandSpec) { + spec, ok := mxc.CommandSpecs[owner] + if ok && newSpec.Equals(spec.CommandSpec) { return false } - spec.CommandSpec = newSpec - spec.Update() - mxc.Commands[owner] = spec + if spec == nil { + spec = mxc.db.CommandSpec.New() + spec.CommandSpec = newSpec + spec.Insert() + } else { + spec.CommandSpec = newSpec + spec.Update() + } + mxc.CommandSpecs[owner] = spec return true } func (mxc *MatrixClient) LoadCommandSpecs() *MatrixClient { specs := mxc.db.CommandSpec.GetAllByClient(mxc.UserID) - mxc.Commands = make(map[string]*CommandSpec) + mxc.CommandSpecs = make(map[string]*CommandSpec) for _, spec := range specs { - mxc.Commands[spec.Owner] = spec + mxc.CommandSpecs[spec.Owner] = spec } + log.Debugln("Loaded command specs:", mxc.CommandSpecs) return mxc } +func (mxc *MatrixClient) CommandSpecIDs() []string { + keys := make([]string, len(mxc.CommandSpecs)) + i := 0 + for key := range mxc.CommandSpecs { + keys[i] = key + i++ + } + sort.Strings(keys) + return keys +} + +func (mxc *MatrixClient) Commands() *maubot.CommandSpec { + if len(mxc.CommandSpecs) == 0 { + return &maubot.CommandSpec{} + } + specIDs := mxc.CommandSpecIDs() + spec := mxc.CommandSpecs[specIDs[0]].Clone() + for _, specID := range specIDs[1:] { + spec.Merge(mxc.CommandSpecs[specID].CommandSpec) + } + return spec +} + func (mxc *MatrixClient) Insert() error { _, err := mxc.sql.Exec("INSERT INTO matrix_client (user_id, homeserver, access_token, next_batch, filter_id, sync, autojoin, display_name, avatar_url) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)", mxc.UserID, mxc.Homeserver, mxc.AccessToken, mxc.NextBatch, mxc.FilterID, mxc.Sync, mxc.AutoJoinRooms, mxc.DisplayName, mxc.AvatarURL) diff --git a/database/commands.go b/database/commands.go index ccf5e37..0425a47 100644 --- a/database/commands.go +++ b/database/commands.go @@ -54,11 +54,11 @@ func (css *CommandSpecStatic) CreateTable() error { } func (css *CommandSpecStatic) Get(owner, client string) *CommandSpec { - row := css.sql.QueryRow("SELECT * FROM command_spec WHERE owner=? AND client=?", owner, client) - if row != nil { - return css.New().Scan(row) + rows, err := css.sql.Query("SELECT * FROM command_spec WHERE owner=? AND client=?", owner, client) + if err != nil { + log.Errorf("Failed to Get(%s, %s): %v\n", owner, client, err) } - return nil + return css.New().Scan(rows) } func (css *CommandSpecStatic) GetOrCreate(owner, client string) (spec *CommandSpec) { @@ -74,13 +74,15 @@ func (css *CommandSpecStatic) GetOrCreate(owner, client string) (spec *CommandSp func (css *CommandSpecStatic) getAllByQuery(query string, args ...interface{}) (specs []*CommandSpec) { rows, err := css.sql.Query(query, args...) - if err != nil || rows == nil { + if err != nil { + log.Errorf("Failed to getAllByQuery(%s): %v\n", query, err) return nil } defer rows.Close() for rows.Next() { specs = append(specs, css.New().Scan(rows)) } + log.Debugln("getAllByQuery() =", specs) return } @@ -103,9 +105,14 @@ func (cs *CommandSpec) Scan(row Scannable) *CommandSpec { var spec string err := row.Scan(&cs.Owner, &cs.Client, &spec) if err != nil { - log.Fatalln("Database scan failed:", err) + log.Errorln("CommandSpec scan failed:", err) + return cs + } + cs.CommandSpec = &maubot.CommandSpec{} + err = json.Unmarshal([]byte(spec), cs.CommandSpec) + if err != nil { + log.Errorln("CommandSpec parse failed:", err) } - json.Unmarshal([]byte(spec), &cs.CommandSpec) return cs } diff --git a/database/database.go b/database/database.go index b49b063..09fc87d 100644 --- a/database/database.go +++ b/database/database.go @@ -62,6 +62,11 @@ func (db *Database) CreateTables() { if err != nil { log.Errorln("Failed to create plugin table:", err) } + + err = db.CommandSpec.CreateTable() + if err != nil { + log.Errorln("Failed to create command_spec table:", err) + } } func (db *Database) SQL() *sql.DB { diff --git a/database/plugins.go b/database/plugins.go index 3b4ef91..0d43f43 100644 --- a/database/plugins.go +++ b/database/plugins.go @@ -87,7 +87,7 @@ func (ps *PluginStatic) New() *Plugin { func (p *Plugin) Scan(row Scannable) *Plugin { err := row.Scan(&p.ID, &p.Type, &p.Enabled, &p.UserID) if err != nil { - log.Fatalln("Database scan failed:", err) + log.Errorln("Plugin scan failed:", err) } return p } diff --git a/matrix.go b/matrix.go index 89b16a3..04a0ae7 100644 --- a/matrix.go +++ b/matrix.go @@ -54,11 +54,13 @@ const ( const FormatHTML = "org.matrix.custom.html" type EventHandler func(*Event) EventHandlerResult -type EventHandlerResult bool +type EventHandlerResult int +type CommandHandlerResult = EventHandlerResult const ( - Continue EventHandlerResult = false - StopPropagation EventHandlerResult = true + Continue EventHandlerResult = iota + StopEventPropagation + StopCommandPropagation CommandHandlerResult = iota ) type MatrixClient interface { @@ -130,7 +132,8 @@ type Content struct { Membership string `json:"membership,omitempty"` - RelatesTo RelatesTo `json:"m.relates_to,omitempty"` + Command MatchedCommand `json:"m.command,omitempty"` + RelatesTo RelatesTo `json:"m.relates_to,omitempty"` } func (content Content) Equals(otherContent *Content) bool { @@ -162,6 +165,12 @@ func (fi *FileInfo) Equals(otherFI *FileInfo) bool { ((fi.ThumbnailInfo != nil && fi.ThumbnailInfo.Equals(otherFI.ThumbnailInfo)) || otherFI.ThumbnailInfo == nil) } +type MatchedCommand struct { + Target string `json:"target"` + Matched string `json:"matched"` + Arguments map[string]string `json:"arguments"` +} + type RelatesTo struct { InReplyTo InReplyTo `json:"m.in_reply_to,omitempty"` } diff --git a/matrix/commands.go b/matrix/commands.go new file mode 100644 index 0000000..62bee76 --- /dev/null +++ b/matrix/commands.go @@ -0,0 +1,138 @@ +// maubot - A plugin-based Matrix bot system written in Go. +// Copyright (C) 2018 Tulir Asokan +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package matrix + +import ( + "fmt" + "regexp" + "strings" + + log "maunium.net/go/maulogger" + + "maubot.xyz" +) + +type ParsedCommand struct { + Name string + StartsWith string + Matches *regexp.Regexp + MatchAgainst string + MatchesEvent *maubot.Event +} + +func (pc *ParsedCommand) parseCommandSyntax(command maubot.Command) error { + regexBuilder := &strings.Builder{} + swBuilder := &strings.Builder{} + argumentEncountered := false + + regexBuilder.WriteRune('^') + words := strings.Split(command.Syntax, " ") + for i, word := range words { + argument, ok := command.Arguments[word] + if ok { + argumentEncountered = true + regex := argument.Matches + if argument.Required { + regex = fmt.Sprintf("(?:%s)?", regex) + } + regexBuilder.WriteString(regex) + } else { + if !argumentEncountered { + swBuilder.WriteString(word) + } + regexBuilder.WriteString(regexp.QuoteMeta(word)) + } + + if i < len(words) - 1 { + if !argumentEncountered { + swBuilder.WriteRune(' ') + } + regexBuilder.WriteRune(' ') + } + } + regexBuilder.WriteRune('$') + + var err error + pc.StartsWith = swBuilder.String() + // Trim the extra space at the end added in the parse loop + pc.StartsWith = pc.StartsWith[:len(pc.StartsWith)-1] + pc.Matches, err = regexp.Compile(regexBuilder.String()) + pc.MatchAgainst = "body" + return err +} + +func (pc *ParsedCommand) parsePassiveCommandSyntax(command maubot.PassiveCommand) error { + pc.MatchAgainst = command.MatchAgainst + var err error + pc.Matches, err = regexp.Compile(command.Matches) + pc.MatchesEvent = command.MatchEvent + return err +} + +func ParseSpec(spec *maubot.CommandSpec) (commands []*ParsedCommand) { + for _, command := range spec.Commands { + parsing := &ParsedCommand{ + Name: command.Syntax, + } + err := parsing.parseCommandSyntax(command) + if err != nil { + log.Warnf("Failed to parse regex of command %s: %v\n", command.Syntax, err) + continue + } + commands = append(commands, parsing) + } + for _, command := range spec.PassiveCommands { + parsing := &ParsedCommand{ + Name: command.Name, + } + err := parsing.parsePassiveCommandSyntax(command) + if err != nil { + log.Warnf("Failed to parse regex of passive command %s: %v\n", command.Name, err) + continue + } + commands = append(commands, parsing) + } + return commands +} + +func deepGet(from map[string]interface{}, path string) interface{} { + for { + dotIndex := strings.IndexRune(path, '.') + if dotIndex == -1 { + return from[path] + } + + var key string + key, path = path[:dotIndex], path[dotIndex+1:] + var ok bool + from, ok = from[key].(map[string]interface{}) + if !ok { + return nil + } + } +} + +func (pc *ParsedCommand) Match(evt *maubot.Event) bool { + matchAgainst, ok := deepGet(evt.Content.Raw, pc.MatchAgainst).(string) + if !ok { + matchAgainst = evt.Content.Body + } + + return strings.HasPrefix(matchAgainst, pc.StartsWith) && + pc.Matches.MatchString(matchAgainst) && + (pc.MatchesEvent == nil || pc.MatchesEvent.Equals(evt)) +} diff --git a/matrix/matrix.go b/matrix/matrix.go index 55762a4..bb4bdbe 100644 --- a/matrix/matrix.go +++ b/matrix/matrix.go @@ -25,9 +25,10 @@ import ( type Client struct { *gomatrix.Client - syncer *MaubotSyncer - - DB *database.MatrixClient + syncer *MaubotSyncer + handlers map[string][]maubot.CommandHandler + commands []*ParsedCommand + DB *database.MatrixClient } func NewClient(db *database.MatrixClient) (*Client, error) { @@ -37,8 +38,10 @@ func NewClient(db *database.MatrixClient) (*Client, error) { } client := &Client{ - Client: mxClient, - DB: db, + Client: mxClient, + handlers: make(map[string][]maubot.CommandHandler), + commands: ParseSpec(db.Commands()), + DB: db, } client.syncer = NewMaubotSyncer(client, client.Store) @@ -60,21 +63,29 @@ func (client *Client) Proxy(owner string) *ClientProxy { func (client *Client) AddEventHandler(evt maubot.EventType, handler maubot.EventHandler) { client.syncer.OnEventType(evt, func(evt *maubot.Event) maubot.EventHandlerResult { if evt.Sender == client.UserID { - return maubot.StopPropagation + return maubot.StopEventPropagation } return handler(evt) }) } -func (client *Client) AddCommandHandler(evt string, handler maubot.CommandHandler) { - // TODO add command handler +func (client *Client) AddCommandHandler(owner, evt string, handler maubot.CommandHandler) { + log.Debugln("Registering command handler for event", evt, "by", owner) + list, ok := client.handlers[evt] + if !ok { + list = []maubot.CommandHandler{handler} + } else { + list = append(list, handler) + } + client.handlers[evt] = list } func (client *Client) SetCommandSpec(owner string, spec *maubot.CommandSpec) { + log.Debugln("Registering command spec for", owner, "on", client.UserID) changed := client.DB.SetCommandSpec(owner, spec) if changed { + client.commands = ParseSpec(client.DB.Commands()) log.Debugln("Command spec of", owner, "on", client.UserID, "updated.") - // TODO } } @@ -87,15 +98,38 @@ func (client *Client) GetEvent(roomID, eventID string) *maubot.Event { return client.ParseEvent(evt).Event } +func (client *Client) TriggerCommand(command *ParsedCommand, evt *maubot.Event) maubot.CommandHandlerResult { + handlers, ok := client.handlers[command.Name] + if !ok { + log.Warnf("Command %s triggered by %s doesn't have any handlers.", command.Name, evt.Sender) + return maubot.Continue + } + log.Debugf("Command %s on client %s triggered by %s\n", command.Name, client.UserID, evt.Sender) + for _, handler := range handlers { + result := handler(evt) + if result == maubot.StopCommandPropagation { + break + } else if result != maubot.Continue { + return result + } + } + + return maubot.Continue +} + func (client *Client) onMessage(evt *maubot.Event) maubot.EventHandlerResult { - // TODO call command handlers + for _, command := range client.commands { + if command.Match(evt) { + return client.TriggerCommand(command, evt) + } + } return maubot.Continue } func (client *Client) onJoin(evt *maubot.Event) maubot.EventHandlerResult { if client.DB.AutoJoinRooms && evt.StateKey == client.DB.UserID && evt.Content.Membership == "invite" { client.JoinRoom(evt.RoomID) - return maubot.StopPropagation + return maubot.StopEventPropagation } return maubot.Continue } @@ -120,6 +154,10 @@ type ClientProxy struct { owner string } +func (cp *ClientProxy) AddCommandHandler(evt string, handler maubot.CommandHandler) { + cp.hiddenClient.AddCommandHandler(cp.owner, evt, handler) +} + func (cp *ClientProxy) SetCommandSpec(spec *maubot.CommandSpec) { cp.hiddenClient.SetCommandSpec(cp.owner, spec) } diff --git a/matrix/sync.go b/matrix/sync.go index dbdaa68..39bc588 100644 --- a/matrix/sync.go +++ b/matrix/sync.go @@ -135,7 +135,7 @@ func (s *MaubotSyncer) notifyListeners(mxEvent *gomatrix.Event) { return } for _, fn := range listeners { - if fn(event.Event) { + if fn(event.Event) == maubot.StopEventPropagation { break } }