mirror of
https://mau.dev/mautrix/go.git
synced 2026-03-14 14:25:53 +01:00
commands: add subcommand system
This commit is contained in:
parent
33f3ccd6ae
commit
3badb9b332
5 changed files with 161 additions and 74 deletions
89
commands/container.go
Normal file
89
commands/container.go
Normal file
|
|
@ -0,0 +1,89 @@
|
|||
// Copyright (c) 2025 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package commands
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type CommandContainer[MetaType any] struct {
|
||||
commands map[string]*Handler[MetaType]
|
||||
aliases map[string]string
|
||||
lock sync.RWMutex
|
||||
}
|
||||
|
||||
func NewCommandContainer[MetaType any]() *CommandContainer[MetaType] {
|
||||
return &CommandContainer[MetaType]{
|
||||
commands: make(map[string]*Handler[MetaType]),
|
||||
aliases: make(map[string]string),
|
||||
}
|
||||
}
|
||||
|
||||
// Register registers the given command handlers.
|
||||
func (cont *CommandContainer[MetaType]) Register(handlers ...*Handler[MetaType]) {
|
||||
if cont == nil {
|
||||
return
|
||||
}
|
||||
cont.lock.Lock()
|
||||
defer cont.lock.Unlock()
|
||||
for _, handler := range handlers {
|
||||
cont.registerOne(handler)
|
||||
}
|
||||
}
|
||||
|
||||
func (cont *CommandContainer[MetaType]) registerOne(handler *Handler[MetaType]) {
|
||||
if strings.ToLower(handler.Name) != handler.Name {
|
||||
panic(fmt.Errorf("command %q is not lowercase", handler.Name))
|
||||
}
|
||||
cont.commands[handler.Name] = handler
|
||||
for _, alias := range handler.Aliases {
|
||||
if strings.ToLower(alias) != alias {
|
||||
panic(fmt.Errorf("alias %q is not lowercase", alias))
|
||||
}
|
||||
cont.aliases[alias] = handler.Name
|
||||
}
|
||||
handler.initSubcommandContainer()
|
||||
}
|
||||
|
||||
func (cont *CommandContainer[MetaType]) Unregister(handlers ...*Handler[MetaType]) {
|
||||
if cont == nil {
|
||||
return
|
||||
}
|
||||
cont.lock.Lock()
|
||||
defer cont.lock.Unlock()
|
||||
for _, handler := range handlers {
|
||||
cont.unregisterOne(handler)
|
||||
}
|
||||
}
|
||||
|
||||
func (cont *CommandContainer[MetaType]) unregisterOne(handler *Handler[MetaType]) {
|
||||
delete(cont.commands, handler.Name)
|
||||
for _, alias := range handler.Aliases {
|
||||
if cont.aliases[alias] == handler.Name {
|
||||
delete(cont.aliases, alias)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (cont *CommandContainer[MetaType]) GetHandler(name string) *Handler[MetaType] {
|
||||
if cont == nil {
|
||||
return nil
|
||||
}
|
||||
cont.lock.RLock()
|
||||
defer cont.lock.RUnlock()
|
||||
alias, ok := cont.aliases[name]
|
||||
if ok {
|
||||
name = alias
|
||||
}
|
||||
handler, ok := cont.commands[name]
|
||||
if !ok {
|
||||
handler = cont.commands[UnknownCommandName]
|
||||
}
|
||||
return handler
|
||||
}
|
||||
|
|
@ -23,6 +23,9 @@ type Event[MetaType any] struct {
|
|||
*event.Event
|
||||
// RawInput is the entire message before splitting into command and arguments.
|
||||
RawInput string
|
||||
// ParentCommands is the chain of commands leading up to this command.
|
||||
// This is only set if the command is a subcommand.
|
||||
ParentCommands []string
|
||||
// Command is the lowercased first word of the message.
|
||||
Command string
|
||||
// Args are the rest of the message split by whitespace ([strings.Fields]).
|
||||
|
|
@ -122,3 +125,18 @@ func (evt *Event[MetaType]) MarkRead() {
|
|||
zerolog.Ctx(evt.Ctx).Err(err).Msg("Failed to send read receipt")
|
||||
}
|
||||
}
|
||||
|
||||
// PromoteFirstArgToCommand promotes the first argument to the command name.
|
||||
//
|
||||
// Command will be set to the lowercased first item in the Args list.
|
||||
// Both Args and RawArgs will be updated to remove the first argument, but RawInput will be left as-is.
|
||||
//
|
||||
// The caller MUST check that there are args before calling this function.
|
||||
func (evt *Event[MetaType]) PromoteFirstArgToCommand() {
|
||||
if len(evt.Args) == 0 {
|
||||
panic(fmt.Errorf("PromoteFirstArgToCommand called with no args"))
|
||||
}
|
||||
evt.Command = strings.ToLower(evt.Args[0])
|
||||
evt.RawArgs = strings.TrimLeft(strings.TrimPrefix(evt.RawArgs, evt.Args[0]), " ")
|
||||
evt.Args = evt.Args[1:]
|
||||
}
|
||||
|
|
|
|||
29
commands/handler.go
Normal file
29
commands/handler.go
Normal file
|
|
@ -0,0 +1,29 @@
|
|||
// Copyright (c) 2025 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package commands
|
||||
|
||||
type Handler[MetaType any] struct {
|
||||
Func func(ce *Event[MetaType])
|
||||
|
||||
// Name is the primary name of the command. It must be lowercase.
|
||||
Name string
|
||||
// Aliases are alternative names for the command. They must be lowercase.
|
||||
Aliases []string
|
||||
// Subcommands are subcommands of this command.
|
||||
Subcommands []*Handler[MetaType]
|
||||
|
||||
subcommandContainer *CommandContainer[MetaType]
|
||||
}
|
||||
|
||||
func (h *Handler[MetaType]) initSubcommandContainer() {
|
||||
if len(h.Subcommands) > 0 {
|
||||
h.subcommandContainer = NewCommandContainer[MetaType]()
|
||||
h.subcommandContainer.Register(h.Subcommands...)
|
||||
} else {
|
||||
h.subcommandContainer = nil
|
||||
}
|
||||
}
|
||||
|
|
@ -61,9 +61,7 @@ func (f AnyPreValidator[MetaType]) Validate(ce *Event[MetaType]) bool {
|
|||
func ValidatePrefixCommand[MetaType any](prefix string) PreValidator[MetaType] {
|
||||
return FuncPreValidator[MetaType](func(ce *Event[MetaType]) bool {
|
||||
if ce.Command == prefix && len(ce.Args) > 0 {
|
||||
ce.Command = strings.ToLower(ce.Args[0])
|
||||
ce.RawArgs = strings.TrimLeft(strings.TrimPrefix(ce.RawArgs, ce.Args[0]), " ")
|
||||
ce.Args = ce.Args[1:]
|
||||
ce.PromoteFirstArgToCommand()
|
||||
return true
|
||||
}
|
||||
return false
|
||||
|
|
|
|||
|
|
@ -8,10 +8,8 @@ package commands
|
|||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"runtime/debug"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
|
|
@ -22,34 +20,23 @@ import (
|
|||
// Processor implements boilerplate code for splitting messages into a command and arguments,
|
||||
// and finding the appropriate handler for the command.
|
||||
type Processor[MetaType any] struct {
|
||||
*CommandContainer[MetaType]
|
||||
|
||||
Client *mautrix.Client
|
||||
LogArgs bool
|
||||
PreValidator PreValidator[MetaType]
|
||||
Meta MetaType
|
||||
commands map[string]*Handler[MetaType]
|
||||
aliases map[string]string
|
||||
lock sync.RWMutex
|
||||
}
|
||||
|
||||
type Handler[MetaType any] struct {
|
||||
Func func(ce *Event[MetaType])
|
||||
|
||||
// Name is the primary name of the command. It must be lowercase.
|
||||
Name string
|
||||
// Aliases are alternative names for the command. They must be lowercase.
|
||||
Aliases []string
|
||||
}
|
||||
|
||||
// UnknownCommandName is the name of the fallback handler which is used if no other handler is found.
|
||||
// If even the unknown command handler is not found, the command is ignored.
|
||||
const UnknownCommandName = "unknown-command"
|
||||
const UnknownCommandName = "__unknown-command__"
|
||||
|
||||
func NewProcessor[MetaType any](cli *mautrix.Client) *Processor[MetaType] {
|
||||
proc := &Processor[MetaType]{
|
||||
Client: cli,
|
||||
PreValidator: ValidatePrefixSubstring[MetaType]("!"),
|
||||
commands: make(map[string]*Handler[MetaType]),
|
||||
aliases: make(map[string]string),
|
||||
CommandContainer: NewCommandContainer[MetaType](),
|
||||
Client: cli,
|
||||
PreValidator: ValidatePrefixSubstring[MetaType]("!"),
|
||||
}
|
||||
proc.Register(&Handler[MetaType]{
|
||||
Name: UnknownCommandName,
|
||||
|
|
@ -60,45 +47,6 @@ func NewProcessor[MetaType any](cli *mautrix.Client) *Processor[MetaType] {
|
|||
return proc
|
||||
}
|
||||
|
||||
// Register registers the given command handlers.
|
||||
func (proc *Processor[MetaType]) Register(handlers ...*Handler[MetaType]) {
|
||||
proc.lock.Lock()
|
||||
defer proc.lock.Unlock()
|
||||
for _, handler := range handlers {
|
||||
proc.registerOne(handler)
|
||||
}
|
||||
}
|
||||
|
||||
func (proc *Processor[MetaType]) registerOne(handler *Handler[MetaType]) {
|
||||
if strings.ToLower(handler.Name) != handler.Name {
|
||||
panic(fmt.Errorf("command %q is not lowercase", handler.Name))
|
||||
}
|
||||
proc.commands[handler.Name] = handler
|
||||
for _, alias := range handler.Aliases {
|
||||
if strings.ToLower(alias) != alias {
|
||||
panic(fmt.Errorf("alias %q is not lowercase", alias))
|
||||
}
|
||||
proc.aliases[alias] = handler.Name
|
||||
}
|
||||
}
|
||||
|
||||
func (proc *Processor[MetaType]) Unregister(handlers ...*Handler[MetaType]) {
|
||||
proc.lock.Lock()
|
||||
defer proc.lock.Unlock()
|
||||
for _, handler := range handlers {
|
||||
proc.unregisterOne(handler)
|
||||
}
|
||||
}
|
||||
|
||||
func (proc *Processor[MetaType]) unregisterOne(handler *Handler[MetaType]) {
|
||||
delete(proc.commands, handler.Name)
|
||||
for _, alias := range handler.Aliases {
|
||||
if proc.aliases[alias] == handler.Name {
|
||||
delete(proc.aliases, alias)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (proc *Processor[MetaType]) Process(ctx context.Context, evt *event.Event) {
|
||||
log := *zerolog.Ctx(ctx)
|
||||
defer func() {
|
||||
|
|
@ -123,25 +71,30 @@ func (proc *Processor[MetaType]) Process(ctx context.Context, evt *event.Event)
|
|||
return
|
||||
}
|
||||
|
||||
realCommand := parsed.Command
|
||||
proc.lock.RLock()
|
||||
alias, ok := proc.aliases[realCommand]
|
||||
if ok {
|
||||
realCommand = alias
|
||||
}
|
||||
handler, ok := proc.commands[realCommand]
|
||||
if !ok {
|
||||
handler, ok = proc.commands[UnknownCommandName]
|
||||
}
|
||||
proc.lock.RUnlock()
|
||||
if !ok {
|
||||
handler := proc.GetHandler(parsed.Command)
|
||||
if handler == nil {
|
||||
return
|
||||
}
|
||||
handlerChain := zerolog.Arr()
|
||||
handlerChain.Str(handler.Name)
|
||||
for handler.subcommandContainer != nil && len(parsed.Args) > 0 {
|
||||
subHandler := handler.subcommandContainer.GetHandler(strings.ToLower(parsed.Args[0]))
|
||||
if subHandler != nil {
|
||||
parsed.ParentCommands = append(parsed.ParentCommands, parsed.Command)
|
||||
handlerChain.Str(subHandler.Name)
|
||||
parsed.PromoteFirstArgToCommand()
|
||||
handler = subHandler
|
||||
}
|
||||
}
|
||||
|
||||
logWith := log.With().
|
||||
Str("command", realCommand).
|
||||
Str("command", parsed.Command).
|
||||
Array("handler", handlerChain).
|
||||
Stringer("sender", evt.Sender).
|
||||
Stringer("room_id", evt.RoomID)
|
||||
if len(parsed.ParentCommands) > 0 {
|
||||
logWith = logWith.Strs("parent_commands", parsed.ParentCommands)
|
||||
}
|
||||
if proc.LogArgs {
|
||||
logWith = logWith.Strs("args", parsed.Args)
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue