diff --git a/commands/handler.go b/commands/handler.go index 3b92a908..56f27f06 100644 --- a/commands/handler.go +++ b/commands/handler.go @@ -33,6 +33,7 @@ type Handler[MetaType any] struct { // Parameters is a description of structured command parameters. // If set, the StructuredArgs field of Event will be populated. Parameters []*cmdschema.Parameter + TailParam string parents []*Handler[MetaType] nestedNameCache []string @@ -68,9 +69,18 @@ func (h *Handler[MetaType]) Spec() *cmdschema.EventContent { Aliases: names[1:], Parameters: h.Parameters, Description: h.Description, + TailParam: h.TailParam, } } +func (h *Handler[MetaType]) CopyFrom(other *Handler[MetaType]) { + if h.Parameters == nil { + h.Parameters = other.Parameters + h.TailParam = other.TailParam + } + h.Func = other.Func +} + func (h *Handler[MetaType]) initSubcommandContainer() { if len(h.Subcommands) > 0 { h.subcommandContainer = NewCommandContainer[MetaType]() diff --git a/commands/processor.go b/commands/processor.go index 0089226f..80f6745d 100644 --- a/commands/processor.go +++ b/commands/processor.go @@ -108,6 +108,7 @@ func (proc *Processor[MetaType]) Process(ctx context.Context, evt *event.Event) } } if parsed.StructuredArgs != nil && len(parsed.Args) > 0 { + // TODO allow unknown command handlers to be called? // The client sent MSC4391 data, but the target command wasn't found log.Debug().Msg("Didn't find handler for MSC4391 command") return diff --git a/event/cmdschema/content.go b/event/cmdschema/content.go index b69f0c1f..e7f362ed 100644 --- a/event/cmdschema/content.go +++ b/event/cmdschema/content.go @@ -13,6 +13,7 @@ import ( "reflect" "slices" + "go.mau.fi/util/exsync" "go.mau.fi/util/ptr" "maunium.net/go/mautrix/event" @@ -24,6 +25,7 @@ type EventContent struct { Aliases []string `json:"aliases,omitempty"` Parameters []*Parameter `json:"parameters,omitempty"` Description *event.ExtensibleTextContainer `json:"description,omitempty"` + TailParam string `json:"fi.mau.tail_parameter,omitempty"` } func (ec *EventContent) Validate() error { @@ -32,11 +34,22 @@ func (ec *EventContent) Validate() error { } else if ec.Command == "" { return fmt.Errorf("command is empty") } + var tailFound bool + dupMap := exsync.NewSet[string]() for i, p := range ec.Parameters { if err := p.Validate(); err != nil { return fmt.Errorf("parameter %q (#%d) is invalid: %w", ptr.Val(p).Key, i+1, err) + } else if !dupMap.Add(p.Key) { + return fmt.Errorf("duplicate parameter key %q at #%d", p.Key, i+1) + } else if p.Key == ec.TailParam { + tailFound = true + } else if tailFound && !p.Optional { + return fmt.Errorf("required parameter %q (#%d) is after tail parameter %q", p.Key, i+1, ec.TailParam) } } + if ec.TailParam != "" && !tailFound { + return fmt.Errorf("tail parameter %q not found in parameters", ec.TailParam) + } return nil } diff --git a/event/cmdschema/parse.go b/event/cmdschema/parse.go index 5269ab28..91a02827 100644 --- a/event/cmdschema/parse.go +++ b/event/cmdschema/parse.go @@ -135,8 +135,8 @@ func (ec *EventContent) ParseArguments(input string) (json.RawMessage, error) { args[param.Key] = collector } else { nextVal, input, wasQuoted = parseQuoted(input) - if isLast && !wasQuoted && len(input) > 0 { - // If the last argument is not quoted and not variadic, just treat the rest of the string + if isLast && !wasQuoted && len(input) > 0 && !strings.Contains(input, "--") { + // If the last argument is not quoted and doesn't have flags, just treat the rest of the string // as the argument without escapes (arguments with escapes should be quoted). nextVal += " " + input input = "" @@ -146,7 +146,7 @@ func (ec *EventContent) ParseArguments(input string) (json.RawMessage, error) { args[param.Key] = true return } - if nextVal == "" && !param.Optional { + if nextVal == "" && !wasQuoted && !isNamed && !param.Optional { setError(fmt.Errorf("missing value for required parameter %s", param.Key)) } parsedVal, err := param.Schema.ParseString(nextVal) @@ -180,10 +180,11 @@ func (ec *EventContent) ParseArguments(input string) (json.RawMessage, error) { break } } - if skipParams[i] { + isTail := param.Key == ec.TailParam + if skipParams[i] || (param.Optional && !isTail) { continue } - processParameter(param, i == len(ec.Parameters)-1, false) + processParameter(param, i == len(ec.Parameters)-1 || isTail, false) } jsonArgs, marshalErr := json.Marshal(args) if marshalErr != nil { diff --git a/event/cmdschema/parse_test.go b/event/cmdschema/parse_test.go index 725b0150..1e0d1817 100644 --- a/event/cmdschema/parse_test.go +++ b/event/cmdschema/parse_test.go @@ -109,7 +109,7 @@ func TestMSC4391BotCommandEventContent_ParseInput(t *testing.T) { assert.Nil(t, output) } else { assert.Equal(t, ctd.Spec.Command, output.MSC4391BotCommand.Command) - assert.Equal(t, outputStr, exbytes.UnsafeString(output.MSC4391BotCommand.Arguments)) + assert.Equalf(t, outputStr, exbytes.UnsafeString(output.MSC4391BotCommand.Arguments), "Input: %s", test.Input) } }) } diff --git a/event/cmdschema/testdata/commands/flags.json b/event/cmdschema/testdata/commands/flags.json index dedde348..469986f0 100644 --- a/event/cmdschema/testdata/commands/flags.json +++ b/event/cmdschema/testdata/commands/flags.json @@ -27,7 +27,8 @@ "optional": true, "fi.mau.default_value": false } - ] + ], + "fi.mau.tail_parameter": "user" }, "tests": [ { @@ -35,17 +36,15 @@ "input": "/flag mrrp", "output": { "meow": "mrrp", - "user": null, - "woof": false + "user": null } }, { - "name": "positional flag", - "input": "/flag mrrp @user:example.com yes", + "name": "no flags, has tail", + "input": "/flag mrrp @user:example.com", "output": { "meow": "mrrp", - "user": "@user:example.com", - "woof": true + "user": "@user:example.com" } }, { @@ -130,18 +129,9 @@ "woof": true } }, - { - "name": "only string variables named", - "input": "/flag --user=@user:example.com --meow=mrrp yes", - "output": { - "meow": "mrrp", - "user": "@user:example.com", - "woof": true - } - }, { "name": "invalid value for named parameter", - "input": "/flag --user=meowings mrrp yes", + "input": "/flag --user=meowings mrrp --woof", "error": true, "output": { "meow": "mrrp",