diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index dc4f17e2..c0add220 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -10,12 +10,12 @@ jobs: runs-on: ubuntu-latest name: Lint (latest) steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - name: Set up Go - uses: actions/setup-go@v5 + uses: actions/setup-go@v6 with: - go-version: "1.25" + go-version: "1.26" cache: true - name: Install libolm @@ -24,6 +24,7 @@ jobs: - name: Install goimports run: | go install golang.org/x/tools/cmd/goimports@latest + go install honnef.co/go/tools/cmd/staticcheck@latest export PATH="$HOME/go/bin:$PATH" - name: Run pre-commit @@ -34,14 +35,14 @@ jobs: strategy: fail-fast: false matrix: - go-version: ["1.24", "1.25"] - name: Build (${{ matrix.go-version == '1.25' && 'latest' || 'old' }}, libolm) + go-version: ["1.25", "1.26"] + name: Build (${{ matrix.go-version == '1.26' && 'latest' || 'old' }}, libolm) steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - name: Set up Go ${{ matrix.go-version }} - uses: actions/setup-go@v5 + uses: actions/setup-go@v6 with: go-version: ${{ matrix.go-version }} cache: true @@ -61,7 +62,6 @@ jobs: run: go test -json -v ./... 2>&1 | gotestfmt - name: Test (jsonv2) - if: matrix.go-version == '1.25' env: GOEXPERIMENT: jsonv2 run: go test -json -v ./... 2>&1 | gotestfmt @@ -71,14 +71,14 @@ jobs: strategy: fail-fast: false matrix: - go-version: ["1.24", "1.25"] - name: Build (${{ matrix.go-version == '1.25' && 'latest' || 'old' }}, goolm) + go-version: ["1.25", "1.26"] + name: Build (${{ matrix.go-version == '1.26' && 'latest' || 'old' }}, goolm) steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - name: Set up Go ${{ matrix.go-version }} - uses: actions/setup-go@v5 + uses: actions/setup-go@v6 with: go-version: ${{ matrix.go-version }} cache: true diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml index 578349c9..9a9e7375 100644 --- a/.github/workflows/stale.yml +++ b/.github/workflows/stale.yml @@ -17,7 +17,7 @@ jobs: lock-stale: runs-on: ubuntu-latest steps: - - uses: dessant/lock-threads@v5 + - uses: dessant/lock-threads@v6 id: lock with: issue-inactive-days: 90 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 0b9785ae..616fccb2 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v5.0.0 + rev: v6.0.0 hooks: - id: trailing-whitespace exclude_types: [markdown] @@ -9,7 +9,7 @@ repos: - id: check-added-large-files - repo: https://github.com/tekwizely/pre-commit-golang - rev: v1.0.0-rc.1 + rev: v1.0.0-rc.4 hooks: - id: go-imports-repo args: @@ -18,8 +18,7 @@ repos: - "-w" - id: go-vet-repo-mod - id: go-mod-tidy - # TODO enable this - #- id: go-staticcheck-repo-mod + - id: go-staticcheck-repo-mod - repo: https://github.com/beeper/pre-commit-go rev: v0.4.2 diff --git a/CHANGELOG.md b/CHANGELOG.md index f59e6853..f2829199 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,130 @@ +## v0.26.3 (2026-02-16) + +* Bumped minimum Go version to 1.25. +* *(client)* Added fields for sending [MSC4354] sticky events. +* *(bridgev2)* Added automatic message request accepting when sending message. +* *(mediaproxy)* Added support for federation thumbnail endpoint. +* *(crypto/ssss)* Improved support for recovery keys with slightly broken + metadata. +* *(crypto)* Changed key import to call session received callback even for + sessions that already exist in the database. +* *(appservice)* Fixed building websocket URL accidentally using file path + separators instead of always `/`. +* *(crypto)* Fixed key exports not including the `sender_claimed_keys` field. +* *(client)* Fixed incorrect context usage in async uploads. +* *(crypto)* Fixed panic when passing invalid input to megolm message index + parser used for debugging. +* *(bridgev2/provisioning)* Fixed completed or failed logins not being cleaned + up properly. + +[MSC4354]: https://github.com/matrix-org/matrix-spec-proposals/pull/4354 + +## v0.26.2 (2026-01-16) + +* *(bridgev2)* Added chunked portal deletion to avoid database locks when + deleting large portals. +* *(crypto,bridgev2)* Added option to encrypt reaction and reply metadata + as per [MSC4392]. +* *(bridgev2/login)* Added `default_value` for user input fields. +* *(bridgev2)* Added interfaces to let the Matrix connector provide suggested + HTTP client settings and to reset active connections of the network connector. +* *(bridgev2)* Added interface to let network connectors get the provisioning + API HTTP router and add new endpoints. +* *(event)* Added blurhash field to Beeper link preview objects. +* *(event)* Added [MSC4391] support for bot commands. +* *(event)* Dropped [MSC4332] support for bot commands. +* *(client)* Changed media download methods to return an error if the provided + MXC URI is empty. +* *(client)* Stabilized support for [MSC4323]. +* *(bridgev2/matrix)* Fixed `GetEvent` panicking when trying to decrypt events. +* *(bridgev2)* Fixed some deadlocks when room creation happens in parallel with + a portal re-ID call. + +[MSC4391]: https://github.com/matrix-org/matrix-spec-proposals/pull/4391 +[MSC4392]: https://github.com/matrix-org/matrix-spec-proposals/pull/4392 + +## v0.26.1 (2025-12-16) + +* **Breaking change *(mediaproxy)*** Changed `GetMediaResponseFile` to return + the mime type from the callback rather than in the return get media return + value. The callback can now also redirect the caller to a different file. +* *(federation)* Added join/knock/leave functions + (thanks to [@nexy7574] in [#422]). +* *(federation/eventauth)* Fixed various incorrect checks. +* *(client)* Added backoff for retrying media uploads to external URLs + (with MSC3870). +* *(bridgev2/config)* Added support for overriding config fields using + environment variables. +* *(bridgev2/commands)* Added command to mute chat on remote network. +* *(bridgev2)* Added interface for network connectors to redirect to a different + user ID when handling an invite from Matrix. +* *(bridgev2)* Added interface for signaling message request status of portals. +* *(bridgev2)* Changed portal creation to not backfill unless `CanBackfill` flag + is set in chat info. +* *(bridgev2)* Changed Matrix reaction handling to only delete old reaction if + bridging the new one is successful. +* *(bridgev2/mxmain)* Improved error message when trying to run bridge with + pre-megabridge database when no database migration exists. +* *(bridgev2)* Improved reliability of database migration when enabling split + portals. +* *(bridgev2)* Improved detection of orphaned DM rooms when starting new chats. +* *(bridgev2)* Stopped sending redundant invites when joining ghosts to public + portal rooms. +* *(bridgev2)* Stopped hardcoding room versions in favor of checking + server capabilities to determine appropriate `/createRoom` parameters. + +[#422]: https://github.com/mautrix/go/pull/422 + +## v0.26.0 (2025-11-16) + +* *(client,appservice)* Deprecated `SendMassagedStateEvent` as `SendStateEvent` + has been able to do the same for a while now. +* *(client,federation)* Added size limits for responses to make it safer to send + requests to untrusted servers. +* *(client)* Added wrapper for `/admin/whois` client API + (thanks to [@nexy7574] in [#411]). +* *(synapseadmin)* Added `force_purge` option to DeleteRoom + (thanks to [@nexy7574] in [#420]). +* *(statestore)* Added saving join rules for rooms. +* *(bridgev2)* Added optional automatic rollback of room state if bridging the + change to the remote network fails. +* *(bridgev2)* Added management room notices if transient disconnect state + doesn't resolve within 3 minutes. +* *(bridgev2)* Added interface to signal that certain participants couldn't be + invited when creating a group. +* *(bridgev2)* Added `select` type for user input fields in login. +* *(bridgev2)* Added interface to let network connector customize personal + filtering space. +* *(bridgev2/matrix)* Added checks to avoid sending error messages in reply to + other bots. +* *(bridgev2/matrix)* Switched to using [MSC4169] to send redactions whenever + possible. +* *(bridgev2/publicmedia)* Added support for custom path prefixes, file names, + and encrypted files. +* *(bridgev2/commands)* Added command to resync a single portal. +* *(bridgev2/commands)* Added create group command. +* *(bridgev2/config)* Added option to limit maximum number of logins. +* *(bridgev2)* Changed ghost joining to skip unnecessary invite if portal room + is public. +* *(bridgev2/disappear)* Changed read receipt handling to only start + disappearing timers for messages up to the read message (note: may not work in + all cases if the read receipt points at an unknown event). +* *(event/reply)* Changed plaintext reply fallback removal to only happen when + an HTML reply fallback is removed successfully. +* *(bridgev2/matrix)* Fixed unnecessary sleep after registering bot on first run. +* *(crypto/goolm)* Fixed panic when processing certain malformed Olm messages. +* *(federation)* Fixed HTTP method for sending transactions + (thanks to [@nexy7574] in [#426]). +* *(federation)* Fixed response body being closed even when using `DontReadBody` + parameter. +* *(federation)* Fixed validating auth for requests with query params. +* *(federation/eventauth)* Fixed typo causing restricted joins to not work. + +[MSC4169]: https://github.com/matrix-org/matrix-spec-proposals/pull/4169 +[#411]: github.com/mautrix/go/pull/411 +[#420]: github.com/mautrix/go/pull/420 +[#426]: github.com/mautrix/go/pull/426 + ## v0.25.2 (2025-10-16) * **Breaking change *(id)*** Split `UserID.ParseAndValidate` into @@ -310,6 +437,7 @@ [MSC4156]: https://github.com/matrix-org/matrix-spec-proposals/pull/4156 [MSC4190]: https://github.com/matrix-org/matrix-spec-proposals/pull/4190 [#288]: https://github.com/mautrix/go/pull/288 +[@onestacked]: https://github.com/onestacked ## v0.22.0 (2024-11-16) diff --git a/README.md b/README.md index ac41ca78..b1a2edf8 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,9 @@ # mautrix-go [![GoDoc](https://pkg.go.dev/badge/maunium.net/go/mautrix)](https://pkg.go.dev/maunium.net/go/mautrix) -A Golang Matrix framework. Used by [gomuks](https://matrix.org/docs/projects/client/gomuks), -[go-neb](https://github.com/matrix-org/go-neb), [mautrix-whatsapp](https://github.com/mautrix/whatsapp) +A Golang Matrix framework. Used by [gomuks](https://gomuks.app), +[go-neb](https://github.com/matrix-org/go-neb), +[mautrix-whatsapp](https://github.com/mautrix/whatsapp) and others. Matrix room: [`#go:maunium.net`](https://matrix.to/#/#go:maunium.net) @@ -13,9 +14,10 @@ The original project is licensed under [Apache 2.0](https://github.com/matrix-or In addition to the basic client API features the original project has, this framework also has: * Appservice support (Intent API like mautrix-python, room state storage, etc) -* End-to-end encryption support (incl. interactive SAS verification) +* End-to-end encryption support (incl. key backup, cross-signing, interactive verification, etc) * High-level module for building puppeting bridges -* High-level module for building chat clients +* Partial federation module (making requests, PDU processing and event authorization) +* A media proxy server which can be used to expose anything as a Matrix media repo * Wrapper functions for the Synapse admin API * Structs for parsing event content * Helpers for parsing and generating Matrix HTML diff --git a/appservice/intent.go b/appservice/intent.go index 4635f59a..5d43f190 100644 --- a/appservice/intent.go +++ b/appservice/intent.go @@ -51,7 +51,7 @@ func (as *AppService) NewIntentAPI(localpart string) *IntentAPI { } func (intent *IntentAPI) Register(ctx context.Context) error { - _, err := intent.Client.MakeRequest(ctx, http.MethodPost, intent.BuildClientURL("v3", "register"), &mautrix.ReqRegister{ + _, err := intent.Client.MakeRequest(ctx, http.MethodPost, intent.BuildClientURL("v3", "register"), &mautrix.ReqRegister[any]{ Username: intent.Localpart, Type: mautrix.AuthTypeAppservice, InhibitLogin: true, @@ -214,23 +214,31 @@ func (intent *IntentAPI) AddDoublePuppetValueWithTS(into any, ts int64) any { } } -func (intent *IntentAPI) SendMessageEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, contentJSON interface{}) (*mautrix.RespSendEvent, error) { +func (intent *IntentAPI) SendMessageEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, contentJSON any, extra ...mautrix.ReqSendEvent) (*mautrix.RespSendEvent, error) { if err := intent.EnsureJoined(ctx, roomID); err != nil { return nil, err } contentJSON = intent.AddDoublePuppetValue(contentJSON) - return intent.Client.SendMessageEvent(ctx, roomID, eventType, contentJSON) + return intent.Client.SendMessageEvent(ctx, roomID, eventType, contentJSON, extra...) } -func (intent *IntentAPI) SendMassagedMessageEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, contentJSON interface{}, ts int64) (*mautrix.RespSendEvent, error) { +func (intent *IntentAPI) BeeperSendEphemeralEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, contentJSON any, extra ...mautrix.ReqSendEvent) (*mautrix.RespSendEvent, error) { if err := intent.EnsureJoined(ctx, roomID); err != nil { return nil, err } - contentJSON = intent.AddDoublePuppetValueWithTS(contentJSON, ts) - return intent.Client.SendMessageEvent(ctx, roomID, eventType, contentJSON, mautrix.ReqSendEvent{Timestamp: ts}) + if !intent.SpecVersions.Supports(mautrix.BeeperFeatureEphemeralEvents) { + return nil, mautrix.MUnrecognized.WithMessage("Homeserver does not advertise com.beeper.ephemeral support") + } + contentJSON = intent.AddDoublePuppetValue(contentJSON) + return intent.Client.BeeperSendEphemeralEvent(ctx, roomID, eventType, contentJSON, extra...) } -func (intent *IntentAPI) SendStateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, contentJSON interface{}) (*mautrix.RespSendEvent, error) { +// Deprecated: use SendMessageEvent with mautrix.ReqSendEvent.Timestamp instead +func (intent *IntentAPI) SendMassagedMessageEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, contentJSON interface{}, ts int64) (*mautrix.RespSendEvent, error) { + return intent.SendMessageEvent(ctx, roomID, eventType, contentJSON, mautrix.ReqSendEvent{Timestamp: ts}) +} + +func (intent *IntentAPI) SendStateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, contentJSON any, extra ...mautrix.ReqSendEvent) (*mautrix.RespSendEvent, error) { if eventType != event.StateMember || stateKey != string(intent.UserID) { if err := intent.EnsureJoined(ctx, roomID); err != nil { return nil, err @@ -239,15 +247,12 @@ func (intent *IntentAPI) SendStateEvent(ctx context.Context, roomID id.RoomID, e return nil, err } contentJSON = intent.AddDoublePuppetValue(contentJSON) - return intent.Client.SendStateEvent(ctx, roomID, eventType, stateKey, contentJSON) + return intent.Client.SendStateEvent(ctx, roomID, eventType, stateKey, contentJSON, extra...) } +// Deprecated: use SendStateEvent with mautrix.ReqSendEvent.Timestamp instead func (intent *IntentAPI) SendMassagedStateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, contentJSON interface{}, ts int64) (*mautrix.RespSendEvent, error) { - if err := intent.EnsureJoined(ctx, roomID); err != nil { - return nil, err - } - contentJSON = intent.AddDoublePuppetValueWithTS(contentJSON, ts) - return intent.Client.SendMassagedStateEvent(ctx, roomID, eventType, stateKey, contentJSON, ts) + return intent.SendStateEvent(ctx, roomID, eventType, stateKey, contentJSON, mautrix.ReqSendEvent{Timestamp: ts}) } func (intent *IntentAPI) StateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, outContent interface{}) error { diff --git a/appservice/websocket.go b/appservice/websocket.go index 1e401c53..ef65e65a 100644 --- a/appservice/websocket.go +++ b/appservice/websocket.go @@ -14,7 +14,7 @@ import ( "io" "net/http" "net/url" - "path/filepath" + "path" "strings" "sync" "sync/atomic" @@ -56,7 +56,7 @@ func (wsc *WebsocketCommand) MakeResponse(ok bool, data any) *WebsocketRequest { var prefixMessage string for unwrappedErr != nil { errorData, jsonErr = json.Marshal(unwrappedErr) - if errorData != nil && len(errorData) > 2 && jsonErr == nil { + if len(errorData) > 2 && jsonErr == nil { prefixMessage = strings.Replace(err.Error(), unwrappedErr.Error(), "", 1) prefixMessage = strings.TrimRight(prefixMessage, ": ") break @@ -374,7 +374,7 @@ func (as *AppService) StartWebsocket(ctx context.Context, baseURL string, onConn copiedURL := *as.hsURLForClient parsed = &copiedURL } - parsed.Path = filepath.Join(parsed.Path, "_matrix/client/unstable/fi.mau.as_sync") + parsed.Path = path.Join(parsed.Path, "_matrix/client/unstable/fi.mau.as_sync") if parsed.Scheme == "http" { parsed.Scheme = "ws" } else if parsed.Scheme == "https" { diff --git a/bridgev2/bridge.go b/bridgev2/bridge.go index 2ad6a614..226adc90 100644 --- a/bridgev2/bridge.go +++ b/bridgev2/bridge.go @@ -11,10 +11,12 @@ import ( "fmt" "os" "sync" + "sync/atomic" "time" "github.com/rs/zerolog" "go.mau.fi/util/dbutil" + "go.mau.fi/util/exhttp" "go.mau.fi/util/exsync" "maunium.net/go/mautrix/bridgev2/bridgeconfig" @@ -52,6 +54,7 @@ type Bridge struct { Background bool ExternallyManagedDB bool + stopping atomic.Bool wakeupBackfillQueue chan struct{} stopBackfillQueue *exsync.Event @@ -127,6 +130,7 @@ func (br *Bridge) Start(ctx context.Context) error { func (br *Bridge) RunOnce(ctx context.Context, loginID networkid.UserLoginID, params *ConnectBackgroundParams) error { br.Background = true + br.stopping.Store(false) err := br.StartConnectors(ctx) if err != nil { return err @@ -162,6 +166,7 @@ func (br *Bridge) RunOnce(ctx context.Context, loginID networkid.UserLoginID, pa case <-time.After(20 * time.Second): case <-ctx.Done(): } + br.stopping.Store(true) return nil } else { br.Log.Info().Str("user_login_id", string(login.ID)).Msg("Starting individual user login in background mode") @@ -171,6 +176,7 @@ func (br *Bridge) RunOnce(ctx context.Context, loginID networkid.UserLoginID, pa func (br *Bridge) StartConnectors(ctx context.Context) error { br.Log.Info().Msg("Starting bridge") + br.stopping.Store(false) if br.BackgroundCtx == nil || br.BackgroundCtx.Err() != nil { br.BackgroundCtx, br.cancelBackgroundCtx = context.WithCancel(context.Background()) br.BackgroundCtx = br.Log.WithContext(br.BackgroundCtx) @@ -368,6 +374,46 @@ func (br *Bridge) StartLogins(ctx context.Context) error { return nil } +func (br *Bridge) ResetNetworkConnections() { + nrn, ok := br.Network.(NetworkResettingNetwork) + if ok { + br.Log.Info().Msg("Resetting network connections with NetworkConnector.ResetNetworkConnections") + nrn.ResetNetworkConnections() + return + } + + br.Log.Info().Msg("Network connector doesn't support ResetNetworkConnections, recreating clients manually") + for _, login := range br.GetAllCachedUserLogins() { + login.Log.Debug().Msg("Disconnecting and recreating client for network reset") + ctx := login.Log.WithContext(br.BackgroundCtx) + login.Client.Disconnect() + err := login.recreateClient(ctx) + if err != nil { + login.Log.Err(err).Msg("Failed to recreate client during network reset") + login.BridgeState.Send(status.BridgeState{ + StateEvent: status.StateUnknownError, + Error: "bridgev2-network-reset-fail", + Info: map[string]any{"go_error": err.Error()}, + }) + } else { + login.Client.Connect(ctx) + } + } + br.Log.Info().Msg("Finished resetting all user logins") +} + +func (br *Bridge) GetHTTPClientSettings() exhttp.ClientSettings { + mchs, ok := br.Matrix.(MatrixConnectorWithHTTPSettings) + if ok { + return mchs.GetHTTPClientSettings() + } + return exhttp.SensibleClientSettings +} + +func (br *Bridge) IsStopping() bool { + return br.stopping.Load() +} + func (br *Bridge) Stop() { br.stop(false, 0) } @@ -378,6 +424,7 @@ func (br *Bridge) StopWithTimeout(timeout time.Duration) { func (br *Bridge) stop(isRunOnce bool, timeout time.Duration) { br.Log.Info().Msg("Shutting down bridge") + br.stopping.Store(true) br.DisappearLoop.Stop() br.stopBackfillQueue.Set() br.Matrix.PreStop() diff --git a/bridgev2/bridgeconfig/backfill.go b/bridgev2/bridgeconfig/backfill.go index 53282e41..eedae1e8 100644 --- a/bridgev2/bridgeconfig/backfill.go +++ b/bridgev2/bridgeconfig/backfill.go @@ -34,10 +34,12 @@ type BackfillQueueConfig struct { MaxBatchesOverride map[string]int `yaml:"max_batches_override"` } -func (bqc *BackfillQueueConfig) GetOverride(name string) int { - override, ok := bqc.MaxBatchesOverride[name] - if !ok { - return bqc.MaxBatches +func (bqc *BackfillQueueConfig) GetOverride(names ...string) int { + for _, name := range names { + override, ok := bqc.MaxBatchesOverride[name] + if ok { + return override + } } - return override + return bqc.MaxBatches } diff --git a/bridgev2/bridgeconfig/config.go b/bridgev2/bridgeconfig/config.go index 13ec738c..bd6b9c06 100644 --- a/bridgev2/bridgeconfig/config.go +++ b/bridgev2/bridgeconfig/config.go @@ -33,6 +33,8 @@ type Config struct { Encryption EncryptionConfig `yaml:"encryption"` Logging zeroconfig.Config `yaml:"logging"` + EnvConfigPrefix string `yaml:"env_config_prefix"` + ManagementRoomTexts ManagementRoomTexts `yaml:"management_room_texts"` } @@ -60,36 +62,40 @@ type CleanupOnLogouts struct { } type BridgeConfig struct { - CommandPrefix string `yaml:"command_prefix"` - PersonalFilteringSpaces bool `yaml:"personal_filtering_spaces"` - PrivateChatPortalMeta bool `yaml:"private_chat_portal_meta"` - AsyncEvents bool `yaml:"async_events"` - SplitPortals bool `yaml:"split_portals"` - ResendBridgeInfo bool `yaml:"resend_bridge_info"` - NoBridgeInfoStateKey bool `yaml:"no_bridge_info_state_key"` - BridgeStatusNotices string `yaml:"bridge_status_notices"` - UnknownErrorAutoReconnect time.Duration `yaml:"unknown_error_auto_reconnect"` - BridgeMatrixLeave bool `yaml:"bridge_matrix_leave"` - BridgeNotices bool `yaml:"bridge_notices"` - TagOnlyOnCreate bool `yaml:"tag_only_on_create"` - OnlyBridgeTags []event.RoomTag `yaml:"only_bridge_tags"` - MuteOnlyOnCreate bool `yaml:"mute_only_on_create"` - DeduplicateMatrixMessages bool `yaml:"deduplicate_matrix_messages"` - CrossRoomReplies bool `yaml:"cross_room_replies"` - OutgoingMessageReID bool `yaml:"outgoing_message_re_id"` - CleanupOnLogout CleanupOnLogouts `yaml:"cleanup_on_logout"` - Relay RelayConfig `yaml:"relay"` - Permissions PermissionConfig `yaml:"permissions"` - Backfill BackfillConfig `yaml:"backfill"` + CommandPrefix string `yaml:"command_prefix"` + PersonalFilteringSpaces bool `yaml:"personal_filtering_spaces"` + PrivateChatPortalMeta bool `yaml:"private_chat_portal_meta"` + AsyncEvents bool `yaml:"async_events"` + SplitPortals bool `yaml:"split_portals"` + ResendBridgeInfo bool `yaml:"resend_bridge_info"` + NoBridgeInfoStateKey bool `yaml:"no_bridge_info_state_key"` + BridgeStatusNotices string `yaml:"bridge_status_notices"` + UnknownErrorAutoReconnect time.Duration `yaml:"unknown_error_auto_reconnect"` + UnknownErrorMaxAutoReconnects int `yaml:"unknown_error_max_auto_reconnects"` + BridgeMatrixLeave bool `yaml:"bridge_matrix_leave"` + BridgeNotices bool `yaml:"bridge_notices"` + TagOnlyOnCreate bool `yaml:"tag_only_on_create"` + OnlyBridgeTags []event.RoomTag `yaml:"only_bridge_tags"` + MuteOnlyOnCreate bool `yaml:"mute_only_on_create"` + DeduplicateMatrixMessages bool `yaml:"deduplicate_matrix_messages"` + CrossRoomReplies bool `yaml:"cross_room_replies"` + OutgoingMessageReID bool `yaml:"outgoing_message_re_id"` + RevertFailedStateChanges bool `yaml:"revert_failed_state_changes"` + KickMatrixUsers bool `yaml:"kick_matrix_users"` + CleanupOnLogout CleanupOnLogouts `yaml:"cleanup_on_logout"` + Relay RelayConfig `yaml:"relay"` + Permissions PermissionConfig `yaml:"permissions"` + Backfill BackfillConfig `yaml:"backfill"` } type MatrixConfig struct { - MessageStatusEvents bool `yaml:"message_status_events"` - DeliveryReceipts bool `yaml:"delivery_receipts"` - MessageErrorNotices bool `yaml:"message_error_notices"` - SyncDirectChatList bool `yaml:"sync_direct_chat_list"` - FederateRooms bool `yaml:"federate_rooms"` - UploadFileThreshold int64 `yaml:"upload_file_threshold"` + MessageStatusEvents bool `yaml:"message_status_events"` + DeliveryReceipts bool `yaml:"delivery_receipts"` + MessageErrorNotices bool `yaml:"message_error_notices"` + SyncDirectChatList bool `yaml:"sync_direct_chat_list"` + FederateRooms bool `yaml:"federate_rooms"` + UploadFileThreshold int64 `yaml:"upload_file_threshold"` + GhostExtraProfileInfo bool `yaml:"ghost_extra_profile_info"` } type AnalyticsConfig struct { @@ -111,10 +117,12 @@ type DirectMediaConfig struct { } type PublicMediaConfig struct { - Enabled bool `yaml:"enabled"` - SigningKey string `yaml:"signing_key"` - HashLength int `yaml:"hash_length"` - Expiry int `yaml:"expiry"` + Enabled bool `yaml:"enabled"` + SigningKey string `yaml:"signing_key"` + Expiry int `yaml:"expiry"` + HashLength int `yaml:"hash_length"` + PathPrefix string `yaml:"path_prefix"` + UseDatabase bool `yaml:"use_database"` } type DoublePuppetConfig struct { diff --git a/bridgev2/bridgeconfig/encryption.go b/bridgev2/bridgeconfig/encryption.go index 5a19b3ad..934613ca 100644 --- a/bridgev2/bridgeconfig/encryption.go +++ b/bridgev2/bridgeconfig/encryption.go @@ -16,6 +16,7 @@ type EncryptionConfig struct { Require bool `yaml:"require"` Appservice bool `yaml:"appservice"` MSC4190 bool `yaml:"msc4190"` + MSC4392 bool `yaml:"msc4392"` SelfSign bool `yaml:"self_sign"` PlaintextMentions bool `yaml:"plaintext_mentions"` diff --git a/bridgev2/bridgeconfig/permissions.go b/bridgev2/bridgeconfig/permissions.go index 610051e0..9efe068e 100644 --- a/bridgev2/bridgeconfig/permissions.go +++ b/bridgev2/bridgeconfig/permissions.go @@ -24,6 +24,7 @@ type Permissions struct { DoublePuppet bool `yaml:"double_puppet"` Admin bool `yaml:"admin"` ManageRelay bool `yaml:"manage_relay"` + MaxLogins int `yaml:"max_logins"` } type PermissionConfig map[string]*Permissions @@ -40,10 +41,7 @@ func (pc PermissionConfig) IsConfigured() bool { _, hasExampleDomain := pc["example.com"] _, hasExampleUser := pc["@admin:example.com"] exampleLen := boolToInt(hasWildcard) + boolToInt(hasExampleUser) + boolToInt(hasExampleDomain) - if len(pc) <= exampleLen { - return false - } - return true + return len(pc) > exampleLen } func (pc PermissionConfig) Get(userID id.UserID) Permissions { diff --git a/bridgev2/bridgeconfig/upgrade.go b/bridgev2/bridgeconfig/upgrade.go index 6533338f..92515ea0 100644 --- a/bridgev2/bridgeconfig/upgrade.go +++ b/bridgev2/bridgeconfig/upgrade.go @@ -33,6 +33,7 @@ func doUpgrade(helper up.Helper) { helper.Copy(up.Bool, "bridge", "no_bridge_info_state_key") helper.Copy(up.Str|up.Null, "bridge", "bridge_status_notices") helper.Copy(up.Str|up.Int|up.Null, "bridge", "unknown_error_auto_reconnect") + helper.Copy(up.Int, "bridge", "unknown_error_max_auto_reconnects") helper.Copy(up.Bool, "bridge", "bridge_matrix_leave") helper.Copy(up.Bool, "bridge", "bridge_notices") helper.Copy(up.Bool, "bridge", "tag_only_on_create") @@ -40,6 +41,8 @@ func doUpgrade(helper up.Helper) { helper.Copy(up.Bool, "bridge", "mute_only_on_create") helper.Copy(up.Bool, "bridge", "deduplicate_matrix_messages") helper.Copy(up.Bool, "bridge", "cross_room_replies") + helper.Copy(up.Bool, "bridge", "revert_failed_state_changes") + helper.Copy(up.Bool, "bridge", "kick_matrix_users") helper.Copy(up.Bool, "bridge", "cleanup_on_logout", "enabled") helper.Copy(up.Str, "bridge", "cleanup_on_logout", "manual", "private") helper.Copy(up.Str, "bridge", "cleanup_on_logout", "manual", "relayed") @@ -98,6 +101,7 @@ func doUpgrade(helper up.Helper) { helper.Copy(up.Bool, "matrix", "sync_direct_chat_list") helper.Copy(up.Bool, "matrix", "federate_rooms") helper.Copy(up.Int, "matrix", "upload_file_threshold") + helper.Copy(up.Bool, "matrix", "ghost_extra_profile_info") helper.Copy(up.Str|up.Null, "analytics", "token") helper.Copy(up.Str|up.Null, "analytics", "url") @@ -132,6 +136,8 @@ func doUpgrade(helper up.Helper) { } helper.Copy(up.Int, "public_media", "expiry") helper.Copy(up.Int, "public_media", "hash_length") + helper.Copy(up.Str|up.Null, "public_media", "path_prefix") + helper.Copy(up.Bool, "public_media", "use_database") helper.Copy(up.Bool, "backfill", "enabled") helper.Copy(up.Int, "backfill", "max_initial_messages") @@ -157,6 +163,7 @@ func doUpgrade(helper up.Helper) { } else { helper.Copy(up.Bool, "encryption", "msc4190") } + helper.Copy(up.Bool, "encryption", "msc4392") helper.Copy(up.Bool, "encryption", "self_sign") helper.Copy(up.Bool, "encryption", "allow_key_sharing") if secret, ok := helper.Get(up.Str, "encryption", "pickle_key"); !ok || secret == "generate" { @@ -180,6 +187,8 @@ func doUpgrade(helper up.Helper) { helper.Copy(up.Int, "encryption", "rotation", "messages") helper.Copy(up.Bool, "encryption", "rotation", "disable_device_change_key_rotation") + helper.Copy(up.Str|up.Null, "env_config_prefix") + helper.Copy(up.Map, "logging") } @@ -207,6 +216,7 @@ var SpacedBlocks = [][]string{ {"backfill"}, {"double_puppet"}, {"encryption"}, + {"env_config_prefix"}, {"logging"}, } diff --git a/bridgev2/bridgestate.go b/bridgev2/bridgestate.go index f31d4e92..96d9fd5c 100644 --- a/bridgev2/bridgestate.go +++ b/bridgev2/bridgestate.go @@ -15,12 +15,15 @@ import ( "time" "github.com/rs/zerolog" + "go.mau.fi/util/exfmt" "maunium.net/go/mautrix/bridgev2/status" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/format" ) +var CatchBridgeStateQueuePanics = true + type BridgeStateQueue struct { prevUnsent *status.BridgeState prevSent *status.BridgeState @@ -29,8 +32,13 @@ type BridgeStateQueue struct { bridge *Bridge login *UserLogin + firstTransientDisconnect time.Time + cancelScheduledNotice atomic.Pointer[context.CancelFunc] + stopChan chan struct{} stopReconnect atomic.Pointer[context.CancelFunc] + + unknownErrorReconnects int } func (br *Bridge) SendGlobalBridgeState(state status.BridgeState) { @@ -74,31 +82,63 @@ func (bsq *BridgeStateQueue) StopUnknownErrorReconnect() { if cancelFn := bsq.stopReconnect.Swap(nil); cancelFn != nil { (*cancelFn)() } + if cancelFn := bsq.cancelScheduledNotice.Swap(nil); cancelFn != nil { + (*cancelFn)() + } } func (bsq *BridgeStateQueue) loop() { - defer func() { - err := recover() - if err != nil { - bsq.login.Log.Error(). - Bytes(zerolog.ErrorStackFieldName, debug.Stack()). - Any(zerolog.ErrorFieldName, err). - Msg("Panic in bridge state loop") - } - }() + if CatchBridgeStateQueuePanics { + defer func() { + err := recover() + if err != nil { + bsq.login.Log.Error(). + Bytes(zerolog.ErrorStackFieldName, debug.Stack()). + Any(zerolog.ErrorFieldName, err). + Msg("Panic in bridge state loop") + } + }() + } for state := range bsq.ch { bsq.immediateSendBridgeState(state) } } -func (bsq *BridgeStateQueue) sendNotice(ctx context.Context, state status.BridgeState) { +func (bsq *BridgeStateQueue) scheduleNotice(triggeredBy status.BridgeState) { + log := bsq.login.Log.With().Str("action", "transient disconnect notice").Logger() + ctx := log.WithContext(bsq.bridge.BackgroundCtx) + if !bsq.waitForTransientDisconnectReconnect(ctx) { + return + } + prevUnsent := bsq.GetPrevUnsent() + prev := bsq.GetPrev() + if triggeredBy.Timestamp != prev.Timestamp || len(bsq.ch) > 0 || bsq.errorSent || + prevUnsent.StateEvent != status.StateTransientDisconnect || prev.StateEvent != status.StateTransientDisconnect { + log.Trace().Any("triggered_by", triggeredBy).Msg("Not sending delayed transient disconnect notice") + return + } + log.Debug().Any("triggered_by", triggeredBy).Msg("Sending delayed transient disconnect notice") + bsq.sendNotice(ctx, triggeredBy, true) +} + +func (bsq *BridgeStateQueue) sendNotice(ctx context.Context, state status.BridgeState, isDelayed bool) { noticeConfig := bsq.bridge.Config.BridgeStatusNotices isError := state.StateEvent == status.StateBadCredentials || state.StateEvent == status.StateUnknownError || - state.UserAction == status.UserActionOpenNative + state.UserAction == status.UserActionOpenNative || + (isDelayed && state.StateEvent == status.StateTransientDisconnect) sendNotice := noticeConfig == "all" || (noticeConfig == "errors" && (isError || (bsq.errorSent && state.StateEvent == status.StateConnected))) + if state.StateEvent != status.StateTransientDisconnect && state.StateEvent != status.StateUnknownError { + bsq.firstTransientDisconnect = time.Time{} + } if !sendNotice { + if !bsq.errorSent && !isDelayed && noticeConfig == "errors" && state.StateEvent == status.StateTransientDisconnect { + if bsq.firstTransientDisconnect.IsZero() { + bsq.firstTransientDisconnect = time.Now() + } + go bsq.scheduleNotice(state) + } return } managementRoom, err := bsq.login.User.GetManagementRoom(ctx) @@ -114,6 +154,9 @@ func (bsq *BridgeStateQueue) sendNotice(ctx context.Context, state status.Bridge if state.Error != "" { message += fmt.Sprintf(" (`%s`)", state.Error) } + if isDelayed { + message += fmt.Sprintf(" not resolved after waiting %s", exfmt.Duration(TransientDisconnectNoticeDelay)) + } if state.Message != "" { message += fmt.Sprintf(": %s", state.Message) } @@ -151,8 +194,14 @@ func (bsq *BridgeStateQueue) unknownErrorReconnect(triggeredBy status.BridgeStat } else if prevUnsent.StateEvent != status.StateUnknownError || prev.StateEvent != status.StateUnknownError { log.Debug().Msg("Not reconnecting as the previous state was not an unknown error") return + } else if bsq.unknownErrorReconnects > bsq.bridge.Config.UnknownErrorMaxAutoReconnects { + log.Warn().Msg("Not reconnecting as the maximum number of unknown error reconnects has been reached") + return } - log.Info().Msg("Disconnecting and reconnecting login due to unknown error") + bsq.unknownErrorReconnects++ + log.Info(). + Int("reconnect_num", bsq.unknownErrorReconnects). + Msg("Disconnecting and reconnecting login due to unknown error") bsq.login.Disconnect() log.Debug().Msg("Disconnection finished, recreating client and reconnecting") err := bsq.login.recreateClient(ctx) @@ -171,14 +220,30 @@ func (bsq *BridgeStateQueue) waitForUnknownErrorReconnect(ctx context.Context) b return false } reconnectIn += time.Duration(rand.Int64N(int64(float64(reconnectIn)*0.4)) - int64(float64(reconnectIn)*0.2)) + return bsq.waitForReconnect(ctx, reconnectIn, &bsq.stopReconnect) +} + +const TransientDisconnectNoticeDelay = 3 * time.Minute + +func (bsq *BridgeStateQueue) waitForTransientDisconnectReconnect(ctx context.Context) bool { + timeUntilSchedule := time.Until(bsq.firstTransientDisconnect.Add(TransientDisconnectNoticeDelay)) + zerolog.Ctx(ctx).Trace(). + Stringer("duration", timeUntilSchedule). + Msg("Waiting before sending notice about transient disconnect") + return bsq.waitForReconnect(ctx, timeUntilSchedule, &bsq.cancelScheduledNotice) +} + +func (bsq *BridgeStateQueue) waitForReconnect( + ctx context.Context, reconnectIn time.Duration, ptr *atomic.Pointer[context.CancelFunc], +) bool { cancelCtx, cancel := context.WithCancel(ctx) defer cancel() - if oldCancel := bsq.stopReconnect.Swap(&cancel); oldCancel != nil { + if oldCancel := ptr.Swap(&cancel); oldCancel != nil { (*oldCancel)() } select { case <-time.After(reconnectIn): - return bsq.stopReconnect.CompareAndSwap(&cancel, nil) + return ptr.CompareAndSwap(&cancel, nil) case <-cancelCtx.Done(): return false case <-bsq.stopChan: @@ -198,7 +263,7 @@ func (bsq *BridgeStateQueue) immediateSendBridgeState(state status.BridgeState) } ctx := bsq.login.Log.WithContext(context.Background()) - bsq.sendNotice(ctx, state) + bsq.sendNotice(ctx, state, false) retryIn := 2 for { diff --git a/bridgev2/commands/debug.go b/bridgev2/commands/debug.go index 4c93dbd4..1cae98fe 100644 --- a/bridgev2/commands/debug.go +++ b/bridgev2/commands/debug.go @@ -7,10 +7,13 @@ package commands import ( + "encoding/json" "strings" + "time" "maunium.net/go/mautrix/bridgev2" "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/event" ) var CommandRegisterPush = &FullHandler{ @@ -59,3 +62,64 @@ var CommandRegisterPush = &FullHandler{ RequiresLogin: true, NetworkAPI: NetworkAPIImplements[bridgev2.PushableNetworkAPI], } + +var CommandSendAccountData = &FullHandler{ + Func: func(ce *Event) { + if len(ce.Args) < 2 { + ce.Reply("Usage: `$cmdprefix debug-account-data ") + return + } + var content event.Content + evtType := event.Type{Type: ce.Args[0], Class: event.AccountDataEventType} + ce.RawArgs = strings.TrimSpace(strings.Trim(ce.RawArgs, ce.Args[0])) + err := json.Unmarshal([]byte(ce.RawArgs), &content) + if err != nil { + ce.Reply("Failed to parse JSON: %v", err) + return + } + err = content.ParseRaw(evtType) + if err != nil { + ce.Reply("Failed to deserialize content: %v", err) + return + } + res := ce.Bridge.QueueMatrixEvent(ce.Ctx, &event.Event{ + Sender: ce.User.MXID, + Type: evtType, + Timestamp: time.Now().UnixMilli(), + RoomID: ce.RoomID, + Content: content, + }) + ce.Reply("Result: %+v", res) + }, + Name: "debug-account-data", + Help: HelpMeta{ + Section: HelpSectionAdmin, + Description: "Send a room account data event to the bridge", + Args: "<_type_> <_content_>", + }, + RequiresAdmin: true, + RequiresPortal: true, + RequiresLogin: true, +} + +var CommandResetNetwork = &FullHandler{ + Func: func(ce *Event) { + if strings.Contains(strings.ToLower(ce.RawArgs), "--reset-transport") { + nrn, ok := ce.Bridge.Network.(bridgev2.NetworkResettingNetwork) + if ok { + nrn.ResetHTTPTransport() + } else { + ce.Reply("Network connector does not support resetting HTTP transport") + } + } + ce.Bridge.ResetNetworkConnections() + ce.React("✅️") + }, + Name: "debug-reset-network", + Help: HelpMeta{ + Section: HelpSectionAdmin, + Description: "Reset network connections to the remote network", + Args: "[--reset-transport]", + }, + RequiresAdmin: true, +} diff --git a/bridgev2/commands/login.go b/bridgev2/commands/login.go index a18564c2..96d62d3e 100644 --- a/bridgev2/commands/login.go +++ b/bridgev2/commands/login.go @@ -70,6 +70,15 @@ func fnLogin(ce *Event) { } ce.Args = ce.Args[1:] } + if reauth == nil && ce.User.HasTooManyLogins() { + ce.Reply( + "You have reached the maximum number of logins (%d). "+ + "Please logout from an existing login before creating a new one. "+ + "If you want to re-authenticate an existing login, use the `$cmdprefix relogin` command.", + ce.User.Permissions.MaxLogins, + ) + return + } flows := ce.Bridge.Network.GetLoginFlows() var chosenFlowID string if len(ce.Args) > 0 { @@ -112,6 +121,7 @@ func fnLogin(ce *Event) { ce.Reply("Failed to start login: %v", err) return } + ce.Log.Debug().Any("first_step", nextStep).Msg("Created login process") nextStep = checkLoginCommandDirectParams(ce, login, nextStep) if nextStep != nil { @@ -190,11 +200,14 @@ type userInputLoginCommandState struct { func (uilcs *userInputLoginCommandState) promptNext(ce *Event) { field := uilcs.RemainingFields[0] + parts := []string{fmt.Sprintf("Please enter your %s", field.Name)} if field.Description != "" { - ce.Reply("Please enter your %s\n%s", field.Name, field.Description) - } else { - ce.Reply("Please enter your %s", field.Name) + parts = append(parts, field.Description) } + if len(field.Options) > 0 { + parts = append(parts, fmt.Sprintf("Options: `%s`", strings.Join(field.Options, "`, `"))) + } + ce.Reply(strings.Join(parts, "\n")) StoreCommandState(ce.User, &CommandState{ Next: MinimalCommandHandlerFunc(uilcs.submitNext), Action: "Login", @@ -239,14 +252,19 @@ func sendQR(ce *Event, qr string, prevEventID *id.EventID) error { return fmt.Errorf("failed to upload image: %w", err) } content := &event.MessageEventContent{ - MsgType: event.MsgImage, - FileName: "qr.png", - URL: qrMXC, - File: qrFile, - + MsgType: event.MsgImage, + FileName: "qr.png", + URL: qrMXC, + File: qrFile, Body: qr, Format: event.FormatHTML, FormattedBody: fmt.Sprintf("
%s
", html.EscapeString(qr)), + Info: &event.FileInfo{ + MimeType: "image/png", + Width: qrSizePx, + Height: qrSizePx, + Size: len(qrData), + }, } if *prevEventID != "" { content.SetEdit(*prevEventID) @@ -261,6 +279,36 @@ func sendQR(ce *Event, qr string, prevEventID *id.EventID) error { return nil } +func sendUserInputAttachments(ce *Event, atts []*bridgev2.LoginUserInputAttachment) error { + for _, att := range atts { + if att.FileName == "" { + return fmt.Errorf("missing attachment filename") + } + mxc, file, err := ce.Bot.UploadMedia(ce.Ctx, ce.RoomID, att.Content, att.FileName, att.Info.MimeType) + if err != nil { + return fmt.Errorf("failed to upload attachment %q: %w", att.FileName, err) + } + content := &event.MessageEventContent{ + MsgType: att.Type, + FileName: att.FileName, + URL: mxc, + File: file, + Info: &event.FileInfo{ + MimeType: att.Info.MimeType, + Width: att.Info.Width, + Height: att.Info.Height, + Size: att.Info.Size, + }, + Body: att.FileName, + } + _, err = ce.Bot.SendMessage(ce.Ctx, ce.RoomID, event.EventMessage, &event.Content{Parsed: content}, nil) + if err != nil { + return nil + } + } + return nil +} + type contextKey int const ( @@ -452,6 +500,7 @@ func maybeURLDecodeCookie(val string, field *bridgev2.LoginCookieField) string { } func doLoginStep(ce *Event, login bridgev2.LoginProcess, step *bridgev2.LoginStep, override *bridgev2.UserLogin) { + ce.Log.Debug().Any("next_step", step).Msg("Got next login step") if step.Instructions != "" { ce.Reply(step.Instructions) } @@ -466,6 +515,10 @@ func doLoginStep(ce *Event, login bridgev2.LoginProcess, step *bridgev2.LoginSte Override: override, }).prompt(ce) case bridgev2.LoginStepTypeUserInput: + err := sendUserInputAttachments(ce, step.UserInputParams.Attachments) + if err != nil { + ce.Reply("Failed to send attachments: %v", err) + } (&userInputLoginCommandState{ Login: login.(bridgev2.LoginProcessUserInput), RemainingFields: step.UserInputParams.Fields, diff --git a/bridgev2/commands/processor.go b/bridgev2/commands/processor.go index c28e3a32..391c3685 100644 --- a/bridgev2/commands/processor.go +++ b/bridgev2/commands/processor.go @@ -41,10 +41,11 @@ func NewProcessor(bridge *bridgev2.Bridge) bridgev2.CommandProcessor { } proc.AddHandlers( CommandHelp, CommandCancel, - CommandRegisterPush, CommandDeletePortal, CommandDeleteAllPortals, CommandSetManagementRoom, + CommandRegisterPush, CommandSendAccountData, CommandResetNetwork, + CommandDeletePortal, CommandDeleteAllPortals, CommandSetManagementRoom, CommandLogin, CommandRelogin, CommandListLogins, CommandLogout, CommandSetPreferredLogin, CommandSetRelay, CommandUnsetRelay, - CommandResolveIdentifier, CommandStartChat, CommandSearch, + CommandResolveIdentifier, CommandStartChat, CommandCreateGroup, CommandSearch, CommandSyncChat, CommandMute, CommandSudo, CommandDoIn, ) return proc diff --git a/bridgev2/commands/relay.go b/bridgev2/commands/relay.go index af756c87..94c19739 100644 --- a/bridgev2/commands/relay.go +++ b/bridgev2/commands/relay.go @@ -37,7 +37,7 @@ func fnSetRelay(ce *Event) { } onlySetDefaultRelays := !ce.User.Permissions.Admin && ce.Bridge.Config.Relay.AdminOnly var relay *bridgev2.UserLogin - if len(ce.Args) == 0 { + if len(ce.Args) == 0 && ce.Portal.Receiver == "" { relay = ce.User.GetDefaultLogin() isLoggedIn := relay != nil if onlySetDefaultRelays { @@ -73,9 +73,19 @@ func fnSetRelay(ce *Event) { } } } else { - relay = ce.Bridge.GetCachedUserLoginByID(networkid.UserLoginID(ce.Args[0])) + var targetID networkid.UserLoginID + if ce.Portal.Receiver != "" { + targetID = ce.Portal.Receiver + if len(ce.Args) > 0 && ce.Args[0] != string(targetID) { + ce.Reply("In split portals, only the receiver (%s) can be set as relay", targetID) + return + } + } else { + targetID = networkid.UserLoginID(ce.Args[0]) + } + relay = ce.Bridge.GetCachedUserLoginByID(targetID) if relay == nil { - ce.Reply("User login with ID `%s` not found", ce.Args[0]) + ce.Reply("User login with ID `%s` not found", targetID) return } else if slices.Contains(ce.Bridge.Config.Relay.DefaultRelays, relay.ID) { // All good diff --git a/bridgev2/commands/startchat.go b/bridgev2/commands/startchat.go index 7b755064..c7b05a6e 100644 --- a/bridgev2/commands/startchat.go +++ b/bridgev2/commands/startchat.go @@ -8,11 +8,13 @@ package commands import ( "context" + "errors" "fmt" "html" "maps" "slices" "strings" + "time" "github.com/rs/zerolog" @@ -20,6 +22,7 @@ import ( "maunium.net/go/mautrix/bridgev2/networkid" "maunium.net/go/mautrix/bridgev2/provisionutil" "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/format" "maunium.net/go/mautrix/id" ) @@ -35,6 +38,35 @@ var CommandResolveIdentifier = &FullHandler{ NetworkAPI: NetworkAPIImplements[bridgev2.IdentifierResolvingNetworkAPI], } +var CommandSyncChat = &FullHandler{ + Func: func(ce *Event) { + login, _, err := ce.Portal.FindPreferredLogin(ce.Ctx, ce.User, false) + if err != nil { + ce.Log.Err(err).Msg("Failed to find login for sync") + ce.Reply("Failed to find login: %v", err) + return + } else if login == nil { + ce.Reply("No login found for sync") + return + } + info, err := login.Client.GetChatInfo(ce.Ctx, ce.Portal) + if err != nil { + ce.Log.Err(err).Msg("Failed to get chat info for sync") + ce.Reply("Failed to get chat info: %v", err) + return + } + ce.Portal.UpdateInfo(ce.Ctx, info, login, nil, time.Time{}) + ce.React("✅️") + }, + Name: "sync-portal", + Help: HelpMeta{ + Section: HelpSectionChats, + Description: "Sync the current portal room", + }, + RequiresPortal: true, + RequiresLogin: true, +} + var CommandStartChat = &FullHandler{ Func: fnResolveIdentifier, Name: "start-chat", @@ -48,9 +80,15 @@ var CommandStartChat = &FullHandler{ NetworkAPI: NetworkAPIImplements[bridgev2.IdentifierResolvingNetworkAPI], } -func getClientForStartingChat[T bridgev2.IdentifierResolvingNetworkAPI](ce *Event, thing string) (*bridgev2.UserLogin, T, []string) { - remainingArgs := ce.Args[1:] - login := ce.Bridge.GetCachedUserLoginByID(networkid.UserLoginID(ce.Args[0])) +func getClientForStartingChat[T bridgev2.NetworkAPI](ce *Event, thing string) (*bridgev2.UserLogin, T, []string) { + var remainingArgs []string + if len(ce.Args) > 1 { + remainingArgs = ce.Args[1:] + } + var login *bridgev2.UserLogin + if len(ce.Args) > 0 { + login = ce.Bridge.GetCachedUserLoginByID(networkid.UserLoginID(ce.Args[0])) + } if login == nil || login.UserMXID != ce.User.MXID { remainingArgs = ce.Args login = ce.User.GetDefaultLogin() @@ -81,9 +119,13 @@ func fnResolveIdentifier(ce *Event) { if api == nil { return } + allLogins := ce.User.GetUserLogins() createChat := ce.Command == "start-chat" || ce.Command == "pm" identifier := strings.Join(identifierParts, " ") resp, err := provisionutil.ResolveIdentifier(ce.Ctx, login, identifier, createChat) + for i := 0; i < len(allLogins) && errors.Is(err, bridgev2.ErrResolveIdentifierTryNext); i++ { + resp, err = provisionutil.ResolveIdentifier(ce.Ctx, allLogins[i], identifier, createChat) + } if err != nil { ce.Reply("Failed to resolve identifier: %v", err) return @@ -195,7 +237,17 @@ func fnCreateGroup(ce *Event) { ce.Reply("Failed to create group: %v", err) return } - ce.Reply("Successfully created group `%s`", resp.ID) + var postfix string + if len(resp.FailedParticipants) > 0 { + failedParticipantsStrings := make([]string, len(resp.FailedParticipants)) + i := 0 + for participantID, meta := range resp.FailedParticipants { + failedParticipantsStrings[i] = fmt.Sprintf("* %s: %s", format.SafeMarkdownCode(participantID), meta.Reason) + i++ + } + postfix += "\n\nFailed to add some participants:\n" + strings.Join(failedParticipantsStrings, "\n") + } + ce.Reply("Successfully created group `%s`%s", resp.ID, postfix) } var CommandSearch = &FullHandler{ @@ -238,3 +290,44 @@ func fnSearch(ce *Event) { } ce.Reply("Search results:\n\n%s", strings.Join(resultsString, "\n")) } + +var CommandMute = &FullHandler{ + Func: fnMute, + Name: "mute", + Aliases: []string{"unmute"}, + Help: HelpMeta{ + Section: HelpSectionChats, + Description: "Mute or unmute a chat on the remote network", + Args: "[duration]", + }, + RequiresPortal: true, + RequiresLogin: true, + NetworkAPI: NetworkAPIImplements[bridgev2.MuteHandlingNetworkAPI], +} + +func fnMute(ce *Event) { + _, api, _ := getClientForStartingChat[bridgev2.MuteHandlingNetworkAPI](ce, "muting chats") + var mutedUntil int64 + if ce.Command == "mute" { + mutedUntil = -1 + if len(ce.Args) > 0 { + duration, err := time.ParseDuration(ce.Args[0]) + if err != nil { + ce.Reply("Invalid duration: %v", err) + return + } + mutedUntil = time.Now().Add(duration).UnixMilli() + } + } + err := api.HandleMute(ce.Ctx, &bridgev2.MatrixMute{ + MatrixEventBase: bridgev2.MatrixEventBase[*event.BeeperMuteEventContent]{ + Content: &event.BeeperMuteEventContent{MutedUntil: mutedUntil}, + Portal: ce.Portal, + }, + }) + if err != nil { + ce.Reply("Failed to %s chat: %v", ce.Command, err) + } else { + ce.React("✅️") + } +} diff --git a/bridgev2/database/database.go b/bridgev2/database/database.go index f1789441..05abddf0 100644 --- a/bridgev2/database/database.go +++ b/bridgev2/database/database.go @@ -7,13 +7,7 @@ package database import ( - "encoding/json" - "reflect" - "strings" - "go.mau.fi/util/dbutil" - "golang.org/x/exp/constraints" - "golang.org/x/exp/maps" "maunium.net/go/mautrix/bridgev2/networkid" @@ -34,6 +28,7 @@ type Database struct { UserPortal *UserPortalQuery BackfillTask *BackfillTaskQuery KV *KVQuery + PublicMedia *PublicMediaQuery } type MetaMerger interface { @@ -141,6 +136,12 @@ func New(bridgeID networkid.BridgeID, mt MetaTypes, db *dbutil.Database) *Databa BridgeID: bridgeID, Database: db, }, + PublicMedia: &PublicMediaQuery{ + BridgeID: bridgeID, + QueryHelper: dbutil.MakeQueryHelper(db, func(_ *dbutil.QueryHelper[*PublicMedia]) *PublicMedia { + return &PublicMedia{} + }), + }, } } @@ -151,55 +152,3 @@ func ensureBridgeIDMatches(ptr *networkid.BridgeID, expected networkid.BridgeID) panic("bridge ID mismatch") } } - -func GetNumberFromMap[T constraints.Integer | constraints.Float](m map[string]any, key string) (T, bool) { - if val, found := m[key]; found { - floatVal, ok := val.(float64) - if ok { - return T(floatVal), true - } - tVal, ok := val.(T) - if ok { - return tVal, true - } - } - return 0, false -} - -func unmarshalMerge(input []byte, data any, extra *map[string]any) error { - err := json.Unmarshal(input, data) - if err != nil { - return err - } - err = json.Unmarshal(input, extra) - if err != nil { - return err - } - if *extra == nil { - *extra = make(map[string]any) - } - return nil -} - -func marshalMerge(data any, extra map[string]any) ([]byte, error) { - if extra == nil { - return json.Marshal(data) - } - merged := make(map[string]any) - maps.Copy(merged, extra) - dataRef := reflect.ValueOf(data).Elem() - dataType := dataRef.Type() - for _, field := range reflect.VisibleFields(dataType) { - parts := strings.Split(field.Tag.Get("json"), ",") - if len(parts) == 0 || len(parts[0]) == 0 || parts[0] == "-" { - continue - } - fieldVal := dataRef.FieldByIndex(field.Index) - if fieldVal.IsZero() { - delete(merged, parts[0]) - } else { - merged[parts[0]] = fieldVal.Interface() - } - } - return json.Marshal(merged) -} diff --git a/bridgev2/database/disappear.go b/bridgev2/database/disappear.go index 9874e472..df36b205 100644 --- a/bridgev2/database/disappear.go +++ b/bridgev2/database/disappear.go @@ -37,6 +37,16 @@ type DisappearingSetting struct { DisappearAt time.Time } +func DisappearingSettingFromEvent(evt *event.BeeperDisappearingTimer) DisappearingSetting { + if evt == nil || evt.Type == event.DisappearingTypeNone { + return DisappearingSetting{} + } + return DisappearingSetting{ + Type: evt.Type, + Timer: evt.Timer.Duration, + } +} + func (ds DisappearingSetting) Normalize() DisappearingSetting { if ds.Type == event.DisappearingTypeNone { ds.Timer = 0 @@ -67,26 +77,27 @@ type DisappearingMessageQuery struct { } type DisappearingMessage struct { - BridgeID networkid.BridgeID - RoomID id.RoomID - EventID id.EventID + BridgeID networkid.BridgeID + RoomID id.RoomID + EventID id.EventID + Timestamp time.Time DisappearingSetting } const ( upsertDisappearingMessageQuery = ` - INSERT INTO disappearing_message (bridge_id, mx_room, mxid, type, timer, disappear_at) - VALUES ($1, $2, $3, $4, $5, $6) + INSERT INTO disappearing_message (bridge_id, mx_room, mxid, timestamp, type, timer, disappear_at) + VALUES ($1, $2, $3, $4, $5, $6, $7) ON CONFLICT (bridge_id, mxid) DO UPDATE SET timer=excluded.timer, disappear_at=excluded.disappear_at ` startDisappearingMessagesQuery = ` UPDATE disappearing_message SET disappear_at=$1 + timer - WHERE bridge_id=$2 AND mx_room=$3 AND disappear_at IS NULL AND type='after_read' - RETURNING bridge_id, mx_room, mxid, type, timer, disappear_at + WHERE bridge_id=$2 AND mx_room=$3 AND disappear_at IS NULL AND type='after_read' AND timestamp<=$4 + RETURNING bridge_id, mx_room, mxid, timestamp, type, timer, disappear_at ` getUpcomingDisappearingMessagesQuery = ` - SELECT bridge_id, mx_room, mxid, type, timer, disappear_at + SELECT bridge_id, mx_room, mxid, timestamp, type, timer, disappear_at FROM disappearing_message WHERE bridge_id = $1 AND disappear_at IS NOT NULL AND disappear_at < $2 ORDER BY disappear_at LIMIT $3 ` @@ -100,8 +111,8 @@ func (dmq *DisappearingMessageQuery) Put(ctx context.Context, dm *DisappearingMe return dmq.Exec(ctx, upsertDisappearingMessageQuery, dm.sqlVariables()...) } -func (dmq *DisappearingMessageQuery) StartAll(ctx context.Context, roomID id.RoomID) ([]*DisappearingMessage, error) { - return dmq.QueryMany(ctx, startDisappearingMessagesQuery, time.Now().UnixNano(), dmq.BridgeID, roomID) +func (dmq *DisappearingMessageQuery) StartAllBefore(ctx context.Context, roomID id.RoomID, beforeTS time.Time) ([]*DisappearingMessage, error) { + return dmq.QueryMany(ctx, startDisappearingMessagesQuery, time.Now().UnixNano(), dmq.BridgeID, roomID, beforeTS.UnixNano()) } func (dmq *DisappearingMessageQuery) GetUpcoming(ctx context.Context, duration time.Duration, limit int) ([]*DisappearingMessage, error) { @@ -113,17 +124,19 @@ func (dmq *DisappearingMessageQuery) Delete(ctx context.Context, eventID id.Even } func (d *DisappearingMessage) Scan(row dbutil.Scannable) (*DisappearingMessage, error) { + var timestamp int64 var disappearAt sql.NullInt64 - err := row.Scan(&d.BridgeID, &d.RoomID, &d.EventID, &d.Type, &d.Timer, &disappearAt) + err := row.Scan(&d.BridgeID, &d.RoomID, &d.EventID, ×tamp, &d.Type, &d.Timer, &disappearAt) if err != nil { return nil, err } if disappearAt.Valid { d.DisappearAt = time.Unix(0, disappearAt.Int64) } + d.Timestamp = time.Unix(0, timestamp) return d, nil } func (d *DisappearingMessage) sqlVariables() []any { - return []any{d.BridgeID, d.RoomID, d.EventID, d.Type, d.Timer, dbutil.ConvertedPtr(d.DisappearAt, time.Time.UnixNano)} + return []any{d.BridgeID, d.RoomID, d.EventID, d.Timestamp.UnixNano(), d.Type, d.Timer, dbutil.ConvertedPtr(d.DisappearAt, time.Time.UnixNano)} } diff --git a/bridgev2/database/ghost.go b/bridgev2/database/ghost.go index c32929ad..16af35ca 100644 --- a/bridgev2/database/ghost.go +++ b/bridgev2/database/ghost.go @@ -7,12 +7,17 @@ package database import ( + "bytes" "context" "encoding/hex" + "encoding/json" + "fmt" "go.mau.fi/util/dbutil" + "go.mau.fi/util/exerrors" "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/crypto/canonicaljson" "maunium.net/go/mautrix/id" ) @@ -22,6 +27,55 @@ type GhostQuery struct { *dbutil.QueryHelper[*Ghost] } +type ExtraProfile map[string]json.RawMessage + +func (ep *ExtraProfile) Set(key string, value any) error { + if key == "displayname" || key == "avatar_url" { + return fmt.Errorf("cannot set reserved profile key %q", key) + } + marshaled, err := json.Marshal(value) + if err != nil { + return err + } + if *ep == nil { + *ep = make(ExtraProfile) + } + (*ep)[key] = canonicaljson.CanonicalJSONAssumeValid(marshaled) + return nil +} + +func (ep *ExtraProfile) With(key string, value any) *ExtraProfile { + exerrors.PanicIfNotNil(ep.Set(key, value)) + return ep +} + +func canonicalizeIfObject(data json.RawMessage) json.RawMessage { + if len(data) > 0 && (data[0] == '{' || data[0] == '[') { + return canonicaljson.CanonicalJSONAssumeValid(data) + } + return data +} + +func (ep *ExtraProfile) CopyTo(dest *ExtraProfile) (changed bool) { + if len(*ep) == 0 { + return + } + if *dest == nil { + *dest = make(ExtraProfile) + } + for key, val := range *ep { + if key == "displayname" || key == "avatar_url" { + continue + } + existing, exists := (*dest)[key] + if !exists || !bytes.Equal(canonicalizeIfObject(existing), val) { + (*dest)[key] = val + changed = true + } + } + return +} + type Ghost struct { BridgeID networkid.BridgeID ID networkid.UserID @@ -35,13 +89,14 @@ type Ghost struct { ContactInfoSet bool IsBot bool Identifiers []string + ExtraProfile ExtraProfile Metadata any } const ( getGhostBaseQuery = ` SELECT bridge_id, id, name, avatar_id, avatar_hash, avatar_mxc, - name_set, avatar_set, contact_info_set, is_bot, identifiers, metadata + name_set, avatar_set, contact_info_set, is_bot, identifiers, extra_profile, metadata FROM ghost ` getGhostByIDQuery = getGhostBaseQuery + `WHERE bridge_id=$1 AND id=$2` @@ -49,13 +104,14 @@ const ( insertGhostQuery = ` INSERT INTO ghost ( bridge_id, id, name, avatar_id, avatar_hash, avatar_mxc, - name_set, avatar_set, contact_info_set, is_bot, identifiers, metadata + name_set, avatar_set, contact_info_set, is_bot, identifiers, extra_profile, metadata ) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13) ` updateGhostQuery = ` UPDATE ghost SET name=$3, avatar_id=$4, avatar_hash=$5, avatar_mxc=$6, - name_set=$7, avatar_set=$8, contact_info_set=$9, is_bot=$10, identifiers=$11, metadata=$12 + name_set=$7, avatar_set=$8, contact_info_set=$9, is_bot=$10, + identifiers=$11, extra_profile=$12, metadata=$13 WHERE bridge_id=$1 AND id=$2 ` ) @@ -86,7 +142,7 @@ func (g *Ghost) Scan(row dbutil.Scannable) (*Ghost, error) { &g.BridgeID, &g.ID, &g.Name, &g.AvatarID, &avatarHash, &g.AvatarMXC, &g.NameSet, &g.AvatarSet, &g.ContactInfoSet, &g.IsBot, - dbutil.JSON{Data: &g.Identifiers}, dbutil.JSON{Data: g.Metadata}, + dbutil.JSON{Data: &g.Identifiers}, dbutil.JSON{Data: &g.ExtraProfile}, dbutil.JSON{Data: g.Metadata}, ) if err != nil { return nil, err @@ -116,6 +172,6 @@ func (g *Ghost) sqlVariables() []any { g.BridgeID, g.ID, g.Name, g.AvatarID, avatarHash, g.AvatarMXC, g.NameSet, g.AvatarSet, g.ContactInfoSet, g.IsBot, - dbutil.JSON{Data: &g.Identifiers}, dbutil.JSON{Data: g.Metadata}, + dbutil.JSON{Data: &g.Identifiers}, dbutil.JSON{Data: g.ExtraProfile}, dbutil.JSON{Data: g.Metadata}, } } diff --git a/bridgev2/database/message.go b/bridgev2/database/message.go index 9b3b1493..4fd599a8 100644 --- a/bridgev2/database/message.go +++ b/bridgev2/database/message.go @@ -11,9 +11,12 @@ import ( "crypto/sha256" "database/sql" "encoding/base64" + "fmt" "strings" + "sync" "time" + "github.com/rs/zerolog" "go.mau.fi/util/dbutil" "maunium.net/go/mautrix/bridgev2/networkid" @@ -24,6 +27,7 @@ type MessageQuery struct { BridgeID networkid.BridgeID MetaType MetaTypeCreator *dbutil.QueryHelper[*Message] + chunkDeleteLock sync.Mutex } type Message struct { @@ -64,8 +68,8 @@ const ( getFirstMessagePartByIDQuery = getMessageBaseQuery + `WHERE bridge_id=$1 AND (room_receiver=$2 OR room_receiver='') AND id=$3 ORDER BY part_id ASC LIMIT 1` getMessagesBetweenTimeQuery = getMessageBaseQuery + `WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 AND timestamp>$4 AND timestamp<=$5` getOldestMessageInPortal = getMessageBaseQuery + `WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 ORDER BY timestamp ASC, part_id ASC LIMIT 1` - getFirstMessageInThread = getMessageBaseQuery + `WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 AND (id=$4 OR thread_root_id=$4) ORDER BY timestamp ASC, part_id ASC LIMIT 1` - getLastMessageInThread = getMessageBaseQuery + `WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 AND (id=$4 OR thread_root_id=$4) ORDER BY timestamp DESC, part_id DESC LIMIT 1` + getFirstMessageInThread = getMessageBaseQuery + `WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 AND (id=$4 OR thread_root_id=$4) ORDER BY thread_root_id NULLS FIRST, timestamp ASC, part_id ASC LIMIT 1` + getLastMessageInThread = getMessageBaseQuery + `WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 AND (id=$4 OR thread_root_id=$4) ORDER BY thread_root_id NULLS LAST, timestamp DESC, part_id DESC LIMIT 1` getLastNInPortal = getMessageBaseQuery + `WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 ORDER BY timestamp DESC, part_id DESC LIMIT $4` getLastMessagePartAtOrBeforeTimeQuery = getMessageBaseQuery + `WHERE bridge_id = $1 AND room_id=$2 AND room_receiver=$3 AND timestamp<=$4 ORDER BY timestamp DESC, part_id DESC LIMIT 1` @@ -96,6 +100,10 @@ const ( deleteMessagePartByRowIDQuery = ` DELETE FROM message WHERE bridge_id=$1 AND rowid=$2 ` + deleteMessageChunkQuery = ` + DELETE FROM message WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 AND rowid > $4 AND rowid <= $5 + ` + getMaxMessageRowIDQuery = `SELECT MAX(rowid) FROM message WHERE bridge_id=$1` ) func (mq *MessageQuery) GetAllPartsByID(ctx context.Context, receiver networkid.UserLoginID, id networkid.MessageID) ([]*Message, error) { @@ -180,6 +188,85 @@ func (mq *MessageQuery) Delete(ctx context.Context, rowID int64) error { return mq.Exec(ctx, deleteMessagePartByRowIDQuery, mq.BridgeID, rowID) } +func (mq *MessageQuery) deleteChunk(ctx context.Context, portal networkid.PortalKey, minRowID, maxRowID int64) (int64, error) { + res, err := mq.GetDB().Exec(ctx, deleteMessageChunkQuery, mq.BridgeID, portal.ID, portal.Receiver, minRowID, maxRowID) + if err != nil { + return 0, err + } + return res.RowsAffected() +} + +func (mq *MessageQuery) getMaxRowID(ctx context.Context) (maxRowID int64, err error) { + err = mq.GetDB().QueryRow(ctx, getMaxMessageRowIDQuery, mq.BridgeID).Scan(&maxRowID) + return +} + +const deleteChunkSize = 100_000 + +func (mq *MessageQuery) DeleteInChunks(ctx context.Context, portal networkid.PortalKey) error { + if mq.GetDB().Dialect != dbutil.SQLite { + return nil + } + log := zerolog.Ctx(ctx).With(). + Str("action", "delete messages in chunks"). + Stringer("portal_key", portal). + Logger() + if !mq.chunkDeleteLock.TryLock() { + log.Warn().Msg("Portal deletion lock is being held, waiting...") + mq.chunkDeleteLock.Lock() + log.Debug().Msg("Acquired portal deletion lock after waiting") + } + defer mq.chunkDeleteLock.Unlock() + total, err := mq.CountMessagesInPortal(ctx, portal) + if err != nil { + return fmt.Errorf("failed to count messages in portal: %w", err) + } else if total < deleteChunkSize/3 { + return nil + } + globalMaxRowID, err := mq.getMaxRowID(ctx) + if err != nil { + return fmt.Errorf("failed to get max row ID: %w", err) + } + log.Debug(). + Int("total_count", total). + Int64("global_max_row_id", globalMaxRowID). + Msg("Portal has lots of messages, deleting in chunks to avoid database locks") + maxRowID := int64(deleteChunkSize) + globalMaxRowID += deleteChunkSize * 1.2 + var dbTimeUsed time.Duration + globalStart := time.Now() + for total > 500 && maxRowID < globalMaxRowID { + start := time.Now() + count, err := mq.deleteChunk(ctx, portal, maxRowID-deleteChunkSize, maxRowID) + duration := time.Since(start) + dbTimeUsed += duration + if err != nil { + return fmt.Errorf("failed to delete chunk of messages before %d: %w", maxRowID, err) + } + total -= int(count) + maxRowID += deleteChunkSize + sleepTime := max(10*time.Millisecond, min(250*time.Millisecond, time.Duration(count/100)*time.Millisecond)) + log.Debug(). + Int64("max_row_id", maxRowID). + Int64("deleted_count", count). + Int("remaining_count", total). + Dur("duration", duration). + Dur("sleep_time", sleepTime). + Msg("Deleted chunk of messages") + select { + case <-time.After(sleepTime): + case <-ctx.Done(): + return ctx.Err() + } + } + log.Debug(). + Int("remaining_count", total). + Dur("db_time_used", dbTimeUsed). + Dur("total_duration", time.Since(globalStart)). + Msg("Finished chunked delete of messages in portal") + return nil +} + func (mq *MessageQuery) CountMessagesInPortal(ctx context.Context, key networkid.PortalKey) (count int, err error) { err = mq.GetDB().QueryRow(ctx, countMessagesInPortalQuery, mq.BridgeID, key.ID, key.Receiver).Scan(&count) return diff --git a/bridgev2/database/portal.go b/bridgev2/database/portal.go index 97af4c4c..0e6be286 100644 --- a/bridgev2/database/portal.go +++ b/bridgev2/database/portal.go @@ -56,30 +56,31 @@ type Portal struct { networkid.PortalKey MXID id.RoomID - ParentKey networkid.PortalKey - RelayLoginID networkid.UserLoginID - OtherUserID networkid.UserID - Name string - Topic string - AvatarID networkid.AvatarID - AvatarHash [32]byte - AvatarMXC id.ContentURIString - NameSet bool - TopicSet bool - AvatarSet bool - NameIsCustom bool - InSpace bool - RoomType RoomType - Disappear DisappearingSetting - CapState CapabilityState - Metadata any + ParentKey networkid.PortalKey + RelayLoginID networkid.UserLoginID + OtherUserID networkid.UserID + Name string + Topic string + AvatarID networkid.AvatarID + AvatarHash [32]byte + AvatarMXC id.ContentURIString + NameSet bool + TopicSet bool + AvatarSet bool + NameIsCustom bool + InSpace bool + MessageRequest bool + RoomType RoomType + Disappear DisappearingSetting + CapState CapabilityState + Metadata any } const ( getPortalBaseQuery = ` SELECT bridge_id, id, receiver, mxid, parent_id, parent_receiver, relay_login_id, other_user_id, name, topic, avatar_id, avatar_hash, avatar_mxc, - name_set, topic_set, avatar_set, name_is_custom, in_space, + name_set, topic_set, avatar_set, name_is_custom, in_space, message_request, room_type, disappear_type, disappear_timer, cap_state, metadata FROM portal @@ -88,8 +89,9 @@ const ( getPortalByIDWithUncertainReceiverQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND id=$2 AND (receiver=$3 OR receiver='')` getPortalByMXIDQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND mxid=$2` getAllPortalsWithMXIDQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND mxid IS NOT NULL` - getAllPortalsWithoutReceiver = getPortalBaseQuery + `WHERE bridge_id=$1 AND receiver=''` + getAllPortalsWithoutReceiver = getPortalBaseQuery + `WHERE bridge_id=$1 AND (receiver='' OR (parent_id<>'' AND parent_receiver='')) ORDER BY parent_id DESC` getAllDMPortalsQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND room_type='dm' AND other_user_id=$2` + getDMPortalQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND room_type='dm' AND receiver=$2 AND other_user_id=$3` getAllPortalsQuery = getPortalBaseQuery + `WHERE bridge_id=$1` getChildPortalsQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND parent_id=$2 AND parent_receiver=$3` @@ -100,11 +102,11 @@ const ( bridge_id, id, receiver, mxid, parent_id, parent_receiver, relay_login_id, other_user_id, name, topic, avatar_id, avatar_hash, avatar_mxc, - name_set, avatar_set, topic_set, name_is_custom, in_space, + name_set, avatar_set, topic_set, name_is_custom, in_space, message_request, room_type, disappear_type, disappear_timer, cap_state, metadata, relay_bridge_id ) VALUES ( - $1, $2, $3, $4, $5, $6, cast($7 AS TEXT), $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, + $1, $2, $3, $4, $5, $6, cast($7 AS TEXT), $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24, CASE WHEN cast($7 AS TEXT) IS NULL THEN NULL ELSE $1 END ) ` @@ -113,8 +115,8 @@ const ( SET mxid=$4, parent_id=$5, parent_receiver=$6, relay_login_id=cast($7 AS TEXT), relay_bridge_id=CASE WHEN cast($7 AS TEXT) IS NULL THEN NULL ELSE bridge_id END, other_user_id=$8, name=$9, topic=$10, avatar_id=$11, avatar_hash=$12, avatar_mxc=$13, - name_set=$14, avatar_set=$15, topic_set=$16, name_is_custom=$17, in_space=$18, - room_type=$19, disappear_type=$20, disappear_timer=$21, cap_state=$22, metadata=$23 + name_set=$14, avatar_set=$15, topic_set=$16, name_is_custom=$17, in_space=$18, message_request=$19, + room_type=$20, disappear_type=$21, disappear_timer=$22, cap_state=$23, metadata=$24 WHERE bridge_id=$1 AND id=$2 AND receiver=$3 ` deletePortalQuery = ` @@ -147,7 +149,10 @@ const ( ) ` fixParentsAfterSplitPortalMigrationQuery = ` - UPDATE portal SET parent_receiver=receiver WHERE bridge_id=$1 AND parent_receiver='' AND receiver<>'' AND parent_id<>''; + UPDATE portal + SET parent_receiver=receiver + WHERE bridge_id=$1 AND parent_receiver='' AND receiver<>'' AND parent_id<>'' + AND EXISTS(SELECT 1 FROM portal pp WHERE pp.bridge_id=$1 AND pp.id=portal.parent_id AND pp.receiver=portal.receiver); ` ) @@ -187,6 +192,10 @@ func (pq *PortalQuery) GetAllDMsWith(ctx context.Context, otherUserID networkid. return pq.QueryMany(ctx, getAllDMPortalsQuery, pq.BridgeID, otherUserID) } +func (pq *PortalQuery) GetDM(ctx context.Context, receiver networkid.UserLoginID, otherUserID networkid.UserID) (*Portal, error) { + return pq.QueryOne(ctx, getDMPortalQuery, pq.BridgeID, receiver, otherUserID) +} + func (pq *PortalQuery) GetChildren(ctx context.Context, parentKey networkid.PortalKey) ([]*Portal, error) { return pq.QueryMany(ctx, getChildPortalsQuery, pq.BridgeID, parentKey.ID, parentKey.Receiver) } @@ -233,7 +242,7 @@ func (p *Portal) Scan(row dbutil.Scannable) (*Portal, error) { &p.BridgeID, &p.ID, &p.Receiver, &mxid, &parentID, &parentReceiver, &relayLoginID, &otherUserID, &p.Name, &p.Topic, &p.AvatarID, &avatarHash, &p.AvatarMXC, - &p.NameSet, &p.TopicSet, &p.AvatarSet, &p.NameIsCustom, &p.InSpace, + &p.NameSet, &p.TopicSet, &p.AvatarSet, &p.NameIsCustom, &p.InSpace, &p.MessageRequest, &p.RoomType, &disappearType, &disappearTimer, dbutil.JSON{Data: &p.CapState}, dbutil.JSON{Data: p.Metadata}, ) @@ -280,7 +289,7 @@ func (p *Portal) sqlVariables() []any { p.BridgeID, p.ID, p.Receiver, dbutil.StrPtr(p.MXID), dbutil.StrPtr(p.ParentKey.ID), p.ParentKey.Receiver, dbutil.StrPtr(p.RelayLoginID), dbutil.StrPtr(p.OtherUserID), p.Name, p.Topic, p.AvatarID, avatarHash, p.AvatarMXC, - p.NameSet, p.TopicSet, p.AvatarSet, p.NameIsCustom, p.InSpace, + p.NameSet, p.TopicSet, p.AvatarSet, p.NameIsCustom, p.InSpace, p.MessageRequest, p.RoomType, dbutil.StrPtr(p.Disappear.Type), dbutil.NumPtr(p.Disappear.Timer), dbutil.JSON{Data: p.CapState}, dbutil.JSON{Data: p.Metadata}, } diff --git a/bridgev2/database/publicmedia.go b/bridgev2/database/publicmedia.go new file mode 100644 index 00000000..b667399c --- /dev/null +++ b/bridgev2/database/publicmedia.go @@ -0,0 +1,72 @@ +// Copyright (c) 2025 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package database + +import ( + "context" + "database/sql" + "time" + + "go.mau.fi/util/dbutil" + + "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/crypto/attachment" + "maunium.net/go/mautrix/id" +) + +type PublicMediaQuery struct { + BridgeID networkid.BridgeID + *dbutil.QueryHelper[*PublicMedia] +} + +type PublicMedia struct { + BridgeID networkid.BridgeID + PublicID string + MXC id.ContentURI + Keys *attachment.EncryptedFile + MimeType string + Expiry time.Time +} + +const ( + upsertPublicMediaQuery = ` + INSERT INTO public_media (bridge_id, public_id, mxc, keys, mimetype, expiry) + VALUES ($1, $2, $3, $4, $5, $6) + ON CONFLICT (bridge_id, public_id) DO UPDATE SET expiry=EXCLUDED.expiry + ` + getPublicMediaQuery = ` + SELECT bridge_id, public_id, mxc, keys, mimetype, expiry + FROM public_media WHERE bridge_id=$1 AND public_id=$2 + ` +) + +func (pmq *PublicMediaQuery) Put(ctx context.Context, pm *PublicMedia) error { + ensureBridgeIDMatches(&pm.BridgeID, pmq.BridgeID) + return pmq.Exec(ctx, upsertPublicMediaQuery, pm.sqlVariables()...) +} + +func (pmq *PublicMediaQuery) Get(ctx context.Context, publicID string) (*PublicMedia, error) { + return pmq.QueryOne(ctx, getPublicMediaQuery, pmq.BridgeID, publicID) +} + +func (pm *PublicMedia) Scan(row dbutil.Scannable) (*PublicMedia, error) { + var expiry sql.NullInt64 + var mimetype sql.NullString + err := row.Scan(&pm.BridgeID, &pm.PublicID, &pm.MXC, dbutil.JSON{Data: &pm.Keys}, &mimetype, &expiry) + if err != nil { + return nil, err + } + if expiry.Valid { + pm.Expiry = time.Unix(0, expiry.Int64) + } + pm.MimeType = mimetype.String + return pm, nil +} + +func (pm *PublicMedia) sqlVariables() []any { + return []any{pm.BridgeID, pm.PublicID, &pm.MXC, dbutil.JSONPtr(pm.Keys), dbutil.StrPtr(pm.MimeType), dbutil.ConvertedPtr(pm.Expiry, time.Time.UnixNano)} +} diff --git a/bridgev2/database/upgrades/00-latest.sql b/bridgev2/database/upgrades/00-latest.sql index 4eea05bb..6092dc24 100644 --- a/bridgev2/database/upgrades/00-latest.sql +++ b/bridgev2/database/upgrades/00-latest.sql @@ -1,4 +1,4 @@ --- v0 -> v22 (compatible with v9+): Latest revision +-- v0 -> v27 (compatible with v9+): Latest revision CREATE TABLE "user" ( bridge_id TEXT NOT NULL, mxid TEXT NOT NULL, @@ -48,6 +48,7 @@ CREATE TABLE portal ( topic_set BOOLEAN NOT NULL, name_is_custom BOOLEAN NOT NULL DEFAULT false, in_space BOOLEAN NOT NULL, + message_request BOOLEAN NOT NULL DEFAULT false, room_type TEXT NOT NULL, disappear_type TEXT, disappear_timer BIGINT, @@ -64,6 +65,7 @@ CREATE TABLE portal ( ON DELETE SET NULL ON UPDATE CASCADE ); CREATE UNIQUE INDEX portal_bridge_mxid_idx ON portal (bridge_id, mxid); +CREATE INDEX portal_parent_idx ON portal (bridge_id, parent_id, parent_receiver); CREATE TABLE ghost ( bridge_id TEXT NOT NULL, @@ -78,6 +80,7 @@ CREATE TABLE ghost ( contact_info_set BOOLEAN NOT NULL, is_bot BOOLEAN NOT NULL, identifiers jsonb NOT NULL, + extra_profile jsonb, metadata jsonb NOT NULL, PRIMARY KEY (bridge_id, id) @@ -127,6 +130,7 @@ CREATE TABLE disappearing_message ( bridge_id TEXT NOT NULL, mx_room TEXT NOT NULL, mxid TEXT NOT NULL, + timestamp BIGINT NOT NULL DEFAULT 0, type TEXT NOT NULL, timer BIGINT NOT NULL, disappear_at BIGINT, @@ -137,6 +141,7 @@ CREATE TABLE disappearing_message ( REFERENCES portal (bridge_id, mxid) ON DELETE CASCADE ); +CREATE INDEX disappearing_message_portal_idx ON disappearing_message (bridge_id, mx_room); CREATE TABLE reaction ( bridge_id TEXT NOT NULL, @@ -215,3 +220,14 @@ CREATE TABLE kv_store ( PRIMARY KEY (bridge_id, key) ); + +CREATE TABLE public_media ( + bridge_id TEXT NOT NULL, + public_id TEXT NOT NULL, + mxc TEXT NOT NULL, + keys jsonb, + mimetype TEXT, + expiry BIGINT, + + PRIMARY KEY (bridge_id, public_id) +); diff --git a/bridgev2/database/upgrades/23-disappearing-timer-ts.sql b/bridgev2/database/upgrades/23-disappearing-timer-ts.sql new file mode 100644 index 00000000..ecd00b8d --- /dev/null +++ b/bridgev2/database/upgrades/23-disappearing-timer-ts.sql @@ -0,0 +1,2 @@ +-- v23 (compatible with v9+): Add event timestamp for disappearing messages +ALTER TABLE disappearing_message ADD COLUMN timestamp BIGINT NOT NULL DEFAULT 0; diff --git a/bridgev2/database/upgrades/24-public-media.sql b/bridgev2/database/upgrades/24-public-media.sql new file mode 100644 index 00000000..c4290090 --- /dev/null +++ b/bridgev2/database/upgrades/24-public-media.sql @@ -0,0 +1,11 @@ +-- v24 (compatible with v9+): Custom URLs for public media +CREATE TABLE public_media ( + bridge_id TEXT NOT NULL, + public_id TEXT NOT NULL, + mxc TEXT NOT NULL, + keys jsonb, + mimetype TEXT, + expiry BIGINT, + + PRIMARY KEY (bridge_id, public_id) +); diff --git a/bridgev2/database/upgrades/25-message-requests.sql b/bridgev2/database/upgrades/25-message-requests.sql new file mode 100644 index 00000000..b9d82a7a --- /dev/null +++ b/bridgev2/database/upgrades/25-message-requests.sql @@ -0,0 +1,2 @@ +-- v25 (compatible with v9+): Flag for message request portals +ALTER TABLE portal ADD COLUMN message_request BOOLEAN NOT NULL DEFAULT false; diff --git a/bridgev2/database/upgrades/26-disappearing-message-portal-index.sql b/bridgev2/database/upgrades/26-disappearing-message-portal-index.sql new file mode 100644 index 00000000..ae5d8cad --- /dev/null +++ b/bridgev2/database/upgrades/26-disappearing-message-portal-index.sql @@ -0,0 +1,3 @@ +-- v26 (compatible with v9+): Add room index for disappearing message table and portal parents +CREATE INDEX disappearing_message_portal_idx ON disappearing_message (bridge_id, mx_room); +CREATE INDEX portal_parent_idx ON portal (bridge_id, parent_id, parent_receiver); diff --git a/bridgev2/database/upgrades/27-ghost-extra-profile.sql b/bridgev2/database/upgrades/27-ghost-extra-profile.sql new file mode 100644 index 00000000..e8e0549a --- /dev/null +++ b/bridgev2/database/upgrades/27-ghost-extra-profile.sql @@ -0,0 +1,2 @@ +-- v27 (compatible with v9+): Add column for extra ghost profile metadata +ALTER TABLE ghost ADD COLUMN extra_profile jsonb; diff --git a/bridgev2/database/userlogin.go b/bridgev2/database/userlogin.go index 9fa6569a..00ff01c9 100644 --- a/bridgev2/database/userlogin.go +++ b/bridgev2/database/userlogin.go @@ -116,7 +116,7 @@ func (u *UserLogin) ensureHasMetadata(metaType MetaTypeCreator) *UserLogin { func (u *UserLogin) sqlVariables() []any { var remoteProfile dbutil.JSON - if !u.RemoteProfile.IsEmpty() { + if !u.RemoteProfile.IsZero() { remoteProfile.Data = &u.RemoteProfile } return []any{u.BridgeID, u.UserMXID, u.ID, u.RemoteName, remoteProfile, dbutil.StrPtr(u.SpaceRoom), dbutil.JSON{Data: u.Metadata}} diff --git a/bridgev2/disappear.go b/bridgev2/disappear.go index f072c01f..b5c37e8f 100644 --- a/bridgev2/disappear.go +++ b/bridgev2/disappear.go @@ -86,8 +86,8 @@ func (dl *DisappearLoop) Stop() { } } -func (dl *DisappearLoop) StartAll(ctx context.Context, roomID id.RoomID) { - startedMessages, err := dl.br.DB.DisappearingMessage.StartAll(ctx, roomID) +func (dl *DisappearLoop) StartAllBefore(ctx context.Context, roomID id.RoomID, beforeTS time.Time) { + startedMessages, err := dl.br.DB.DisappearingMessage.StartAllBefore(ctx, roomID, beforeTS) if err != nil { zerolog.Ctx(ctx).Err(err).Msg("Failed to start disappearing messages") return diff --git a/bridgev2/errors.go b/bridgev2/errors.go index cf27ac6f..f6677d2e 100644 --- a/bridgev2/errors.go +++ b/bridgev2/errors.go @@ -38,40 +38,51 @@ var ErrNotLoggedIn = errors.New("not logged in") // but direct media is not enabled. var ErrDirectMediaNotEnabled = errors.New("direct media is not enabled") +var ErrPortalIsDeleted = errors.New("portal is deleted") +var ErrPortalNotFoundInEventHandler = errors.New("portal not found to handle remote event") + // Common message status errors var ( - ErrPanicInEventHandler error = WrapErrorInStatus(errors.New("panic in event handler")).WithSendNotice(true).WithErrorAsMessage() - ErrNoPortal error = WrapErrorInStatus(errors.New("room is not a portal")).WithIsCertain(true).WithSendNotice(false) - ErrIgnoringReactionFromRelayedUser error = WrapErrorInStatus(errors.New("ignoring reaction event from relayed user")).WithIsCertain(true).WithSendNotice(false) - ErrIgnoringPollFromRelayedUser error = WrapErrorInStatus(errors.New("ignoring poll event from relayed user")).WithIsCertain(true).WithSendNotice(false) - ErrIgnoringDeleteChatRelayedUser error = WrapErrorInStatus(errors.New("ignoring delete chat event from relayed user")).WithIsCertain(true).WithSendNotice(false) - ErrEditsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support edits")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported) - ErrEditsNotSupportedInPortal error = WrapErrorInStatus(errors.New("edits are not allowed in this chat")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported) - ErrCaptionsNotAllowed error = WrapErrorInStatus(errors.New("captions are not supported here")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported) - ErrLocationMessagesNotAllowed error = WrapErrorInStatus(errors.New("location messages are not supported here")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported) - ErrEditTargetTooOld error = WrapErrorInStatus(errors.New("the message is too old to be edited")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported) - ErrEditTargetTooManyEdits error = WrapErrorInStatus(errors.New("the message has been edited too many times")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported) - ErrReactionsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support reactions")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported) - ErrPollsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support polls")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported) - ErrRoomMetadataNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support changing room metadata")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false).WithErrorReason(event.MessageStatusUnsupported) - ErrRedactionsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support deleting messages")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported) - ErrUnexpectedParsedContentType error = WrapErrorInStatus(errors.New("unexpected parsed content type")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(true) - ErrInvalidStateKey error = WrapErrorInStatus(errors.New("room metadata state key is unset or non-empty")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(false) - ErrDatabaseError error = WrapErrorInStatus(errors.New("database error")).WithMessage("internal database error").WithIsCertain(true).WithSendNotice(true) - ErrTargetMessageNotFound error = WrapErrorInStatus(errors.New("target message not found")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(false) - ErrUnsupportedMessageType error = WrapErrorInStatus(errors.New("unsupported message type")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(true).WithErrorReason(event.MessageStatusUnsupported) - ErrUnsupportedMediaType error = WrapErrorInStatus(errors.New("unsupported media type")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(true).WithErrorReason(event.MessageStatusUnsupported) - ErrMediaDurationTooLong error = WrapErrorInStatus(errors.New("media duration too long")).WithErrorAsMessage().WithSendNotice(true).WithErrorReason(event.MessageStatusUnsupported) - ErrMediaTooLarge error = WrapErrorInStatus(errors.New("media too large")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(true).WithErrorReason(event.MessageStatusUnsupported) - ErrIgnoringMNotice error = WrapErrorInStatus(errors.New("ignoring m.notice message")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false) - ErrMediaDownloadFailed error = WrapErrorInStatus(errors.New("failed to download media")).WithMessage("failed to download media").WithIsCertain(true).WithSendNotice(true) - ErrMediaReuploadFailed error = WrapErrorInStatus(errors.New("failed to reupload media")).WithMessage("failed to reupload media").WithIsCertain(true).WithSendNotice(true) - ErrMediaConvertFailed error = WrapErrorInStatus(errors.New("failed to convert media")).WithMessage("failed to convert media").WithIsCertain(true).WithSendNotice(true) - ErrMembershipNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support changing group membership")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false).WithErrorReason(event.MessageStatusUnsupported) - ErrDeleteChatNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support deleting chats")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false).WithErrorReason(event.MessageStatusUnsupported) - ErrPowerLevelsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support changing group power levels")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false).WithErrorReason(event.MessageStatusUnsupported) - ErrRemoteEchoTimeout = WrapErrorInStatus(errors.New("remote echo timed out")).WithIsCertain(false).WithSendNotice(true).WithErrorReason(event.MessageStatusTooOld) - ErrRemoteAckTimeout = WrapErrorInStatus(errors.New("remote ack timed out")).WithIsCertain(false).WithSendNotice(true).WithErrorReason(event.MessageStatusTooOld) + ErrPanicInEventHandler error = WrapErrorInStatus(errors.New("panic in event handler")).WithSendNotice(true).WithErrorAsMessage() + ErrNoPortal error = WrapErrorInStatus(errors.New("room is not a portal")).WithIsCertain(true).WithSendNotice(false) + ErrIgnoringReactionFromRelayedUser error = WrapErrorInStatus(errors.New("ignoring reaction event from relayed user")).WithIsCertain(true).WithSendNotice(false) + ErrIgnoringPollFromRelayedUser error = WrapErrorInStatus(errors.New("ignoring poll event from relayed user")).WithIsCertain(true).WithSendNotice(false) + ErrIgnoringDeleteChatRelayedUser error = WrapErrorInStatus(errors.New("ignoring delete chat event from relayed user")).WithIsCertain(true).WithSendNotice(false) + ErrIgnoringAcceptRequestRelayedUser error = WrapErrorInStatus(errors.New("ignoring accept message request event from relayed user")).WithIsCertain(true).WithSendNotice(false) + ErrEditsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support edits")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported) + ErrEditsNotSupportedInPortal error = WrapErrorInStatus(errors.New("edits are not allowed in this chat")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported) + ErrCaptionsNotAllowed error = WrapErrorInStatus(errors.New("captions are not supported here")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported) + ErrLocationMessagesNotAllowed error = WrapErrorInStatus(errors.New("location messages are not supported here")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported) + ErrEditTargetTooOld error = WrapErrorInStatus(errors.New("the message is too old to be edited")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported) + ErrEditTargetTooManyEdits error = WrapErrorInStatus(errors.New("the message has been edited too many times")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported) + ErrReactionsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support reactions")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported) + ErrPollsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support polls")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported) + ErrRoomMetadataNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support changing room metadata")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false).WithErrorReason(event.MessageStatusUnsupported) + ErrRoomMetadataNotAllowed error = WrapErrorInStatus(errors.New("changes are not allowed here")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false).WithErrorReason(event.MessageStatusUnsupported) + ErrRedactionsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support deleting messages")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported) + ErrUnexpectedParsedContentType error = WrapErrorInStatus(errors.New("unexpected parsed content type")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(true) + ErrInvalidStateKey error = WrapErrorInStatus(errors.New("room metadata state key is unset or non-empty")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(false) + ErrDatabaseError error = WrapErrorInStatus(errors.New("database error")).WithMessage("internal database error").WithIsCertain(true).WithSendNotice(true) + ErrTargetMessageNotFound error = WrapErrorInStatus(errors.New("target message not found")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(false) + ErrUnsupportedMessageType error = WrapErrorInStatus(errors.New("unsupported message type")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(true).WithErrorReason(event.MessageStatusUnsupported) + ErrUnsupportedMediaType error = WrapErrorInStatus(errors.New("unsupported media type")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(true).WithErrorReason(event.MessageStatusUnsupported) + ErrMediaDurationTooLong error = WrapErrorInStatus(errors.New("media duration too long")).WithErrorAsMessage().WithSendNotice(true).WithErrorReason(event.MessageStatusUnsupported) + ErrVoiceMessageDurationTooLong error = WrapErrorInStatus(errors.New("voice message too long")).WithErrorAsMessage().WithSendNotice(true).WithErrorReason(event.MessageStatusUnsupported) + ErrMediaTooLarge error = WrapErrorInStatus(errors.New("media too large")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(true).WithErrorReason(event.MessageStatusUnsupported) + ErrIgnoringMNotice error = WrapErrorInStatus(errors.New("ignoring m.notice message")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false) + ErrMediaDownloadFailed error = WrapErrorInStatus(errors.New("failed to download media")).WithMessage("failed to download media").WithIsCertain(true).WithSendNotice(true) + ErrMediaReuploadFailed error = WrapErrorInStatus(errors.New("failed to reupload media")).WithMessage("failed to reupload media").WithIsCertain(true).WithSendNotice(true) + ErrMediaConvertFailed error = WrapErrorInStatus(errors.New("failed to convert media")).WithMessage("failed to convert media").WithIsCertain(true).WithSendNotice(true) + ErrMembershipNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support changing group membership")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false).WithErrorReason(event.MessageStatusUnsupported) + ErrDeleteChatNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support deleting chats")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false).WithErrorReason(event.MessageStatusUnsupported) + ErrBeeperAIStreamNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support Beeper AI stream events")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false).WithErrorReason(event.MessageStatusUnsupported) + ErrPowerLevelsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support changing group power levels")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false).WithErrorReason(event.MessageStatusUnsupported) + ErrRemoteEchoTimeout = WrapErrorInStatus(errors.New("remote echo timed out")).WithIsCertain(false).WithSendNotice(true).WithErrorReason(event.MessageStatusTooOld) + ErrRemoteAckTimeout = WrapErrorInStatus(errors.New("remote ack timed out")).WithIsCertain(false).WithSendNotice(true).WithErrorReason(event.MessageStatusTooOld) + + ErrPublicMediaDisabled = WrapErrorInStatus(errors.New("public media is not enabled in the bridge config")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported).WithSendNotice(true) + ErrPublicMediaDatabaseDisabled = WrapErrorInStatus(errors.New("public media database storage is disabled")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported).WithSendNotice(true) + ErrPublicMediaGenerateFailed = WrapErrorInStatus(errors.New("failed to generate public media URL")).WithIsCertain(true).WithMessage("failed to generate public media URL").WithErrorReason(event.MessageStatusUnsupported).WithSendNotice(true) ErrDisappearingTimerUnsupported error = WrapErrorInStatus(errors.New("invalid disappearing timer")).WithIsCertain(true) ) diff --git a/bridgev2/ghost.go b/bridgev2/ghost.go index 6cef6f06..590dd1dc 100644 --- a/bridgev2/ghost.go +++ b/bridgev2/ghost.go @@ -9,12 +9,15 @@ package bridgev2 import ( "context" "crypto/sha256" + "encoding/json" "fmt" + "maps" "net/http" + "slices" "github.com/rs/zerolog" + "go.mau.fi/util/exerrors" "go.mau.fi/util/exmime" - "golang.org/x/exp/slices" "maunium.net/go/mautrix/bridgev2/database" "maunium.net/go/mautrix/bridgev2/networkid" @@ -134,10 +137,11 @@ func (a *Avatar) Reupload(ctx context.Context, intent MatrixAPI, currentHash [32 } type UserInfo struct { - Identifiers []string - Name *string - Avatar *Avatar - IsBot *bool + Identifiers []string + Name *string + Avatar *Avatar + IsBot *bool + ExtraProfile database.ExtraProfile ExtraUpdates ExtraUpdater[*Ghost] } @@ -185,9 +189,9 @@ func (ghost *Ghost) UpdateAvatar(ctx context.Context, avatar *Avatar) bool { return true } -func (ghost *Ghost) getExtraProfileMeta() *event.BeeperProfileExtra { +func (ghost *Ghost) getExtraProfileMeta() any { bridgeName := ghost.Bridge.Network.GetName() - return &event.BeeperProfileExtra{ + baseExtra := &event.BeeperProfileExtra{ RemoteID: string(ghost.ID), Identifiers: ghost.Identifiers, Service: bridgeName.BeeperBridgeType, @@ -195,23 +199,35 @@ func (ghost *Ghost) getExtraProfileMeta() *event.BeeperProfileExtra { IsBridgeBot: false, IsNetworkBot: ghost.IsBot, } + if len(ghost.ExtraProfile) == 0 { + return baseExtra + } + mergedExtra := maps.Clone(ghost.ExtraProfile) + baseExtraMarshaled := exerrors.Must(json.Marshal(baseExtra)) + exerrors.PanicIfNotNil(json.Unmarshal(baseExtraMarshaled, &mergedExtra)) + return mergedExtra } -func (ghost *Ghost) UpdateContactInfo(ctx context.Context, identifiers []string, isBot *bool) bool { - if identifiers != nil { - slices.Sort(identifiers) - } - if ghost.ContactInfoSet && - (identifiers == nil || slices.Equal(identifiers, ghost.Identifiers)) && - (isBot == nil || *isBot == ghost.IsBot) { +func (ghost *Ghost) UpdateContactInfo(ctx context.Context, identifiers []string, isBot *bool, extraProfile database.ExtraProfile) bool { + if !ghost.Bridge.Matrix.GetCapabilities().ExtraProfileMeta { + ghost.ContactInfoSet = false return false } if identifiers != nil { + slices.Sort(identifiers) + } + changed := extraProfile.CopyTo(&ghost.ExtraProfile) + if identifiers != nil { + changed = changed || !slices.Equal(identifiers, ghost.Identifiers) ghost.Identifiers = identifiers } if isBot != nil { + changed = changed || *isBot != ghost.IsBot ghost.IsBot = *isBot } + if ghost.ContactInfoSet && !changed { + return false + } err := ghost.Intent.SetExtraProfileMeta(ctx, ghost.getExtraProfileMeta()) if err != nil { zerolog.Ctx(ctx).Err(err).Msg("Failed to set extra profile metadata") @@ -234,7 +250,7 @@ func (br *Bridge) allowAggressiveUpdateForType(evtType RemoteEventType) bool { } func (ghost *Ghost) UpdateInfoIfNecessary(ctx context.Context, source *UserLogin, evtType RemoteEventType) { - if ghost.Name != "" && ghost.NameSet && !ghost.Bridge.allowAggressiveUpdateForType(evtType) { + if ghost.Name != "" && ghost.NameSet && ghost.AvatarSet && !ghost.Bridge.allowAggressiveUpdateForType(evtType) { return } info, err := source.Client.GetUserInfo(ctx, ghost) @@ -244,12 +260,16 @@ func (ghost *Ghost) UpdateInfoIfNecessary(ctx context.Context, source *UserLogin zerolog.Ctx(ctx).Debug(). Bool("has_name", ghost.Name != ""). Bool("name_set", ghost.NameSet). + Bool("has_avatar", ghost.AvatarMXC != ""). + Bool("avatar_set", ghost.AvatarSet). Msg("Updating ghost info in IfNecessary call") ghost.UpdateInfo(ctx, info) } else { zerolog.Ctx(ctx).Trace(). Bool("has_name", ghost.Name != ""). Bool("name_set", ghost.NameSet). + Bool("has_avatar", ghost.AvatarMXC != ""). + Bool("avatar_set", ghost.AvatarSet). Msg("No ghost info received in IfNecessary call") } } @@ -277,9 +297,14 @@ func (ghost *Ghost) UpdateInfo(ctx context.Context, info *UserInfo) { } if info.Avatar != nil { update = ghost.UpdateAvatar(ctx, info.Avatar) || update + } else if oldAvatar == "" && !ghost.AvatarSet { + // Special case: nil avatar means we're not expecting one ever, if we don't currently have + // one we flag it as set to avoid constantly refetching in UpdateInfoIfNecessary. + ghost.AvatarSet = true + update = true } - if info.Identifiers != nil || info.IsBot != nil { - update = ghost.UpdateContactInfo(ctx, info.Identifiers, info.IsBot) || update + if info.Identifiers != nil || info.IsBot != nil || info.ExtraProfile != nil { + update = ghost.UpdateContactInfo(ctx, info.Identifiers, info.IsBot, info.ExtraProfile) || update } if info.ExtraUpdates != nil { update = info.ExtraUpdates(ctx, ghost) || update diff --git a/bridgev2/login.go b/bridgev2/login.go index 1fa3afbc..b8321719 100644 --- a/bridgev2/login.go +++ b/bridgev2/login.go @@ -13,6 +13,7 @@ import ( "strings" "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/event" ) // LoginProcess represents a single occurrence of a user logging into the remote network. @@ -178,6 +179,8 @@ const ( LoginInputFieldTypeToken LoginInputFieldType = "token" LoginInputFieldTypeURL LoginInputFieldType = "url" LoginInputFieldTypeDomain LoginInputFieldType = "domain" + LoginInputFieldTypeSelect LoginInputFieldType = "select" + LoginInputFieldTypeCaptchaCode LoginInputFieldType = "captcha_code" ) type LoginInputDataField struct { @@ -189,8 +192,13 @@ type LoginInputDataField struct { Name string `json:"name"` // The description of the field shown to the user. Description string `json:"description"` + // A default value that the client can pre-fill the field with. + DefaultValue string `json:"default_value,omitempty"` // A regex pattern that the client can use to validate input client-side. Pattern string `json:"pattern,omitempty"` + // For fields of type select, the valid options. + // Pattern may also be filled with a regex that matches the same options. + Options []string `json:"options,omitempty"` // A function that validates the input and optionally cleans it up before it's submitted to the connector. Validate func(string) (string, error) `json:"-"` } @@ -265,6 +273,23 @@ func (f *LoginInputDataField) FillDefaultValidate() { type LoginUserInputParams struct { // The fields that the user needs to fill in. Fields []LoginInputDataField `json:"fields"` + + // Attachments to display alongside the input fields. + Attachments []*LoginUserInputAttachment `json:"attachments"` +} + +type LoginUserInputAttachment struct { + Type event.MessageType `json:"type,omitempty"` + FileName string `json:"filename,omitempty"` + Content []byte `json:"content,omitempty"` + Info LoginUserInputAttachmentInfo `json:"info,omitempty"` +} + +type LoginUserInputAttachmentInfo struct { + MimeType string `json:"mimetype,omitempty"` + Width int `json:"w,omitempty"` + Height int `json:"h,omitempty"` + Size int `json:"size,omitempty"` } type LoginCompleteParams struct { diff --git a/bridgev2/matrix/connector.go b/bridgev2/matrix/connector.go index 3dd9ae1a..5a2df953 100644 --- a/bridgev2/matrix/connector.go +++ b/bridgev2/matrix/connector.go @@ -26,6 +26,7 @@ import ( _ "go.mau.fi/util/dbutil/litestream" "go.mau.fi/util/exbytes" "go.mau.fi/util/exsync" + "go.mau.fi/util/ptr" "go.mau.fi/util/random" "golang.org/x/sync/semaphore" @@ -80,6 +81,8 @@ type Connector struct { MediaConfig mautrix.RespMediaConfig SpecVersions *mautrix.RespVersions + SpecCaps *mautrix.RespCapabilities + specCapsLock sync.Mutex Capabilities *bridgev2.MatrixCapabilities IgnoreUnsupportedServer bool @@ -141,16 +144,20 @@ func (br *Connector) Init(bridge *bridgev2.Bridge) { br.EventProcessor.On(event.EventReaction, br.handleRoomEvent) br.EventProcessor.On(event.EventRedaction, br.handleRoomEvent) br.EventProcessor.On(event.EventEncrypted, br.handleEncryptedEvent) + br.EventProcessor.On(event.EphemeralEventEncrypted, br.handleEncryptedEvent) br.EventProcessor.On(event.StateMember, br.handleRoomEvent) br.EventProcessor.On(event.StatePowerLevels, br.handleRoomEvent) br.EventProcessor.On(event.StateRoomName, br.handleRoomEvent) + br.EventProcessor.On(event.BeeperSendState, br.handleRoomEvent) br.EventProcessor.On(event.StateRoomAvatar, br.handleRoomEvent) br.EventProcessor.On(event.StateTopic, br.handleRoomEvent) br.EventProcessor.On(event.StateTombstone, br.handleRoomEvent) br.EventProcessor.On(event.StateBeeperDisappearingTimer, br.handleRoomEvent) br.EventProcessor.On(event.BeeperDeleteChat, br.handleRoomEvent) + br.EventProcessor.On(event.BeeperAcceptMessageRequest, br.handleRoomEvent) br.EventProcessor.On(event.EphemeralEventReceipt, br.handleEphemeralEvent) br.EventProcessor.On(event.EphemeralEventTyping, br.handleEphemeralEvent) + br.EventProcessor.On(event.BeeperEphemeralEventAIStream, br.handleEphemeralEvent) br.Bot = br.AS.BotIntent() br.Crypto = NewCryptoHelper(br) br.Bridge.Commands.(*commands.Processor).AddHandlers( @@ -275,7 +282,7 @@ func (br *Connector) GetPublicAddress() string { if br.Config.AppService.PublicAddress == "https://bridge.example.com" { return "" } - return br.Config.AppService.PublicAddress + return strings.TrimRight(br.Config.AppService.PublicAddress, "/") } func (br *Connector) GetRouter() *http.ServeMux { @@ -337,16 +344,18 @@ func (br *Connector) logInitialRequestError(err error, defaultMessage string) { } func (br *Connector) ensureConnection(ctx context.Context) { + triedToRegister := false for { versions, err := br.Bot.Versions(ctx) if err != nil { - if errors.Is(err, mautrix.MForbidden) { + if errors.Is(err, mautrix.MForbidden) && !triedToRegister { br.Log.Debug().Msg("M_FORBIDDEN in /versions, trying to register before retrying") err = br.Bot.EnsureRegistered(ctx) if err != nil { br.logInitialRequestError(err, "Failed to register after /versions failed with M_FORBIDDEN") os.Exit(16) } + triedToRegister = true } else if errors.Is(err, mautrix.MUnknownToken) || errors.Is(err, mautrix.MExclusive) { br.logInitialRequestError(err, "/versions request failed with auth error") os.Exit(16) @@ -359,6 +368,9 @@ func (br *Connector) ensureConnection(ctx context.Context) { *br.AS.SpecVersions = *versions br.Capabilities.AutoJoinInvites = br.SpecVersions.Supports(mautrix.BeeperFeatureAutojoinInvites) br.Capabilities.BatchSending = br.SpecVersions.Supports(mautrix.BeeperFeatureBatchSending) + br.Capabilities.ArbitraryMemberChange = br.SpecVersions.Supports(mautrix.BeeperFeatureArbitraryMemberChange) + br.Capabilities.ExtraProfileMeta = br.SpecVersions.Supports(mautrix.BeeperFeatureArbitraryProfileMeta) || + (br.SpecVersions.Supports(mautrix.FeatureArbitraryProfileFields) && br.Config.Matrix.GhostExtraProfileInfo) break } } @@ -403,6 +415,21 @@ func (br *Connector) ensureConnection(ctx context.Context) { br.Bot.EnsureAppserviceConnection(ctx) } +func (br *Connector) fetchCapabilities(ctx context.Context) *mautrix.RespCapabilities { + br.specCapsLock.Lock() + defer br.specCapsLock.Unlock() + if br.SpecCaps != nil { + return br.SpecCaps + } + caps, err := br.Bot.Capabilities(ctx) + if err != nil { + br.Log.Err(err).Msg("Failed to fetch capabilities from homeserver") + return nil + } + br.SpecCaps = caps + return caps +} + func (br *Connector) fetchMediaConfig(ctx context.Context) { cfg, err := br.Bot.GetMediaConfig(ctx) if err != nil { @@ -511,7 +538,8 @@ func (br *Connector) internalSendMessageStatus(ctx context.Context, ms *bridgev2 Msg("Failed to send MSS event") } } - if ms.SendNotice && br.Config.Matrix.MessageErrorNotices && (ms.Status == event.MessageStatusFail || ms.Status == event.MessageStatusRetriable || ms.Step == status.MsgStepDecrypted) { + if ms.SendNotice && br.Config.Matrix.MessageErrorNotices && evt.MessageType != event.MsgNotice && + (ms.Status == event.MessageStatusFail || ms.Status == event.MessageStatusRetriable || ms.Step == status.MsgStepDecrypted) { content := ms.ToNoticeEvent(evt) if editEvent != "" { content.SetEdit(editEvent) @@ -595,13 +623,28 @@ func (br *Connector) GetPowerLevels(ctx context.Context, roomID id.RoomID) (*eve } func (br *Connector) GetStateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string) (*event.Event, error) { - if eventType == event.StateCreate && stateKey == "" { - createEvt, err := br.Bot.StateStore.GetCreate(ctx, roomID) - if err != nil || createEvt != nil { - return createEvt, err + if stateKey == "" { + switch eventType { + case event.StateCreate: + createEvt, err := br.Bot.StateStore.GetCreate(ctx, roomID) + if err != nil || createEvt != nil { + return createEvt, err + } + case event.StateJoinRules: + joinRulesContent, err := br.Bot.StateStore.GetJoinRules(ctx, roomID) + if err != nil { + return nil, err + } else if joinRulesContent != nil { + return &event.Event{ + Type: event.StateJoinRules, + RoomID: roomID, + StateKey: ptr.Ptr(""), + Content: event.Content{Parsed: joinRulesContent}, + }, nil + } } } - return br.Bot.FullStateEvent(ctx, roomID, eventType, "") + return br.Bot.FullStateEvent(ctx, roomID, eventType, stateKey) } func (br *Connector) GetMembers(ctx context.Context, roomID id.RoomID) (map[id.UserID]*event.MemberEventContent, error) { diff --git a/bridgev2/matrix/crypto.go b/bridgev2/matrix/crypto.go index f4a2e9a0..7f18f1f5 100644 --- a/bridgev2/matrix/crypto.go +++ b/bridgev2/matrix/crypto.go @@ -38,9 +38,9 @@ func init() { var _ crypto.StateStore = (*sqlstatestore.SQLStateStore)(nil) -var NoSessionFound = crypto.NoSessionFound -var DuplicateMessageIndex = crypto.DuplicateMessageIndex -var UnknownMessageIndex = olm.UnknownMessageIndex +var NoSessionFound = crypto.ErrNoSessionFound +var DuplicateMessageIndex = crypto.ErrDuplicateMessageIndex +var UnknownMessageIndex = olm.ErrUnknownMessageIndex type CryptoHelper struct { bridge *Connector @@ -439,7 +439,7 @@ func (helper *CryptoHelper) Encrypt(ctx context.Context, roomID id.RoomID, evtTy var encrypted *event.EncryptedEventContent encrypted, err = helper.mach.EncryptMegolmEvent(ctx, roomID, evtType, content) if err != nil { - if !errors.Is(err, crypto.SessionExpired) && !errors.Is(err, crypto.SessionNotShared) && !errors.Is(err, crypto.NoGroupSession) { + if !errors.Is(err, crypto.ErrSessionExpired) && !errors.Is(err, crypto.ErrSessionNotShared) && !errors.Is(err, crypto.ErrNoGroupSession) { return } helper.log.Debug().Err(err). diff --git a/bridgev2/matrix/intent.go b/bridgev2/matrix/intent.go index ab59a582..f7254bd4 100644 --- a/bridgev2/matrix/intent.go +++ b/bridgev2/matrix/intent.go @@ -9,6 +9,7 @@ package matrix import ( "bytes" "context" + "encoding/json" "errors" "fmt" "io" @@ -27,6 +28,7 @@ import ( "maunium.net/go/mautrix/bridgev2" "maunium.net/go/mautrix/bridgev2/bridgeconfig" "maunium.net/go/mautrix/crypto/attachment" + "maunium.net/go/mautrix/crypto/canonicaljson" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" "maunium.net/go/mautrix/pushrules" @@ -43,13 +45,13 @@ type ASIntent struct { var _ bridgev2.MatrixAPI = (*ASIntent)(nil) var _ bridgev2.MarkAsDMMatrixAPI = (*ASIntent)(nil) +var _ bridgev2.EphemeralSendingMatrixAPI = (*ASIntent)(nil) func (as *ASIntent) SendMessage(ctx context.Context, roomID id.RoomID, eventType event.Type, content *event.Content, extra *bridgev2.MatrixSendExtra) (*mautrix.RespSendEvent, error) { if extra == nil { extra = &bridgev2.MatrixSendExtra{} } - // TODO remove this once hungryserv and synapse support sending m.room.redactions directly in all room versions - if eventType == event.EventRedaction { + if eventType == event.EventRedaction && !as.Connector.SpecVersions.Supports(mautrix.FeatureRedactSendAsEvent) { parsedContent := content.Parsed.(*event.RedactionEventContent) as.Matrix.AddDoublePuppetValue(content) return as.Matrix.RedactEvent(ctx, roomID, parsedContent.Redacts, mautrix.ReqRedact{ @@ -57,7 +59,7 @@ func (as *ASIntent) SendMessage(ctx context.Context, roomID id.RoomID, eventType Extra: content.Raw, }) } - if eventType != event.EventReaction && eventType != event.EventRedaction { + if (eventType != event.EventReaction || as.Connector.Config.Encryption.MSC4392) && eventType != event.EventRedaction { msgContent, ok := content.Parsed.(*event.MessageEventContent) if ok { msgContent.AddPerMessageProfileFallback() @@ -82,16 +84,27 @@ func (as *ASIntent) SendMessage(ctx context.Context, roomID id.RoomID, eventType eventType = event.EventEncrypted } } - if extra.Timestamp.IsZero() { - return as.Matrix.SendMessageEvent(ctx, roomID, eventType, content) - } else { - return as.Matrix.SendMassagedMessageEvent(ctx, roomID, eventType, content, extra.Timestamp.UnixMilli()) + return as.Matrix.SendMessageEvent(ctx, roomID, eventType, content, mautrix.ReqSendEvent{Timestamp: extra.Timestamp.UnixMilli()}) +} + +func (as *ASIntent) BeeperSendEphemeralEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, content *event.Content, txnID string) (*mautrix.RespSendEvent, error) { + if !as.Connector.SpecVersions.Supports(mautrix.BeeperFeatureEphemeralEvents) { + return nil, mautrix.MUnrecognized.WithMessage("Homeserver does not advertise com.beeper.ephemeral support") } + if encrypted, err := as.Matrix.StateStore.IsEncrypted(ctx, roomID); err != nil { + return nil, fmt.Errorf("failed to check if room is encrypted: %w", err) + } else if encrypted && as.Connector.Crypto != nil { + if err = as.Connector.Crypto.Encrypt(ctx, roomID, eventType, content); err != nil { + return nil, err + } + eventType = event.EventEncrypted + } + return as.Matrix.BeeperSendEphemeralEvent(ctx, roomID, eventType, content, mautrix.ReqSendEvent{TransactionID: txnID}) } func (as *ASIntent) fillMemberEvent(ctx context.Context, roomID id.RoomID, userID id.UserID, content *event.Content) { - targetContent := content.Parsed.(*event.MemberEventContent) - if targetContent.Displayname != "" || targetContent.AvatarURL != "" { + targetContent, ok := content.Parsed.(*event.MemberEventContent) + if !ok || targetContent.Displayname != "" || targetContent.AvatarURL != "" { return } memberContent, err := as.Matrix.StateStore.TryGetMember(ctx, roomID, userID) @@ -126,11 +139,7 @@ func (as *ASIntent) SendState(ctx context.Context, roomID id.RoomID, eventType e if eventType == event.StateMember { as.fillMemberEvent(ctx, roomID, id.UserID(stateKey), content) } - if ts.IsZero() { - resp, err = as.Matrix.SendStateEvent(ctx, roomID, eventType, stateKey, content) - } else { - resp, err = as.Matrix.SendMassagedStateEvent(ctx, roomID, eventType, stateKey, content, ts.UnixMilli()) - } + resp, err = as.Matrix.SendStateEvent(ctx, roomID, eventType, stateKey, content, mautrix.ReqSendEvent{Timestamp: ts.UnixMilli()}) if err != nil && eventType == event.StateMember { var httpErr mautrix.HTTPError if errors.As(err, &httpErr) && httpErr.RespError != nil && @@ -412,6 +421,7 @@ func (as *ASIntent) UploadMediaStream( removeAndClose(replFile) removeAndClose(tempFile) } + req.AsyncContext = zerolog.Ctx(ctx).WithContext(as.Connector.Bridge.BackgroundCtx) startedAsyncUpload = true var resp *mautrix.RespCreateMXC resp, err = as.Matrix.UploadAsync(ctx, req) @@ -444,6 +454,7 @@ func (as *ASIntent) doUploadReq(ctx context.Context, file *event.EncryptedFileIn as.Connector.uploadSema.Release(int64(len(req.ContentBytes))) } } + req.AsyncContext = zerolog.Ctx(ctx).WithContext(as.Connector.Bridge.BackgroundCtx) var resp *mautrix.RespCreateMXC resp, err = as.Matrix.UploadAsync(ctx, req) if resp != nil { @@ -475,11 +486,62 @@ func (as *ASIntent) SetAvatarURL(ctx context.Context, avatarURL id.ContentURIStr return as.Matrix.SetAvatarURL(ctx, parsedAvatarURL) } -func (as *ASIntent) SetExtraProfileMeta(ctx context.Context, data any) error { - if !as.Connector.SpecVersions.Supports(mautrix.BeeperFeatureArbitraryProfileMeta) { - return nil +func dataToFields(data any) (map[string]json.RawMessage, error) { + fields, ok := data.(map[string]json.RawMessage) + if ok { + return fields, nil } - return as.Matrix.BeeperUpdateProfile(ctx, data) + d, err := json.Marshal(data) + if err != nil { + return nil, err + } + d = canonicaljson.CanonicalJSONAssumeValid(d) + err = json.Unmarshal(d, &fields) + return fields, err +} + +func marshalField(val any) json.RawMessage { + data, _ := json.Marshal(val) + if len(data) > 0 && (data[0] == '{' || data[0] == '[') { + return canonicaljson.CanonicalJSONAssumeValid(data) + } + return data +} + +var nullJSON = json.RawMessage("null") + +func (as *ASIntent) SetExtraProfileMeta(ctx context.Context, data any) error { + if as.Connector.SpecVersions.Supports(mautrix.BeeperFeatureArbitraryProfileMeta) { + return as.Matrix.BeeperUpdateProfile(ctx, data) + } else if as.Connector.SpecVersions.Supports(mautrix.FeatureArbitraryProfileFields) && as.Connector.Config.Matrix.GhostExtraProfileInfo { + fields, err := dataToFields(data) + if err != nil { + return fmt.Errorf("failed to marshal fields: %w", err) + } + currentProfile, err := as.Matrix.GetProfile(ctx, as.Matrix.UserID) + if err != nil { + return fmt.Errorf("failed to get current profile: %w", err) + } + for key, val := range fields { + existing, ok := currentProfile.Extra[key] + if !ok { + if bytes.Equal(val, nullJSON) { + continue + } + err = as.Matrix.SetProfileField(ctx, key, val) + } else if !bytes.Equal(marshalField(existing), val) { + if bytes.Equal(val, nullJSON) { + err = as.Matrix.DeleteProfileField(ctx, key) + } else { + err = as.Matrix.SetProfileField(ctx, key, val) + } + } + if err != nil { + return fmt.Errorf("failed to set profile field %q: %w", key, err) + } + } + } + return nil } func (as *ASIntent) GetMXID() id.UserID { @@ -521,6 +583,39 @@ func (br *Connector) getDefaultEncryptionEvent() *event.EncryptionEventContent { return content } +func (as *ASIntent) filterCreateRequestForV12(ctx context.Context, req *mautrix.ReqCreateRoom) { + if as.Connector.Config.Homeserver.Software == bridgeconfig.SoftwareHungry { + // Hungryserv doesn't override the capabilities endpoint nor do room versions + return + } + caps := as.Connector.fetchCapabilities(ctx) + roomVer := req.RoomVersion + if roomVer == "" && caps != nil && caps.RoomVersions != nil { + roomVer = id.RoomVersion(caps.RoomVersions.Default) + } + if roomVer != "" && !roomVer.PrivilegedRoomCreators() { + return + } + creators, _ := req.CreationContent["additional_creators"].([]id.UserID) + creators = append(slices.Clone(creators), as.GetMXID()) + if req.PowerLevelOverride != nil { + for _, creator := range creators { + delete(req.PowerLevelOverride.Users, creator) + } + } + for _, evt := range req.InitialState { + if evt.Type != event.StatePowerLevels { + continue + } + content, ok := evt.Content.Parsed.(*event.PowerLevelsEventContent) + if ok { + for _, creator := range creators { + delete(content.Users, creator) + } + } + } +} + func (as *ASIntent) CreateRoom(ctx context.Context, req *mautrix.ReqCreateRoom) (id.RoomID, error) { if as.Connector.Config.Encryption.Default { req.InitialState = append(req.InitialState, &event.Event{ @@ -536,6 +631,7 @@ func (as *ASIntent) CreateRoom(ctx context.Context, req *mautrix.ReqCreateRoom) } req.CreationContent["m.federate"] = false } + as.filterCreateRequestForV12(ctx, req) resp, err := as.Matrix.CreateRoom(ctx, req) if err != nil { return "", err @@ -689,10 +785,10 @@ func (as *ASIntent) GetEvent(ctx context.Context, roomID id.RoomID, eventID id.E } if evt.Type == event.EventEncrypted { - if as.Connector.Config.Encryption.DeleteKeys.RatchetOnDecrypt { + if as.Connector.Crypto == nil || as.Connector.Config.Encryption.DeleteKeys.RatchetOnDecrypt { return nil, errors.New("can't decrypt the event") } - return as.Matrix.Crypto.Decrypt(ctx, evt) + return as.Connector.Crypto.Decrypt(ctx, evt) } return evt, nil diff --git a/bridgev2/matrix/matrix.go b/bridgev2/matrix/matrix.go index 64165941..954d0ad9 100644 --- a/bridgev2/matrix/matrix.go +++ b/bridgev2/matrix/matrix.go @@ -27,6 +27,11 @@ func (br *Connector) handleRoomEvent(ctx context.Context, evt *event.Event) { if br.shouldIgnoreEvent(evt) { return } + if !br.Config.Bridge.Permissions.Get(evt.Sender).SendEvents && evt.Type != event.StateMember { + zerolog.Ctx(ctx).Debug().Msg("Dropping event from user with no permission to send events") + br.SendMessageStatus(ctx, &bridgev2.ErrNoPermissionToInteract, bridgev2.StatusEventInfoFromEvent(evt)) + return + } if (evt.Type == event.EventMessage || evt.Type == event.EventSticker) && !evt.Mautrix.WasEncrypted && br.Config.Encryption.Require { zerolog.Ctx(ctx).Warn().Msg("Dropping unencrypted event as encryption is configured to be required") br.sendCryptoStatusError(ctx, evt, errMessageNotEncrypted, nil, 0, true) @@ -63,6 +68,10 @@ func (br *Connector) handleEphemeralEvent(ctx context.Context, evt *event.Event) case event.EphemeralEventTyping: typingContent := evt.Content.AsTyping() typingContent.UserIDs = slices.DeleteFunc(typingContent.UserIDs, br.shouldIgnoreEventFromUser) + case event.BeeperEphemeralEventAIStream: + if br.shouldIgnoreEvent(evt) { + return + } } br.Bridge.QueueMatrixEvent(ctx, evt) } @@ -76,6 +85,11 @@ func (br *Connector) handleEncryptedEvent(ctx context.Context, evt *event.Event) Str("event_id", evt.ID.String()). Str("session_id", content.SessionID.String()). Logger() + if !br.Config.Bridge.Permissions.Get(evt.Sender).SendEvents { + log.Debug().Msg("Dropping event from user with no permission to send events") + br.SendMessageStatus(ctx, &bridgev2.ErrNoPermissionToInteract, bridgev2.StatusEventInfoFromEvent(evt)) + return + } ctx = log.WithContext(ctx) if br.Crypto == nil { br.sendCryptoStatusError(ctx, evt, errNoCrypto, nil, 0, true) @@ -117,6 +131,7 @@ func (br *Connector) waitLongerForSession(ctx context.Context, evt *event.Event, Int("wait_seconds", int(extendedSessionWaitTimeout.Seconds())). Msg("Couldn't find session, requesting keys and waiting longer...") + //lint:ignore SA1019 RequestSession will gracefully request from all devices if DeviceID is blank go br.Crypto.RequestSession(ctx, evt.RoomID, content.SenderKey, content.SessionID, evt.Sender, content.DeviceID) go br.sendCryptoStatusError(ctx, evt, fmt.Errorf("%w. The bridge will retry for %d seconds", errNoDecryptionKeys, int(extendedSessionWaitTimeout.Seconds())), errorEventID, 1, false) @@ -220,7 +235,6 @@ func (br *Connector) postDecrypt(ctx context.Context, original, decrypted *event go br.sendSuccessCheckpoint(ctx, decrypted, status.MsgStepDecrypted, retryCount) decrypted.Mautrix.CheckpointSent = true decrypted.Mautrix.DecryptionDuration = duration - decrypted.Mautrix.EventSource |= event.SourceDecrypted br.EventProcessor.Dispatch(ctx, decrypted) if errorEventID != nil && *errorEventID != "" { _, _ = br.Bot.RedactEvent(ctx, decrypted.RoomID, *errorEventID) diff --git a/bridgev2/matrix/mxmain/dberror.go b/bridgev2/matrix/mxmain/dberror.go index 0f6aa68c..f5e438de 100644 --- a/bridgev2/matrix/mxmain/dberror.go +++ b/bridgev2/matrix/mxmain/dberror.go @@ -66,7 +66,12 @@ func (br *BridgeMain) LogDBUpgradeErrorAndExit(name string, err error, message s } else if errors.Is(err, dbutil.ErrForeignTables) { br.Log.Info().Msg("See https://docs.mau.fi/faq/foreign-tables for more info") } else if errors.Is(err, dbutil.ErrNotOwned) { - br.Log.Info().Msg("Sharing the same database with different programs is not supported") + var noe dbutil.NotOwnedError + if errors.As(err, &noe) && noe.Owner == br.Name { + br.Log.Info().Msg("The database appears to be on a very old pre-megabridge schema. Perhaps you need to run an older version of the bridge with migration support first?") + } else { + br.Log.Info().Msg("Sharing the same database with different programs is not supported") + } } else if errors.Is(err, dbutil.ErrUnsupportedDatabaseVersion) { br.Log.Info().Msg("Downgrading the bridge is not supported") } diff --git a/bridgev2/matrix/mxmain/envconfig.go b/bridgev2/matrix/mxmain/envconfig.go new file mode 100644 index 00000000..1b4f1467 --- /dev/null +++ b/bridgev2/matrix/mxmain/envconfig.go @@ -0,0 +1,161 @@ +// Copyright (c) 2025 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package mxmain + +import ( + "fmt" + "iter" + "os" + "reflect" + "strconv" + "strings" + + "go.mau.fi/util/random" +) + +var randomParseFilePrefix = random.String(16) + "READFILE:" + +func parseEnv(prefix string) iter.Seq2[[]string, string] { + return func(yield func([]string, string) bool) { + for _, s := range os.Environ() { + if !strings.HasPrefix(s, prefix) { + continue + } + kv := strings.SplitN(s, "=", 2) + key := strings.TrimPrefix(kv[0], prefix) + value := kv[1] + if strings.HasSuffix(key, "_FILE") { + key = strings.TrimSuffix(key, "_FILE") + value = randomParseFilePrefix + value + } + key = strings.ToLower(key) + if !strings.ContainsRune(key, '.') { + key = strings.ReplaceAll(key, "__", ".") + } + if !yield(strings.Split(key, "."), value) { + return + } + } + } +} + +func reflectYAMLFieldName(f *reflect.StructField) string { + parts := strings.SplitN(f.Tag.Get("yaml"), ",", 2) + fieldName := parts[0] + if fieldName == "-" && len(parts) == 1 { + return "" + } + if fieldName == "" { + return strings.ToLower(f.Name) + } + return fieldName +} + +type reflectGetResult struct { + val reflect.Value + valKind reflect.Kind + remainingPath []string +} + +func reflectGetYAML(rv reflect.Value, path []string) (*reflectGetResult, bool) { + if len(path) == 0 { + return &reflectGetResult{val: rv, valKind: rv.Kind()}, true + } + if rv.Kind() == reflect.Ptr { + rv = rv.Elem() + } + switch rv.Kind() { + case reflect.Map: + return &reflectGetResult{val: rv, remainingPath: path, valKind: rv.Type().Elem().Kind()}, true + case reflect.Struct: + fields := reflect.VisibleFields(rv.Type()) + for _, field := range fields { + fieldName := reflectYAMLFieldName(&field) + if fieldName != "" && fieldName == path[0] { + return reflectGetYAML(rv.FieldByIndex(field.Index), path[1:]) + } + } + default: + } + return nil, false +} + +func reflectGetFromMainOrNetwork(main, network reflect.Value, path []string) (*reflectGetResult, bool) { + if len(path) > 0 && path[0] == "network" { + return reflectGetYAML(network, path[1:]) + } + return reflectGetYAML(main, path) +} + +func formatKeyString(key []string) string { + return strings.Join(key, "->") +} + +func UpdateConfigFromEnv(cfg, networkData any, prefix string) error { + cfgVal := reflect.ValueOf(cfg) + networkVal := reflect.ValueOf(networkData) + for key, value := range parseEnv(prefix) { + field, ok := reflectGetFromMainOrNetwork(cfgVal, networkVal, key) + if !ok { + return fmt.Errorf("%s not found", formatKeyString(key)) + } + if strings.HasPrefix(value, randomParseFilePrefix) { + filepath := strings.TrimPrefix(value, randomParseFilePrefix) + fileData, err := os.ReadFile(filepath) + if err != nil { + return fmt.Errorf("failed to read file %s for %s: %w", filepath, formatKeyString(key), err) + } + value = strings.TrimSpace(string(fileData)) + } + var parsedVal any + var err error + switch field.valKind { + case reflect.String: + parsedVal = value + case reflect.Bool: + parsedVal, err = strconv.ParseBool(value) + if err != nil { + return fmt.Errorf("invalid value for %s: %w", formatKeyString(key), err) + } + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + parsedVal, err = strconv.ParseInt(value, 10, 64) + if err != nil { + return fmt.Errorf("invalid value for %s: %w", formatKeyString(key), err) + } + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + parsedVal, err = strconv.ParseUint(value, 10, 64) + if err != nil { + return fmt.Errorf("invalid value for %s: %w", formatKeyString(key), err) + } + case reflect.Float32, reflect.Float64: + parsedVal, err = strconv.ParseFloat(value, 64) + if err != nil { + return fmt.Errorf("invalid value for %s: %w", formatKeyString(key), err) + } + default: + return fmt.Errorf("unsupported type %s in %s", field.valKind, formatKeyString(key)) + } + if field.val.Kind() == reflect.Ptr { + if field.val.IsNil() { + field.val.Set(reflect.New(field.val.Type().Elem())) + } + field.val = field.val.Elem() + } + if field.val.Kind() == reflect.Map { + key = key[:len(key)-len(field.remainingPath)] + mapKeyStr := strings.Join(field.remainingPath, ".") + key = append(key, mapKeyStr) + if field.val.Type().Key().Kind() != reflect.String { + return fmt.Errorf("unsupported map key type %s in %s", field.val.Type().Key().Kind(), formatKeyString(key)) + } + field.val.SetMapIndex(reflect.ValueOf(mapKeyStr), reflect.ValueOf(parsedVal)) + } else { + field.val.Set(reflect.ValueOf(parsedVal)) + } + } + return nil +} diff --git a/bridgev2/matrix/mxmain/example-config.yaml b/bridgev2/matrix/mxmain/example-config.yaml index d8634028..ccc81c4b 100644 --- a/bridgev2/matrix/mxmain/example-config.yaml +++ b/bridgev2/matrix/mxmain/example-config.yaml @@ -29,6 +29,9 @@ bridge: # How long after an unknown error should the bridge attempt a full reconnect? # Must be at least 1 minute. The bridge will add an extra ±20% jitter to this value. unknown_error_auto_reconnect: null + # Maximum number of times to do the auto-reconnect above. + # The counter is per login, but is never reset except on logout and restart. + unknown_error_max_auto_reconnects: 10 # Should leaving Matrix rooms be bridged as leaving groups on the remote network? bridge_matrix_leave: false @@ -47,6 +50,11 @@ bridge: # Should cross-room reply metadata be bridged? # Most Matrix clients don't support this and servers may reject such messages too. cross_room_replies: false + # If a state event fails to bridge, should the bridge revert any state changes made by that event? + revert_failed_state_changes: false + # In portals with no relay set, should Matrix users be kicked if they're + # not logged into an account that's in the remote chat? + kick_matrix_users: true # What should be done to portal rooms when a user logs out or is logged out? # Permitted values: @@ -236,6 +244,9 @@ matrix: # The threshold as bytes after which the bridge should roundtrip uploads via the disk # rather than keeping the whole file in memory. upload_file_threshold: 5242880 + # Should the bridge set additional custom profile info for ghosts? + # This can make a lot of requests, as there's no batch profile update endpoint. + ghost_extra_profile_info: false # Segment-compatible analytics endpoint for tracking some events, like provisioning API login and encryption errors. analytics: @@ -275,6 +286,14 @@ public_media: expiry: 0 # Length of hash to use for public media URLs. Must be between 0 and 32. hash_length: 32 + # The path prefix for generated URLs. Note that this will NOT change the path where media is actually served. + # If you change this, you must configure your reverse proxy to rewrite the path accordingly. + path_prefix: /_mautrix/publicmedia + # Should the bridge store media metadata in the database in order to support encrypted media and generate shorter URLs? + # If false, the generated URLs will just have the MXC URI and a HMAC signature. + # The hash_length field will be used to decide the length of the generated URL. + # This also allows invalidating URLs by deleting the database entry. + use_database: false # Settings for converting remote media to custom mxc:// URIs instead of reuploading. # More details can be found at https://docs.mau.fi/bridges/go/discord/direct-media.html @@ -365,6 +384,8 @@ encryption: # Only relevant when using end-to-bridge encryption, required when using encryption with next-gen auth (MSC3861). # Changing this option requires updating the appservice registration file. msc4190: false + # Whether to encrypt reactions and reply metadata as per MSC4392. + msc4392: false # Should the bridge bot generate a recovery key and cross-signing keys and verify itself? # Note that without the latest version of MSC4190, this will fail if you reset the bridge database. # The generated recovery key will be saved in the kv_store table under `recovery_key`. @@ -431,6 +452,16 @@ encryption: # You should not enable this option unless you understand all the implications. disable_device_change_key_rotation: false +# Prefix for environment variables. All variables with this prefix must map to valid config fields. +# Nesting in variable names is represented with a dot (.). +# If there are no dots in the name, two underscores (__) are replaced with a dot. +# +# e.g. if the prefix is set to `BRIDGE_`, then `BRIDGE_APPSERVICE__AS_TOKEN` will set appservice.as_token. +# `BRIDGE_appservice.as_token` would work as well, but can't be set in a shell as easily. +# +# If this is null, reading config fields from environment will be disabled. +env_config_prefix: null + # Logging config. See https://github.com/tulir/zeroconfig for details. logging: min_level: debug diff --git a/bridgev2/matrix/mxmain/legacymigrate.go b/bridgev2/matrix/mxmain/legacymigrate.go index c8eb820b..97cdeddf 100644 --- a/bridgev2/matrix/mxmain/legacymigrate.go +++ b/bridgev2/matrix/mxmain/legacymigrate.go @@ -135,7 +135,10 @@ func (br *BridgeMain) CheckLegacyDB( } var dbVersion int err = br.DB.QueryRow(ctx, "SELECT version FROM version").Scan(&dbVersion) - if dbVersion < expectedVersion { + if err != nil { + log.Fatal().Err(err).Msg("Failed to get database version") + return + } else if dbVersion < expectedVersion { log.Fatal(). Int("expected_version", expectedVersion). Int("version", dbVersion). diff --git a/bridgev2/matrix/mxmain/main.go b/bridgev2/matrix/mxmain/main.go index ca0ca5f7..1e8b51d1 100644 --- a/bridgev2/matrix/mxmain/main.go +++ b/bridgev2/matrix/mxmain/main.go @@ -354,6 +354,13 @@ func (br *BridgeMain) LoadConfig() { } } cfg.Bridge.Backfill = cfg.Backfill + if cfg.EnvConfigPrefix != "" { + err = UpdateConfigFromEnv(&cfg, networkData, cfg.EnvConfigPrefix) + if err != nil { + _, _ = fmt.Fprintln(os.Stderr, "Failed to parse environment variables:", err) + os.Exit(10) + } + } br.Config = &cfg } diff --git a/bridgev2/matrix/provisioning.go b/bridgev2/matrix/provisioning.go index 61aad869..243b91da 100644 --- a/bridgev2/matrix/provisioning.go +++ b/bridgev2/matrix/provisioning.go @@ -85,10 +85,9 @@ const ( provisioningUserKey provisioningContextKey = iota provisioningUserLoginKey provisioningLoginProcessKey + ProvisioningKeyRequest ) -const ProvisioningKeyRequest = "fi.mau.provision.request" - func (prov *ProvisioningAPI) GetUser(r *http.Request) *bridgev2.User { return r.Context().Value(provisioningUserKey).(*bridgev2.User) } @@ -97,12 +96,7 @@ func (prov *ProvisioningAPI) GetRouter() *http.ServeMux { return prov.Router } -type IProvisioningAPI interface { - GetRouter() *http.ServeMux - GetUser(r *http.Request) *bridgev2.User -} - -func (br *Connector) GetProvisioning() IProvisioningAPI { +func (br *Connector) GetProvisioning() bridgev2.IProvisioningAPI { return br.Provisioning } @@ -330,7 +324,7 @@ func (prov *ProvisioningAPI) GetWhoami(w http.ResponseWriter, r *http.Request) { prevState.UserID = "" prevState.RemoteID = "" prevState.RemoteName = "" - prevState.RemoteProfile = nil + prevState.RemoteProfile = status.RemoteProfile{} resp.Logins[i] = RespWhoamiLogin{ StateEvent: prevState.StateEvent, StateTS: prevState.Timestamp, @@ -367,17 +361,19 @@ func (prov *ProvisioningAPI) GetCapabilities(w http.ResponseWriter, r *http.Requ } var ErrNilStep = errors.New("bridge returned nil step with no error") +var ErrTooManyLogins = bridgev2.RespError{ErrCode: "FI.MAU.BRIDGE.TOO_MANY_LOGINS", Err: "Maximum number of logins exceeded"} func (prov *ProvisioningAPI) PostLoginStart(w http.ResponseWriter, r *http.Request) { overrideLogin, failed := prov.GetExplicitLoginForRequest(w, r) if failed { return } - login, err := prov.net.CreateLogin( - r.Context(), - prov.GetUser(r), - r.PathValue("flowID"), - ) + user := prov.GetUser(r) + if overrideLogin == nil && user.HasTooManyLogins() { + ErrTooManyLogins.AppendMessage(" (%d)", user.Permissions.MaxLogins).Write(w) + return + } + login, err := prov.net.CreateLogin(r.Context(), user, r.PathValue("flowID")) if err != nil { zerolog.Ctx(r.Context()).Err(err).Msg("Failed to create login process") RespondWithError(w, err, "Internal error creating login process") @@ -407,10 +403,18 @@ func (prov *ProvisioningAPI) PostLoginStart(w http.ResponseWriter, r *http.Reque Override: overrideLogin, } prov.loginsLock.Unlock() + zerolog.Ctx(r.Context()).Info(). + Any("first_step", firstStep). + Msg("Created login process") exhttp.WriteJSONResponse(w, http.StatusOK, &RespSubmitLogin{LoginID: loginID, LoginStep: firstStep}) } func (prov *ProvisioningAPI) handleCompleteStep(ctx context.Context, login *ProvLogin, step *bridgev2.LoginStep) { + zerolog.Ctx(ctx).Info(). + Str("step_id", step.StepID). + Str("user_login_id", string(step.CompleteParams.UserLoginID)). + Msg("Login completed successfully") + prov.deleteLogin(login, false) if login.Override == nil || login.Override.ID == step.CompleteParams.UserLoginID { return } @@ -424,6 +428,15 @@ func (prov *ProvisioningAPI) handleCompleteStep(ctx context.Context, login *Prov }, bridgev2.DeleteOpts{LogoutRemote: true}) } +func (prov *ProvisioningAPI) deleteLogin(login *ProvLogin, cancel bool) { + if cancel { + login.Process.Cancel() + } + prov.loginsLock.Lock() + delete(prov.logins, login.ID) + prov.loginsLock.Unlock() +} + func (prov *ProvisioningAPI) PostLoginStep(w http.ResponseWriter, r *http.Request) { loginID := r.PathValue("loginProcessID") prov.loginsLock.RLock() @@ -494,11 +507,14 @@ func (prov *ProvisioningAPI) PostLoginSubmitInput(w http.ResponseWriter, r *http if err != nil { zerolog.Ctx(r.Context()).Err(err).Msg("Failed to submit input") RespondWithError(w, err, "Internal error submitting input") + prov.deleteLogin(login, true) return } login.NextStep = nextStep if nextStep.Type == bridgev2.LoginStepTypeComplete { prov.handleCompleteStep(r.Context(), login, nextStep) + } else { + zerolog.Ctx(r.Context()).Debug().Any("next_step", nextStep).Msg("Returning next login step") } exhttp.WriteJSONResponse(w, http.StatusOK, &RespSubmitLogin{LoginID: login.ID, LoginStep: nextStep}) } @@ -512,11 +528,14 @@ func (prov *ProvisioningAPI) PostLoginWait(w http.ResponseWriter, r *http.Reques if err != nil { zerolog.Ctx(r.Context()).Err(err).Msg("Failed to wait") RespondWithError(w, err, "Internal error waiting for login") + prov.deleteLogin(login, true) return } login.NextStep = nextStep if nextStep.Type == bridgev2.LoginStepTypeComplete { prov.handleCompleteStep(r.Context(), login, nextStep) + } else { + zerolog.Ctx(r.Context()).Debug().Any("next_step", nextStep).Msg("Returning next login step") } exhttp.WriteJSONResponse(w, http.StatusOK, &RespSubmitLogin{LoginID: login.ID, LoginStep: nextStep}) } diff --git a/bridgev2/matrix/provisioning.yaml b/bridgev2/matrix/provisioning.yaml index 21c93ca4..26068db4 100644 --- a/bridgev2/matrix/provisioning.yaml +++ b/bridgev2/matrix/provisioning.yaml @@ -714,7 +714,7 @@ components: type: type: string description: The type of field. - enum: [ username, phone_number, email, password, 2fa_code, token, url, domain ] + enum: [ username, phone_number, email, password, 2fa_code, token, url, domain, select ] id: type: string description: The internal ID of the field. This must be used as the key in the object when submitting the data back to the bridge. @@ -728,10 +728,53 @@ components: description: A more detailed description of the field shown to the user. examples: - Include the country code with a + + default_value: + type: string + description: A default value that the client can pre-fill the field with. pattern: type: string format: regex description: A regular expression that the field value must match. + options: + type: array + description: For fields of type select, the valid options. + items: + type: string + attachments: + type: array + description: A list of media attachments to show the user alongside the form fields. + items: + type: object + description: A media attachment to show the user. + required: [ type, filename, content ] + properties: + type: + type: string + description: The type of media attachment, using the same media type identifiers as Matrix attachments. Only some are supported. + enum: [ m.image, m.audio ] + filename: + type: string + description: The filename for the media attachment. + content: + type: string + description: The raw file content for the attachment encoded in base64. + info: + type: object + description: Optional but recommended metadata for the attachment. Can generally be derived from the raw content if omitted. + properties: + mimetype: + type: string + description: The MIME type for the media content. + examples: [ image/png, audio/mpeg ] + w: + type: number + description: The width of the media in pixels. Only applicable for images and videos. + h: + type: number + description: The height of the media in pixels. Only applicable for images and videos. + size: + type: number + description: The size of the media content in number of bytes. Strongly recommended to include. - description: Cookie login step required: [ type, cookies ] properties: diff --git a/bridgev2/matrix/publicmedia.go b/bridgev2/matrix/publicmedia.go index 95e37262..82ea8c2b 100644 --- a/bridgev2/matrix/publicmedia.go +++ b/bridgev2/matrix/publicmedia.go @@ -7,16 +7,26 @@ package matrix import ( + "context" "crypto/hmac" "crypto/sha256" "encoding/base64" "encoding/binary" "fmt" "io" + "mime" "net/http" + "net/url" + "slices" + "strings" "time" + "github.com/rs/zerolog" + "maunium.net/go/mautrix/bridgev2" + "maunium.net/go/mautrix/bridgev2/database" + "maunium.net/go/mautrix/crypto/attachment" + "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" ) @@ -33,7 +43,10 @@ func (br *Connector) initPublicMedia() error { return fmt.Errorf("public media hash length is negative") } br.pubMediaSigKey = []byte(br.Config.PublicMedia.SigningKey) + br.AS.Router.HandleFunc("GET /_mautrix/publicmedia/{customID}", br.serveDatabasePublicMedia) + br.AS.Router.HandleFunc("GET /_mautrix/publicmedia/{customID}/{filename}", br.serveDatabasePublicMedia) br.AS.Router.HandleFunc("GET /_mautrix/publicmedia/{server}/{mediaID}/{checksum}", br.servePublicMedia) + br.AS.Router.HandleFunc("GET /_mautrix/publicmedia/{server}/{mediaID}/{checksum}/{filename}", br.servePublicMedia) return nil } @@ -44,6 +57,20 @@ func (br *Connector) hashContentURI(uri id.ContentURI, expiry []byte) []byte { return hasher.Sum(expiry)[:br.Config.PublicMedia.HashLength+len(expiry)] } +func (br *Connector) hashDBPublicMedia(pm *database.PublicMedia) []byte { + hasher := hmac.New(sha256.New, br.pubMediaSigKey) + hasher.Write([]byte(pm.MXC.String())) + hasher.Write([]byte(pm.MimeType)) + if pm.Keys != nil { + hasher.Write([]byte(pm.Keys.Version)) + hasher.Write([]byte(pm.Keys.Key.Algorithm)) + hasher.Write([]byte(pm.Keys.Key.Key)) + hasher.Write([]byte(pm.Keys.InitVector)) + hasher.Write([]byte(pm.Keys.Hashes.SHA256)) + } + return hasher.Sum(nil)[:br.Config.PublicMedia.HashLength] +} + func (br *Connector) makePublicMediaChecksum(uri id.ContentURI) []byte { var expiresAt []byte if br.Config.PublicMedia.Expiry > 0 { @@ -93,9 +120,47 @@ func (br *Connector) servePublicMedia(w http.ResponseWriter, r *http.Request) { http.Error(w, "checksum expired", http.StatusGone) return } + br.doProxyMedia(w, r, contentURI, nil, "") +} + +func (br *Connector) serveDatabasePublicMedia(w http.ResponseWriter, r *http.Request) { + if !br.Config.PublicMedia.UseDatabase { + http.Error(w, "public media short links are disabled", http.StatusNotFound) + return + } + log := zerolog.Ctx(r.Context()) + media, err := br.Bridge.DB.PublicMedia.Get(r.Context(), r.PathValue("customID")) + if err != nil { + log.Err(err).Msg("Failed to get public media from database") + http.Error(w, "failed to get media metadata", http.StatusInternalServerError) + return + } else if media == nil { + http.Error(w, "media ID not found", http.StatusNotFound) + return + } else if !media.Expiry.IsZero() && media.Expiry.Before(time.Now()) { + // This is not gone as it can still be refreshed in the DB + http.Error(w, "media expired", http.StatusNotFound) + return + } else if media.Keys != nil && media.Keys.PrepareForDecryption() != nil { + http.Error(w, "media keys are malformed", http.StatusInternalServerError) + return + } + br.doProxyMedia(w, r, media.MXC, media.Keys, media.MimeType) +} + +var safeMimes = []string{ + "text/css", "text/plain", "text/csv", + "application/json", "application/ld+json", + "image/jpeg", "image/gif", "image/png", "image/apng", "image/webp", "image/avif", + "video/mp4", "video/webm", "video/ogg", "video/quicktime", + "audio/mp4", "audio/webm", "audio/aac", "audio/mpeg", "audio/ogg", "audio/wave", + "audio/wav", "audio/x-wav", "audio/x-pn-wav", "audio/flac", "audio/x-flac", +} + +func (br *Connector) doProxyMedia(w http.ResponseWriter, r *http.Request, contentURI id.ContentURI, encInfo *attachment.EncryptedFile, mimeType string) { resp, err := br.Bot.Download(r.Context(), contentURI) if err != nil { - br.Log.Warn().Stringer("uri", contentURI).Err(err).Msg("Failed to download media to proxy") + zerolog.Ctx(r.Context()).Warn().Stringer("uri", contentURI).Err(err).Msg("Failed to download media to proxy") http.Error(w, "failed to download media", http.StatusInternalServerError) return } @@ -103,11 +168,41 @@ func (br *Connector) servePublicMedia(w http.ResponseWriter, r *http.Request) { for _, hdr := range proxyHeadersToCopy { w.Header()[hdr] = resp.Header[hdr] } + stream := resp.Body + if encInfo != nil { + if mimeType == "" { + mimeType = "application/octet-stream" + } + contentDisposition := "attachment" + if slices.Contains(safeMimes, mimeType) { + contentDisposition = "inline" + } + dispositionArgs := map[string]string{} + if filename := r.PathValue("filename"); filename != "" { + dispositionArgs["filename"] = filename + } + w.Header().Set("Content-Type", mimeType) + w.Header().Set("Content-Disposition", mime.FormatMediaType(contentDisposition, dispositionArgs)) + // Note: this won't check the Close result like it should, but it's probably not a big deal here + stream = encInfo.DecryptStream(stream) + } else if filename := r.PathValue("filename"); filename != "" { + contentDisposition, _, _ := mime.ParseMediaType(resp.Header.Get("Content-Disposition")) + if contentDisposition == "" { + contentDisposition = "attachment" + } + w.Header().Set("Content-Disposition", mime.FormatMediaType(contentDisposition, map[string]string{ + "filename": filename, + })) + } w.WriteHeader(http.StatusOK) - _, _ = io.Copy(w, resp.Body) + _, _ = io.Copy(w, stream) } func (br *Connector) GetPublicMediaAddress(contentURI id.ContentURIString) string { + return br.getPublicMediaAddressWithFileName(contentURI, "") +} + +func (br *Connector) getPublicMediaAddressWithFileName(contentURI id.ContentURIString, fileName string) string { if br.pubMediaSigKey == nil { return "" } @@ -115,11 +210,69 @@ func (br *Connector) GetPublicMediaAddress(contentURI id.ContentURIString) strin if err != nil || !parsed.IsValid() { return "" } - return fmt.Sprintf( - "%s/_mautrix/publicmedia/%s/%s/%s", + fileName = url.PathEscape(strings.ReplaceAll(fileName, "/", "_")) + if fileName == ".." { + fileName = "" + } + parts := []string{ br.GetPublicAddress(), + strings.Trim(br.Config.PublicMedia.PathPrefix, "/"), parsed.Homeserver, parsed.FileID, base64.RawURLEncoding.EncodeToString(br.makePublicMediaChecksum(parsed)), - ) + fileName, + } + if fileName == "" { + parts = parts[:len(parts)-1] + } + return strings.Join(parts, "/") +} + +func (br *Connector) GetPublicMediaAddressForEvent(ctx context.Context, evt *event.MessageEventContent) (string, error) { + if br.pubMediaSigKey == nil { + return "", bridgev2.ErrPublicMediaDisabled + } + if !br.Config.PublicMedia.UseDatabase { + if evt.File != nil { + return "", fmt.Errorf("can't generate address for encrypted file: %w", bridgev2.ErrPublicMediaDatabaseDisabled) + } + return br.getPublicMediaAddressWithFileName(evt.URL, evt.GetFileName()), nil + } + mxc := evt.URL + var keys *attachment.EncryptedFile + if evt.File != nil { + mxc = evt.File.URL + keys = &evt.File.EncryptedFile + } + parsedMXC, err := mxc.Parse() + if err != nil { + return "", fmt.Errorf("%w: failed to parse MXC: %w", bridgev2.ErrPublicMediaGenerateFailed, err) + } + pm := &database.PublicMedia{ + MXC: parsedMXC, + Keys: keys, + MimeType: evt.GetInfo().MimeType, + } + if br.Config.PublicMedia.Expiry > 0 { + pm.Expiry = time.Now().Add(time.Duration(br.Config.PublicMedia.Expiry) * time.Second) + } + pm.PublicID = base64.RawURLEncoding.EncodeToString(br.hashDBPublicMedia(pm)) + err = br.Bridge.DB.PublicMedia.Put(ctx, pm) + if err != nil { + return "", fmt.Errorf("%w: failed to store public media in database: %w", bridgev2.ErrPublicMediaGenerateFailed, err) + } + fileName := url.PathEscape(strings.ReplaceAll(evt.GetFileName(), "/", "_")) + if fileName == ".." { + fileName = "" + } + parts := []string{ + br.GetPublicAddress(), + strings.Trim(br.Config.PublicMedia.PathPrefix, "/"), + pm.PublicID, + fileName, + } + if fileName == "" { + parts = parts[:len(parts)-1] + } + return strings.Join(parts, "/"), nil } diff --git a/bridgev2/matrixinterface.go b/bridgev2/matrixinterface.go index 6fa5360c..be26db49 100644 --- a/bridgev2/matrixinterface.go +++ b/bridgev2/matrixinterface.go @@ -14,6 +14,8 @@ import ( "os" "time" + "go.mau.fi/util/exhttp" + "maunium.net/go/mautrix" "maunium.net/go/mautrix/bridgev2/database" "maunium.net/go/mautrix/bridgev2/networkid" @@ -23,8 +25,10 @@ import ( ) type MatrixCapabilities struct { - AutoJoinInvites bool - BatchSending bool + AutoJoinInvites bool + BatchSending bool + ArbitraryMemberChange bool + ExtraProfileMeta bool } type MatrixConnector interface { @@ -58,35 +62,54 @@ type MatrixConnector interface { } type MatrixConnectorWithArbitraryRoomState interface { + MatrixConnector GetStateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string) (*event.Event, error) } type MatrixConnectorWithServer interface { + MatrixConnector GetPublicAddress() string GetRouter() *http.ServeMux } +type IProvisioningAPI interface { + GetRouter() *http.ServeMux + GetUser(r *http.Request) *User +} + +type MatrixConnectorWithProvisioning interface { + MatrixConnector + GetProvisioning() IProvisioningAPI +} + type MatrixConnectorWithPublicMedia interface { + MatrixConnector GetPublicMediaAddress(contentURI id.ContentURIString) string + GetPublicMediaAddressForEvent(ctx context.Context, evt *event.MessageEventContent) (string, error) } type MatrixConnectorWithNameDisambiguation interface { + MatrixConnector IsConfusableName(ctx context.Context, roomID id.RoomID, userID id.UserID, name string) ([]id.UserID, error) } type MatrixConnectorWithBridgeIdentifier interface { + MatrixConnector GetUniqueBridgeID() string } type MatrixConnectorWithURLPreviews interface { + MatrixConnector GetURLPreview(ctx context.Context, url string) (*event.LinkPreview, error) } type MatrixConnectorWithPostRoomBridgeHandling interface { + MatrixConnector HandleNewlyBridgedRoom(ctx context.Context, roomID id.RoomID) error } type MatrixConnectorWithAnalytics interface { + MatrixConnector TrackAnalytics(userID id.UserID, event string, properties map[string]any) } @@ -101,9 +124,15 @@ type DirectNotificationData struct { } type MatrixConnectorWithNotifications interface { + MatrixConnector DisplayNotification(ctx context.Context, data *DirectNotificationData) } +type MatrixConnectorWithHTTPSettings interface { + MatrixConnector + GetHTTPClientSettings() exhttp.ClientSettings +} + type MatrixSendExtra struct { Timestamp time.Time MessageMeta *database.Message @@ -181,9 +210,16 @@ type MatrixAPI interface { } type StreamOrderReadingMatrixAPI interface { + MatrixAPI MarkStreamOrderRead(ctx context.Context, roomID id.RoomID, streamOrder int64, ts time.Time) error } type MarkAsDMMatrixAPI interface { + MatrixAPI MarkAsDM(ctx context.Context, roomID id.RoomID, otherUser id.UserID) error } + +type EphemeralSendingMatrixAPI interface { + MatrixAPI + BeeperSendEphemeralEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, content *event.Content, txnID string) (*mautrix.RespSendEvent, error) +} diff --git a/bridgev2/matrixinvite.go b/bridgev2/matrixinvite.go index 2c14cc7f..75c00cb0 100644 --- a/bridgev2/matrixinvite.go +++ b/bridgev2/matrixinvite.go @@ -88,6 +88,36 @@ func sendErrorAndLeave(ctx context.Context, evt *event.Event, intent MatrixAPI, rejectInvite(ctx, evt, intent, "") } +func (portal *Portal) CleanupOrphanedDM(ctx context.Context, userMXID id.UserID) { + if portal.MXID == "" { + return + } + log := zerolog.Ctx(ctx) + existingPortalMembers, err := portal.Bridge.Matrix.GetMembers(ctx, portal.MXID) + if err != nil { + log.Err(err). + Stringer("old_portal_mxid", portal.MXID). + Msg("Failed to check existing portal members, deleting room") + } else if targetUserMember, ok := existingPortalMembers[userMXID]; !ok { + log.Debug(). + Stringer("old_portal_mxid", portal.MXID). + Msg("Inviter has no member event in old portal, deleting room") + } else if targetUserMember.Membership.IsInviteOrJoin() { + return + } else { + log.Debug(). + Stringer("old_portal_mxid", portal.MXID). + Str("membership", string(targetUserMember.Membership)). + Msg("Inviter is not in old portal, deleting room") + } + + if err = portal.RemoveMXID(ctx); err != nil { + log.Err(err).Msg("Failed to delete old portal mxid") + } else if err = portal.Bridge.Bot.DeleteRoom(ctx, portal.MXID, true); err != nil { + log.Err(err).Msg("Failed to clean up old portal room") + } +} + func (br *Bridge) handleGhostDMInvite(ctx context.Context, evt *event.Event, sender *User) EventHandlingResult { ghostID, _ := br.Matrix.ParseGhostMXID(id.UserID(evt.GetStateKey())) validator, ok := br.Network.(IdentifierValidatingNetwork) @@ -165,34 +195,7 @@ func (br *Bridge) handleGhostDMInvite(ctx context.Context, evt *event.Event, sen return EventHandlingResultFailed } } - if portal.MXID != "" { - doCleanup := true - existingPortalMembers, err := br.Matrix.GetMembers(ctx, portal.MXID) - if err != nil { - log.Err(err). - Stringer("old_portal_mxid", portal.MXID). - Msg("Failed to check existing portal members, deleting room") - } else if targetUserMember, ok := existingPortalMembers[sender.MXID]; !ok { - log.Debug(). - Stringer("old_portal_mxid", portal.MXID). - Msg("Inviter has no member event in old portal, deleting room") - } else if targetUserMember.Membership.IsInviteOrJoin() { - doCleanup = false - } else { - log.Debug(). - Stringer("old_portal_mxid", portal.MXID). - Str("membership", string(targetUserMember.Membership)). - Msg("Inviter is not in old portal, deleting room") - } - - if doCleanup { - if err = portal.RemoveMXID(ctx); err != nil { - log.Err(err).Msg("Failed to delete old portal mxid") - } else if err = br.Bot.DeleteRoom(ctx, portal.MXID, true); err != nil { - log.Err(err).Msg("Failed to clean up old portal room") - } - } - } + portal.CleanupOrphanedDM(ctx, sender.MXID) err = invitedGhost.Intent.EnsureInvited(ctx, evt.RoomID, br.Bot.GetMXID()) if err != nil { log.Err(err).Msg("Failed to ensure bot is invited to room") @@ -221,11 +224,12 @@ func (br *Bridge) handleGhostDMInvite(ctx context.Context, evt *event.Event, sen rejectInvite(ctx, evt, br.Bot, "") return EventHandlingResultSuccess } + overrideIntent := invitedGhost.Intent if resp.DMRedirectedTo != "" && resp.DMRedirectedTo != invitedGhost.ID { log.Debug(). Str("dm_redirected_to_id", string(resp.DMRedirectedTo)). Msg("Created DM was redirected to another user ID") - _, err = invitedGhost.Intent.SendState(ctx, portal.MXID, event.StateMember, invitedGhost.Intent.GetMXID().String(), &event.Content{ + _, err = invitedGhost.Intent.SendState(ctx, evt.RoomID, event.StateMember, invitedGhost.Intent.GetMXID().String(), &event.Content{ Parsed: &event.MemberEventContent{ Membership: event.MembershipLeave, Reason: "Direct chat redirected to another internal user ID", @@ -234,11 +238,13 @@ func (br *Bridge) handleGhostDMInvite(ctx context.Context, evt *event.Event, sen if err != nil { log.Err(err).Msg("Failed to make incorrect ghost leave new DM room") } - otherUserGhost, err := br.GetGhostByID(ctx, resp.DMRedirectedTo) - if err != nil { + if resp.DMRedirectedTo == SpecialValueDMRedirectedToBot { + overrideIntent = br.Bot + } else if otherUserGhost, err := br.GetGhostByID(ctx, resp.DMRedirectedTo); err != nil { log.Err(err).Msg("Failed to get ghost of real portal other user ID") } else { invitedGhost = otherUserGhost + overrideIntent = otherUserGhost.Intent } } err = portal.UpdateMatrixRoomID(ctx, evt.RoomID, UpdateMatrixRoomIDParams{ @@ -251,7 +257,7 @@ func (br *Bridge) handleGhostDMInvite(ctx context.Context, evt *event.Event, sen }) if err != nil { log.Err(err).Msg("Failed to update Matrix room ID for new DM portal") - sendNotice(ctx, evt, invitedGhost.Intent, "Failed to finish configuring portal. The chat may or may not work") + sendNotice(ctx, evt, overrideIntent, "Failed to finish configuring portal. The chat may or may not work") return EventHandlingResultSuccess } message := "Private chat portal created" @@ -263,7 +269,7 @@ func (br *Bridge) handleGhostDMInvite(ctx context.Context, evt *event.Event, sen message += fmt.Sprintf("\n\nWarning: %s", err.Error()) } } - sendNotice(ctx, evt, invitedGhost.Intent, message) + sendNotice(ctx, evt, overrideIntent, message) return EventHandlingResultSuccess } diff --git a/bridgev2/messagestatus.go b/bridgev2/messagestatus.go index 7118649d..df0c9e4d 100644 --- a/bridgev2/messagestatus.go +++ b/bridgev2/messagestatus.go @@ -20,6 +20,7 @@ import ( type MessageStatusEventInfo struct { RoomID id.RoomID + TransactionID string SourceEventID id.EventID NewEventID id.EventID EventType event.Type @@ -41,6 +42,7 @@ func StatusEventInfoFromEvent(evt *event.Event) *MessageStatusEventInfo { return &MessageStatusEventInfo{ RoomID: evt.RoomID, + TransactionID: evt.Unsigned.TransactionID, SourceEventID: evt.ID, EventType: evt.Type, MessageType: evt.Content.AsMessage().MsgType, @@ -182,9 +184,10 @@ func (ms *MessageStatus) ToMSSEvent(evt *MessageStatusEventInfo) *event.BeeperMe Type: event.RelReference, EventID: evt.SourceEventID, }, - Status: ms.Status, - Reason: ms.ErrorReason, - Message: ms.Message, + TargetTxnID: evt.TransactionID, + Status: ms.Status, + Reason: ms.ErrorReason, + Message: ms.Message, } if ms.InternalError != nil { content.InternalError = ms.InternalError.Error() diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index 31647f63..b706aedb 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -16,7 +16,9 @@ import ( "github.com/rs/zerolog" "go.mau.fi/util/configupgrade" "go.mau.fi/util/ptr" + "go.mau.fi/util/random" + "maunium.net/go/mautrix" "maunium.net/go/mautrix/bridgev2/database" "maunium.net/go/mautrix/bridgev2/networkid" "maunium.net/go/mautrix/event" @@ -259,6 +261,7 @@ type NetworkConnector interface { } type StoppableNetwork interface { + NetworkConnector // Stop is called when the bridge is stopping, after all network clients have been disconnected. Stop() } @@ -315,6 +318,16 @@ type MaxFileSizeingNetwork interface { SetMaxFileSize(maxSize int64) } +type NetworkResettingNetwork interface { + NetworkConnector + // ResetHTTPTransport should recreate the HTTP client used by the bridge. + // It should refetch settings from the Matrix connector using GetHTTPClientSettings if applicable. + ResetHTTPTransport() + // ResetNetworkConnections should forcefully disconnect and restart any persistent network connections. + // ResetHTTPTransport will usually be called before this, so resetting the transport is not necessary here. + ResetNetworkConnections() +} + type RemoteEchoHandler func(RemoteMessage, *database.Message) (bool, error) type MatrixMessageResponse struct { @@ -705,6 +718,19 @@ type DeleteChatHandlingNetworkAPI interface { HandleMatrixDeleteChat(ctx context.Context, msg *MatrixDeleteChat) error } +// MessageRequestAcceptingNetworkAPI is an optional interface that network connectors +// can implement to accept message requests from the remote network. +type MessageRequestAcceptingNetworkAPI interface { + NetworkAPI + // HandleMatrixAcceptMessageRequest is called when the user accepts a message request. + HandleMatrixAcceptMessageRequest(ctx context.Context, msg *MatrixAcceptMessageRequest) error +} + +type BeeperAIStreamHandlingNetworkAPI interface { + NetworkAPI + HandleMatrixBeeperAIStream(ctx context.Context, msg *MatrixBeeperAIStream) error +} + type ResolveIdentifierResponse struct { // Ghost is the ghost of the user that the identifier resolves to. // This field should be set whenever possible. However, it is not required, @@ -724,6 +750,8 @@ type ResolveIdentifierResponse struct { Chat *CreateChatResponse } +var SpecialValueDMRedirectedToBot = networkid.UserID("__fi.mau.bridgev2.dm_redirected_to_bot::" + random.String(10)) + type CreateChatResponse struct { PortalKey networkid.PortalKey // Portal and PortalInfo are not required, the caller will fetch them automatically based on PortalKey if necessary. @@ -732,6 +760,17 @@ type CreateChatResponse struct { // If a start DM request (CreateChatWithGhost or ResolveIdentifier) returns the DM to a different user, // this field should have the user ID of said different user. DMRedirectedTo networkid.UserID + + FailedParticipants map[networkid.UserID]*CreateChatFailedParticipant +} + +type CreateChatFailedParticipant struct { + Reason string `json:"reason"` + InviteEventType string `json:"invite_event_type,omitempty"` + InviteContent *event.Content `json:"invite_content,omitempty"` + + UserMXID id.UserID `json:"user_mxid,omitempty"` + DMRoomMXID id.RoomID `json:"dm_room_mxid,omitempty"` } // IdentifierResolvingNetworkAPI is an optional interface that network connectors can implement to support starting new direct chats. @@ -764,6 +803,16 @@ type UserSearchingNetworkAPI interface { SearchUsers(ctx context.Context, query string) ([]*ResolveIdentifierResponse, error) } +type GroupCreatingNetworkAPI interface { + IdentifierResolvingNetworkAPI + CreateGroup(ctx context.Context, params *GroupCreateParams) (*CreateChatResponse, error) +} + +type PersonalFilteringCustomizingNetworkAPI interface { + NetworkAPI + CustomizePersonalFilteringSpace(req *mautrix.ReqCreateRoom) +} + type ProvisioningCapabilities struct { ResolveIdentifier ResolveIdentifierCapabilities `json:"resolve_identifier"` GroupCreation map[string]GroupTypeCapabilities `json:"group_creation"` @@ -812,12 +861,17 @@ type GroupFieldCapability struct { // Only for the disappear field: allowed disappearing settings DisappearSettings *event.DisappearingTimerCapability `json:"settings,omitempty"` + + // This can be used to tell provisionutil not to call ValidateUserID on each participant. + // It only meant to allow hacks where ResolveIdentifier returns a fake ID that isn't actually valid for MXIDs. + SkipIdentifierValidation bool `json:"-"` } type GroupCreateParams struct { Type string `json:"type,omitempty"` - Username string `json:"username,omitempty"` + Username string `json:"username,omitempty"` + // Clients may also provide MXIDs here, but provisionutil will normalize them, so bridges only need to handle network IDs Participants []networkid.UserID `json:"participants,omitempty"` Parent *networkid.PortalKey `json:"parent,omitempty"` @@ -830,11 +884,6 @@ type GroupCreateParams struct { RoomID id.RoomID `json:"room_id,omitempty"` } -type GroupCreatingNetworkAPI interface { - IdentifierResolvingNetworkAPI - CreateGroup(ctx context.Context, params *GroupCreateParams) (*CreateChatResponse, error) -} - type MembershipChangeType struct { From event.Membership To event.Membership @@ -872,16 +921,15 @@ type MatrixMembershipChange struct { MatrixRoomMeta[*event.MemberEventContent] Target GhostOrUserLogin Type MembershipChangeType +} - // Deprecated: Use Target instead - TargetGhost *Ghost - // Deprecated: Use Target instead - TargetUserLogin *UserLogin +type MatrixMembershipResult struct { + RedirectTo networkid.UserID } type MembershipHandlingNetworkAPI interface { NetworkAPI - HandleMatrixMembership(ctx context.Context, msg *MatrixMembershipChange) (bool, error) + HandleMatrixMembership(ctx context.Context, msg *MatrixMembershipChange) (*MatrixMembershipResult, error) } type SinglePowerLevelChange struct { @@ -1067,6 +1115,11 @@ type RemoteEvent interface { GetSender() EventSender } +type RemoteEventWithContextMutation interface { + RemoteEvent + MutateContext(ctx context.Context) context.Context +} + type RemoteEventWithUncertainPortalReceiver interface { RemoteEvent PortalReceiverIsUncertain() bool @@ -1120,6 +1173,11 @@ type RemoteChatDelete interface { RemoteDeleteOnlyForMe } +type RemoteChatDeleteWithChildren interface { + RemoteChatDelete + DeleteChildren() bool +} + type RemoteEventThatMayCreatePortal interface { RemoteEvent ShouldCreatePortal() bool @@ -1352,7 +1410,8 @@ type MatrixMessageRemove struct { type MatrixRoomMeta[ContentType any] struct { MatrixEventBase[ContentType] - PrevContent ContentType + PrevContent ContentType + IsStateRequest bool } type MatrixRoomName = MatrixRoomMeta[*event.RoomNameEventContent] @@ -1389,6 +1448,8 @@ type MatrixViewingChat struct { } type MatrixDeleteChat = MatrixEventBase[*event.BeeperChatDeleteEventContent] +type MatrixAcceptMessageRequest = MatrixEventBase[*event.BeeperAcceptMessageRequestEventContent] +type MatrixBeeperAIStream = MatrixEventBase[*event.BeeperAIStreamEventContent] type MatrixMarkedUnread = MatrixRoomMeta[*event.MarkedUnreadEventContent] type MatrixMute = MatrixRoomMeta[*event.BeeperMuteEventContent] type MatrixRoomTag = MatrixRoomMeta[*event.TagEventContent] diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 067d92c2..5ba29507 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -86,13 +86,15 @@ type Portal struct { lastCapUpdate time.Time - roomCreateLock sync.Mutex + roomCreateLock sync.Mutex + cancelRoomCreate atomic.Pointer[context.CancelFunc] + RoomCreated *exsync.Event functionalMembersLock sync.Mutex functionalMembersCache *event.ElementFunctionalMembersContent events chan portalEvent - deleted bool + deleted *exsync.Event eventsLock sync.Mutex eventIdx int @@ -124,6 +126,12 @@ func (br *Bridge) loadPortal(ctx context.Context, dbPortal *database.Portal, que currentlyTypingLogins: make(map[id.UserID]*UserLogin), currentlyTypingGhosts: exsync.NewSet[id.UserID](), outgoingMessages: make(map[networkid.TransactionID]*outgoingMessage), + + RoomCreated: exsync.NewEvent(), + deleted: exsync.NewEvent(), + } + if portal.MXID != "" { + portal.RoomCreated.Set() } // Putting the portal in the cache before it's fully initialized is mildly dangerous, // but loading the relay user login may depend on it. @@ -161,7 +169,9 @@ func (br *Bridge) loadPortal(ctx context.Context, dbPortal *database.Portal, que } func (portal *Portal) updateLogger() { - logWith := portal.Bridge.Log.With().Str("portal_id", string(portal.ID)) + logWith := portal.Bridge.Log.With(). + Str("portal_id", string(portal.ID)). + Str("portal_receiver", string(portal.Receiver)) if portal.MXID != "" { logWith = logWith.Stringer("portal_mxid", portal.MXID) } @@ -185,6 +195,16 @@ func (br *Bridge) loadManyPortals(ctx context.Context, portals []*database.Porta return output, nil } +func (br *Bridge) loadPortalWithCacheCheck(ctx context.Context, dbPortal *database.Portal) (*Portal, error) { + if dbPortal == nil { + return nil, nil + } else if cached, ok := br.portalsByKey[dbPortal.PortalKey]; ok { + return cached, nil + } else { + return br.loadPortal(ctx, dbPortal, nil, nil) + } +} + func (br *Bridge) UnlockedGetPortalByKey(ctx context.Context, key networkid.PortalKey, onlyIfExists bool) (*Portal, error) { if br.Config.SplitPortals && key.Receiver == "" { return nil, fmt.Errorf("receiver must always be set when split portals is enabled") @@ -274,6 +294,26 @@ func (br *Bridge) GetDMPortalsWith(ctx context.Context, otherUserID networkid.Us return br.loadManyPortals(ctx, rows) } +func (br *Bridge) GetChildPortals(ctx context.Context, parent networkid.PortalKey) ([]*Portal, error) { + br.cacheLock.Lock() + defer br.cacheLock.Unlock() + rows, err := br.DB.Portal.GetChildren(ctx, parent) + if err != nil { + return nil, err + } + return br.loadManyPortals(ctx, rows) +} + +func (br *Bridge) GetDMPortal(ctx context.Context, receiver networkid.UserLoginID, otherUserID networkid.UserID) (*Portal, error) { + br.cacheLock.Lock() + defer br.cacheLock.Unlock() + dbPortal, err := br.DB.Portal.GetDM(ctx, receiver, otherUserID) + if err != nil { + return nil, err + } + return br.loadPortalWithCacheCheck(ctx, dbPortal) +} + func (br *Bridge) GetPortalByKey(ctx context.Context, key networkid.PortalKey) (*Portal, error) { br.cacheLock.Lock() defer br.cacheLock.Unlock() @@ -299,6 +339,9 @@ func (br *Bridge) GetExistingPortalByKey(ctx context.Context, key networkid.Port } func (portal *Portal) queueEvent(ctx context.Context, evt portalEvent) EventHandlingResult { + if portal.deleted.IsSet() { + return EventHandlingResultIgnored + } if PortalEventBuffer == 0 { portal.eventsLock.Lock() defer portal.eventsLock.Unlock() @@ -311,6 +354,8 @@ func (portal *Portal) queueEvent(ctx context.Context, evt portalEvent) EventHand select { case portal.events <- evt: return EventHandlingResultQueued + case <-portal.deleted.GetChan(): + return EventHandlingResultIgnored default: zerolog.Ctx(ctx).Error(). Str("portal_id", string(portal.ID)). @@ -335,17 +380,21 @@ func (portal *Portal) eventLoop() { go portal.pendingMessageTimeoutLoop(ctx, cfg) defer cancel() } - i := 0 - for rawEvt := range portal.events { - if portal.deleted { + deleteCh := portal.deleted.GetChan() + for i := 0; ; i++ { + select { + case rawEvt := <-portal.events: + if rawEvt == nil { + return + } + if portal.Bridge.Config.AsyncEvents { + go portal.handleSingleEventWithDelayLogging(i, rawEvt) + } else { + portal.handleSingleEventWithDelayLogging(i, rawEvt) + } + case <-deleteCh: return } - i++ - if portal.Bridge.Config.AsyncEvents { - go portal.handleSingleEventWithDelayLogging(i, rawEvt) - } else { - portal.handleSingleEventWithDelayLogging(i, rawEvt) - } } } @@ -399,6 +448,23 @@ func (portal *Portal) handleSingleEventWithDelayLogging(idx int, rawEvt any) (ou return } +type contextKey int + +const ( + contextKeyRemoteEvent contextKey = iota + contextKeyMatrixEvent +) + +func GetMatrixEventFromContext(ctx context.Context) (evt *event.Event) { + evt, _ = ctx.Value(contextKeyMatrixEvent).(*event.Event) + return +} + +func GetRemoteEventFromContext(ctx context.Context) (evt RemoteEvent) { + evt, _ = ctx.Value(contextKeyRemoteEvent).(RemoteEvent) + return +} + func (portal *Portal) getEventCtxWithLog(rawEvt any, idx int) context.Context { var logWith zerolog.Context switch evt := rawEvt.(type) { @@ -412,6 +478,10 @@ func (portal *Portal) getEventCtxWithLog(rawEvt any, idx int) context.Context { Stringer("event_id", evt.evt.ID). Stringer("sender", evt.sender.MXID) } + ctx := portal.Bridge.BackgroundCtx + ctx = context.WithValue(ctx, contextKeyMatrixEvent, evt.evt) + ctx = logWith.Logger().WithContext(ctx) + return ctx case *portalRemoteEvent: evt.evtType = evt.evt.GetType() logWith = portal.Log.With().Int("event_loop_index", idx). @@ -437,10 +507,23 @@ func (portal *Portal) getEventCtxWithLog(rawEvt any, idx int) context.Context { logWith = logWith.Int64("remote_stream_order", remoteStreamOrder) } } + if remoteMsg, ok := evt.evt.(RemoteEventWithTimestamp); ok { + if remoteTimestamp := remoteMsg.GetTimestamp(); !remoteTimestamp.IsZero() { + logWith = logWith.Time("remote_timestamp", remoteTimestamp) + } + } + ctx := portal.Bridge.BackgroundCtx + ctx = context.WithValue(ctx, contextKeyRemoteEvent, evt.evt) + ctx = logWith.Logger().WithContext(ctx) + if ctxMut, ok := evt.evt.(RemoteEventWithContextMutation); ok { + ctx = ctxMut.MutateContext(ctx) + } + return ctx case *portalCreateEvent: return evt.ctx + default: + panic(fmt.Errorf("invalid type %T in getEventCtxWithLog", evt)) } - return logWith.Logger().WithContext(portal.Bridge.BackgroundCtx) } func (portal *Portal) handleSingleEvent(ctx context.Context, rawEvt any, doneCallback func(res EventHandlingResult)) { @@ -476,7 +559,14 @@ func (portal *Portal) handleSingleEvent(ctx context.Context, rawEvt any, doneCal }() switch evt := rawEvt.(type) { case *portalMatrixEvent: - res = portal.handleMatrixEvent(ctx, evt.sender, evt.evt) + isStateRequest := evt.evt.Type == event.BeeperSendState + if isStateRequest { + if err := portal.unwrapBeeperSendState(ctx, evt.evt); err != nil { + portal.sendErrorStatus(ctx, evt.evt, err) + return + } + } + res = portal.handleMatrixEvent(ctx, evt.sender, evt.evt, isStateRequest) if res.SendMSS { if res.Error != nil { portal.sendErrorStatus(ctx, evt.evt, res.Error) @@ -484,6 +574,21 @@ func (portal *Portal) handleSingleEvent(ctx context.Context, rawEvt any, doneCal portal.sendSuccessStatus(ctx, evt.evt, 0, "") } } + if !isStateRequest && res.Error != nil && evt.evt.StateKey != nil { + portal.revertRoomMeta(ctx, evt.evt) + } + if isStateRequest && res.Success && !res.SkipStateEcho { + portal.sendRoomMeta( + ctx, + evt.sender.DoublePuppet(ctx), + time.UnixMilli(evt.evt.Timestamp), + evt.evt.Type, + evt.evt.GetStateKey(), + evt.evt.Content.Parsed, + false, + evt.evt.Content.Raw, + ) + } case *portalRemoteEvent: res = portal.handleRemoteEvent(ctx, evt.source, evt.evtType, evt.evt) case *portalCreateEvent: @@ -495,18 +600,44 @@ func (portal *Portal) handleSingleEvent(ctx context.Context, rawEvt any, doneCal } } +func (portal *Portal) unwrapBeeperSendState(ctx context.Context, evt *event.Event) error { + content, ok := evt.Content.Parsed.(*event.BeeperSendStateEventContent) + if !ok { + return fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed) + } + evt.Content = content.Content + evt.StateKey = &content.StateKey + evt.Type = event.Type{Type: content.Type, Class: event.StateEventType} + _ = evt.Content.ParseRaw(evt.Type) + mx, ok := portal.Bridge.Matrix.(MatrixConnectorWithArbitraryRoomState) + if !ok { + return fmt.Errorf("matrix connector doesn't support fetching state") + } + prevEvt, err := mx.GetStateEvent(ctx, portal.MXID, evt.Type, evt.GetStateKey()) + if err != nil && !errors.Is(err, mautrix.MNotFound) { + return fmt.Errorf("failed to get prev event: %w", err) + } else if prevEvt != nil { + evt.Unsigned.PrevContent = &prevEvt.Content + evt.Unsigned.PrevSender = prevEvt.Sender + } + return nil +} + func (portal *Portal) FindPreferredLogin(ctx context.Context, user *User, allowRelay bool) (*UserLogin, *database.UserPortal, error) { if portal.Receiver != "" { login, err := portal.Bridge.GetExistingUserLoginByID(ctx, portal.Receiver) if err != nil { return nil, nil, err } - if login == nil || login.UserMXID != user.MXID || !login.Client.IsLoggedIn() { + if login == nil { + return nil, nil, fmt.Errorf("%w (receiver login is nil)", ErrNotLoggedIn) + } else if !login.Client.IsLoggedIn() { + return nil, nil, fmt.Errorf("%w (receiver login is not logged in)", ErrNotLoggedIn) + } else if login.UserMXID != user.MXID { if allowRelay && portal.Relay != nil { return nil, nil, nil } - // TODO different error for this case? - return nil, nil, ErrNotLoggedIn + return nil, nil, fmt.Errorf("%w (relay not set and receiver login is owned by %s, not %s)", ErrNotLoggedIn, login.UserMXID, user.MXID) } up, err := portal.Bridge.DB.UserPortal.Get(ctx, login.UserLogin, portal.PortalKey) return login, up, err @@ -589,7 +720,7 @@ func (portal *Portal) checkConfusableName(ctx context.Context, userID id.UserID, var fakePerMessageProfileEventType = event.Type{Class: event.StateEventType, Type: "m.per_message_profile"} -func (portal *Portal) handleMatrixEvent(ctx context.Context, sender *User, evt *event.Event) EventHandlingResult { +func (portal *Portal) handleMatrixEvent(ctx context.Context, sender *User, evt *event.Event, isStateRequest bool) EventHandlingResult { log := zerolog.Ctx(ctx) if evt.Mautrix.EventSource&event.SourceEphemeral != 0 { switch evt.Type { @@ -597,6 +728,8 @@ func (portal *Portal) handleMatrixEvent(ctx context.Context, sender *User, evt * return portal.handleMatrixReceipts(ctx, evt) case event.EphemeralEventTyping: return portal.handleMatrixTyping(ctx, evt) + case event.BeeperEphemeralEventAIStream: + return portal.handleMatrixAIStream(ctx, sender, evt) default: return EventHandlingResultIgnored } @@ -621,6 +754,9 @@ func (portal *Portal) handleMatrixEvent(ctx context.Context, sender *User, evt * } var origSender *OrigSender if login == nil { + if isStateRequest { + return EventHandlingResultFailed.WithMSSError(ErrCantRelayStateRequest) + } login = portal.Relay origSender = &OrigSender{ User: sender, @@ -691,13 +827,13 @@ func (portal *Portal) handleMatrixEvent(ctx context.Context, sender *User, evt * case event.EventRedaction: return portal.handleMatrixRedaction(ctx, login, origSender, evt) case event.StateRoomName: - return handleMatrixRoomMeta(portal, ctx, login, origSender, evt, RoomNameHandlingNetworkAPI.HandleMatrixRoomName) + return handleMatrixRoomMeta(portal, ctx, login, origSender, evt, isStateRequest, RoomNameHandlingNetworkAPI.HandleMatrixRoomName) case event.StateTopic: - return handleMatrixRoomMeta(portal, ctx, login, origSender, evt, RoomTopicHandlingNetworkAPI.HandleMatrixRoomTopic) + return handleMatrixRoomMeta(portal, ctx, login, origSender, evt, isStateRequest, RoomTopicHandlingNetworkAPI.HandleMatrixRoomTopic) case event.StateRoomAvatar: - return handleMatrixRoomMeta(portal, ctx, login, origSender, evt, RoomAvatarHandlingNetworkAPI.HandleMatrixRoomAvatar) + return handleMatrixRoomMeta(portal, ctx, login, origSender, evt, isStateRequest, RoomAvatarHandlingNetworkAPI.HandleMatrixRoomAvatar) case event.StateBeeperDisappearingTimer: - return handleMatrixRoomMeta(portal, ctx, login, origSender, evt, DisappearTimerChangingNetworkAPI.HandleMatrixDisappearingTimer) + return handleMatrixRoomMeta(portal, ctx, login, origSender, evt, isStateRequest, DisappearTimerChangingNetworkAPI.HandleMatrixDisappearingTimer) case event.StateEncryption: // TODO? return EventHandlingResultIgnored @@ -708,11 +844,13 @@ func (portal *Portal) handleMatrixEvent(ctx context.Context, sender *User, evt * case event.AccountDataBeeperMute: return handleMatrixAccountData(portal, ctx, login, evt, MuteHandlingNetworkAPI.HandleMute) case event.StateMember: - return portal.handleMatrixMembership(ctx, login, origSender, evt) + return portal.handleMatrixMembership(ctx, login, origSender, evt, isStateRequest) case event.StatePowerLevels: - return portal.handleMatrixPowerLevels(ctx, login, origSender, evt) + return portal.handleMatrixPowerLevels(ctx, login, origSender, evt, isStateRequest) case event.BeeperDeleteChat: return portal.handleMatrixDeleteChat(ctx, login, origSender, evt) + case event.BeeperAcceptMessageRequest: + return portal.handleMatrixAcceptMessageRequest(ctx, login, origSender, evt) default: return EventHandlingResultIgnored } @@ -815,7 +953,7 @@ func (portal *Portal) callReadReceiptHandler( if err != nil { zerolog.Ctx(ctx).Err(err).Msg("Failed to save user portal metadata") } - portal.Bridge.DisappearLoop.StartAll(ctx, portal.MXID) + portal.Bridge.DisappearLoop.StartAllBefore(ctx, portal.MXID, evt.ReadUpTo) } func (portal *Portal) handleMatrixTyping(ctx context.Context, evt *event.Event) EventHandlingResult { @@ -836,6 +974,50 @@ func (portal *Portal) handleMatrixTyping(ctx context.Context, evt *event.Event) return EventHandlingResultSuccess } +func (portal *Portal) handleMatrixAIStream(ctx context.Context, sender *User, evt *event.Event) EventHandlingResult { + log := zerolog.Ctx(ctx) + if sender == nil { + log.Error().Msg("Missing sender for Matrix AI stream event") + return EventHandlingResultIgnored + } + login, _, err := portal.FindPreferredLogin(ctx, sender, true) + if err != nil { + log.Err(err).Msg("Failed to get user login to handle Matrix AI stream event") + return EventHandlingResultFailed.WithMSSError(err) + } + var origSender *OrigSender + if login == nil { + if portal.Relay == nil { + return EventHandlingResultIgnored + } + login = portal.Relay + origSender = &OrigSender{ + User: sender, + UserID: sender.MXID, + } + } + content, ok := evt.Content.Parsed.(*event.BeeperAIStreamEventContent) + if !ok { + log.Error().Type("content_type", evt.Content.Parsed).Msg("Unexpected parsed content type") + return EventHandlingResultFailed.WithMSSError(fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed)) + } + api, ok := login.Client.(BeeperAIStreamHandlingNetworkAPI) + if !ok { + return EventHandlingResultIgnored.WithMSSError(ErrBeeperAIStreamNotSupported) + } + err = api.HandleMatrixBeeperAIStream(ctx, &MatrixBeeperAIStream{ + Event: evt, + Content: content, + Portal: portal, + OrigSender: origSender, + }) + if err != nil { + log.Err(err).Msg("Failed to handle Matrix AI stream event") + return EventHandlingResultFailed.WithMSSError(err) + } + return EventHandlingResultSuccess.WithMSS() +} + func (portal *Portal) sendTypings(ctx context.Context, userIDs []id.UserID, typing bool) { for _, userID := range userIDs { login, ok := portal.currentlyTypingLogins[userID] @@ -947,6 +1129,9 @@ func (portal *Portal) checkMessageContentCaps(caps *event.RoomFeatures, content if content.Info != nil { dur := time.Duration(content.Info.Duration) * time.Millisecond if feat.MaxDuration != nil && dur > feat.MaxDuration.Duration { + if capMsgType == event.CapMsgVoice { + return fmt.Errorf("%w: %s supports voice messages up to %s long", ErrVoiceMessageDurationTooLong, portal.Bridge.Network.GetName().DisplayName, exfmt.Duration(feat.MaxDuration.Duration)) + } return fmt.Errorf("%w: %s is longer than the maximum of %s", ErrMediaDurationTooLong, exfmt.Duration(dur), exfmt.Duration(feat.MaxDuration.Duration)) } if feat.MaxSize != 0 && int64(content.Info.Size) > feat.MaxSize { @@ -1012,10 +1197,12 @@ func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin log.Debug().Msg("Ignoring poll event from relayed user") return EventHandlingResultIgnored.WithMSSError(ErrIgnoringPollFromRelayedUser) } - msgContent, err = portal.Bridge.Config.Relay.FormatMessage(msgContent, origSender) - if err != nil { - log.Err(err).Msg("Failed to format message for relaying") - return EventHandlingResultFailed.WithMSSError(err) + if !caps.PerMessageProfileRelay { + msgContent, err = portal.Bridge.Config.Relay.FormatMessage(msgContent, origSender) + if err != nil { + log.Err(err).Msg("Failed to format message for relaying") + return EventHandlingResultFailed.WithMSSError(err) + } } } if msgContent != nil { @@ -1083,6 +1270,16 @@ func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin } } } + var messageTimer *event.BeeperDisappearingTimer + if msgContent != nil { + messageTimer = msgContent.BeeperDisappearingTimer + } + if messageTimer != nil && *portal.Disappear.ToEventContent() != *messageTimer { + log.Warn(). + Any("event_timer", messageTimer). + Any("portal_timer", portal.Disappear.ToEventContent()). + Msg("Mismatching disappearing timer in event") + } wrappedMsgEvt := &MatrixMessage{ MatrixEventBase: MatrixEventBase[*event.MessageEventContent]{ @@ -1108,6 +1305,12 @@ func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin } } + err = portal.autoAcceptMessageRequest(ctx, evt, sender, origSender, caps) + if err != nil { + log.Warn().Err(err).Msg("Failed to auto-accept message request on message") + // TODO stop processing? + } + var resp *MatrixMessageResponse if msgContent != nil { resp, err = sender.Client.HandleMatrixMessage(ctx, wrappedMsgEvt) @@ -1159,18 +1362,23 @@ func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin } portal.sendSuccessStatus(ctx, evt, resp.StreamOrder, message.MXID) } - if portal.Disappear.Type != event.DisappearingTypeNone { + ds := portal.Disappear + if messageTimer != nil { + ds = database.DisappearingSettingFromEvent(messageTimer) + } + if ds.Type != event.DisappearingTypeNone { go portal.Bridge.DisappearLoop.Add(ctx, &database.DisappearingMessage{ RoomID: portal.MXID, EventID: message.MXID, - DisappearingSetting: portal.Disappear.StartingAt(message.Timestamp), + Timestamp: message.Timestamp, + DisappearingSetting: ds.StartingAt(message.Timestamp), }) } if resp.Pending { // Not exactly queued, but not finished either return EventHandlingResultQueued } - return EventHandlingResultSuccess + return EventHandlingResultSuccess.WithEventID(message.MXID).WithStreamOrder(resp.StreamOrder) } // AddPendingToIgnore adds a transaction ID that should be ignored if encountered as a new message. @@ -1359,7 +1567,7 @@ func (portal *Portal) handleMatrixEdit( return EventHandlingResultSuccess } -func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogin, evt *event.Event) EventHandlingResult { +func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogin, evt *event.Event) (handleRes EventHandlingResult) { log := zerolog.Ctx(ctx) reactingAPI, ok := sender.Client.(ReactionHandlingNetworkAPI) if !ok { @@ -1382,6 +1590,12 @@ func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogi log.Warn().Msg("Reaction target message not found in database") return EventHandlingResultFailed.WithMSSError(fmt.Errorf("reaction %w", ErrTargetMessageNotFound)) } + caps := sender.Client.GetCapabilities(ctx, portal) + err = portal.autoAcceptMessageRequest(ctx, evt, sender, nil, caps) + if err != nil { + log.Warn().Err(err).Msg("Failed to auto-accept message request on reaction") + // TODO stop processing? + } log.UpdateContext(func(c zerolog.Context) zerolog.Context { return c.Str("reaction_target_remote_id", string(reactionTarget.ID)) }) @@ -1404,6 +1618,31 @@ func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogi if portal.Bridge.Config.OutgoingMessageReID { deterministicID = portal.Bridge.Matrix.GenerateReactionEventID(portal.MXID, reactionTarget, preResp.SenderID, preResp.EmojiID) } + defer func() { + // Do this in a defer so that it happens after any potential defer calls to removeOutdatedReaction + if handleRes.Success { + portal.sendSuccessStatus(ctx, evt, 0, deterministicID) + } + }() + removeOutdatedReaction := func(oldReact *database.Reaction, deleteDB bool) { + if !handleRes.Success { + return + } + _, err := portal.Bridge.Bot.SendMessage(ctx, portal.MXID, event.EventRedaction, &event.Content{ + Parsed: &event.RedactionEventContent{ + Redacts: oldReact.MXID, + }, + }, nil) + if err != nil { + log.Err(err).Msg("Failed to remove old reaction") + } + if deleteDB { + err = portal.Bridge.DB.Reaction.Delete(ctx, oldReact) + if err != nil { + log.Err(err).Msg("Failed to delete old reaction from database") + } + } + } existing, err := portal.Bridge.DB.Reaction.GetByID(ctx, portal.Receiver, reactionTarget.ID, reactionTarget.PartID, preResp.SenderID, preResp.EmojiID) if err != nil { log.Err(err).Msg("Failed to check if reaction is a duplicate") @@ -1412,17 +1651,10 @@ func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogi if existing.EmojiID != "" || existing.Emoji == preResp.Emoji { log.Debug().Msg("Ignoring duplicate reaction") portal.sendSuccessStatus(ctx, evt, 0, deterministicID) - return EventHandlingResultIgnored + return EventHandlingResultIgnored.WithEventID(deterministicID) } react.ReactionToOverride = existing - _, err = portal.Bridge.Bot.SendMessage(ctx, portal.MXID, event.EventRedaction, &event.Content{ - Parsed: &event.RedactionEventContent{ - Redacts: existing.MXID, - }, - }, nil) - if err != nil { - log.Err(err).Msg("Failed to remove old reaction") - } + defer removeOutdatedReaction(existing, false) } react.PreHandleResp = &preResp if preResp.MaxReactions > 0 { @@ -1437,18 +1669,14 @@ func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogi // Keep n-1 previous reactions and remove the rest react.ExistingReactionsToKeep = allReactions[:preResp.MaxReactions-1] for _, oldReaction := range allReactions[preResp.MaxReactions-1:] { - _, err = portal.Bridge.Bot.SendMessage(ctx, portal.MXID, event.EventRedaction, &event.Content{ - Parsed: &event.RedactionEventContent{ - Redacts: oldReaction.MXID, - }, - }, nil) - if err != nil { - log.Err(err).Msg("Failed to remove previous reaction after limit was exceeded") - } - err = portal.Bridge.DB.Reaction.Delete(ctx, oldReaction) - if err != nil { - log.Err(err).Msg("Failed to delete previous reaction from database after limit was exceeded") + if existing != nil && oldReaction.EmojiID == existing.EmojiID { + // Don't double-delete on networks that only allow one emoji + continue } + // Intentionally defer in a loop, there won't be that many items, + // and we want all of them to be done after this function completes successfully + //goland:noinspection GoDeferInLoop + defer removeOutdatedReaction(oldReaction, true) } } } @@ -1493,8 +1721,7 @@ func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogi if err != nil { log.Err(err).Msg("Failed to save reaction to database") } - portal.sendSuccessStatus(ctx, evt, 0, deterministicID) - return EventHandlingResultSuccess + return EventHandlingResultSuccess.WithEventID(deterministicID) } func handleMatrixRoomMeta[APIType any, ContentType any]( @@ -1503,14 +1730,19 @@ func handleMatrixRoomMeta[APIType any, ContentType any]( sender *UserLogin, origSender *OrigSender, evt *event.Event, + isStateRequest bool, fn func(APIType, context.Context, *MatrixRoomMeta[ContentType]) (bool, error), ) EventHandlingResult { if evt.StateKey == nil || *evt.StateKey != "" { return EventHandlingResultFailed.WithMSSError(ErrInvalidStateKey) } + //caps := sender.Client.GetCapabilities(ctx, portal) + //if stateCap, ok := caps.State[evt.Type.Type]; !ok || stateCap.Level <= event.CapLevelUnsupported { + // return EventHandlingResultIgnored.WithMSSError(fmt.Errorf("%s %w", evt.Type.Type, ErrRoomMetadataNotAllowed)) + //} api, ok := sender.Client.(APIType) if !ok { - return EventHandlingResultIgnored.WithMSSError(ErrRoomMetadataNotSupported) + return EventHandlingResultIgnored.WithMSSError(fmt.Errorf("%w of type %s", ErrRoomMetadataNotSupported, evt.Type)) } log := zerolog.Ctx(ctx) content, ok := evt.Content.Parsed.(ContentType) @@ -1544,7 +1776,6 @@ func handleMatrixRoomMeta[APIType any, ContentType any]( return EventHandlingResultIgnored } if !sender.Client.GetCapabilities(ctx, portal).DisappearingTimer.Supports(typedContent) { - portal.sendRoomMeta(ctx, nil, time.Now(), event.StateBeeperDisappearingTimer, "", portal.Disappear.ToEventContent(), false) return EventHandlingResultFailed.WithMSSError(ErrDisappearingTimerUnsupported) } } @@ -1563,13 +1794,11 @@ func handleMatrixRoomMeta[APIType any, ContentType any]( InputTransactionID: portal.parseInputTransactionID(origSender, evt), }, - PrevContent: prevContent, + IsStateRequest: isStateRequest, + PrevContent: prevContent, }) if err != nil { log.Err(err).Msg("Failed to handle Matrix room metadata") - if evt.Type == event.StateBeeperDisappearingTimer { - portal.sendRoomMeta(ctx, nil, time.Now(), event.StateBeeperDisappearingTimer, "", portal.Disappear.ToEventContent(), false) - } return EventHandlingResultFailed.WithMSSError(err) } if changed { @@ -1636,6 +1865,77 @@ func (portal *Portal) getTargetUser(ctx context.Context, userID id.UserID) (Ghos } } +func (portal *Portal) handleMatrixAcceptMessageRequest( + ctx context.Context, + sender *UserLogin, + origSender *OrigSender, + evt *event.Event, +) EventHandlingResult { + if origSender != nil { + return EventHandlingResultFailed.WithMSSError(ErrIgnoringAcceptRequestRelayedUser) + } + log := zerolog.Ctx(ctx) + content, ok := evt.Content.Parsed.(*event.BeeperAcceptMessageRequestEventContent) + if !ok { + log.Error().Type("content_type", evt.Content.Parsed).Msg("Unexpected parsed content type") + return EventHandlingResultFailed.WithMSSError(fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed)) + } + api, ok := sender.Client.(MessageRequestAcceptingNetworkAPI) + if !ok { + return EventHandlingResultIgnored.WithMSSError(ErrDeleteChatNotSupported) + } + err := api.HandleMatrixAcceptMessageRequest(ctx, &MatrixAcceptMessageRequest{ + Event: evt, + Content: content, + Portal: portal, + }) + if err != nil { + log.Err(err).Msg("Failed to handle Matrix accept message request") + return EventHandlingResultFailed.WithMSSError(err) + } + if portal.MessageRequest { + portal.MessageRequest = false + portal.UpdateBridgeInfo(ctx) + err = portal.Save(ctx) + if err != nil { + log.Err(err).Msg("Failed to save portal after accepting message request") + } + } + return EventHandlingResultSuccess.WithMSS() +} + +func (portal *Portal) autoAcceptMessageRequest( + ctx context.Context, evt *event.Event, sender *UserLogin, origSender *OrigSender, caps *event.RoomFeatures, +) error { + if !portal.MessageRequest || caps.MessageRequest == nil || caps.MessageRequest.AcceptWithMessage == event.CapLevelFullySupported { + return nil + } + mran, ok := sender.Client.(MessageRequestAcceptingNetworkAPI) + if !ok { + return nil + } + err := mran.HandleMatrixAcceptMessageRequest(ctx, &MatrixAcceptMessageRequest{ + Event: evt, + Content: &event.BeeperAcceptMessageRequestEventContent{ + IsImplicit: true, + }, + Portal: portal, + OrigSender: origSender, + }) + if err != nil { + return err + } + if portal.MessageRequest { + portal.MessageRequest = false + portal.UpdateBridgeInfo(ctx) + err = portal.Save(ctx) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to save portal after accepting message request") + } + } + return nil +} + func (portal *Portal) handleMatrixDeleteChat( ctx context.Context, sender *UserLogin, @@ -1693,6 +1993,7 @@ func (portal *Portal) handleMatrixMembership( sender *UserLogin, origSender *OrigSender, evt *event.Event, + isStateRequest bool, ) EventHandlingResult { if evt.StateKey == nil { return EventHandlingResultFailed.WithMSSError(ErrInvalidStateKey) @@ -1732,7 +2033,6 @@ func (portal *Portal) handleMatrixMembership( return EventHandlingResultIgnored //.WithMSSError(ErrIgnoringLeaveEvent) } targetGhost, _ := target.(*Ghost) - targetUserLogin, _ := target.(*UserLogin) membershipChange := &MatrixMembershipChange{ MatrixRoomMeta: MatrixRoomMeta[*event.MemberEventContent]{ MatrixEventBase: MatrixEventBase[*event.MemberEventContent]{ @@ -1743,19 +2043,60 @@ func (portal *Portal) handleMatrixMembership( InputTransactionID: portal.parseInputTransactionID(origSender, evt), }, - PrevContent: prevContent, + IsStateRequest: isStateRequest, + PrevContent: prevContent, }, - Target: target, - TargetGhost: targetGhost, - TargetUserLogin: targetUserLogin, - Type: membershipChangeType, + Target: target, + Type: membershipChangeType, } - _, err = api.HandleMatrixMembership(ctx, membershipChange) + res, err := api.HandleMatrixMembership(ctx, membershipChange) if err != nil { log.Err(err).Msg("Failed to handle Matrix membership change") return EventHandlingResultFailed.WithMSSError(err) } - return EventHandlingResultSuccess.WithMSS() + didRedirectInvite := membershipChangeType == Invite && + targetGhost != nil && + res != nil && + res.RedirectTo != "" && + res.RedirectTo != targetGhost.ID + if didRedirectInvite { + log.Debug(). + Str("orig_id", string(targetGhost.ID)). + Str("redirect_id", string(res.RedirectTo)). + Msg("Invite was redirected to different ghost") + var redirectGhost *Ghost + redirectGhost, err = portal.Bridge.GetGhostByID(ctx, res.RedirectTo) + if err != nil { + log.Err(err).Msg("Failed to get redirect target ghost") + return EventHandlingResultFailed.WithError(err) + } + if !isStateRequest { + portal.sendRoomMeta( + ctx, + sender.User.DoublePuppet(ctx), + time.UnixMilli(evt.Timestamp), + event.StateMember, + evt.GetStateKey(), + &event.MemberEventContent{ + Membership: event.MembershipLeave, + Reason: fmt.Sprintf("Invite redirected to %s", res.RedirectTo), + }, + true, + nil, + ) + } + portal.sendRoomMeta( + ctx, + sender.User.DoublePuppet(ctx), + time.UnixMilli(evt.Timestamp), + event.StateMember, + redirectGhost.Intent.GetMXID().String(), + content, + false, + nil, + ) + } + return EventHandlingResultSuccess.WithMSS().WithSkipStateEcho(didRedirectInvite) } func makePLChange(old, new int, newIsSet bool) *SinglePowerLevelChange { @@ -1780,6 +2121,7 @@ func (portal *Portal) handleMatrixPowerLevels( sender *UserLogin, origSender *OrigSender, evt *event.Event, + isStateRequest bool, ) EventHandlingResult { if evt.StateKey == nil || *evt.StateKey != "" { return EventHandlingResultFailed.WithMSSError(ErrInvalidStateKey) @@ -1821,7 +2163,8 @@ func (portal *Portal) handleMatrixPowerLevels( InputTransactionID: portal.parseInputTransactionID(origSender, evt), }, - PrevContent: prevContent, + IsStateRequest: isStateRequest, + PrevContent: prevContent, }, Users: make(map[id.UserID]*UserPowerLevelChange), Events: make(map[string]*SinglePowerLevelChange), @@ -2009,6 +2352,7 @@ func (portal *Portal) UpdateMatrixRoomID( } else if alreadyExists { log.Debug().Msg("Replacement room is already a portal, overwriting") existingPortal.MXID = "" + existingPortal.RoomCreated.Clear() err := existingPortal.Save(ctx) if err != nil { return fmt.Errorf("failed to clear mxid of existing portal: %w", err) @@ -2016,6 +2360,7 @@ func (portal *Portal) UpdateMatrixRoomID( delete(portal.Bridge.portalsByMXID, portal.MXID) } portal.MXID = newRoomID + portal.RoomCreated.Set() portal.Bridge.portalsByMXID[portal.MXID] = portal portal.NameSet = false portal.AvatarSet = false @@ -2273,7 +2618,7 @@ func (portal *Portal) handleRemoteEvent(ctx context.Context, source *UserLogin, } func (portal *Portal) ensureFunctionalMember(ctx context.Context, ghost *Ghost) { - if !ghost.IsBot || portal.RoomType != database.RoomTypeDM || portal.OtherUserID == ghost.ID { + if !ghost.IsBot || portal.RoomType != database.RoomTypeDM || portal.OtherUserID == ghost.ID || portal.MXID == "" { return } ars, ok := portal.Bridge.Matrix.(MatrixConnectorWithArbitraryRoomState) @@ -2447,7 +2792,7 @@ func (portal *Portal) getRelationMeta( log.Err(err).Msg("Failed to get last thread message from database") } if prevThreadEvent == nil { - prevThreadEvent = threadRoot + prevThreadEvent = ptr.Clone(threadRoot) } } return @@ -2558,6 +2903,7 @@ func (portal *Portal) sendConvertedMessage( portal.Bridge.DisappearLoop.Add(ctx, &database.DisappearingMessage{ RoomID: portal.MXID, EventID: dbMessage.MXID, + Timestamp: dbMessage.Timestamp, DisappearingSetting: converted.Disappear, }) } @@ -3344,11 +3690,15 @@ func (portal *Portal) handleRemoteReadReceipt(ctx context.Context, source *UserL return evt.Int64("target_stream_order", targetStreamOrder) } err = soIntent.MarkStreamOrderRead(ctx, portal.MXID, targetStreamOrder, getEventTS(evt)) + if readUpTo.IsZero() { + readUpTo = getEventTS(evt) + } } else { addTargetLog = func(evt *zerolog.Event) *zerolog.Event { return evt.Stringer("target_mxid", lastTarget.MXID) } err = intent.MarkRead(ctx, portal.MXID, lastTarget.MXID, getEventTS(evt)) + readUpTo = lastTarget.Timestamp } if err != nil { addTargetLog(log.Err(err)).Msg("Failed to bridge read receipt") @@ -3357,7 +3707,7 @@ func (portal *Portal) handleRemoteReadReceipt(ctx context.Context, source *UserL addTargetLog(log.Debug()).Msg("Bridged read receipt") } if sender.IsFromMe { - portal.Bridge.DisappearLoop.StartAll(ctx, portal.MXID) + portal.Bridge.DisappearLoop.StartAllBefore(ctx, portal.MXID, readUpTo) } return EventHandlingResultSuccess } @@ -3380,7 +3730,7 @@ func (portal *Portal) handleRemoteMarkUnread(ctx context.Context, source *UserLo } func (portal *Portal) handleRemoteDeliveryReceipt(ctx context.Context, source *UserLogin, evt RemoteDeliveryReceipt) EventHandlingResult { - if portal.RoomType != database.RoomTypeDM || evt.GetSender().Sender != portal.OtherUserID { + if portal.RoomType != database.RoomTypeDM || (evt.GetSender().Sender != portal.OtherUserID && portal.OtherUserID != "") { return EventHandlingResultIgnored } intent, ok := portal.GetIntentFor(ctx, evt.GetSender(), source, RemoteEventDeliveryReceipt) @@ -3494,6 +3844,20 @@ func (portal *Portal) findOtherLogins(ctx context.Context, source *UserLogin) (o return } +type childDeleteProxy struct { + RemoteChatDeleteWithChildren + child networkid.PortalKey + done func() +} + +func (cdp *childDeleteProxy) AddLogContext(c zerolog.Context) zerolog.Context { + return cdp.RemoteChatDeleteWithChildren.AddLogContext(c).Str("subaction", "delete children") +} +func (cdp *childDeleteProxy) GetPortalKey() networkid.PortalKey { return cdp.child } +func (cdp *childDeleteProxy) ShouldCreatePortal() bool { return false } +func (cdp *childDeleteProxy) PreHandle(ctx context.Context, portal *Portal) {} +func (cdp *childDeleteProxy) PostHandle(ctx context.Context, portal *Portal) { cdp.done() } + func (portal *Portal) handleRemoteChatDelete(ctx context.Context, source *UserLogin, evt RemoteChatDelete) EventHandlingResult { log := zerolog.Ctx(ctx) if portal.Receiver == "" && evt.DeleteOnlyForMe() { @@ -3529,6 +3893,31 @@ func (portal *Portal) handleRemoteChatDelete(ctx context.Context, source *UserLo } } } + if childDeleter, ok := evt.(RemoteChatDeleteWithChildren); ok && childDeleter.DeleteChildren() && portal.RoomType == database.RoomTypeSpace { + children, err := portal.Bridge.GetChildPortals(ctx, portal.PortalKey) + if err != nil { + log.Err(err).Msg("Failed to fetch children to delete") + return EventHandlingResultFailed.WithError(err) + } + log.Debug(). + Int("portal_count", len(children)). + Msg("Deleting child portals before remote chat delete") + var wg sync.WaitGroup + wg.Add(len(children)) + for _, child := range children { + child.queueEvent(ctx, &portalRemoteEvent{ + evt: &childDeleteProxy{ + RemoteChatDeleteWithChildren: childDeleter, + child: child.PortalKey, + done: wg.Done, + }, + source: source, + evtType: RemoteEventChatDelete, + }) + } + wg.Wait() + log.Debug().Msg("Finished deleting child portals") + } err := portal.Delete(ctx) if err != nil { log.Err(err).Msg("Failed to delete portal from database") @@ -3584,12 +3973,43 @@ type PortalInfo = ChatInfo type ChatMember struct { EventSender Membership event.Membership - Nickname *string + // Per-room nickname for the user. Not yet used. + Nickname *string + // The power level to set for the user when syncing power levels. PowerLevel *int - UserInfo *UserInfo - + // Optional user info to sync the ghost user while updating membership. + UserInfo *UserInfo + // The user who sent the membership change (user who invited/kicked/banned this user). + // Not yet used. Not applicable if Membership is join or knock. + MemberSender EventSender + // Extra fields to include in the member event. MemberEventExtra map[string]any - PrevMembership event.Membership + // The expected previous membership. If this doesn't match, the change is ignored. + PrevMembership event.Membership +} + +type ChatMemberMap map[networkid.UserID]ChatMember + +// Set adds the given entry to this map, overwriting any existing entry with the same Sender field. +func (cmm ChatMemberMap) Set(member ChatMember) ChatMemberMap { + if member.Sender == "" && member.SenderLogin == "" && !member.IsFromMe { + return cmm + } + cmm[member.Sender] = member + return cmm +} + +// Add adds the given entry to this map, but will ignore it if an entry with the same Sender field already exists. +// It returns true if the entry was added, false otherwise. +func (cmm ChatMemberMap) Add(member ChatMember) bool { + if member.Sender == "" && member.SenderLogin == "" && !member.IsFromMe { + return false + } + if _, exists := cmm[member.Sender]; exists { + return false + } + cmm[member.Sender] = member + return true } type ChatMemberList struct { @@ -3613,7 +4033,7 @@ type ChatMemberList struct { // Deprecated: Use MemberMap instead to avoid duplicate entries Members []ChatMember - MemberMap map[networkid.UserID]ChatMember + MemberMap ChatMemberMap PowerLevels *PowerLevelOverrides } @@ -3715,9 +4135,9 @@ type ChatInfo struct { Disappear *database.DisappearingSetting ParentID *networkid.PortalID - UserLocal *UserLocalPortalInfo - - CanBackfill bool + UserLocal *UserLocalPortalInfo + MessageRequest *bool + CanBackfill bool ExcludeChangesFromTimeline bool @@ -3761,7 +4181,7 @@ func (portal *Portal) updateName( } portal.Name = name portal.NameSet = portal.sendRoomMeta( - ctx, sender, ts, event.StateRoomName, "", &event.RoomNameEventContent{Name: name}, excludeFromTimeline, + ctx, sender, ts, event.StateRoomName, "", &event.RoomNameEventContent{Name: name}, excludeFromTimeline, nil, ) return true } @@ -3774,7 +4194,7 @@ func (portal *Portal) updateTopic( } portal.Topic = topic portal.TopicSet = portal.sendRoomMeta( - ctx, sender, ts, event.StateTopic, "", &event.TopicEventContent{Topic: topic}, excludeFromTimeline, + ctx, sender, ts, event.StateTopic, "", &event.TopicEventContent{Topic: topic}, excludeFromTimeline, nil, ) return true } @@ -3805,7 +4225,7 @@ func (portal *Portal) updateAvatar( portal.AvatarHash = newHash } portal.AvatarSet = portal.sendRoomMeta( - ctx, sender, ts, event.StateRoomAvatar, "", &event.RoomAvatarEventContent{URL: portal.AvatarMXC}, excludeFromTimeline, + ctx, sender, ts, event.StateRoomAvatar, "", &event.RoomAvatarEventContent{URL: portal.AvatarMXC}, excludeFromTimeline, nil, ) return true } @@ -3837,10 +4257,11 @@ func (portal *Portal) getBridgeInfo() (string, event.BridgeEventContent) { Creator: portal.Bridge.Bot.GetMXID(), Protocol: portal.Bridge.Network.GetName().AsBridgeInfoSection(), Channel: event.BridgeInfoSection{ - ID: string(portal.ID), - DisplayName: portal.Name, - AvatarURL: portal.AvatarMXC, - Receiver: string(portal.Receiver), + ID: string(portal.ID), + DisplayName: portal.Name, + AvatarURL: portal.AvatarMXC, + Receiver: string(portal.Receiver), + MessageRequest: portal.MessageRequest, // TODO external URL? }, BeeperRoomTypeV2: string(portal.RoomType), @@ -3850,6 +4271,7 @@ func (portal *Portal) getBridgeInfo() (string, event.BridgeEventContent) { } if bridgeInfo.Protocol.ID == "slackgo" { bridgeInfo.TempSlackRemoteIDMigratedFlag = true + bridgeInfo.TempSlackRemoteIDMigratedFlag2 = true } parent := portal.GetTopLevelParent() if parent != nil { @@ -3872,8 +4294,8 @@ func (portal *Portal) UpdateBridgeInfo(ctx context.Context) { return } stateKey, bridgeInfo := portal.getBridgeInfo() - portal.sendRoomMeta(ctx, nil, time.Now(), event.StateBridge, stateKey, &bridgeInfo, false) - portal.sendRoomMeta(ctx, nil, time.Now(), event.StateHalfShotBridge, stateKey, &bridgeInfo, false) + portal.sendRoomMeta(ctx, nil, time.Now(), event.StateBridge, stateKey, &bridgeInfo, false, nil) + portal.sendRoomMeta(ctx, nil, time.Now(), event.StateHalfShotBridge, stateKey, &bridgeInfo, false, nil) } func (portal *Portal) UpdateCapabilities(ctx context.Context, source *UserLogin, implicit bool) bool { @@ -3895,7 +4317,7 @@ func (portal *Portal) UpdateCapabilities(ctx context.Context, source *UserLogin, Str("old_id", portal.CapState.ID). Str("new_id", capID). Msg("Sending new room capability event") - success := portal.sendRoomMeta(ctx, nil, time.Now(), event.StateBeeperRoomFeatures, portal.getBridgeInfoStateKey(), caps, false) + success := portal.sendRoomMeta(ctx, nil, time.Now(), event.StateBeeperRoomFeatures, portal.getBridgeInfoStateKey(), caps, false, nil) if !success { return false } @@ -3906,7 +4328,7 @@ func (portal *Portal) UpdateCapabilities(ctx context.Context, source *UserLogin, } if caps.DisappearingTimer != nil && !portal.CapState.Flags.Has(database.CapStateFlagDisappearingTimerSet) { zerolog.Ctx(ctx).Debug().Msg("Disappearing timer capability was added, sending disappearing timer state event") - success = portal.sendRoomMeta(ctx, nil, time.Now(), event.StateBeeperDisappearingTimer, "", portal.Disappear.ToEventContent(), true) + success = portal.sendRoomMeta(ctx, nil, time.Now(), event.StateBeeperDisappearingTimer, "", portal.Disappear.ToEventContent(), true, nil) if !success { return false } @@ -3945,11 +4367,14 @@ func (portal *Portal) sendRoomMeta( stateKey string, content any, excludeFromTimeline bool, + extra map[string]any, ) bool { if portal.MXID == "" { return false } - extra := make(map[string]any) + if extra == nil { + extra = make(map[string]any) + } if excludeFromTimeline { extra["com.beeper.exclude_from_timeline"] = true } @@ -3966,9 +4391,55 @@ func (portal *Portal) sendRoomMeta( Msg("Failed to set room metadata") return false } + if eventType == event.StateBeeperDisappearingTimer { + // TODO remove this debug log at some point + zerolog.Ctx(ctx).Debug(). + Any("content", content). + Msg("Sent new disappearing timer event") + } return true } +func (portal *Portal) revertRoomMeta(ctx context.Context, evt *event.Event) { + if !portal.Bridge.Config.RevertFailedStateChanges { + return + } + if evt.GetStateKey() != "" && evt.Type != event.StateMember { + return + } + switch evt.Type { + case event.StateRoomName: + portal.sendRoomMeta(ctx, nil, time.Time{}, event.StateRoomName, "", &event.RoomNameEventContent{Name: portal.Name}, true, nil) + case event.StateRoomAvatar: + portal.sendRoomMeta(ctx, nil, time.Time{}, event.StateRoomAvatar, "", &event.RoomAvatarEventContent{URL: portal.AvatarMXC}, true, nil) + case event.StateTopic: + portal.sendRoomMeta(ctx, nil, time.Time{}, event.StateTopic, "", &event.TopicEventContent{Topic: portal.Topic}, true, nil) + case event.StateBeeperDisappearingTimer: + portal.sendRoomMeta(ctx, nil, time.Time{}, event.StateBeeperDisappearingTimer, "", portal.Disappear.ToEventContent(), true, nil) + case event.StateMember: + var prevContent *event.MemberEventContent + var extra map[string]any + if evt.Unsigned.PrevContent != nil { + _ = evt.Unsigned.PrevContent.ParseRaw(evt.Type) + prevContent = evt.Unsigned.PrevContent.AsMember() + newContent := evt.Content.AsMember() + if prevContent.Membership == newContent.Membership { + return + } + extra = evt.Unsigned.PrevContent.Raw + } else { + prevContent = &event.MemberEventContent{Membership: event.MembershipLeave} + } + if portal.Bridge.Matrix.GetCapabilities().ArbitraryMemberChange { + if extra == nil { + extra = make(map[string]any) + } + extra["com.beeper.member_rollback"] = true + portal.sendRoomMeta(ctx, nil, time.Time{}, event.StateMember, evt.GetStateKey(), prevContent, true, extra) + } + } +} + func (portal *Portal) getInitialMemberList(ctx context.Context, members *ChatMemberList, source *UserLogin, pl *event.PowerLevelsEventContent) (invite, functional []id.UserID, err error) { if members == nil { invite = []id.UserID{source.UserMXID} @@ -4052,6 +4523,39 @@ func (portal *Portal) updateOtherUser(ctx context.Context, members *ChatMemberLi return false } +func looksDirectlyJoinable(rule *event.JoinRulesEventContent) bool { + switch rule.JoinRule { + case event.JoinRulePublic: + return true + case event.JoinRuleKnockRestricted, event.JoinRuleRestricted: + for _, allow := range rule.Allow { + if allow.Type == "fi.mau.spam_checker" { + return true + } + } + } + return false +} + +func (portal *Portal) roomIsPublic(ctx context.Context) bool { + mx, ok := portal.Bridge.Matrix.(MatrixConnectorWithArbitraryRoomState) + if !ok { + return false + } + evt, err := mx.GetStateEvent(ctx, portal.MXID, event.StateJoinRules, "") + if err != nil { + zerolog.Ctx(ctx).Warn().Err(err).Msg("Failed to get join rules to check if room is public") + return false + } else if evt == nil { + return false + } + content, ok := evt.Content.Parsed.(*event.JoinRulesEventContent) + if !ok { + return false + } + return looksDirectlyJoinable(content) +} + func (portal *Portal) syncParticipants( ctx context.Context, members *ChatMemberList, @@ -4120,7 +4624,7 @@ func (portal *Portal) syncParticipants( wrappedContent := &event.Content{Parsed: content, Raw: exmaps.NonNilClone(member.MemberEventExtra)} addExcludeFromTimeline(wrappedContent.Raw) thisEvtSender := sender - if member.Membership == event.MembershipJoin { + if member.Membership == event.MembershipJoin && (intent == nil || !portal.roomIsPublic(ctx)) { content.Membership = event.MembershipInvite if intent != nil { wrappedContent.Raw["fi.mau.will_auto_accept"] = true @@ -4150,7 +4654,11 @@ func (portal *Portal) syncParticipants( currentMember.Membership = event.MembershipLeave } } - _, err = portal.sendStateWithIntentOrBot(ctx, thisEvtSender, event.StateMember, extraUserID.String(), wrappedContent, ts) + if content.Membership == event.MembershipJoin && intent != nil && intent.GetMXID() == extraUserID { + _, err = intent.SendState(ctx, portal.MXID, event.StateMember, extraUserID.String(), wrappedContent, ts) + } else { + _, err = portal.sendStateWithIntentOrBot(ctx, thisEvtSender, event.StateMember, extraUserID.String(), wrappedContent, ts) + } if err != nil { addLogContext(log.Err(err)). Str("new_membership", string(content.Membership)). @@ -4227,7 +4735,7 @@ func (portal *Portal) syncParticipants( if memberEvt.Membership == event.MembershipLeave || memberEvt.Membership == event.MembershipBan { continue } - if !portal.Bridge.IsGhostMXID(extraMember) && portal.Relay != nil { + if !portal.Bridge.IsGhostMXID(extraMember) && (portal.Relay != nil || !portal.Bridge.Config.KickMatrixUsers) { continue } _, err = portal.Bridge.Bot.SendState(ctx, portal.MXID, event.StateMember, extraMember.String(), &event.Content{ @@ -4353,6 +4861,7 @@ func (portal *Portal) UpdateDisappearingSetting( "", setting.ToEventContent(), opts.ExcludeFromTimeline, + nil, ) if !opts.SendNotice { @@ -4481,7 +4990,7 @@ func (portal *Portal) UpdateInfo(ctx context.Context, info *ChatInfo, source *Us } if info.JoinRule != nil { // TODO change detection instead of spamming this every time? - portal.sendRoomMeta(ctx, sender, ts, event.StateJoinRules, "", info.JoinRule, info.ExcludeChangesFromTimeline) + portal.sendRoomMeta(ctx, sender, ts, event.StateJoinRules, "", info.JoinRule, info.ExcludeChangesFromTimeline, nil) } if info.Type != nil && portal.RoomType != *info.Type { if portal.MXID != "" && (*info.Type == database.RoomTypeSpace || portal.RoomType == database.RoomTypeSpace) { @@ -4494,6 +5003,10 @@ func (portal *Portal) UpdateInfo(ctx context.Context, info *ChatInfo, source *Us portal.RoomType = *info.Type } } + if info.MessageRequest != nil && *info.MessageRequest != portal.MessageRequest { + changed = true + portal.MessageRequest = *info.MessageRequest + } if info.Members != nil && portal.MXID != "" && source != nil { err := portal.syncParticipants(ctx, info.Members, source, nil, time.Time{}) if err != nil { @@ -4535,6 +5048,9 @@ func (portal *Portal) CreateMatrixRoom(ctx context.Context, source *UserLogin, i } return nil } + if portal.deleted.IsSet() { + return ErrPortalIsDeleted + } waiter := make(chan struct{}) closed := false evt := &portalCreateEvent{ @@ -4552,7 +5068,11 @@ func (portal *Portal) CreateMatrixRoom(ctx context.Context, source *UserLogin, i if PortalEventBuffer == 0 { go portal.queueEvent(ctx, evt) } else { - portal.events <- evt + select { + case portal.events <- evt: + case <-portal.deleted.GetChan(): + return ErrPortalIsDeleted + } } select { case <-ctx.Done(): @@ -4563,7 +5083,11 @@ func (portal *Portal) CreateMatrixRoom(ctx context.Context, source *UserLogin, i } func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLogin, info *ChatInfo, backfillBundle any) error { + cancellableCtx, cancel := context.WithCancel(ctx) + defer cancel() + portal.cancelRoomCreate.CompareAndSwap(nil, &cancel) portal.roomCreateLock.Lock() + portal.cancelRoomCreate.Store(&cancel) defer portal.roomCreateLock.Unlock() if portal.MXID != "" { if source != nil { @@ -4574,6 +5098,7 @@ func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLo log := zerolog.Ctx(ctx).With(). Str("action", "create matrix room"). Logger() + cancellableCtx = log.WithContext(cancellableCtx) ctx = log.WithContext(ctx) log.Info().Msg("Creating Matrix room") @@ -4582,16 +5107,16 @@ func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLo if info != nil { log.Warn().Msg("CreateMatrixRoom got info without members. Refetching info") } - info, err = source.Client.GetChatInfo(ctx, portal) + info, err = source.Client.GetChatInfo(cancellableCtx, portal) if err != nil { log.Err(err).Msg("Failed to update portal info for creation") return err } } - portal.UpdateInfo(ctx, info, source, nil, time.Time{}) - if ctx.Err() != nil { - return ctx.Err() + portal.UpdateInfo(cancellableCtx, info, source, nil, time.Time{}) + if cancellableCtx.Err() != nil { + return cancellableCtx.Err() } powerLevels := &event.PowerLevelsEventContent{ @@ -4604,7 +5129,7 @@ func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLo portal.Bridge.Bot.GetMXID(): 9001, }, } - initialMembers, extraFunctionalMembers, err := portal.getInitialMemberList(ctx, info.Members, source, powerLevels) + initialMembers, extraFunctionalMembers, err := portal.getInitialMemberList(cancellableCtx, info.Members, source, powerLevels) if err != nil { log.Err(err).Msg("Failed to process participant list for portal creation") return err @@ -4619,7 +5144,6 @@ func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLo IsDirect: portal.RoomType == database.RoomTypeDM, PowerLevelOverride: powerLevels, BeeperLocalRoomID: portal.Bridge.Matrix.GenerateDeterministicRoomID(portal.PortalKey), - RoomVersion: id.RoomV11, } autoJoinInvites := portal.Bridge.Matrix.GetCapabilities().AutoJoinInvites if autoJoinInvites { @@ -4632,7 +5156,7 @@ func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLo req.CreationContent["type"] = event.RoomTypeSpace } bridgeInfoStateKey, bridgeInfo := portal.getBridgeInfo() - roomFeatures := source.Client.GetCapabilities(ctx, portal) + roomFeatures := source.Client.GetCapabilities(cancellableCtx, portal) portal.CapState = database.CapabilityState{ Source: source.ID, ID: roomFeatures.GetID(), @@ -4714,6 +5238,9 @@ func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLo Content: event.Content{Parsed: info.JoinRule}, }) } + if cancellableCtx.Err() != nil { + return cancellableCtx.Err() + } roomID, err := portal.Bridge.Bot.CreateRoom(ctx, &req) if err != nil { log.Err(err).Msg("Failed to create Matrix room") @@ -4724,6 +5251,7 @@ func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLo portal.TopicSet = true portal.NameSet = true portal.MXID = roomID + portal.RoomCreated.Set() portal.Bridge.cacheLock.Lock() portal.Bridge.portalsByMXID[roomID] = portal portal.Bridge.cacheLock.Unlock() @@ -4771,7 +5299,10 @@ func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLo } } portal.addToUserSpaces(ctx) - if portal.Bridge.Config.Backfill.Enabled && portal.RoomType != database.RoomTypeSpace && !portal.Bridge.Background { + if info.CanBackfill && + portal.Bridge.Config.Backfill.Enabled && + portal.RoomType != database.RoomTypeSpace && + !portal.Bridge.Background { portal.doForwardBackfill(ctx, source, nil, backfillBundle) } return nil @@ -4786,7 +5317,7 @@ func (portal *Portal) addToUserSpaces(ctx context.Context) { if portal.Receiver != "" { login := portal.Bridge.GetCachedUserLoginByID(portal.Receiver) if login != nil { - up, err := portal.Bridge.DB.UserPortal.Get(ctx, login.UserLogin, portal.PortalKey) + up, err := portal.Bridge.DB.UserPortal.GetOrCreate(ctx, login.UserLogin, portal.PortalKey) if err != nil { log.Err(err).Msg("Failed to get user portal to add portal to spaces") } else { @@ -4811,8 +5342,11 @@ func (portal *Portal) addToUserSpaces(ctx context.Context) { } func (portal *Portal) Delete(ctx context.Context) error { + if portal.deleted.IsSet() { + return nil + } portal.removeInPortalCache(ctx) - err := portal.Bridge.DB.Portal.Delete(ctx, portal.PortalKey) + err := portal.safeDBDelete(ctx) if err != nil { return err } @@ -4822,11 +5356,21 @@ func (portal *Portal) Delete(ctx context.Context) error { return nil } +func (portal *Portal) safeDBDelete(ctx context.Context) error { + err := portal.Bridge.DB.Message.DeleteInChunks(ctx, portal.PortalKey) + if err != nil { + return fmt.Errorf("failed to delete messages in portal: %w", err) + } + // TODO delete child portals? + return portal.Bridge.DB.Portal.Delete(ctx, portal.PortalKey) +} + func (portal *Portal) RemoveMXID(ctx context.Context) error { if portal.MXID == "" { return nil } portal.MXID = "" + portal.RoomCreated.Clear() err := portal.Save(ctx) if err != nil { return err @@ -4859,8 +5403,10 @@ func (portal *Portal) removeInPortalCache(ctx context.Context) { } func (portal *Portal) unlockedDelete(ctx context.Context) error { - // TODO delete child portals? - err := portal.Bridge.DB.Portal.Delete(ctx, portal.PortalKey) + if portal.deleted.IsSet() { + return nil + } + err := portal.safeDBDelete(ctx) if err != nil { return err } @@ -4869,15 +5415,18 @@ func (portal *Portal) unlockedDelete(ctx context.Context) error { } func (portal *Portal) unlockedDeleteCache() { + if portal.deleted.IsSet() { + return + } delete(portal.Bridge.portalsByKey, portal.PortalKey) if portal.MXID != "" { delete(portal.Bridge.portalsByMXID, portal.MXID) } + portal.deleted.Set() if portal.events != nil { // TODO there's a small risk of this racing with a queueEvent call close(portal.events) } - portal.deleted = true } func (portal *Portal) Save(ctx context.Context) error { @@ -4885,6 +5434,9 @@ func (portal *Portal) Save(ctx context.Context) error { } func (portal *Portal) SetRelay(ctx context.Context, relay *UserLogin) error { + if portal.Receiver != "" && relay.ID != portal.Receiver { + return fmt.Errorf("can't set non-receiver login as relay") + } portal.Relay = relay if relay == nil { portal.RelayLoginID = "" diff --git a/bridgev2/portalbackfill.go b/bridgev2/portalbackfill.go index f7819968..879f07ae 100644 --- a/bridgev2/portalbackfill.go +++ b/bridgev2/portalbackfill.go @@ -194,6 +194,9 @@ func (portal *Portal) doThreadBackfill(ctx context.Context, source *UserLogin, t if err != nil { log.Err(err).Msg("Failed to get last thread message") return + } else if anchorMessage == nil { + log.Warn().Msg("No messages found in thread?") + return } resp := portal.fetchThreadBackfill(ctx, source, anchorMessage) if resp != nil { @@ -387,12 +390,16 @@ func (portal *Portal) compileBatchMessage(ctx context.Context, source *UserLogin out.Disappear = append(out.Disappear, &database.DisappearingMessage{ RoomID: portal.MXID, EventID: evtID, + Timestamp: msg.Timestamp, DisappearingSetting: msg.Disappear, }) } } slices.Sort(partIDs) for _, reaction := range msg.Reactions { + if reaction == nil { + continue + } reactionIntent, ok := portal.GetIntentFor(ctx, reaction.Sender, source, RemoteEventReactionRemove) if !ok { continue @@ -403,6 +410,7 @@ func (portal *Portal) compileBatchMessage(ctx context.Context, source *UserLogin if reaction.Timestamp.IsZero() { reaction.Timestamp = msg.Timestamp.Add(10 * time.Millisecond) } + //lint:ignore SA4006 it's a todo targetPart, ok := partMap[*reaction.TargetPart] if !ok { // TODO warning log and/or skip reaction? diff --git a/bridgev2/portalinternal.go b/bridgev2/portalinternal.go index d9373eb6..4c7e2447 100644 --- a/bridgev2/portalinternal.go +++ b/bridgev2/portalinternal.go @@ -49,6 +49,10 @@ func (portal *PortalInternals) HandleSingleEvent(ctx context.Context, rawEvt any (*Portal)(portal).handleSingleEvent(ctx, rawEvt, doneCallback) } +func (portal *PortalInternals) UnwrapBeeperSendState(ctx context.Context, evt *event.Event) error { + return (*Portal)(portal).unwrapBeeperSendState(ctx, evt) +} + func (portal *PortalInternals) SendSuccessStatus(ctx context.Context, evt *event.Event, streamOrder int64, newEventID id.EventID) { (*Portal)(portal).sendSuccessStatus(ctx, evt, streamOrder, newEventID) } @@ -61,8 +65,8 @@ func (portal *PortalInternals) CheckConfusableName(ctx context.Context, userID i return (*Portal)(portal).checkConfusableName(ctx, userID, name) } -func (portal *PortalInternals) HandleMatrixEvent(ctx context.Context, sender *User, evt *event.Event) EventHandlingResult { - return (*Portal)(portal).handleMatrixEvent(ctx, sender, evt) +func (portal *PortalInternals) HandleMatrixEvent(ctx context.Context, sender *User, evt *event.Event, isStateRequest bool) EventHandlingResult { + return (*Portal)(portal).handleMatrixEvent(ctx, sender, evt, isStateRequest) } func (portal *PortalInternals) HandleMatrixReceipts(ctx context.Context, evt *event.Event) EventHandlingResult { @@ -125,12 +129,12 @@ func (portal *PortalInternals) HandleMatrixDeleteChat(ctx context.Context, sende return (*Portal)(portal).handleMatrixDeleteChat(ctx, sender, origSender, evt) } -func (portal *PortalInternals) HandleMatrixMembership(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event) EventHandlingResult { - return (*Portal)(portal).handleMatrixMembership(ctx, sender, origSender, evt) +func (portal *PortalInternals) HandleMatrixMembership(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event, isStateRequest bool) EventHandlingResult { + return (*Portal)(portal).handleMatrixMembership(ctx, sender, origSender, evt, isStateRequest) } -func (portal *PortalInternals) HandleMatrixPowerLevels(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event) EventHandlingResult { - return (*Portal)(portal).handleMatrixPowerLevels(ctx, sender, origSender, evt) +func (portal *PortalInternals) HandleMatrixPowerLevels(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event, isStateRequest bool) EventHandlingResult { + return (*Portal)(portal).handleMatrixPowerLevels(ctx, sender, origSender, evt, isStateRequest) } func (portal *PortalInternals) HandleMatrixTombstone(ctx context.Context, evt *event.Event) EventHandlingResult { @@ -289,8 +293,12 @@ func (portal *PortalInternals) SendStateWithIntentOrBot(ctx context.Context, sen return (*Portal)(portal).sendStateWithIntentOrBot(ctx, sender, eventType, stateKey, content, ts) } -func (portal *PortalInternals) SendRoomMeta(ctx context.Context, sender MatrixAPI, ts time.Time, eventType event.Type, stateKey string, content any, excludeFromTimeline bool) bool { - return (*Portal)(portal).sendRoomMeta(ctx, sender, ts, eventType, stateKey, content, excludeFromTimeline) +func (portal *PortalInternals) SendRoomMeta(ctx context.Context, sender MatrixAPI, ts time.Time, eventType event.Type, stateKey string, content any, excludeFromTimeline bool, extra map[string]any) bool { + return (*Portal)(portal).sendRoomMeta(ctx, sender, ts, eventType, stateKey, content, excludeFromTimeline, extra) +} + +func (portal *PortalInternals) RevertRoomMeta(ctx context.Context, evt *event.Event) { + (*Portal)(portal).revertRoomMeta(ctx, evt) } func (portal *PortalInternals) GetInitialMemberList(ctx context.Context, members *ChatMemberList, source *UserLogin, pl *event.PowerLevelsEventContent) (invite, functional []id.UserID, err error) { @@ -301,6 +309,10 @@ func (portal *PortalInternals) UpdateOtherUser(ctx context.Context, members *Cha return (*Portal)(portal).updateOtherUser(ctx, members) } +func (portal *PortalInternals) RoomIsPublic(ctx context.Context) bool { + return (*Portal)(portal).roomIsPublic(ctx) +} + func (portal *PortalInternals) SyncParticipants(ctx context.Context, members *ChatMemberList, source *UserLogin, sender MatrixAPI, ts time.Time) error { return (*Portal)(portal).syncParticipants(ctx, members, source, sender, ts) } diff --git a/bridgev2/portalreid.go b/bridgev2/portalreid.go index a25fe820..c976d97c 100644 --- a/bridgev2/portalreid.go +++ b/bridgev2/portalreid.go @@ -32,21 +32,40 @@ func (br *Bridge) ReIDPortal(ctx context.Context, source, target networkid.Porta if source == target { return ReIDResultError, nil, fmt.Errorf("illegal re-ID call: source and target are the same") } - log := zerolog.Ctx(ctx) - log.Debug().Msg("Re-ID'ing portal") + log := zerolog.Ctx(ctx).With(). + Str("action", "re-id portal"). + Stringer("source_portal_key", source). + Stringer("target_portal_key", target). + Logger() + ctx = log.WithContext(ctx) defer func() { log.Debug().Msg("Finished handling portal re-ID") }() - br.cacheLock.Lock() - defer br.cacheLock.Unlock() - sourcePortal, err := br.UnlockedGetPortalByKey(ctx, source, true) + acquireCacheLock := func() { + if !br.cacheLock.TryLock() { + log.Debug().Msg("Waiting for global cache lock") + br.cacheLock.Lock() + log.Debug().Msg("Acquired global cache lock after waiting") + } else { + log.Trace().Msg("Acquired global cache lock without waiting") + } + } + log.Debug().Msg("Re-ID'ing portal") + sourcePortal, err := br.GetExistingPortalByKey(ctx, source) if err != nil { return ReIDResultError, nil, fmt.Errorf("failed to get source portal: %w", err) } else if sourcePortal == nil { log.Debug().Msg("Source portal not found, re-ID is no-op") return ReIDResultNoOp, nil, nil } - sourcePortal.roomCreateLock.Lock() + if !sourcePortal.roomCreateLock.TryLock() { + if cancelCreate := sourcePortal.cancelRoomCreate.Swap(nil); cancelCreate != nil { + (*cancelCreate)() + } + log.Debug().Msg("Waiting for source portal room creation lock") + sourcePortal.roomCreateLock.Lock() + log.Debug().Msg("Acquired source portal room creation lock after waiting") + } defer sourcePortal.roomCreateLock.Unlock() if sourcePortal.MXID == "" { log.Info().Msg("Source portal doesn't have Matrix room, deleting row") @@ -59,22 +78,37 @@ func (br *Bridge) ReIDPortal(ctx context.Context, source, target networkid.Porta log.UpdateContext(func(c zerolog.Context) zerolog.Context { return c.Stringer("source_portal_mxid", sourcePortal.MXID) }) + + acquireCacheLock() targetPortal, err := br.UnlockedGetPortalByKey(ctx, target, true) if err != nil { + br.cacheLock.Unlock() return ReIDResultError, nil, fmt.Errorf("failed to get target portal: %w", err) } if targetPortal == nil { log.Info().Msg("Target portal doesn't exist, re-ID'ing source portal") err = sourcePortal.unlockedReID(ctx, target) + br.cacheLock.Unlock() if err != nil { return ReIDResultError, nil, fmt.Errorf("failed to re-ID source portal: %w", err) } return ReIDResultSourceReIDd, sourcePortal, nil } - targetPortal.roomCreateLock.Lock() + br.cacheLock.Unlock() + + if !targetPortal.roomCreateLock.TryLock() { + if cancelCreate := targetPortal.cancelRoomCreate.Swap(nil); cancelCreate != nil { + (*cancelCreate)() + } + log.Debug().Msg("Waiting for target portal room creation lock") + targetPortal.roomCreateLock.Lock() + log.Debug().Msg("Acquired target portal room creation lock after waiting") + } defer targetPortal.roomCreateLock.Unlock() if targetPortal.MXID == "" { log.Info().Msg("Target portal row exists, but doesn't have a Matrix room. Deleting target portal row and re-ID'ing source portal") + acquireCacheLock() + defer br.cacheLock.Unlock() err = targetPortal.unlockedDelete(ctx) if err != nil { return ReIDResultError, nil, fmt.Errorf("failed to delete target portal: %w", err) @@ -89,6 +123,9 @@ func (br *Bridge) ReIDPortal(ctx context.Context, source, target networkid.Porta return c.Stringer("target_portal_mxid", targetPortal.MXID) }) log.Info().Msg("Both target and source portals have Matrix rooms, tombstoning source portal") + sourcePortal.removeInPortalCache(ctx) + acquireCacheLock() + defer br.cacheLock.Unlock() err = sourcePortal.unlockedDelete(ctx) if err != nil { return ReIDResultError, nil, fmt.Errorf("failed to delete source portal row: %w", err) @@ -96,7 +133,7 @@ func (br *Bridge) ReIDPortal(ctx context.Context, source, target networkid.Porta go func() { _, err := br.Bot.SendState(ctx, sourcePortal.MXID, event.StateTombstone, "", &event.Content{ Parsed: &event.TombstoneEventContent{ - Body: fmt.Sprintf("This room has been merged"), + Body: "This room has been merged", ReplacementRoom: targetPortal.MXID, }, }, time.Now()) diff --git a/bridgev2/provisionutil/creategroup.go b/bridgev2/provisionutil/creategroup.go index f389ab42..72bacaff 100644 --- a/bridgev2/provisionutil/creategroup.go +++ b/bridgev2/provisionutil/creategroup.go @@ -15,6 +15,7 @@ import ( "maunium.net/go/mautrix" "maunium.net/go/mautrix/bridgev2" "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" ) @@ -22,6 +23,8 @@ type RespCreateGroup struct { ID networkid.PortalID `json:"id"` MXID id.RoomID `json:"mxid"` Portal *bridgev2.Portal `json:"-"` + + FailedParticipants map[networkid.UserID]*bridgev2.CreateChatFailedParticipant `json:"failed_participants,omitempty"` } func CreateGroup(ctx context.Context, login *bridgev2.UserLogin, params *bridgev2.GroupCreateParams) (*RespCreateGroup, error) { @@ -29,6 +32,9 @@ func CreateGroup(ctx context.Context, login *bridgev2.UserLogin, params *bridgev if !ok { return nil, bridgev2.RespError(mautrix.MUnrecognized.WithMessage("This bridge does not support creating groups")) } + zerolog.Ctx(ctx).Debug(). + Any("create_params", params). + Msg("Creating group chat on remote network") caps := login.Bridge.Network.GetCapabilities() typeSpec, validType := caps.Provisioning.GroupCreation[params.Type] if !validType { @@ -36,11 +42,20 @@ func CreateGroup(ctx context.Context, login *bridgev2.UserLogin, params *bridgev } if len(params.Participants) < typeSpec.Participants.MinLength { return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Must have at least %d members", typeSpec.Participants.MinLength)) + } else if typeSpec.Participants.MaxLength > 0 && len(params.Participants) > typeSpec.Participants.MaxLength { + return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Must have at most %d members", typeSpec.Participants.MaxLength)) } userIDValidatingNetwork, uidValOK := login.Bridge.Network.(bridgev2.IdentifierValidatingNetwork) - for _, participant := range params.Participants { - if uidValOK && !userIDValidatingNetwork.ValidateUserID(participant) { - return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("User ID %q is not valid on this network", participant)) + for i, participant := range params.Participants { + parsedParticipant, ok := login.Bridge.Matrix.ParseGhostMXID(id.UserID(participant)) + if ok { + participant = parsedParticipant + params.Participants[i] = participant + } + if !typeSpec.Participants.SkipIdentifierValidation { + if uidValOK && !userIDValidatingNetwork.ValidateUserID(participant) { + return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("User ID %q is not valid on this network", participant)) + } } if api.IsThisUser(ctx, participant) { return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("You can't include yourself in the participants list", participant)) @@ -50,7 +65,7 @@ func CreateGroup(ctx context.Context, login *bridgev2.UserLogin, params *bridgev return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Name is required")) } else if nameLen := len(ptr.Val(params.Name).Name); nameLen > 0 && nameLen < typeSpec.Name.MinLength { return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Name must be at least %d characters", typeSpec.Name.MinLength)) - } else if nameLen > typeSpec.Name.MaxLength { + } else if typeSpec.Name.MaxLength > 0 && nameLen > typeSpec.Name.MaxLength { return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Name must be at most %d characters", typeSpec.Name.MaxLength)) } if (params.Avatar == nil || params.Avatar.URL == "") && typeSpec.Avatar.Required { @@ -60,7 +75,7 @@ func CreateGroup(ctx context.Context, login *bridgev2.UserLogin, params *bridgev return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Topic is required")) } else if topicLen := len(ptr.Val(params.Topic).Topic); topicLen > 0 && topicLen < typeSpec.Topic.MinLength { return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Topic must be at least %d characters", typeSpec.Topic.MinLength)) - } else if topicLen > typeSpec.Topic.MaxLength { + } else if typeSpec.Topic.MaxLength > 0 && topicLen > typeSpec.Topic.MaxLength { return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Topic must be at most %d characters", typeSpec.Topic.MaxLength)) } if (params.Disappear == nil || params.Disappear.Timer.Duration == 0) && typeSpec.Disappear.Required { @@ -72,7 +87,7 @@ func CreateGroup(ctx context.Context, login *bridgev2.UserLogin, params *bridgev return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Username is required")) } else if len(params.Username) > 0 && len(params.Username) < typeSpec.Username.MinLength { return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Username must be at least %d characters", typeSpec.Username.MinLength)) - } else if len(params.Username) > typeSpec.Username.MaxLength { + } else if typeSpec.Username.MaxLength > 0 && len(params.Username) > typeSpec.Username.MaxLength { return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Username must be at most %d characters", typeSpec.Username.MaxLength)) } if params.Parent == nil && typeSpec.Parent.Required { @@ -86,6 +101,9 @@ func CreateGroup(ctx context.Context, login *bridgev2.UserLogin, params *bridgev if resp.PortalKey.IsEmpty() { return nil, ErrNoPortalKey } + zerolog.Ctx(ctx).Debug(). + Object("portal_key", resp.PortalKey). + Msg("Successfully created group on remote network") if resp.Portal == nil { resp.Portal, err = login.Bridge.GetPortalByKey(ctx, resp.PortalKey) if err != nil { @@ -100,9 +118,32 @@ func CreateGroup(ctx context.Context, login *bridgev2.UserLogin, params *bridgev return nil, bridgev2.RespError(mautrix.MUnknown.WithMessage("Failed to create portal room")) } } + for key, fp := range resp.FailedParticipants { + if fp.InviteEventType == "" { + fp.InviteEventType = event.EventMessage.Type + } + if fp.UserMXID == "" { + ghost, err := login.Bridge.GetGhostByID(ctx, key) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to get ghost for failed participant") + } else if ghost != nil { + fp.UserMXID = ghost.Intent.GetMXID() + } + } + if fp.DMRoomMXID == "" { + portal, err := login.Bridge.GetDMPortal(ctx, login.ID, key) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to get DM portal for failed participant") + } else if portal != nil { + fp.DMRoomMXID = portal.MXID + } + } + } return &RespCreateGroup{ ID: resp.Portal.ID, MXID: resp.Portal.MXID, Portal: resp.Portal, + + FailedParticipants: resp.FailedParticipants, }, nil } diff --git a/bridgev2/provisionutil/resolveidentifier.go b/bridgev2/provisionutil/resolveidentifier.go index 5387347c..cfc388d0 100644 --- a/bridgev2/provisionutil/resolveidentifier.go +++ b/bridgev2/provisionutil/resolveidentifier.go @@ -109,6 +109,7 @@ func ResolveIdentifier( return nil, bridgev2.RespError(mautrix.MUnknown.WithMessage("Failed to get portal")) } } + resp.Chat.Portal.CleanupOrphanedDM(ctx, login.UserMXID) if createChat && resp.Chat.Portal.MXID == "" { apiResp.JustCreated = true err := resp.Chat.Portal.CreateMatrixRoom(ctx, login, resp.Chat.PortalInfo) diff --git a/bridgev2/queue.go b/bridgev2/queue.go index 95011cda..3775c825 100644 --- a/bridgev2/queue.go +++ b/bridgev2/queue.go @@ -63,6 +63,13 @@ func (br *Bridge) rejectInviteOnNoPermission(ctx context.Context, evt *event.Eve return true } +var ( + ErrEventSenderUserNotFound = WrapErrorInStatus(errors.New("sender not found for event")).WithIsCertain(true).WithErrorAsMessage() + ErrNoPermissionToInteract = WrapErrorInStatus(errors.New("you don't have permission to send messages")).WithIsCertain(true).WithSendNotice(false).WithErrorAsMessage() + ErrNoPermissionForCommands = WrapErrorInStatus(WrapErrorInStatus(errors.New("you don't have permission to use commands")).WithIsCertain(true).WithSendNotice(false).WithErrorAsMessage()) + ErrCantRelayStateRequest = WrapErrorInStatus(errors.New("relayed users can't use beeper state requests")).WithIsCertain(true).WithErrorAsMessage() +) + func (br *Bridge) QueueMatrixEvent(ctx context.Context, evt *event.Event) EventHandlingResult { // TODO maybe HandleMatrixEvent would be more appropriate as this also handles bot invites and commands @@ -78,13 +85,11 @@ func (br *Bridge) QueueMatrixEvent(ctx context.Context, evt *event.Event) EventH return EventHandlingResultFailed } else if sender == nil { log.Error().Msg("Couldn't get sender for incoming non-ephemeral Matrix event") - status := WrapErrorInStatus(errors.New("sender not found for event")).WithIsCertain(true).WithErrorAsMessage() - br.Matrix.SendMessageStatus(ctx, &status, StatusEventInfoFromEvent(evt)) + br.Matrix.SendMessageStatus(ctx, &ErrEventSenderUserNotFound, StatusEventInfoFromEvent(evt)) return EventHandlingResultFailed } else if !sender.Permissions.SendEvents { if !br.rejectInviteOnNoPermission(ctx, evt, "interact with") { - status := WrapErrorInStatus(errors.New("you don't have permission to send messages")).WithIsCertain(true).WithSendNotice(false).WithErrorAsMessage() - br.Matrix.SendMessageStatus(ctx, &status, StatusEventInfoFromEvent(evt)) + br.Matrix.SendMessageStatus(ctx, &ErrNoPermissionToInteract, StatusEventInfoFromEvent(evt)) } return EventHandlingResultIgnored } else if !sender.Permissions.Commands && br.rejectInviteOnNoPermission(ctx, evt, "send commands to") { @@ -92,8 +97,7 @@ func (br *Bridge) QueueMatrixEvent(ctx context.Context, evt *event.Event) EventH } } else if evt.Type.Class != event.EphemeralEventType { log.Error().Msg("Missing sender for incoming non-ephemeral Matrix event") - status := WrapErrorInStatus(errors.New("sender not found for event")).WithIsCertain(true).WithErrorAsMessage() - br.Matrix.SendMessageStatus(ctx, &status, StatusEventInfoFromEvent(evt)) + br.Matrix.SendMessageStatus(ctx, &ErrEventSenderUserNotFound, StatusEventInfoFromEvent(evt)) return EventHandlingResultIgnored } if evt.Type == event.EventMessage && sender != nil { @@ -102,8 +106,7 @@ func (br *Bridge) QueueMatrixEvent(ctx context.Context, evt *event.Event) EventH msg.RemovePerMessageProfileFallback() if strings.HasPrefix(msg.Body, br.Config.CommandPrefix) || evt.RoomID == sender.ManagementRoom { if !sender.Permissions.Commands { - status := WrapErrorInStatus(errors.New("you don't have permission to use commands")).WithIsCertain(true).WithSendNotice(false).WithErrorAsMessage() - br.Matrix.SendMessageStatus(ctx, &status, StatusEventInfoFromEvent(evt)) + br.Matrix.SendMessageStatus(ctx, &ErrNoPermissionForCommands, StatusEventInfoFromEvent(evt)) return EventHandlingResultIgnored } go br.Commands.Handle( @@ -157,10 +160,27 @@ type EventHandlingResult struct { Ignored bool Queued bool + SkipStateEcho bool + // Error is an optional reason for failure. It is not required, Success may be false even without a specific error. Error error // Whether the Error should be sent as a MSS event. SendMSS bool + + // EventID from the network + EventID id.EventID + // Stream order from the network + StreamOrder int64 +} + +func (ehr EventHandlingResult) WithEventID(id id.EventID) EventHandlingResult { + ehr.EventID = id + return ehr +} + +func (ehr EventHandlingResult) WithStreamOrder(order int64) EventHandlingResult { + ehr.StreamOrder = order + return ehr } func (ehr EventHandlingResult) WithError(err error) EventHandlingResult { @@ -177,6 +197,11 @@ func (ehr EventHandlingResult) WithMSS() EventHandlingResult { return ehr } +func (ehr EventHandlingResult) WithSkipStateEcho(skip bool) EventHandlingResult { + ehr.SkipStateEcho = skip + return ehr +} + func (ehr EventHandlingResult) WithMSSError(err error) EventHandlingResult { if err == nil { return ehr @@ -195,7 +220,7 @@ func (ul *UserLogin) QueueRemoteEvent(evt RemoteEvent) EventHandlingResult { return ul.Bridge.QueueRemoteEvent(ul, evt) } -func (br *Bridge) QueueRemoteEvent(login *UserLogin, evt RemoteEvent) (res EventHandlingResult) { +func (br *Bridge) QueueRemoteEvent(login *UserLogin, evt RemoteEvent) EventHandlingResult { log := login.Log ctx := log.WithContext(br.BackgroundCtx) maybeUncertain, ok := evt.(RemoteEventWithUncertainPortalReceiver) @@ -211,14 +236,14 @@ func (br *Bridge) QueueRemoteEvent(login *UserLogin, evt RemoteEvent) (res Event if err != nil { log.Err(err).Object("portal_key", key).Bool("uncertain_receiver", isUncertain). Msg("Failed to get portal to handle remote event") - return + return EventHandlingResultFailed.WithError(fmt.Errorf("failed to get portal: %w", err)) } else if portal == nil { log.Warn(). Stringer("event_type", evt.GetType()). Object("portal_key", key). Bool("uncertain_receiver", isUncertain). Msg("Portal not found to handle remote event") - return + return EventHandlingResultFailed.WithError(ErrPortalNotFoundInEventHandler) } // TODO put this in a better place, and maybe cache to avoid constant db queries login.MarkInPortal(ctx, portal) diff --git a/bridgev2/simplevent/chat.go b/bridgev2/simplevent/chat.go index c725141b..56e3a6b1 100644 --- a/bridgev2/simplevent/chat.go +++ b/bridgev2/simplevent/chat.go @@ -65,14 +65,19 @@ func (evt *ChatResync) GetChatInfo(ctx context.Context, portal *bridgev2.Portal) type ChatDelete struct { EventMeta OnlyForMe bool + Children bool } -var _ bridgev2.RemoteChatDelete = (*ChatDelete)(nil) +var _ bridgev2.RemoteChatDeleteWithChildren = (*ChatDelete)(nil) func (evt *ChatDelete) DeleteOnlyForMe() bool { return evt.OnlyForMe } +func (evt *ChatDelete) DeleteChildren() bool { + return evt.Children +} + // ChatInfoChange is a simple implementation of [bridgev2.RemoteChatInfoChange]. type ChatInfoChange struct { EventMeta diff --git a/bridgev2/simplevent/meta.go b/bridgev2/simplevent/meta.go index 8aa91866..96c8a9c5 100644 --- a/bridgev2/simplevent/meta.go +++ b/bridgev2/simplevent/meta.go @@ -27,8 +27,9 @@ type EventMeta struct { Timestamp time.Time StreamOrder int64 - PreHandleFunc func(context.Context, *bridgev2.Portal) - PostHandleFunc func(context.Context, *bridgev2.Portal) + PreHandleFunc func(context.Context, *bridgev2.Portal) + PostHandleFunc func(context.Context, *bridgev2.Portal) + MutateContextFunc func(context.Context) context.Context } var ( @@ -39,6 +40,7 @@ var ( _ bridgev2.RemoteEventWithStreamOrder = (*EventMeta)(nil) _ bridgev2.RemotePreHandler = (*EventMeta)(nil) _ bridgev2.RemotePostHandler = (*EventMeta)(nil) + _ bridgev2.RemoteEventWithContextMutation = (*EventMeta)(nil) ) func (evt *EventMeta) AddLogContext(c zerolog.Context) zerolog.Context { @@ -91,6 +93,13 @@ func (evt *EventMeta) PostHandle(ctx context.Context, portal *bridgev2.Portal) { } } +func (evt *EventMeta) MutateContext(ctx context.Context) context.Context { + if evt.MutateContextFunc == nil { + return ctx + } + return evt.MutateContextFunc(ctx) +} + func (evt EventMeta) WithType(t bridgev2.RemoteEventType) EventMeta { evt.Type = t return evt @@ -101,6 +110,18 @@ func (evt EventMeta) WithLogContext(f func(c zerolog.Context) zerolog.Context) E return evt } +func (evt EventMeta) WithMoreLogContext(f func(c zerolog.Context) zerolog.Context) EventMeta { + origFunc := evt.LogContext + if origFunc == nil { + evt.LogContext = f + return evt + } + evt.LogContext = func(c zerolog.Context) zerolog.Context { + return f(origFunc(c)) + } + return evt +} + func (evt EventMeta) WithPortalKey(p networkid.PortalKey) EventMeta { evt.PortalKey = p return evt diff --git a/bridgev2/space.go b/bridgev2/space.go index ae9013cb..2ca2bce3 100644 --- a/bridgev2/space.go +++ b/bridgev2/space.go @@ -164,14 +164,17 @@ func (ul *UserLogin) GetSpaceRoom(ctx context.Context) (id.RoomID, error) { ul.UserMXID: 50, }, }, - RoomVersion: id.RoomV11, - Invite: []id.UserID{ul.UserMXID}, + Invite: []id.UserID{ul.UserMXID}, } if autoJoin { req.BeeperInitialMembers = []id.UserID{ul.UserMXID} // TODO remove this after initial_members is supported in hungryserv req.BeeperAutoJoinInvites = true } + pfc, ok := ul.Client.(PersonalFilteringCustomizingNetworkAPI) + if ok { + pfc.CustomizePersonalFilteringSpace(req) + } ul.SpaceRoom, err = ul.Bridge.Bot.CreateRoom(ctx, req) if err != nil { return "", fmt.Errorf("failed to create space room: %w", err) diff --git a/bridgev2/status/bridgestate.go b/bridgev2/status/bridgestate.go index 430d4c7c..5925dd4f 100644 --- a/bridgev2/status/bridgestate.go +++ b/bridgev2/status/bridgestate.go @@ -19,7 +19,6 @@ import ( "github.com/tidwall/sjson" "go.mau.fi/util/jsontime" - "go.mau.fi/util/ptr" "maunium.net/go/mautrix" "maunium.net/go/mautrix/bridgev2/networkid" @@ -112,7 +111,7 @@ func (rp *RemoteProfile) Merge(other RemoteProfile) RemoteProfile { return other } -func (rp *RemoteProfile) IsEmpty() bool { +func (rp *RemoteProfile) IsZero() bool { return rp == nil || (rp.Phone == "" && rp.Email == "" && rp.Username == "" && rp.Name == "" && rp.Avatar == "" && rp.AvatarFile == nil) } @@ -130,7 +129,7 @@ type BridgeState struct { UserID id.UserID `json:"user_id,omitempty"` RemoteID networkid.UserLoginID `json:"remote_id,omitempty"` RemoteName string `json:"remote_name,omitempty"` - RemoteProfile *RemoteProfile `json:"remote_profile,omitempty"` + RemoteProfile RemoteProfile `json:"remote_profile,omitzero"` Reason string `json:"reason,omitempty"` Info map[string]interface{} `json:"info,omitempty"` @@ -210,7 +209,7 @@ func (pong *BridgeState) ShouldDeduplicate(newPong *BridgeState) bool { pong.StateEvent == newPong.StateEvent && pong.RemoteName == newPong.RemoteName && pong.UserAction == newPong.UserAction && - ptr.Val(pong.RemoteProfile) == ptr.Val(newPong.RemoteProfile) && + pong.RemoteProfile == newPong.RemoteProfile && pong.Error == newPong.Error && maps.EqualFunc(pong.Info, newPong.Info, reflect.DeepEqual) && pong.Timestamp.Add(time.Duration(pong.TTL)*time.Second).After(time.Now()) diff --git a/bridgev2/user.go b/bridgev2/user.go index 87ced1d7..9a7896d6 100644 --- a/bridgev2/user.go +++ b/bridgev2/user.go @@ -176,6 +176,10 @@ func (user *User) GetUserLogins() []*UserLogin { return maps.Values(user.logins) } +func (user *User) HasTooManyLogins() bool { + return user.Permissions.MaxLogins > 0 && len(user.GetUserLoginIDs()) >= user.Permissions.MaxLogins +} + func (user *User) GetFormattedUserLogins() string { user.Bridge.cacheLock.Lock() logins := make([]string, len(user.logins)) @@ -225,9 +229,8 @@ func (user *User) GetManagementRoom(ctx context.Context) (id.RoomID, error) { user.MXID: 50, }, }, - RoomVersion: id.RoomV11, - Invite: []id.UserID{user.MXID}, - IsDirect: true, + Invite: []id.UserID{user.MXID}, + IsDirect: true, } if autoJoin { req.BeeperInitialMembers = []id.UserID{user.MXID} diff --git a/bridgev2/userlogin.go b/bridgev2/userlogin.go index b5fcfcd0..d56dc4cc 100644 --- a/bridgev2/userlogin.go +++ b/bridgev2/userlogin.go @@ -10,6 +10,7 @@ import ( "cmp" "context" "fmt" + "maps" "slices" "sync" "time" @@ -50,6 +51,8 @@ func (br *Bridge) loadUserLogin(ctx context.Context, user *User, dbUserLogin *da if err != nil { return nil, fmt.Errorf("failed to get user: %w", err) } + // TODO if loading the user caused the provided userlogin to be loaded, cancel here? + // Currently this will double-load it } userLogin := &UserLogin{ UserLogin: dbUserLogin, @@ -140,6 +143,12 @@ func (br *Bridge) GetCachedUserLoginByID(id networkid.UserLoginID) *UserLogin { return br.userLoginsByID[id] } +func (br *Bridge) GetAllCachedUserLogins() (logins []*UserLogin) { + br.cacheLock.Lock() + defer br.cacheLock.Unlock() + return slices.Collect(maps.Values(br.userLoginsByID)) +} + func (br *Bridge) GetCurrentBridgeStates() (states []status.BridgeState) { br.cacheLock.Lock() defer br.cacheLock.Unlock() @@ -503,7 +512,7 @@ func (ul *UserLogin) FillBridgeState(state status.BridgeState) status.BridgeStat state.UserID = ul.UserMXID state.RemoteID = ul.ID state.RemoteName = ul.RemoteName - state.RemoteProfile = &ul.RemoteProfile + state.RemoteProfile = ul.RemoteProfile filler, ok := ul.Client.(status.BridgeStateFiller) if ok { return filler.FillBridgeState(state) diff --git a/client.go b/client.go index 95cbacb5..045d7b8e 100644 --- a/client.go +++ b/client.go @@ -111,6 +111,8 @@ type Client struct { // Set to true to disable automatically sleeping on 429 errors. IgnoreRateLimit bool + ResponseSizeLimit int64 + txnID int32 // Should the ?user_id= query parameter be set in requests? @@ -143,6 +145,8 @@ func DiscoverClientAPI(ctx context.Context, serverName string) (*ClientWellKnown return DiscoverClientAPIWithClient(ctx, &http.Client{Timeout: 30 * time.Second}, serverName) } +const WellKnownMaxSize = 64 * 1024 + func DiscoverClientAPIWithClient(ctx context.Context, client *http.Client, serverName string) (*ClientWellKnown, error) { wellKnownURL := url.URL{ Scheme: "https", @@ -168,11 +172,15 @@ func DiscoverClientAPIWithClient(ctx context.Context, client *http.Client, serve if resp.StatusCode == http.StatusNotFound { return nil, nil + } else if resp.ContentLength > WellKnownMaxSize { + return nil, errors.New(".well-known response too large") } - data, err := io.ReadAll(resp.Body) + data, err := io.ReadAll(io.LimitReader(resp.Body, WellKnownMaxSize)) if err != nil { return nil, err + } else if len(data) >= WellKnownMaxSize { + return nil, errors.New(".well-known response too large") } var wellKnown ClientWellKnown @@ -378,7 +386,14 @@ func (cli *Client) LogRequestDone(req *http.Request, resp *http.Response, err er } } if body := req.Context().Value(LogBodyContextKey); body != nil { - evt.Interface("req_body", body) + switch typedLogBody := body.(type) { + case json.RawMessage: + evt.RawJSON("req_body", typedLogBody) + case string: + evt.Str("req_body", typedLogBody) + default: + panic(fmt.Errorf("invalid type for LogBodyContextKey: %T", body)) + } } if errors.Is(err, context.Canceled) { evt.Msg("Request canceled") @@ -395,24 +410,25 @@ func (cli *Client) MakeRequest(ctx context.Context, method string, httpURL strin return cli.MakeFullRequest(ctx, FullRequest{Method: method, URL: httpURL, RequestJSON: reqBody, ResponseJSON: resBody}) } -type ClientResponseHandler = func(req *http.Request, res *http.Response, responseJSON interface{}) ([]byte, error) +type ClientResponseHandler = func(req *http.Request, res *http.Response, responseJSON any, sizeLimit int64) ([]byte, error) type FullRequest struct { - Method string - URL string - Headers http.Header - RequestJSON interface{} - RequestBytes []byte - RequestBody io.Reader - RequestLength int64 - ResponseJSON interface{} - MaxAttempts int - BackoffDuration time.Duration - SensitiveContent bool - Handler ClientResponseHandler - DontReadResponse bool - Logger *zerolog.Logger - Client *http.Client + Method string + URL string + Headers http.Header + RequestJSON interface{} + RequestBytes []byte + RequestBody io.Reader + RequestLength int64 + ResponseJSON interface{} + MaxAttempts int + BackoffDuration time.Duration + SensitiveContent bool + Handler ClientResponseHandler + DontReadResponse bool + ResponseSizeLimit int64 + Logger *zerolog.Logger + Client *http.Client } var requestID int32 @@ -441,8 +457,10 @@ func (params *FullRequest) compileRequest(ctx context.Context) (*http.Request, e } if params.SensitiveContent && !logSensitiveContent { logBody = "" + } else if len(jsonStr) > 32768 { + logBody = fmt.Sprintf("", len(jsonStr)) } else { - logBody = params.RequestJSON + logBody = json.RawMessage(jsonStr) } reqBody = bytes.NewReader(jsonStr) reqLen = int64(len(jsonStr)) @@ -467,7 +485,7 @@ func (params *FullRequest) compileRequest(ctx context.Context) (*http.Request, e } } else if params.Method != http.MethodGet && params.Method != http.MethodHead { params.RequestJSON = struct{}{} - logBody = params.RequestJSON + logBody = json.RawMessage("{}") reqBody = bytes.NewReader([]byte("{}")) reqLen = 2 } @@ -537,10 +555,25 @@ func (cli *Client) MakeFullRequestWithResp(ctx context.Context, params FullReque if len(cli.AccessToken) > 0 { req.Header.Set("Authorization", "Bearer "+cli.AccessToken) } + if params.ResponseSizeLimit == 0 { + params.ResponseSizeLimit = cli.ResponseSizeLimit + } + if params.ResponseSizeLimit == 0 { + params.ResponseSizeLimit = DefaultResponseSizeLimit + } if params.Client == nil { params.Client = cli.Client } - return cli.executeCompiledRequest(req, params.MaxAttempts-1, params.BackoffDuration, params.ResponseJSON, params.Handler, params.DontReadResponse, params.Client) + return cli.executeCompiledRequest( + req, + params.MaxAttempts-1, + params.BackoffDuration, + params.ResponseJSON, + params.Handler, + params.DontReadResponse, + params.ResponseSizeLimit, + params.Client, + ) } func (cli *Client) cliOrContextLog(ctx context.Context) *zerolog.Logger { @@ -551,7 +584,17 @@ func (cli *Client) cliOrContextLog(ctx context.Context) *zerolog.Logger { return log } -func (cli *Client) doRetry(req *http.Request, cause error, retries int, backoff time.Duration, responseJSON any, handler ClientResponseHandler, dontReadResponse bool, client *http.Client) ([]byte, *http.Response, error) { +func (cli *Client) doRetry( + req *http.Request, + cause error, + retries int, + backoff time.Duration, + responseJSON any, + handler ClientResponseHandler, + dontReadResponse bool, + sizeLimit int64, + client *http.Client, +) ([]byte, *http.Response, error) { log := zerolog.Ctx(req.Context()) if req.Body != nil { var err error @@ -580,16 +623,30 @@ func (cli *Client) doRetry(req *http.Request, cause error, retries int, backoff select { case <-time.After(backoff): case <-req.Context().Done(): - return nil, nil, req.Context().Err() + if !errors.Is(context.Cause(req.Context()), ErrContextCancelRetry) { + return nil, nil, req.Context().Err() + } } if cli.UpdateRequestOnRetry != nil { req = cli.UpdateRequestOnRetry(req, cause) } - return cli.executeCompiledRequest(req, retries-1, backoff*2, responseJSON, handler, dontReadResponse, client) + return cli.executeCompiledRequest(req, retries-1, backoff*2, responseJSON, handler, dontReadResponse, sizeLimit, client) } -func readResponseBody(req *http.Request, res *http.Response) ([]byte, error) { - contents, err := io.ReadAll(res.Body) +func readResponseBody(req *http.Request, res *http.Response, limit int64) ([]byte, error) { + if res.ContentLength > limit { + return nil, HTTPError{ + Request: req, + Response: res, + + Message: "not reading response", + WrappedError: fmt.Errorf("%w (%.2f MiB)", ErrResponseTooLong, float64(res.ContentLength)/1024/1024), + } + } + contents, err := io.ReadAll(io.LimitReader(res.Body, limit+1)) + if err == nil && len(contents) > int(limit) { + err = ErrBodyReadReachedLimit + } if err != nil { return nil, HTTPError{ Request: req, @@ -610,17 +667,20 @@ func closeTemp(log *zerolog.Logger, file *os.File) { } } -func streamResponse(req *http.Request, res *http.Response, responseJSON interface{}) ([]byte, error) { +func streamResponse(req *http.Request, res *http.Response, responseJSON any, limit int64) ([]byte, error) { log := zerolog.Ctx(req.Context()) file, err := os.CreateTemp("", "mautrix-response-") if err != nil { log.Warn().Err(err).Msg("Failed to create temporary file for streaming response") - _, err = handleNormalResponse(req, res, responseJSON) + _, err = handleNormalResponse(req, res, responseJSON, limit) return nil, err } defer closeTemp(log, file) - if _, err = io.Copy(file, res.Body); err != nil { + var n int64 + if n, err = io.Copy(file, io.LimitReader(res.Body, limit+1)); err != nil { return nil, fmt.Errorf("failed to copy response to file: %w", err) + } else if n > limit { + return nil, ErrBodyReadReachedLimit } else if _, err = file.Seek(0, 0); err != nil { return nil, fmt.Errorf("failed to seek to beginning of response file: %w", err) } else if err = json.NewDecoder(file).Decode(responseJSON); err != nil { @@ -630,12 +690,12 @@ func streamResponse(req *http.Request, res *http.Response, responseJSON interfac } } -func noopHandleResponse(req *http.Request, res *http.Response, responseJSON interface{}) ([]byte, error) { +func noopHandleResponse(req *http.Request, res *http.Response, responseJSON any, limit int64) ([]byte, error) { return nil, nil } -func handleNormalResponse(req *http.Request, res *http.Response, responseJSON interface{}) ([]byte, error) { - if contents, err := readResponseBody(req, res); err != nil { +func handleNormalResponse(req *http.Request, res *http.Response, responseJSON any, limit int64) ([]byte, error) { + if contents, err := readResponseBody(req, res, limit); err != nil { return nil, err } else if responseJSON == nil { return contents, nil @@ -653,8 +713,13 @@ func handleNormalResponse(req *http.Request, res *http.Response, responseJSON in } } +const ErrorResponseSizeLimit = 512 * 1024 + +var DefaultResponseSizeLimit int64 = 512 * 1024 * 1024 + func ParseErrorResponse(req *http.Request, res *http.Response) ([]byte, error) { - contents, err := readResponseBody(req, res) + defer res.Body.Close() + contents, err := readResponseBody(req, res, ErrorResponseSizeLimit) if err != nil { return contents, err } @@ -673,17 +738,31 @@ func ParseErrorResponse(req *http.Request, res *http.Response) ([]byte, error) { } } -func (cli *Client) executeCompiledRequest(req *http.Request, retries int, backoff time.Duration, responseJSON any, handler ClientResponseHandler, dontReadResponse bool, client *http.Client) ([]byte, *http.Response, error) { +func (cli *Client) executeCompiledRequest( + req *http.Request, + retries int, + backoff time.Duration, + responseJSON any, + handler ClientResponseHandler, + dontReadResponse bool, + sizeLimit int64, + client *http.Client, +) ([]byte, *http.Response, error) { cli.RequestStart(req) startTime := time.Now() res, err := client.Do(req) - duration := time.Now().Sub(startTime) + duration := time.Since(startTime) if res != nil && !dontReadResponse { defer res.Body.Close() } if err != nil { - if retries > 0 && !errors.Is(err, context.Canceled) { - return cli.doRetry(req, err, retries, backoff, responseJSON, handler, dontReadResponse, client) + // Either error is *not* canceled or the underlying cause of cancelation explicitly asks to retry + canRetry := !errors.Is(err, context.Canceled) || + errors.Is(context.Cause(req.Context()), ErrContextCancelRetry) + if retries > 0 && canRetry { + return cli.doRetry( + req, err, retries, backoff, responseJSON, handler, dontReadResponse, sizeLimit, client, + ) } err = HTTPError{ Request: req, @@ -698,7 +777,9 @@ func (cli *Client) executeCompiledRequest(req *http.Request, retries int, backof if retries > 0 && retryafter.Should(res.StatusCode, !cli.IgnoreRateLimit) { backoff = retryafter.Parse(res.Header.Get("Retry-After"), backoff) - return cli.doRetry(req, fmt.Errorf("HTTP %d", res.StatusCode), retries, backoff, responseJSON, handler, dontReadResponse, client) + return cli.doRetry( + req, fmt.Errorf("HTTP %d", res.StatusCode), retries, backoff, responseJSON, handler, dontReadResponse, sizeLimit, client, + ) } var body []byte @@ -706,7 +787,7 @@ func (cli *Client) executeCompiledRequest(req *http.Request, retries int, backof body, err = ParseErrorResponse(req, res) cli.LogRequestDone(req, res, nil, nil, len(body), duration) } else { - body, err = handler(req, res, responseJSON) + body, err = handler(req, res, responseJSON, sizeLimit) cli.LogRequestDone(req, res, nil, err, len(body), duration) } return body, res, err @@ -790,7 +871,7 @@ func (cli *Client) FullSyncRequest(ctx context.Context, req ReqSync) (resp *Resp } start := time.Now() _, err = cli.MakeFullRequest(ctx, fullReq) - duration := time.Now().Sub(start) + duration := time.Since(start) timeout := time.Duration(req.Timeout) * time.Millisecond buffer := 10 * time.Second if req.Since == "" { @@ -837,7 +918,7 @@ func (cli *Client) RegisterAvailable(ctx context.Context, username string) (resp return } -func (cli *Client) register(ctx context.Context, url string, req *ReqRegister) (resp *RespRegister, uiaResp *RespUserInteractive, err error) { +func (cli *Client) register(ctx context.Context, url string, req *ReqRegister[any]) (resp *RespRegister, uiaResp *RespUserInteractive, err error) { var bodyBytes []byte bodyBytes, err = cli.MakeFullRequest(ctx, FullRequest{ Method: http.MethodPost, @@ -861,7 +942,7 @@ func (cli *Client) register(ctx context.Context, url string, req *ReqRegister) ( // Register makes an HTTP request according to https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3register // // Registers with kind=user. For kind=guest, see RegisterGuest. -func (cli *Client) Register(ctx context.Context, req *ReqRegister) (*RespRegister, *RespUserInteractive, error) { +func (cli *Client) Register(ctx context.Context, req *ReqRegister[any]) (*RespRegister, *RespUserInteractive, error) { u := cli.BuildClientURL("v3", "register") return cli.register(ctx, u, req) } @@ -870,7 +951,7 @@ func (cli *Client) Register(ctx context.Context, req *ReqRegister) (*RespRegiste // with kind=guest. // // For kind=user, see Register. -func (cli *Client) RegisterGuest(ctx context.Context, req *ReqRegister) (*RespRegister, *RespUserInteractive, error) { +func (cli *Client) RegisterGuest(ctx context.Context, req *ReqRegister[any]) (*RespRegister, *RespUserInteractive, error) { query := map[string]string{ "kind": "guest", } @@ -893,8 +974,8 @@ func (cli *Client) RegisterGuest(ctx context.Context, req *ReqRegister) (*RespRe // panic(err) // } // token := res.AccessToken -func (cli *Client) RegisterDummy(ctx context.Context, req *ReqRegister) (*RespRegister, error) { - res, uia, err := cli.Register(ctx, req) +func (cli *Client) RegisterDummy(ctx context.Context, req *ReqRegister[any]) (*RespRegister, error) { + _, uia, err := cli.Register(ctx, req) if err != nil && uia == nil { return nil, err } else if uia == nil { @@ -903,7 +984,7 @@ func (cli *Client) RegisterDummy(ctx context.Context, req *ReqRegister) (*RespRe return nil, errors.New("server does not support m.login.dummy") } req.Auth = BaseAuthData{Type: AuthTypeDummy, Session: uia.Session} - res, _, err = cli.Register(ctx, req) + res, _, err := cli.Register(ctx, req) if err != nil { return nil, err } @@ -1077,7 +1158,9 @@ func (cli *Client) SearchUserDirectory(ctx context.Context, query string, limit } func (cli *Client) GetMutualRooms(ctx context.Context, otherUserID id.UserID, extras ...ReqMutualRooms) (resp *RespMutualRooms, err error) { - if cli.SpecVersions != nil && !cli.SpecVersions.Supports(FeatureMutualRooms) { + supportsStable := cli.SpecVersions.Supports(FeatureStableMutualRooms) + supportsUnstable := cli.SpecVersions.Supports(FeatureUnstableMutualRooms) + if cli.SpecVersions != nil && !supportsUnstable && !supportsStable { err = fmt.Errorf("server does not support fetching mutual rooms") return } @@ -1087,7 +1170,10 @@ func (cli *Client) GetMutualRooms(ctx context.Context, otherUserID id.UserID, ex if len(extras) > 0 { query["from"] = extras[0].From } - urlPath := cli.BuildURLWithQuery(ClientURLPath{"unstable", "uk.half-shot.msc2666", "user", "mutual_rooms"}, query) + urlPath := cli.BuildURLWithQuery(ClientURLPath{"v1", "mutual_rooms"}, query) + if !supportsStable && supportsUnstable { + urlPath = cli.BuildURLWithQuery(ClientURLPath{"unstable", "uk.half-shot.msc2666", "user", "mutual_rooms"}, query) + } _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp) return } @@ -1252,6 +1338,9 @@ func (cli *Client) SendMessageEvent(ctx context.Context, roomID id.RoomID, event if req.UnstableDelay > 0 { queryParams["org.matrix.msc4140.delay"] = strconv.FormatInt(req.UnstableDelay.Milliseconds(), 10) } + if req.UnstableStickyDuration > 0 { + queryParams["org.matrix.msc4354.sticky_duration_ms"] = strconv.FormatInt(req.UnstableStickyDuration.Milliseconds(), 10) + } if !req.DontEncrypt && cli != nil && cli.Crypto != nil && eventType != event.EventReaction && eventType != event.EventEncrypted { var isEncrypted bool @@ -1275,9 +1364,51 @@ func (cli *Client) SendMessageEvent(ctx context.Context, roomID id.RoomID, event return } -// SendStateEvent sends a state event into a room. See https://spec.matrix.org/v1.2/client-server-api/#put_matrixclientv3roomsroomidstateeventtypestatekey +// BeeperSendEphemeralEvent sends an ephemeral event into a room using Beeper's unstable endpoint. +// contentJSON should be a value that can be encoded as JSON using json.Marshal. +func (cli *Client) BeeperSendEphemeralEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, contentJSON any, extra ...ReqSendEvent) (resp *RespSendEvent, err error) { + var req ReqSendEvent + if len(extra) > 0 { + req = extra[0] + } + + var txnID string + if len(req.TransactionID) > 0 { + txnID = req.TransactionID + } else { + txnID = cli.TxnID() + } + + queryParams := map[string]string{} + if req.Timestamp > 0 { + queryParams["ts"] = strconv.FormatInt(req.Timestamp, 10) + } + + if !req.DontEncrypt && cli != nil && cli.Crypto != nil && eventType != event.EventEncrypted { + var isEncrypted bool + isEncrypted, err = cli.StateStore.IsEncrypted(ctx, roomID) + if err != nil { + err = fmt.Errorf("failed to check if room is encrypted: %w", err) + return + } + if isEncrypted { + if contentJSON, err = cli.Crypto.Encrypt(ctx, roomID, eventType, contentJSON); err != nil { + err = fmt.Errorf("failed to encrypt event: %w", err) + return + } + eventType = event.EventEncrypted + } + } + + urlData := ClientURLPath{"unstable", "com.beeper.ephemeral", "rooms", roomID, "ephemeral", eventType.String(), txnID} + urlPath := cli.BuildURLWithQuery(urlData, queryParams) + _, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, contentJSON, &resp) + return +} + +// SendStateEvent sends a state event into a room. See https://spec.matrix.org/v1.16/client-server-api/#put_matrixclientv3roomsroomidstateeventtypestatekey // contentJSON should be a pointer to something that can be encoded as JSON using json.Marshal. -func (cli *Client) SendStateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, contentJSON interface{}, extra ...ReqSendEvent) (resp *RespSendEvent, err error) { +func (cli *Client) SendStateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, contentJSON any, extra ...ReqSendEvent) (resp *RespSendEvent, err error) { var req ReqSendEvent if len(extra) > 0 { req = extra[0] @@ -1287,9 +1418,18 @@ func (cli *Client) SendStateEvent(ctx context.Context, roomID id.RoomID, eventTy if req.MeowEventID != "" { queryParams["fi.mau.event_id"] = req.MeowEventID.String() } + if req.TransactionID != "" { + queryParams["fi.mau.transaction_id"] = req.TransactionID + } if req.UnstableDelay > 0 { queryParams["org.matrix.msc4140.delay"] = strconv.FormatInt(req.UnstableDelay.Milliseconds(), 10) } + if req.UnstableStickyDuration > 0 { + queryParams["org.matrix.msc4354.sticky_duration_ms"] = strconv.FormatInt(req.UnstableStickyDuration.Milliseconds(), 10) + } + if req.Timestamp > 0 { + queryParams["ts"] = strconv.FormatInt(req.Timestamp, 10) + } urlData := ClientURLPath{"v3", "rooms", roomID, "state", eventType.String(), stateKey} urlPath := cli.BuildURLWithQuery(urlData, queryParams) @@ -1302,14 +1442,12 @@ func (cli *Client) SendStateEvent(ctx context.Context, roomID id.RoomID, eventTy // SendMassagedStateEvent sends a state event into a room with a custom timestamp. See https://spec.matrix.org/v1.2/client-server-api/#put_matrixclientv3roomsroomidstateeventtypestatekey // contentJSON should be a pointer to something that can be encoded as JSON using json.Marshal. +// +// Deprecated: SendStateEvent accepts a timestamp via ReqSendEvent and should be used instead. func (cli *Client) SendMassagedStateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, contentJSON interface{}, ts int64) (resp *RespSendEvent, err error) { - urlPath := cli.BuildURLWithQuery(ClientURLPath{"v3", "rooms", roomID, "state", eventType.String(), stateKey}, map[string]string{ - "ts": strconv.FormatInt(ts, 10), + resp, err = cli.SendStateEvent(ctx, roomID, eventType, stateKey, contentJSON, ReqSendEvent{ + Timestamp: ts, }) - _, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, contentJSON, &resp) - if err == nil && cli.StateStore != nil { - cli.updateStoreWithOutgoingEvent(ctx, roomID, eventType, stateKey, contentJSON) - } return } @@ -1628,11 +1766,20 @@ func (cli *Client) FullStateEvent(ctx context.Context, roomID id.RoomID, eventTy } // parseRoomStateArray parses a JSON array as a stream and stores the events inside it in a room state map. -func parseRoomStateArray(_ *http.Request, res *http.Response, responseJSON interface{}) ([]byte, error) { +func parseRoomStateArray(req *http.Request, res *http.Response, responseJSON any, limit int64) ([]byte, error) { + if res.ContentLength > limit { + return nil, HTTPError{ + Request: req, + Response: res, + + Message: "not reading response", + WrappedError: fmt.Errorf("%w (%.2f MiB)", ErrResponseTooLong, float64(res.ContentLength)/1024/1024), + } + } response := make(RoomStateMap) responsePtr := responseJSON.(*map[event.Type]map[string]*event.Event) *responsePtr = response - dec := json.NewDecoder(res.Body) + dec := json.NewDecoder(io.LimitReader(res.Body, limit)) arrayStart, err := dec.Token() if err != nil { @@ -1666,6 +1813,8 @@ func parseRoomStateArray(_ *http.Request, res *http.Response, responseJSON inter return nil, nil } +type RoomStateMap = map[event.Type]map[string]*event.Event + // State gets all state in a room. // See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3roomsroomidstate func (cli *Client) State(ctx context.Context, roomID id.RoomID) (stateMap RoomStateMap, err error) { @@ -1748,6 +1897,9 @@ func (cli *Client) UploadLink(ctx context.Context, link string) (*RespMediaUploa } func (cli *Client) Download(ctx context.Context, mxcURL id.ContentURI) (*http.Response, error) { + if mxcURL.IsEmpty() { + return nil, fmt.Errorf("empty mxc uri provided to Download") + } _, resp, err := cli.MakeFullRequestWithResp(ctx, FullRequest{ Method: http.MethodGet, URL: cli.BuildClientURL("v1", "media", "download", mxcURL.Homeserver, mxcURL.FileID), @@ -1762,6 +1914,9 @@ type DownloadThumbnailExtra struct { } func (cli *Client) DownloadThumbnail(ctx context.Context, mxcURL id.ContentURI, height, width int, extras ...DownloadThumbnailExtra) (*http.Response, error) { + if mxcURL.IsEmpty() { + return nil, fmt.Errorf("empty mxc uri provided to DownloadThumbnail") + } if len(extras) > 1 { panic(fmt.Errorf("invalid number of arguments to DownloadThumbnail: %d", len(extras))) } @@ -1834,10 +1989,15 @@ func (cli *Client) UploadAsync(ctx context.Context, req ReqUploadMedia) (*RespCr } req.MXC = resp.ContentURI req.UnstableUploadURL = resp.UnstableUploadURL + if req.AsyncContext == nil { + req.AsyncContext = cli.cliOrContextLog(ctx).WithContext(context.Background()) + } go func() { - _, err = cli.UploadMedia(ctx, req) + _, err = cli.UploadMedia(req.AsyncContext, req) if err != nil { - cli.Log.Error().Stringer("mxc", req.MXC).Err(err).Msg("Async upload of media failed") + zerolog.Ctx(req.AsyncContext).Err(err). + Stringer("mxc", req.MXC). + Msg("Async upload of media failed") } }() return resp, nil @@ -1873,6 +2033,7 @@ type ReqUploadMedia struct { ContentType string FileName string + AsyncContext context.Context DoneCallback func() // MXC specifies an existing MXC URI which doesn't have content yet to upload into. @@ -1885,7 +2046,10 @@ type ReqUploadMedia struct { } func (cli *Client) tryUploadMediaToURL(ctx context.Context, url, contentType string, content io.Reader, contentLength int64) (*http.Response, error) { - cli.Log.Debug().Str("url", url).Msg("Uploading media to external URL") + cli.Log.Debug(). + Str("url", url). + Int64("content_length", contentLength). + Msg("Uploading media to external URL") req, err := http.NewRequestWithContext(ctx, http.MethodPut, url, content) if err != nil { return nil, err @@ -1934,8 +2098,16 @@ func (cli *Client) uploadMediaToURL(ctx context.Context, data ReqUploadMedia) (* Msg("Error uploading media to external URL, not retrying") return nil, err } - cli.Log.Warn().Str("url", data.UnstableUploadURL).Err(err). + backoff := time.Second * time.Duration(cli.DefaultHTTPRetries-retries) + cli.Log.Warn().Err(err). + Str("url", data.UnstableUploadURL). + Int("retry_in_seconds", int(backoff.Seconds())). Msg("Error uploading media to external URL, retrying") + select { + case <-time.After(backoff): + case <-ctx.Done(): + return nil, ctx.Err() + } retries-- _, err = readerSeeker.Seek(0, io.SeekStart) if err != nil { @@ -2515,13 +2687,13 @@ func (cli *Client) SetDeviceInfo(ctx context.Context, deviceID id.DeviceID, req return err } -func (cli *Client) DeleteDevice(ctx context.Context, deviceID id.DeviceID, req *ReqDeleteDevice) error { +func (cli *Client) DeleteDevice(ctx context.Context, deviceID id.DeviceID, req *ReqDeleteDevice[any]) error { urlPath := cli.BuildClientURL("v3", "devices", deviceID) _, err := cli.MakeRequest(ctx, http.MethodDelete, urlPath, req, nil) return err } -func (cli *Client) DeleteDevices(ctx context.Context, req *ReqDeleteDevices) error { +func (cli *Client) DeleteDevices(ctx context.Context, req *ReqDeleteDevices[any]) error { urlPath := cli.BuildClientURL("v3", "delete_devices") _, err := cli.MakeRequest(ctx, http.MethodPost, urlPath, req, nil) return err @@ -2532,7 +2704,7 @@ type UIACallback = func(*RespUserInteractive) interface{} // UploadCrossSigningKeys uploads the given cross-signing keys to the server. // Because the endpoint requires user-interactive authentication a callback must be provided that, // given the UI auth parameters, produces the required result (or nil to end the flow). -func (cli *Client) UploadCrossSigningKeys(ctx context.Context, keys *UploadCrossSigningKeysReq, uiaCallback UIACallback) error { +func (cli *Client) UploadCrossSigningKeys(ctx context.Context, keys *UploadCrossSigningKeysReq[any], uiaCallback UIACallback) error { content, err := cli.MakeFullRequest(ctx, FullRequest{ Method: http.MethodPost, URL: cli.BuildClientURL("v3", "keys", "device_signing", "upload"), @@ -2614,30 +2786,60 @@ func (cli *Client) ReportRoom(ctx context.Context, roomID id.RoomID, reason stri return err } -// UnstableGetSuspendedStatus uses MSC4323 to check if a user is suspended. -func (cli *Client) UnstableGetSuspendedStatus(ctx context.Context, userID id.UserID) (res *RespSuspended, err error) { - urlPath := cli.BuildClientURL("unstable", "uk.timedout.msc4323", "admin", "suspend", userID) +// AdminWhoIs fetches session information belonging to a specific user. Typically requires being a server admin. +// +// https://spec.matrix.org/v1.15/client-server-api/#get_matrixclientv3adminwhoisuserid +func (cli *Client) AdminWhoIs(ctx context.Context, userID id.UserID) (resp RespWhoIs, err error) { + urlPath := cli.BuildClientURL("v3", "admin", "whois", userID) + _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp) + return +} + +func (cli *Client) makeMSC4323URL(action string, target id.UserID) string { + if cli.SpecVersions.Supports(FeatureUnstableAccountModeration) { + return cli.BuildClientURL("unstable", "uk.timedout.msc4323", "admin", action, target) + } else if cli.SpecVersions.Supports(FeatureStableAccountModeration) { + return cli.BuildClientURL("v1", "admin", action, target) + } + return "" +} + +// GetSuspendedStatus uses MSC4323 to check if a user is suspended. +func (cli *Client) GetSuspendedStatus(ctx context.Context, userID id.UserID) (res *RespSuspended, err error) { + urlPath := cli.makeMSC4323URL("suspend", userID) + if urlPath == "" { + return nil, MUnrecognized.WithMessage("Homeserver does not advertise MSC4323 support") + } _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, res) return } -// UnstableGetLockStatus uses MSC4323 to check if a user is locked. -func (cli *Client) UnstableGetLockStatus(ctx context.Context, userID id.UserID) (res *RespLocked, err error) { - urlPath := cli.BuildClientURL("unstable", "uk.timedout.msc4323", "admin", "lock", userID) +// GetLockStatus uses MSC4323 to check if a user is locked. +func (cli *Client) GetLockStatus(ctx context.Context, userID id.UserID) (res *RespLocked, err error) { + urlPath := cli.makeMSC4323URL("lock", userID) + if urlPath == "" { + return nil, MUnrecognized.WithMessage("Homeserver does not advertise MSC4323 support") + } _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, res) return } -// UnstableSetSuspendedStatus uses MSC4323 to set whether a user account is suspended. -func (cli *Client) UnstableSetSuspendedStatus(ctx context.Context, userID id.UserID, suspended bool) (res *RespSuspended, err error) { - urlPath := cli.BuildClientURL("unstable", "uk.timedout.msc4323", "admin", "suspend", userID) +// SetSuspendedStatus uses MSC4323 to set whether a user account is suspended. +func (cli *Client) SetSuspendedStatus(ctx context.Context, userID id.UserID, suspended bool) (res *RespSuspended, err error) { + urlPath := cli.makeMSC4323URL("suspend", userID) + if urlPath == "" { + return nil, MUnrecognized.WithMessage("Homeserver does not advertise MSC4323 support") + } _, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, &ReqSuspend{Suspended: suspended}, res) return } -// UnstableSetLockStatus uses MSC4323 to set whether a user account is locked. -func (cli *Client) UnstableSetLockStatus(ctx context.Context, userID id.UserID, locked bool) (res *RespLocked, err error) { - urlPath := cli.BuildClientURL("unstable", "uk.timedout.msc4323", "admin", "lock", userID) +// SetLockStatus uses MSC4323 to set whether a user account is locked. +func (cli *Client) SetLockStatus(ctx context.Context, userID id.UserID, locked bool) (res *RespLocked, err error) { + urlPath := cli.makeMSC4323URL("lock", userID) + if urlPath == "" { + return nil, MUnrecognized.WithMessage("Homeserver does not advertise MSC4323 support") + } _, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, &ReqLocked{Locked: locked}, res) return } diff --git a/client_ephemeral_test.go b/client_ephemeral_test.go new file mode 100644 index 00000000..c2846427 --- /dev/null +++ b/client_ephemeral_test.go @@ -0,0 +1,158 @@ +// Copyright (c) 2026 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package mautrix_test + +import ( + "context" + "encoding/json" + "errors" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +func TestClient_SendEphemeralEvent_UsesUnstablePathTxnAndTS(t *testing.T) { + roomID := id.RoomID("!room:example.com") + evtType := event.Type{Type: "com.example.ephemeral", Class: event.EphemeralEventType} + txnID := "txn-123" + + var gotPath string + var gotQueryTS string + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotPath = r.URL.Path + gotQueryTS = r.URL.Query().Get("ts") + assert.Equal(t, http.MethodPut, r.Method) + _, _ = w.Write([]byte(`{"event_id":"$evt"}`)) + })) + defer ts.Close() + + cli, err := mautrix.NewClient(ts.URL, "", "") + require.NoError(t, err) + + _, err = cli.BeeperSendEphemeralEvent( + context.Background(), + roomID, + evtType, + map[string]any{"foo": "bar"}, + mautrix.ReqSendEvent{TransactionID: txnID, Timestamp: 1234}, + ) + require.NoError(t, err) + + assert.True(t, strings.Contains(gotPath, "/_matrix/client/unstable/com.beeper.ephemeral/rooms/")) + assert.True(t, strings.HasSuffix(gotPath, "/ephemeral/com.example.ephemeral/"+txnID)) + assert.Equal(t, "1234", gotQueryTS) +} + +func TestClient_SendEphemeralEvent_UnsupportedReturnsMUnrecognized(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte(`{"errcode":"M_UNRECOGNIZED","error":"Unrecognized endpoint"}`)) + })) + defer ts.Close() + + cli, err := mautrix.NewClient(ts.URL, "", "") + require.NoError(t, err) + + _, err = cli.BeeperSendEphemeralEvent( + context.Background(), + id.RoomID("!room:example.com"), + event.Type{Type: "com.example.ephemeral", Class: event.EphemeralEventType}, + map[string]any{"foo": "bar"}, + ) + require.Error(t, err) + assert.True(t, errors.Is(err, mautrix.MUnrecognized)) +} + +func TestClient_SendEphemeralEvent_EncryptsInEncryptedRooms(t *testing.T) { + roomID := id.RoomID("!room:example.com") + evtType := event.Type{Type: "com.example.ephemeral", Class: event.EphemeralEventType} + txnID := "txn-encrypted" + + stateStore := mautrix.NewMemoryStateStore() + err := stateStore.SetEncryptionEvent(context.Background(), roomID, &event.EncryptionEventContent{ + Algorithm: id.AlgorithmMegolmV1, + }) + require.NoError(t, err) + + fakeCrypto := &fakeCryptoHelper{ + encryptedContent: &event.EncryptedEventContent{ + Algorithm: id.AlgorithmMegolmV1, + MegolmCiphertext: []byte("ciphertext"), + }, + } + + var gotPath string + var gotBody map[string]any + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotPath = r.URL.Path + assert.Equal(t, http.MethodPut, r.Method) + err := json.NewDecoder(r.Body).Decode(&gotBody) + require.NoError(t, err) + _, _ = w.Write([]byte(`{"event_id":"$evt"}`)) + })) + defer ts.Close() + + cli, err := mautrix.NewClient(ts.URL, "", "") + require.NoError(t, err) + cli.StateStore = stateStore + cli.Crypto = fakeCrypto + + _, err = cli.BeeperSendEphemeralEvent( + context.Background(), + roomID, + evtType, + map[string]any{"foo": "bar"}, + mautrix.ReqSendEvent{TransactionID: txnID}, + ) + require.NoError(t, err) + + assert.True(t, strings.HasSuffix(gotPath, "/ephemeral/m.room.encrypted/"+txnID)) + assert.Equal(t, string(id.AlgorithmMegolmV1), gotBody["algorithm"]) + assert.Equal(t, 1, fakeCrypto.encryptCalls) + assert.Equal(t, roomID, fakeCrypto.lastRoomID) + assert.Equal(t, evtType, fakeCrypto.lastEventType) +} + +type fakeCryptoHelper struct { + encryptCalls int + lastRoomID id.RoomID + lastEventType event.Type + lastEncryptInput any + encryptedContent *event.EncryptedEventContent +} + +func (f *fakeCryptoHelper) Encrypt(_ context.Context, roomID id.RoomID, eventType event.Type, content any) (*event.EncryptedEventContent, error) { + f.encryptCalls++ + f.lastRoomID = roomID + f.lastEventType = eventType + f.lastEncryptInput = content + return f.encryptedContent, nil +} + +func (f *fakeCryptoHelper) Decrypt(context.Context, *event.Event) (*event.Event, error) { + return nil, nil +} + +func (f *fakeCryptoHelper) WaitForSession(context.Context, id.RoomID, id.SenderKey, id.SessionID, time.Duration) bool { + return false +} + +func (f *fakeCryptoHelper) RequestSession(context.Context, id.RoomID, id.SenderKey, id.SessionID, id.UserID, id.DeviceID) { +} + +func (f *fakeCryptoHelper) Init(context.Context) error { + return nil +} diff --git a/commands/container.go b/commands/container.go index bc685b7b..9b909b75 100644 --- a/commands/container.go +++ b/commands/container.go @@ -1,4 +1,4 @@ -// Copyright (c) 2025 Tulir Asokan +// Copyright (c) 2026 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this @@ -8,14 +8,20 @@ package commands import ( "fmt" + "slices" "strings" "sync" + + "go.mau.fi/util/exmaps" + + "maunium.net/go/mautrix/event/cmdschema" ) type CommandContainer[MetaType any] struct { commands map[string]*Handler[MetaType] aliases map[string]string lock sync.RWMutex + parent *Handler[MetaType] } func NewCommandContainer[MetaType any]() *CommandContainer[MetaType] { @@ -25,6 +31,29 @@ func NewCommandContainer[MetaType any]() *CommandContainer[MetaType] { } } +func (cont *CommandContainer[MetaType]) AllSpecs() []*cmdschema.EventContent { + data := make(exmaps.Set[*Handler[MetaType]]) + cont.collectHandlers(data) + specs := make([]*cmdschema.EventContent, 0, data.Size()) + for handler := range data.Iter() { + if handler.Parameters != nil { + specs = append(specs, handler.Spec()) + } + } + return specs +} + +func (cont *CommandContainer[MetaType]) collectHandlers(into exmaps.Set[*Handler[MetaType]]) { + cont.lock.RLock() + defer cont.lock.RUnlock() + for _, handler := range cont.commands { + into.Add(handler) + if handler.subcommandContainer != nil { + handler.subcommandContainer.collectHandlers(into) + } + } +} + // Register registers the given command handlers. func (cont *CommandContainer[MetaType]) Register(handlers ...*Handler[MetaType]) { if cont == nil { @@ -32,7 +61,10 @@ func (cont *CommandContainer[MetaType]) Register(handlers ...*Handler[MetaType]) } cont.lock.Lock() defer cont.lock.Unlock() - for _, handler := range handlers { + for i, handler := range handlers { + if handler == nil { + panic(fmt.Errorf("handler #%d is nil", i+1)) + } cont.registerOne(handler) } } @@ -45,6 +77,10 @@ func (cont *CommandContainer[MetaType]) registerOne(handler *Handler[MetaType]) } else if aliasTarget, alreadyExists := cont.aliases[handler.Name]; alreadyExists { panic(fmt.Errorf("tried to register command %q, but it's already registered as an alias for %q", handler.Name, aliasTarget)) } + if !slices.Contains(handler.parents, cont.parent) { + handler.parents = append(handler.parents, cont.parent) + handler.nestedNameCache = nil + } cont.commands[handler.Name] = handler for _, alias := range handler.Aliases { if strings.ToLower(alias) != alias { diff --git a/commands/event.go b/commands/event.go index 77a3c0d2..76d6c9f0 100644 --- a/commands/event.go +++ b/commands/event.go @@ -1,4 +1,4 @@ -// Copyright (c) 2025 Tulir Asokan +// Copyright (c) 2026 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this @@ -8,6 +8,7 @@ package commands import ( "context" + "encoding/json" "fmt" "strings" @@ -35,6 +36,8 @@ type Event[MetaType any] struct { // RawArgs is the same as args, but without the splitting by whitespace. RawArgs string + StructuredArgs json.RawMessage + Ctx context.Context Log *zerolog.Logger Proc *Processor[MetaType] @@ -61,7 +64,7 @@ var IDHTMLParser = &format.HTMLParser{ } // ParseEvent parses a message into a command event struct. -func ParseEvent[MetaType any](ctx context.Context, evt *event.Event) *Event[MetaType] { +func (proc *Processor[MetaType]) ParseEvent(ctx context.Context, evt *event.Event) *Event[MetaType] { content, ok := evt.Content.Parsed.(*event.MessageEventContent) if !ok || content.MsgType == event.MsgNotice || content.RelatesTo.GetReplaceID() != "" { return nil @@ -70,12 +73,34 @@ func ParseEvent[MetaType any](ctx context.Context, evt *event.Event) *Event[Meta if content.Format == event.FormatHTML { text = IDHTMLParser.Parse(content.FormattedBody, format.NewContext(ctx)) } + if content.MSC4391BotCommand != nil { + if !content.Mentions.Has(proc.Client.UserID) || len(content.Mentions.UserIDs) != 1 { + return nil + } + wrapped := StructuredCommandToEvent[MetaType](ctx, evt, content.MSC4391BotCommand) + wrapped.RawInput = text + return wrapped + } if len(text) == 0 { return nil } return RawTextToEvent[MetaType](ctx, evt, text) } +func StructuredCommandToEvent[MetaType any](ctx context.Context, evt *event.Event, content *event.MSC4391BotCommandInput) *Event[MetaType] { + commandParts := strings.Split(content.Command, " ") + return &Event[MetaType]{ + Event: evt, + // Fake a command and args to let the subcommand finder in Process work. + Command: commandParts[0], + Args: commandParts[1:], + Ctx: ctx, + Log: zerolog.Ctx(ctx), + + StructuredArgs: content.Arguments, + } +} + func RawTextToEvent[MetaType any](ctx context.Context, evt *event.Event, text string) *Event[MetaType] { parts := strings.Fields(text) if len(parts) == 0 { @@ -188,3 +213,25 @@ func (evt *Event[MetaType]) UnshiftArg(arg string) { evt.RawArgs = arg + " " + evt.RawArgs evt.Args = append([]string{arg}, evt.Args...) } + +func (evt *Event[MetaType]) ParseArgs(into any) error { + return json.Unmarshal(evt.StructuredArgs, into) +} + +func ParseArgs[T, MetaType any](evt *Event[MetaType]) (into T, err error) { + err = evt.ParseArgs(&into) + return +} + +func WithParsedArgs[T, MetaType any](fn func(*Event[MetaType], T)) func(*Event[MetaType]) { + return func(evt *Event[MetaType]) { + parsed, err := ParseArgs[T, MetaType](evt) + if err != nil { + evt.Log.Debug().Err(err).Msg("Failed to parse structured args into struct") + // TODO better error, usage info? deduplicate with Process + evt.Reply("Failed to parse arguments: %v", err) + return + } + fn(evt, parsed) + } +} diff --git a/commands/handler.go b/commands/handler.go index b01d594f..56f27f06 100644 --- a/commands/handler.go +++ b/commands/handler.go @@ -1,4 +1,4 @@ -// Copyright (c) 2025 Tulir Asokan +// Copyright (c) 2026 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this @@ -8,6 +8,9 @@ package commands import ( "strings" + + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/event/cmdschema" ) type Handler[MetaType any] struct { @@ -25,12 +28,63 @@ type Handler[MetaType any] struct { // Event.ShiftArg will likely be useful for implementing such parameters. PreFunc func(ce *Event[MetaType]) + // Description is a short description of the command. + Description *event.ExtensibleTextContainer + // Parameters is a description of structured command parameters. + // If set, the StructuredArgs field of Event will be populated. + Parameters []*cmdschema.Parameter + TailParam string + + parents []*Handler[MetaType] + nestedNameCache []string subcommandContainer *CommandContainer[MetaType] } +func (h *Handler[MetaType]) NestedNames() []string { + if h.nestedNameCache != nil { + return h.nestedNameCache + } + nestedNames := make([]string, 0, (1+len(h.Aliases))*len(h.parents)) + for _, parent := range h.parents { + if parent == nil { + nestedNames = append(nestedNames, h.Name) + nestedNames = append(nestedNames, h.Aliases...) + } else { + for _, parentName := range parent.NestedNames() { + nestedNames = append(nestedNames, parentName+" "+h.Name) + for _, alias := range h.Aliases { + nestedNames = append(nestedNames, parentName+" "+alias) + } + } + } + } + h.nestedNameCache = nestedNames + return nestedNames +} + +func (h *Handler[MetaType]) Spec() *cmdschema.EventContent { + names := h.NestedNames() + return &cmdschema.EventContent{ + Command: names[0], + Aliases: names[1:], + Parameters: h.Parameters, + Description: h.Description, + TailParam: h.TailParam, + } +} + +func (h *Handler[MetaType]) CopyFrom(other *Handler[MetaType]) { + if h.Parameters == nil { + h.Parameters = other.Parameters + h.TailParam = other.TailParam + } + h.Func = other.Func +} + func (h *Handler[MetaType]) initSubcommandContainer() { if len(h.Subcommands) > 0 { h.subcommandContainer = NewCommandContainer[MetaType]() + h.subcommandContainer.parent = h h.subcommandContainer.Register(h.Subcommands...) } else { h.subcommandContainer = nil diff --git a/commands/processor.go b/commands/processor.go index 9341329b..80f6745d 100644 --- a/commands/processor.go +++ b/commands/processor.go @@ -1,4 +1,4 @@ -// Copyright (c) 2025 Tulir Asokan +// Copyright (c) 2026 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this @@ -72,9 +72,9 @@ func (proc *Processor[MetaType]) Process(ctx context.Context, evt *event.Event) case event.EventReaction: parsed = proc.ParseReaction(ctx, evt) case event.EventMessage: - parsed = ParseEvent[MetaType](ctx, evt) + parsed = proc.ParseEvent(ctx, evt) } - if parsed == nil || !proc.PreValidator.Validate(parsed) { + if parsed == nil || (!proc.PreValidator.Validate(parsed) && parsed.StructuredArgs == nil) { return } parsed.Proc = proc @@ -107,6 +107,12 @@ func (proc *Processor[MetaType]) Process(ctx context.Context, evt *event.Event) break } } + if parsed.StructuredArgs != nil && len(parsed.Args) > 0 { + // TODO allow unknown command handlers to be called? + // The client sent MSC4391 data, but the target command wasn't found + log.Debug().Msg("Didn't find handler for MSC4391 command") + return + } logWith := log.With(). Str("command", parsed.Command). @@ -116,11 +122,31 @@ func (proc *Processor[MetaType]) Process(ctx context.Context, evt *event.Event) } if proc.LogArgs { logWith = logWith.Strs("args", parsed.Args) + if parsed.StructuredArgs != nil { + logWith = logWith.RawJSON("structured_args", parsed.StructuredArgs) + } } log = logWith.Logger() parsed.Ctx = log.WithContext(ctx) parsed.Log = &log + if handler.Parameters != nil && parsed.StructuredArgs == nil { + // The handler wants structured parameters, but the client didn't send MSC4391 data + var err error + parsed.StructuredArgs, err = handler.Spec().ParseArguments(parsed.RawArgs) + if err != nil { + log.Debug().Err(err).Msg("Failed to parse structured arguments") + // TODO better error, usage info? deduplicate with WithParsedArgs + parsed.Reply("Failed to parse arguments: %v", err) + return + } + if proc.LogArgs { + log.UpdateContext(func(c zerolog.Context) zerolog.Context { + return c.RawJSON("structured_args", parsed.StructuredArgs) + }) + } + } + log.Debug().Msg("Processing command") handler.Func(parsed) } diff --git a/commands/reactions.go b/commands/reactions.go index 0df372e5..0d316219 100644 --- a/commands/reactions.go +++ b/commands/reactions.go @@ -1,4 +1,4 @@ -// Copyright (c) 2025 Tulir Asokan +// Copyright (c) 2026 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this @@ -8,6 +8,7 @@ package commands import ( "context" + "encoding/json" "strings" "github.com/rs/zerolog" @@ -19,6 +20,11 @@ import ( const ReactionCommandsKey = "fi.mau.reaction_commands" const ReactionMultiUseKey = "fi.mau.reaction_multi_use" +type ReactionCommandData struct { + Command string `json:"command"` + Args any `json:"args,omitempty"` +} + func (proc *Processor[MetaType]) ParseReaction(ctx context.Context, evt *event.Event) *Event[MetaType] { content, ok := evt.Content.Parsed.(*event.ReactionEventContent) if !ok { @@ -67,21 +73,33 @@ func (proc *Processor[MetaType]) ParseReaction(ctx context.Context, evt *event.E Msg("Reaction command not found in target event") return nil } - cmdString, ok := rawCmd.(string) - if !ok { + var wrappedEvt *Event[MetaType] + switch typedCmd := rawCmd.(type) { + case string: + wrappedEvt = RawTextToEvent[MetaType](ctx, evt, typedCmd) + case map[string]any: + var input event.MSC4391BotCommandInput + if marshaled, err := json.Marshal(typedCmd); err != nil { + + } else if err = json.Unmarshal(marshaled, &input); err != nil { + + } else { + wrappedEvt = StructuredCommandToEvent[MetaType](ctx, evt, &input) + } + } + if wrappedEvt == nil { zerolog.Ctx(ctx).Debug(). Stringer("target_event_id", evtID). Str("reaction_key", content.RelatesTo.Key). Msg("Reaction command data is invalid") return nil } - wrappedEvt := RawTextToEvent[MetaType](ctx, evt, cmdString) wrappedEvt.Proc = proc wrappedEvt.Redact() if !isMultiUse { DeleteAllReactions(ctx, proc.Client, evt) } - if cmdString == "" { + if wrappedEvt.Command == "" { return nil } return wrappedEvt diff --git a/crypto/attachment/attachments.go b/crypto/attachment/attachments.go index 155cca5c..727aacbf 100644 --- a/crypto/attachment/attachments.go +++ b/crypto/attachment/attachments.go @@ -21,13 +21,24 @@ import ( ) var ( - HashMismatch = errors.New("mismatching SHA-256 digest") - UnsupportedVersion = errors.New("unsupported Matrix file encryption version") - UnsupportedAlgorithm = errors.New("unsupported JWK encryption algorithm") - InvalidKey = errors.New("failed to decode key") - InvalidInitVector = errors.New("failed to decode initialization vector") - InvalidHash = errors.New("failed to decode SHA-256 hash") - ReaderClosed = errors.New("encrypting reader was already closed") + ErrHashMismatch = errors.New("mismatching SHA-256 digest") + ErrUnsupportedVersion = errors.New("unsupported Matrix file encryption version") + ErrUnsupportedAlgorithm = errors.New("unsupported JWK encryption algorithm") + ErrInvalidKey = errors.New("failed to decode key") + ErrInvalidInitVector = errors.New("failed to decode initialization vector") + ErrInvalidHash = errors.New("failed to decode SHA-256 hash") + ErrReaderClosed = errors.New("encrypting reader was already closed") +) + +// Deprecated: use variables prefixed with Err +var ( + HashMismatch = ErrHashMismatch + UnsupportedVersion = ErrUnsupportedVersion + UnsupportedAlgorithm = ErrUnsupportedAlgorithm + InvalidKey = ErrInvalidKey + InvalidInitVector = ErrInvalidInitVector + InvalidHash = ErrInvalidHash + ReaderClosed = ErrReaderClosed ) var ( @@ -85,25 +96,25 @@ func (ef *EncryptedFile) decodeKeys(includeHash bool) error { if ef.decoded != nil { return nil } else if len(ef.Key.Key) != keyBase64Length { - return InvalidKey + return ErrInvalidKey } else if len(ef.InitVector) != ivBase64Length { - return InvalidInitVector + return ErrInvalidInitVector } else if includeHash && len(ef.Hashes.SHA256) != hashBase64Length { - return InvalidHash + return ErrInvalidHash } ef.decoded = &decodedKeys{} _, err := base64.RawURLEncoding.Decode(ef.decoded.key[:], []byte(ef.Key.Key)) if err != nil { - return InvalidKey + return ErrInvalidKey } _, err = base64.RawStdEncoding.Decode(ef.decoded.iv[:], []byte(ef.InitVector)) if err != nil { - return InvalidInitVector + return ErrInvalidInitVector } if includeHash { _, err = base64.RawStdEncoding.Decode(ef.decoded.sha256[:], []byte(ef.Hashes.SHA256)) if err != nil { - return InvalidHash + return ErrInvalidHash } } return nil @@ -179,7 +190,7 @@ var _ io.ReadSeekCloser = (*encryptingReader)(nil) func (r *encryptingReader) Seek(offset int64, whence int) (int64, error) { if r.closed { - return 0, ReaderClosed + return 0, ErrReaderClosed } if offset != 0 || whence != io.SeekStart { return 0, fmt.Errorf("attachments.EncryptStream: only seeking to the beginning is supported") @@ -200,7 +211,7 @@ func (r *encryptingReader) Seek(offset int64, whence int) (int64, error) { func (r *encryptingReader) Read(dst []byte) (n int, err error) { if r.closed { - return 0, ReaderClosed + return 0, ErrReaderClosed } else if r.isDecrypting && r.file.decoded == nil { if err = r.file.PrepareForDecryption(); err != nil { return @@ -224,7 +235,7 @@ func (r *encryptingReader) Close() (err error) { } if r.isDecrypting { if !hmac.Equal(r.hash.Sum(nil), r.file.decoded.sha256[:]) { - return HashMismatch + return ErrHashMismatch } } else { r.file.Hashes.SHA256 = base64.RawStdEncoding.EncodeToString(r.hash.Sum(nil)) @@ -265,9 +276,9 @@ func (ef *EncryptedFile) Decrypt(ciphertext []byte) ([]byte, error) { // DecryptInPlace will always call this automatically, so calling this manually is not necessary when using that function. func (ef *EncryptedFile) PrepareForDecryption() error { if ef.Version != "v2" { - return UnsupportedVersion + return ErrUnsupportedVersion } else if ef.Key.Algorithm != "A256CTR" { - return UnsupportedAlgorithm + return ErrUnsupportedAlgorithm } else if err := ef.decodeKeys(true); err != nil { return err } @@ -281,7 +292,7 @@ func (ef *EncryptedFile) DecryptInPlace(data []byte) error { } dataHash := sha256.Sum256(data) if !hmac.Equal(ef.decoded.sha256[:], dataHash[:]) { - return HashMismatch + return ErrHashMismatch } utils.XorA256CTR(data, ef.decoded.key, ef.decoded.iv) return nil diff --git a/crypto/attachment/attachments_test.go b/crypto/attachment/attachments_test.go index d7f1394a..9fe929ab 100644 --- a/crypto/attachment/attachments_test.go +++ b/crypto/attachment/attachments_test.go @@ -53,33 +53,33 @@ func TestUnsupportedVersion(t *testing.T) { file := parseHelloWorld() file.Version = "foo" err := file.DecryptInPlace([]byte(helloWorldCiphertext)) - assert.ErrorIs(t, err, UnsupportedVersion) + assert.ErrorIs(t, err, ErrUnsupportedVersion) } func TestUnsupportedAlgorithm(t *testing.T) { file := parseHelloWorld() file.Key.Algorithm = "bar" err := file.DecryptInPlace([]byte(helloWorldCiphertext)) - assert.ErrorIs(t, err, UnsupportedAlgorithm) + assert.ErrorIs(t, err, ErrUnsupportedAlgorithm) } func TestHashMismatch(t *testing.T) { file := parseHelloWorld() file.Hashes.SHA256 = base64.RawStdEncoding.EncodeToString([]byte(random32Bytes)) err := file.DecryptInPlace([]byte(helloWorldCiphertext)) - assert.ErrorIs(t, err, HashMismatch) + assert.ErrorIs(t, err, ErrHashMismatch) } func TestTooLongHash(t *testing.T) { file := parseHelloWorld() file.Hashes.SHA256 = "TG9yZW0gaXBzdW0gZG9sb3Igc2l0IGFtZXQsIGNvbnNlY3RldHVlciBhZGlwaXNjaW5nIGVsaXQuIFNlZCBwb3N1ZXJlIGludGVyZHVtIHNlbS4gUXVpc3F1ZSBsaWd1bGEgZXJvcyB1bGxhbWNvcnBlciBxdWlzLCBsYWNpbmlhIHF1aXMgZmFjaWxpc2lzIHNlZCBzYXBpZW4uCg" err := file.DecryptInPlace([]byte(helloWorldCiphertext)) - assert.ErrorIs(t, err, InvalidHash) + assert.ErrorIs(t, err, ErrInvalidHash) } func TestTooShortHash(t *testing.T) { file := parseHelloWorld() file.Hashes.SHA256 = "5/Gy1JftyyQ" err := file.DecryptInPlace([]byte(helloWorldCiphertext)) - assert.ErrorIs(t, err, InvalidHash) + assert.ErrorIs(t, err, ErrInvalidHash) } diff --git a/crypto/cross_sign_key.go b/crypto/cross_sign_key.go index 4094f695..5d9bf5b3 100644 --- a/crypto/cross_sign_key.go +++ b/crypto/cross_sign_key.go @@ -135,7 +135,7 @@ func (mach *OlmMachine) PublishCrossSigningKeys(ctx context.Context, keys *Cross } userKey.Signatures = signatures.NewSingleSignature(userID, id.KeyAlgorithmEd25519, keys.MasterKey.PublicKey().String(), userSig) - err = mach.Client.UploadCrossSigningKeys(ctx, &mautrix.UploadCrossSigningKeysReq{ + err = mach.Client.UploadCrossSigningKeys(ctx, &mautrix.UploadCrossSigningKeysReq[any]{ Master: masterKey, SelfSigning: selfKey, UserSigning: userKey, diff --git a/crypto/cross_sign_pubkey.go b/crypto/cross_sign_pubkey.go index f85d1ea3..223fc7b5 100644 --- a/crypto/cross_sign_pubkey.go +++ b/crypto/cross_sign_pubkey.go @@ -63,8 +63,8 @@ func (mach *OlmMachine) GetCrossSigningPublicKeys(ctx context.Context, userID id if len(dbKeys) > 0 { masterKey, ok := dbKeys[id.XSUsageMaster] if ok { - selfSigning, _ := dbKeys[id.XSUsageSelfSigning] - userSigning, _ := dbKeys[id.XSUsageUserSigning] + selfSigning := dbKeys[id.XSUsageSelfSigning] + userSigning := dbKeys[id.XSUsageUserSigning] return &CrossSigningPublicKeysCache{ MasterKey: masterKey.Key, SelfSigningKey: selfSigning.Key, diff --git a/crypto/cross_sign_ssss.go b/crypto/cross_sign_ssss.go index 50b58ea0..fd42880d 100644 --- a/crypto/cross_sign_ssss.go +++ b/crypto/cross_sign_ssss.go @@ -8,6 +8,7 @@ package crypto import ( "context" + "errors" "fmt" "maunium.net/go/mautrix" @@ -77,7 +78,11 @@ func (mach *OlmMachine) VerifyWithRecoveryKey(ctx context.Context, recoveryKey s return fmt.Errorf("failed to get default SSSS key data: %w", err) } key, err := keyData.VerifyRecoveryKey(keyID, recoveryKey) - if err != nil { + if errors.Is(err, ssss.ErrUnverifiableKey) { + mach.machOrContextLog(ctx).Warn(). + Str("key_id", keyID). + Msg("SSSS key is unverifiable, trying to use without verifying") + } else if err != nil { return err } err = mach.FetchCrossSigningKeysFromSSSS(ctx, key) diff --git a/crypto/cross_sign_store.go b/crypto/cross_sign_store.go index d30b7e32..57406b11 100644 --- a/crypto/cross_sign_store.go +++ b/crypto/cross_sign_store.go @@ -26,24 +26,22 @@ func (mach *OlmMachine) storeCrossSigningKeys(ctx context.Context, crossSigningK log.Error().Err(err). Msg("Error fetching current cross-signing keys of user") } - if currentKeys != nil { - for curKeyUsage, curKey := range currentKeys { - log := log.With().Stringer("old_key", curKey.Key).Str("old_key_usage", string(curKeyUsage)).Logger() - // got a new key with the same usage as an existing key - for _, newKeyUsage := range userKeys.Usage { - if newKeyUsage == curKeyUsage { - if _, ok := userKeys.Keys[id.NewKeyID(id.KeyAlgorithmEd25519, curKey.Key.String())]; !ok { - // old key is not in the new key map, so we drop signatures made by it - if count, err := mach.CryptoStore.DropSignaturesByKey(ctx, userID, curKey.Key); err != nil { - log.Error().Err(err).Msg("Error deleting old signatures made by user") - } else { - log.Debug(). - Int64("signature_count", count). - Msg("Dropped signatures made by old key as it has been replaced") - } + for curKeyUsage, curKey := range currentKeys { + log := log.With().Stringer("old_key", curKey.Key).Str("old_key_usage", string(curKeyUsage)).Logger() + // got a new key with the same usage as an existing key + for _, newKeyUsage := range userKeys.Usage { + if newKeyUsage == curKeyUsage { + if _, ok := userKeys.Keys[id.NewKeyID(id.KeyAlgorithmEd25519, curKey.Key.String())]; !ok { + // old key is not in the new key map, so we drop signatures made by it + if count, err := mach.CryptoStore.DropSignaturesByKey(ctx, userID, curKey.Key); err != nil { + log.Error().Err(err).Msg("Error deleting old signatures made by user") + } else { + log.Debug(). + Int64("signature_count", count). + Msg("Dropped signatures made by old key as it has been replaced") } - break } + break } } } diff --git a/crypto/cryptohelper/cryptohelper.go b/crypto/cryptohelper/cryptohelper.go index 1939ea79..b62dc128 100644 --- a/crypto/cryptohelper/cryptohelper.go +++ b/crypto/cryptohelper/cryptohelper.go @@ -278,7 +278,7 @@ func (helper *CryptoHelper) verifyDeviceKeysOnServer(ctx context.Context) error } } -var NoSessionFound = crypto.NoSessionFound +var NoSessionFound = crypto.ErrNoSessionFound const initialSessionWaitTimeout = 3 * time.Second const extendedSessionWaitTimeout = 22 * time.Second @@ -371,6 +371,7 @@ func (helper *CryptoHelper) waitLongerForSession(ctx context.Context, evt *event content := evt.Content.AsEncrypted() log.Debug().Int("wait_seconds", int(extendedSessionWaitTimeout.Seconds())).Msg("Couldn't find session, requesting keys and waiting longer...") + //lint:ignore SA1019 RequestSession will gracefully request from all devices if DeviceID is blank go helper.RequestSession(context.TODO(), evt.RoomID, content.SenderKey, content.SessionID, evt.Sender, content.DeviceID) if !helper.mach.WaitForSession(ctx, evt.RoomID, content.SenderKey, content.SessionID, extendedSessionWaitTimeout) { @@ -418,7 +419,7 @@ func (helper *CryptoHelper) EncryptWithStateKey(ctx context.Context, roomID id.R defer helper.lock.RUnlock() encrypted, err = helper.mach.EncryptMegolmEventWithStateKey(ctx, roomID, evtType, stateKey, content) if err != nil { - if !errors.Is(err, crypto.SessionExpired) && err != crypto.NoGroupSession && !errors.Is(err, crypto.SessionNotShared) { + if !errors.Is(err, crypto.ErrSessionExpired) && err != crypto.ErrNoGroupSession && !errors.Is(err, crypto.ErrSessionNotShared) { return } helper.log.Debug(). diff --git a/crypto/decryptmegolm.go b/crypto/decryptmegolm.go index 47279474..457d5a0c 100644 --- a/crypto/decryptmegolm.go +++ b/crypto/decryptmegolm.go @@ -24,13 +24,23 @@ import ( ) var ( - IncorrectEncryptedContentType = errors.New("event content is not instance of *event.EncryptedEventContent") - NoSessionFound = errors.New("failed to decrypt megolm event: no session with given ID found") - DuplicateMessageIndex = errors.New("duplicate megolm message index") - WrongRoom = errors.New("encrypted megolm event is not intended for this room") - DeviceKeyMismatch = errors.New("device keys in event and verified device info do not match") - SenderKeyMismatch = errors.New("sender keys in content and megolm session do not match") - RatchetError = errors.New("failed to ratchet session after use") + ErrIncorrectEncryptedContentType = errors.New("event content is not instance of *event.EncryptedEventContent") + ErrNoSessionFound = errors.New("failed to decrypt megolm event: no session with given ID found") + ErrDuplicateMessageIndex = errors.New("duplicate megolm message index") + ErrWrongRoom = errors.New("encrypted megolm event is not intended for this room") + ErrDeviceKeyMismatch = errors.New("device keys in event and verified device info do not match") + ErrRatchetError = errors.New("failed to ratchet session after use") + ErrCorruptedMegolmPayload = errors.New("corrupted megolm payload") +) + +// Deprecated: use variables prefixed with Err +var ( + IncorrectEncryptedContentType = ErrIncorrectEncryptedContentType + NoSessionFound = ErrNoSessionFound + DuplicateMessageIndex = ErrDuplicateMessageIndex + WrongRoom = ErrWrongRoom + DeviceKeyMismatch = ErrDeviceKeyMismatch + RatchetError = ErrRatchetError ) type megolmEvent struct { @@ -45,13 +55,30 @@ var ( relatesToTopLevelPath = exgjson.Path("content", "m.relates_to") ) +const sessionIDLength = 43 + +func validateCiphertextCharacters(ciphertext []byte) bool { + for _, b := range ciphertext { + if (b < 'a' || b > 'z') && (b < 'A' || b > 'Z') && (b < '0' || b > '9') && b != '+' && b != '/' { + return false + } + } + return true +} + // DecryptMegolmEvent decrypts an m.room.encrypted event where the algorithm is m.megolm.v1.aes-sha2 func (mach *OlmMachine) DecryptMegolmEvent(ctx context.Context, evt *event.Event) (*event.Event, error) { content, ok := evt.Content.Parsed.(*event.EncryptedEventContent) if !ok { - return nil, IncorrectEncryptedContentType + return nil, ErrIncorrectEncryptedContentType } else if content.Algorithm != id.AlgorithmMegolmV1 { - return nil, UnsupportedAlgorithm + return nil, ErrUnsupportedAlgorithm + } else if len(content.MegolmCiphertext) < 74 { + return nil, fmt.Errorf("%w: ciphertext too short (%d bytes)", ErrCorruptedMegolmPayload, len(content.MegolmCiphertext)) + } else if len(content.SessionID) != sessionIDLength { + return nil, fmt.Errorf("%w: invalid session ID length %d", ErrCorruptedMegolmPayload, len(content.SessionID)) + } else if !validateCiphertextCharacters(content.MegolmCiphertext) { + return nil, fmt.Errorf("%w: invalid characters in ciphertext", ErrCorruptedMegolmPayload) } log := mach.machOrContextLog(ctx).With(). Str("action", "decrypt megolm event"). @@ -97,7 +124,13 @@ func (mach *OlmMachine) DecryptMegolmEvent(ctx context.Context, evt *event.Event Msg("Couldn't resolve trust level of session: sent by unknown device") trustLevel = id.TrustStateUnknownDevice } else if device.SigningKey != sess.SigningKey || device.IdentityKey != sess.SenderKey { - return nil, DeviceKeyMismatch + log.Debug(). + Stringer("session_sender_key", sess.SenderKey). + Stringer("device_sender_key", device.IdentityKey). + Stringer("session_signing_key", sess.SigningKey). + Stringer("device_signing_key", device.SigningKey). + Msg("Device keys don't match keys in session, marking as untrusted") + trustLevel = id.TrustStateDeviceKeyMismatch } else { trustLevel, err = mach.ResolveTrustContext(ctx, device) if err != nil { @@ -147,7 +180,7 @@ func (mach *OlmMachine) DecryptMegolmEvent(ctx context.Context, evt *event.Event if err != nil { return nil, fmt.Errorf("failed to parse megolm payload: %w", err) } else if megolmEvt.RoomID != encryptionRoomID { - return nil, WrongRoom + return nil, ErrWrongRoom } if evt.StateKey != nil && megolmEvt.StateKey != nil && mach.AllowEncryptedState { megolmEvt.Type.Class = event.StateEventType @@ -180,6 +213,7 @@ func (mach *OlmMachine) DecryptMegolmEvent(ctx context.Context, evt *event.Event TrustSource: device, ForwardedKeys: forwardedKeys, WasEncrypted: true, + EventSource: evt.Mautrix.EventSource | event.SourceDecrypted, ReceivedAt: evt.Mautrix.ReceivedAt, }, }, nil @@ -201,19 +235,19 @@ func (mach *OlmMachine) checkUndecryptableMessageIndexDuplication(ctx context.Co messageIndex, decodeErr := ParseMegolmMessageIndex(content.MegolmCiphertext) if decodeErr != nil { log.Warn().Err(decodeErr).Msg("Failed to parse message index to check if it's a duplicate for message that failed to decrypt") - return 0, fmt.Errorf("%w (also failed to parse message index)", olm.UnknownMessageIndex) + return 0, fmt.Errorf("%w (also failed to parse message index)", olm.ErrUnknownMessageIndex) } firstKnown := sess.Internal.FirstKnownIndex() log = log.With().Uint("message_index", messageIndex).Uint32("first_known_index", firstKnown).Logger() if ok, err := mach.CryptoStore.ValidateMessageIndex(ctx, sess.SenderKey, content.SessionID, evt.ID, messageIndex, evt.Timestamp); err != nil { log.Debug().Err(err).Msg("Failed to check if message index is duplicate") - return messageIndex, fmt.Errorf("%w (failed to check if index is duplicate; received: %d, earliest known: %d)", olm.UnknownMessageIndex, messageIndex, firstKnown) + return messageIndex, fmt.Errorf("%w (failed to check if index is duplicate; received: %d, earliest known: %d)", olm.ErrUnknownMessageIndex, messageIndex, firstKnown) } else if !ok { log.Debug().Msg("Failed to decrypt message due to unknown index and found duplicate") - return messageIndex, fmt.Errorf("%w %d (also failed to decrypt because earliest known index is %d)", DuplicateMessageIndex, messageIndex, firstKnown) + return messageIndex, fmt.Errorf("%w %d (also failed to decrypt because earliest known index is %d)", ErrDuplicateMessageIndex, messageIndex, firstKnown) } log.Debug().Msg("Failed to decrypt message due to unknown index, but index is not duplicate") - return messageIndex, fmt.Errorf("%w (not duplicate index; received: %d, earliest known: %d)", olm.UnknownMessageIndex, messageIndex, firstKnown) + return messageIndex, fmt.Errorf("%w (not duplicate index; received: %d, earliest known: %d)", olm.ErrUnknownMessageIndex, messageIndex, firstKnown) } func (mach *OlmMachine) actuallyDecryptMegolmEvent(ctx context.Context, evt *event.Event, encryptionRoomID id.RoomID, content *event.EncryptedEventContent) (*InboundGroupSession, []byte, uint, error) { @@ -224,13 +258,11 @@ func (mach *OlmMachine) actuallyDecryptMegolmEvent(ctx context.Context, evt *eve if err != nil { return nil, nil, 0, fmt.Errorf("failed to get group session: %w", err) } else if sess == nil { - return nil, nil, 0, fmt.Errorf("%w (ID %s)", NoSessionFound, content.SessionID) - } else if content.SenderKey != "" && content.SenderKey != sess.SenderKey { - return sess, nil, 0, SenderKeyMismatch + return nil, nil, 0, fmt.Errorf("%w (ID %s)", ErrNoSessionFound, content.SessionID) } plaintext, messageIndex, err := sess.Internal.Decrypt(content.MegolmCiphertext) if err != nil { - if errors.Is(err, olm.UnknownMessageIndex) && mach.RatchetKeysOnDecrypt { + if errors.Is(err, olm.ErrUnknownMessageIndex) && mach.RatchetKeysOnDecrypt { messageIndex, err = mach.checkUndecryptableMessageIndexDuplication(ctx, sess, evt, content) return sess, nil, messageIndex, fmt.Errorf("failed to decrypt megolm event: %w", err) } @@ -238,7 +270,7 @@ func (mach *OlmMachine) actuallyDecryptMegolmEvent(ctx context.Context, evt *eve } else if ok, err := mach.CryptoStore.ValidateMessageIndex(ctx, sess.SenderKey, content.SessionID, evt.ID, messageIndex, evt.Timestamp); err != nil { return sess, nil, messageIndex, fmt.Errorf("failed to check if message index is duplicate: %w", err) } else if !ok { - return sess, nil, messageIndex, fmt.Errorf("%w %d", DuplicateMessageIndex, messageIndex) + return sess, nil, messageIndex, fmt.Errorf("%w %d", ErrDuplicateMessageIndex, messageIndex) } // Normal clients don't care about tracking the ratchet state, so let them bypass the rest of the function @@ -290,24 +322,24 @@ func (mach *OlmMachine) actuallyDecryptMegolmEvent(ctx context.Context, evt *eve err = mach.CryptoStore.RedactGroupSession(ctx, sess.RoomID, sess.ID(), "maximum messages reached") if err != nil { log.Err(err).Msg("Failed to delete fully used session") - return sess, plaintext, messageIndex, RatchetError + return sess, plaintext, messageIndex, ErrRatchetError } else { log.Info().Msg("Deleted fully used session") } } else if ratchetCurrentIndex < ratchetTargetIndex && mach.RatchetKeysOnDecrypt { if err = sess.RatchetTo(ratchetTargetIndex); err != nil { log.Err(err).Msg("Failed to ratchet session") - return sess, plaintext, messageIndex, RatchetError + return sess, plaintext, messageIndex, ErrRatchetError } else if err = mach.CryptoStore.PutGroupSession(ctx, sess); err != nil { log.Err(err).Msg("Failed to store ratcheted session") - return sess, plaintext, messageIndex, RatchetError + return sess, plaintext, messageIndex, ErrRatchetError } else { log.Info().Msg("Ratcheted session forward") } } else if didModify { if err = mach.CryptoStore.PutGroupSession(ctx, sess); err != nil { log.Err(err).Msg("Failed to store updated ratchet safety data") - return sess, plaintext, messageIndex, RatchetError + return sess, plaintext, messageIndex, ErrRatchetError } else { log.Debug().Msg("Ratchet safety data changed (ratchet state didn't change)") } diff --git a/crypto/decryptolm.go b/crypto/decryptolm.go index 30cc4cfe..aea5e6dc 100644 --- a/crypto/decryptolm.go +++ b/crypto/decryptolm.go @@ -26,15 +26,27 @@ import ( ) var ( - UnsupportedAlgorithm = errors.New("unsupported event encryption algorithm") - NotEncryptedForMe = errors.New("olm event doesn't contain ciphertext for this device") - UnsupportedOlmMessageType = errors.New("unsupported olm message type") - DecryptionFailedWithMatchingSession = errors.New("decryption failed with matching session") - DecryptionFailedForNormalMessage = errors.New("decryption failed for normal message") - SenderMismatch = errors.New("mismatched sender in olm payload") - RecipientMismatch = errors.New("mismatched recipient in olm payload") - RecipientKeyMismatch = errors.New("mismatched recipient key in olm payload") - ErrDuplicateMessage = errors.New("duplicate olm message") + ErrUnsupportedAlgorithm = errors.New("unsupported event encryption algorithm") + ErrNotEncryptedForMe = errors.New("olm event doesn't contain ciphertext for this device") + ErrUnsupportedOlmMessageType = errors.New("unsupported olm message type") + ErrDecryptionFailedWithMatchingSession = errors.New("decryption failed with matching session") + ErrDecryptionFailedForNormalMessage = errors.New("decryption failed for normal message") + ErrSenderMismatch = errors.New("mismatched sender in olm payload") + ErrRecipientMismatch = errors.New("mismatched recipient in olm payload") + ErrRecipientKeyMismatch = errors.New("mismatched recipient key in olm payload") + ErrDuplicateMessage = errors.New("duplicate olm message") +) + +// Deprecated: use variables prefixed with Err +var ( + UnsupportedAlgorithm = ErrUnsupportedAlgorithm + NotEncryptedForMe = ErrNotEncryptedForMe + UnsupportedOlmMessageType = ErrUnsupportedOlmMessageType + DecryptionFailedWithMatchingSession = ErrDecryptionFailedWithMatchingSession + DecryptionFailedForNormalMessage = ErrDecryptionFailedForNormalMessage + SenderMismatch = ErrSenderMismatch + RecipientMismatch = ErrRecipientMismatch + RecipientKeyMismatch = ErrRecipientKeyMismatch ) // DecryptedOlmEvent represents an event that was decrypted from an event encrypted with the m.olm.v1.curve25519-aes-sha2 algorithm. @@ -56,13 +68,13 @@ type DecryptedOlmEvent struct { func (mach *OlmMachine) decryptOlmEvent(ctx context.Context, evt *event.Event) (*DecryptedOlmEvent, error) { content, ok := evt.Content.Parsed.(*event.EncryptedEventContent) if !ok { - return nil, IncorrectEncryptedContentType + return nil, ErrIncorrectEncryptedContentType } else if content.Algorithm != id.AlgorithmOlmV1 { - return nil, UnsupportedAlgorithm + return nil, ErrUnsupportedAlgorithm } ownContent, ok := content.OlmCiphertext[mach.account.IdentityKey()] if !ok { - return nil, NotEncryptedForMe + return nil, ErrNotEncryptedForMe } decrypted, err := mach.decryptAndParseOlmCiphertext(ctx, evt, content.SenderKey, ownContent.Type, ownContent.Body) if err != nil { @@ -78,7 +90,7 @@ type OlmEventKeys struct { func (mach *OlmMachine) decryptAndParseOlmCiphertext(ctx context.Context, evt *event.Event, senderKey id.SenderKey, olmType id.OlmMsgType, ciphertext string) (*DecryptedOlmEvent, error) { if olmType != id.OlmMsgTypePreKey && olmType != id.OlmMsgTypeMsg { - return nil, UnsupportedOlmMessageType + return nil, ErrUnsupportedOlmMessageType } log := mach.machOrContextLog(ctx).With(). @@ -102,11 +114,11 @@ func (mach *OlmMachine) decryptAndParseOlmCiphertext(ctx context.Context, evt *e } olmEvt.Type.Class = evt.Type.Class if evt.Sender != olmEvt.Sender { - return nil, SenderMismatch + return nil, ErrSenderMismatch } else if mach.Client.UserID != olmEvt.Recipient { - return nil, RecipientMismatch + return nil, ErrRecipientMismatch } else if mach.account.SigningKey() != olmEvt.RecipientKeys.Ed25519 { - return nil, RecipientKeyMismatch + return nil, ErrRecipientKeyMismatch } if len(olmEvt.Content.VeryRaw) > 0 { @@ -122,6 +134,9 @@ func (mach *OlmMachine) decryptAndParseOlmCiphertext(ctx context.Context, evt *e } func olmMessageHash(ciphertext string) ([32]byte, error) { + if ciphertext == "" { + return [32]byte{}, fmt.Errorf("empty ciphertext") + } ciphertextBytes, err := base64.RawStdEncoding.DecodeString(ciphertext) return sha256.Sum256(ciphertextBytes), err } @@ -151,7 +166,7 @@ func (mach *OlmMachine) tryDecryptOlmCiphertext(ctx context.Context, sender id.U plaintext, err := mach.tryDecryptOlmCiphertextWithExistingSession(ctx, senderKey, olmType, ciphertext, ciphertextHash) if err != nil { - if err == DecryptionFailedWithMatchingSession { + if err == ErrDecryptionFailedWithMatchingSession { log.Warn().Msg("Found matching session, but decryption failed") go mach.unwedgeDevice(log, sender, senderKey) } @@ -169,10 +184,10 @@ func (mach *OlmMachine) tryDecryptOlmCiphertext(ctx context.Context, sender id.U // if it isn't one at this point in time anymore, so return early. if olmType != id.OlmMsgTypePreKey { go mach.unwedgeDevice(log, sender, senderKey) - return nil, DecryptionFailedForNormalMessage + return nil, ErrDecryptionFailedForNormalMessage } - accountBackup, err := mach.account.Internal.Pickle([]byte("tmp")) + accountBackup, _ := mach.account.Internal.Pickle([]byte("tmp")) log.Trace().Msg("Trying to create inbound session") endTimeTrace = mach.timeTrace(ctx, "creating inbound olm session", time.Second) session, err := mach.createInboundSession(ctx, senderKey, ciphertext) @@ -302,7 +317,7 @@ func (mach *OlmMachine) tryDecryptOlmCiphertextWithExistingSession( Str("session_description", session.Describe()). Msg("Failed to decrypt olm message") if olmType == id.OlmMsgTypePreKey { - return nil, DecryptionFailedWithMatchingSession + return nil, ErrDecryptionFailedWithMatchingSession } } else { endTimeTrace = mach.timeTrace(ctx, "updating session in database", time.Second) @@ -345,7 +360,7 @@ func (mach *OlmMachine) unwedgeDevice(log zerolog.Logger, sender id.UserID, send ctx := log.WithContext(mach.backgroundCtx) mach.recentlyUnwedgedLock.Lock() prevUnwedge, ok := mach.recentlyUnwedged[senderKey] - delta := time.Now().Sub(prevUnwedge) + delta := time.Since(prevUnwedge) if ok && delta < MinUnwedgeInterval { log.Debug(). Str("previous_recreation", delta.String()). diff --git a/crypto/devicelist.go b/crypto/devicelist.go index 61a22522..f0d2b129 100644 --- a/crypto/devicelist.go +++ b/crypto/devicelist.go @@ -22,14 +22,23 @@ import ( ) var ( - MismatchingDeviceID = errors.New("mismatching device ID in parameter and keys object") - MismatchingUserID = errors.New("mismatching user ID in parameter and keys object") - MismatchingSigningKey = errors.New("received update for device with different signing key") - NoSigningKeyFound = errors.New("didn't find ed25519 signing key") - NoIdentityKeyFound = errors.New("didn't find curve25519 identity key") - InvalidKeySignature = errors.New("invalid signature on device keys") + ErrMismatchingDeviceID = errors.New("mismatching device ID in parameter and keys object") + ErrMismatchingUserID = errors.New("mismatching user ID in parameter and keys object") + ErrMismatchingSigningKey = errors.New("received update for device with different signing key") + ErrNoSigningKeyFound = errors.New("didn't find ed25519 signing key") + ErrNoIdentityKeyFound = errors.New("didn't find curve25519 identity key") + ErrInvalidKeySignature = errors.New("invalid signature on device keys") + ErrUserNotTracked = errors.New("user is not tracked") +) - ErrUserNotTracked = errors.New("user is not tracked") +// Deprecated: use variables prefixed with Err +var ( + MismatchingDeviceID = ErrMismatchingDeviceID + MismatchingUserID = ErrMismatchingUserID + MismatchingSigningKey = ErrMismatchingSigningKey + NoSigningKeyFound = ErrNoSigningKeyFound + NoIdentityKeyFound = ErrNoIdentityKeyFound + InvalidKeySignature = ErrInvalidKeySignature ) func (mach *OlmMachine) LoadDevices(ctx context.Context, user id.UserID) (keys map[id.DeviceID]*id.Device) { @@ -312,28 +321,28 @@ func (mach *OlmMachine) OnDevicesChanged(ctx context.Context, userID id.UserID) func (mach *OlmMachine) validateDevice(userID id.UserID, deviceID id.DeviceID, deviceKeys mautrix.DeviceKeys, existing *id.Device) (*id.Device, error) { if deviceID != deviceKeys.DeviceID { - return nil, fmt.Errorf("%w (expected %s, got %s)", MismatchingDeviceID, deviceID, deviceKeys.DeviceID) + return nil, fmt.Errorf("%w (expected %s, got %s)", ErrMismatchingDeviceID, deviceID, deviceKeys.DeviceID) } else if userID != deviceKeys.UserID { - return nil, fmt.Errorf("%w (expected %s, got %s)", MismatchingUserID, userID, deviceKeys.UserID) + return nil, fmt.Errorf("%w (expected %s, got %s)", ErrMismatchingUserID, userID, deviceKeys.UserID) } signingKey := deviceKeys.Keys.GetEd25519(deviceID) identityKey := deviceKeys.Keys.GetCurve25519(deviceID) if signingKey == "" { - return nil, NoSigningKeyFound + return nil, ErrNoSigningKeyFound } else if identityKey == "" { - return nil, NoIdentityKeyFound + return nil, ErrNoIdentityKeyFound } if existing != nil && existing.SigningKey != signingKey { - return existing, fmt.Errorf("%w (expected %s, got %s)", MismatchingSigningKey, existing.SigningKey, signingKey) + return existing, fmt.Errorf("%w (expected %s, got %s)", ErrMismatchingSigningKey, existing.SigningKey, signingKey) } ok, err := signatures.VerifySignatureJSON(deviceKeys, userID, deviceID.String(), signingKey) if err != nil { return existing, fmt.Errorf("failed to verify signature: %w", err) } else if !ok { - return existing, InvalidKeySignature + return existing, ErrInvalidKeySignature } name, ok := deviceKeys.Unsigned["device_display_name"].(string) diff --git a/crypto/encryptmegolm.go b/crypto/encryptmegolm.go index b3d19618..88f9c8d4 100644 --- a/crypto/encryptmegolm.go +++ b/crypto/encryptmegolm.go @@ -25,7 +25,12 @@ import ( ) var ( - NoGroupSession = errors.New("no group session created") + ErrNoGroupSession = errors.New("no group session created") +) + +// Deprecated: use variables prefixed with Err +var ( + NoGroupSession = ErrNoGroupSession ) func getRawJSON[T any](content json.RawMessage, path ...string) *T { @@ -82,15 +87,20 @@ type rawMegolmEvent struct { // IsShareError returns true if the error is caused by the lack of an outgoing megolm session and can be solved with OlmMachine.ShareGroupSession func IsShareError(err error) bool { - return err == SessionExpired || err == SessionNotShared || err == NoGroupSession + return err == ErrSessionExpired || err == ErrSessionNotShared || err == ErrNoGroupSession } func ParseMegolmMessageIndex(ciphertext []byte) (uint, error) { + if len(ciphertext) == 0 { + return 0, fmt.Errorf("empty ciphertext") + } decoded := make([]byte, base64.RawStdEncoding.DecodedLen(len(ciphertext))) var err error _, err = base64.RawStdEncoding.Decode(decoded, ciphertext) if err != nil { return 0, err + } else if len(decoded) < 2+binary.MaxVarintLen64 { + return 0, fmt.Errorf("decoded ciphertext too short: %d bytes", len(decoded)) } else if decoded[0] != 3 || decoded[1] != 8 { return 0, fmt.Errorf("unexpected initial bytes %d and %d", decoded[0], decoded[1]) } @@ -120,7 +130,7 @@ func (mach *OlmMachine) EncryptMegolmEventWithStateKey(ctx context.Context, room if err != nil { return nil, fmt.Errorf("failed to get outbound group session: %w", err) } else if session == nil { - return nil, NoGroupSession + return nil, ErrNoGroupSession } plaintext, err := json.Marshal(&rawMegolmEvent{ RoomID: roomID, @@ -164,6 +174,15 @@ func (mach *OlmMachine) EncryptMegolmEventWithStateKey(ctx context.Context, room SenderKey: mach.account.IdentityKey(), DeviceID: mach.Client.DeviceID, } + if mach.MSC4392Relations && encrypted.RelatesTo != nil { + // When MSC4392 mode is enabled, reply and reaction metadata is stripped from the unencrypted content. + // Other relations like threads are still left unencrypted. + encrypted.RelatesTo.InReplyTo = nil + encrypted.RelatesTo.IsFallingBack = false + if evtType == event.EventReaction || encrypted.RelatesTo.Type == "" { + encrypted.RelatesTo = nil + } + } if mach.PlaintextMentions { encrypted.Mentions = getMentions(content) } @@ -351,26 +370,19 @@ func (mach *OlmMachine) encryptAndSendGroupSession(ctx context.Context, session log.Trace().Msg("Encrypting group session for all found devices") deviceCount := 0 toDevice := &mautrix.ReqSendToDevice{Messages: make(map[id.UserID]map[id.DeviceID]*event.Content)} + logUsers := zerolog.Dict() for userID, sessions := range olmSessions { if len(sessions) == 0 { continue } + logDevices := zerolog.Dict() output := make(map[id.DeviceID]*event.Content) toDevice.Messages[userID] = output for deviceID, device := range sessions { - log.Trace(). - Stringer("target_user_id", userID). - Stringer("target_device_id", deviceID). - Stringer("target_identity_key", device.identity.IdentityKey). - Msg("Encrypting group session for device") content := mach.encryptOlmEvent(ctx, device.session, device.identity, event.ToDeviceRoomKey, session.ShareContent()) output[deviceID] = &event.Content{Parsed: content} + logDevices.Str(string(deviceID), string(device.identity.IdentityKey)) deviceCount++ - log.Debug(). - Stringer("target_user_id", userID). - Stringer("target_device_id", deviceID). - Stringer("target_identity_key", device.identity.IdentityKey). - Msg("Encrypted group session for device") if !mach.DisableSharedGroupSessionTracking { err := mach.CryptoStore.MarkOutboundGroupSessionShared(ctx, userID, device.identity.IdentityKey, session.id) if err != nil { @@ -384,11 +396,13 @@ func (mach *OlmMachine) encryptAndSendGroupSession(ctx context.Context, session } } } + logUsers.Dict(string(userID), logDevices) } log.Debug(). Int("device_count", deviceCount). Int("user_count", len(toDevice.Messages)). + Dict("destination_map", logUsers). Msg("Sending to-device messages to share group session") _, err := mach.Client.SendToDevice(ctx, event.ToDeviceEncrypted, toDevice) return err diff --git a/crypto/encryptolm.go b/crypto/encryptolm.go index 80b76dc5..765307af 100644 --- a/crypto/encryptolm.go +++ b/crypto/encryptolm.go @@ -96,15 +96,19 @@ func (mach *OlmMachine) encryptOlmEvent(ctx context.Context, session *OlmSession panic(err) } log := mach.machOrContextLog(ctx) - log.Debug(). - Str("recipient_identity_key", recipient.IdentityKey.String()). - Str("olm_session_id", session.ID().String()). - Str("olm_session_description", session.Describe()). - Msg("Encrypting olm message") msgType, ciphertext, err := session.Encrypt(plaintext) if err != nil { panic(err) } + ciphertextStr := string(ciphertext) + ciphertextHash, _ := olmMessageHash(ciphertextStr) + log.Debug(). + Stringer("event_type", evtType). + Str("recipient_identity_key", recipient.IdentityKey.String()). + Str("olm_session_id", session.ID().String()). + Str("olm_session_description", session.Describe()). + Hex("ciphertext_hash", ciphertextHash[:]). + Msg("Encrypted olm message") err = mach.CryptoStore.UpdateSession(ctx, recipient.IdentityKey, session) if err != nil { log.Error().Err(err).Msg("Failed to update olm session in crypto store after encrypting") @@ -115,7 +119,7 @@ func (mach *OlmMachine) encryptOlmEvent(ctx context.Context, session *OlmSession OlmCiphertext: event.OlmCiphertexts{ recipient.IdentityKey: { Type: msgType, - Body: string(ciphertext), + Body: ciphertextStr, }, }, } diff --git a/crypto/goolm/account/account.go b/crypto/goolm/account/account.go index 4da08a73..b48843a4 100644 --- a/crypto/goolm/account/account.go +++ b/crypto/goolm/account/account.go @@ -334,7 +334,7 @@ func (a *Account) UnpickleLibOlm(buf []byte) error { if err != nil { return err } else if pickledVersion != accountPickleVersionLibOLM && pickledVersion != 3 && pickledVersion != 2 { - return fmt.Errorf("unpickle account: %w (found version %d)", olm.ErrBadVersion, pickledVersion) + return fmt.Errorf("unpickle account: %w (found version %d)", olm.ErrUnknownOlmPickleVersion, pickledVersion) } else if err = a.IdKeys.Ed25519.UnpickleLibOlm(decoder); err != nil { // read the ed25519 key pair return err } else if err = a.IdKeys.Curve25519.UnpickleLibOlm(decoder); err != nil { // read curve25519 key pair diff --git a/crypto/goolm/account/account_test.go b/crypto/goolm/account/account_test.go index e1c9b452..d0dec5f0 100644 --- a/crypto/goolm/account/account_test.go +++ b/crypto/goolm/account/account_test.go @@ -124,7 +124,7 @@ func TestOldAccountPickle(t *testing.T) { account, err := account.NewAccount() assert.NoError(t, err) err = account.Unpickle(pickled, pickleKey) - assert.ErrorIs(t, err, olm.ErrBadVersion) + assert.ErrorIs(t, err, olm.ErrUnknownOlmPickleVersion) } func TestLoopback(t *testing.T) { diff --git a/crypto/goolm/crypto/curve25519.go b/crypto/goolm/crypto/curve25519.go index e9759501..6e42d886 100644 --- a/crypto/goolm/crypto/curve25519.go +++ b/crypto/goolm/crypto/curve25519.go @@ -53,6 +53,7 @@ func (c Curve25519KeyPair) B64Encoded() id.Curve25519 { // SharedSecret returns the shared secret between the key pair and the given public key. func (c Curve25519KeyPair) SharedSecret(pubKey Curve25519PublicKey) ([]byte, error) { + // Note: the standard library checks that the output is non-zero return c.PrivateKey.SharedSecret(pubKey) } diff --git a/crypto/goolm/crypto/curve25519_test.go b/crypto/goolm/crypto/curve25519_test.go index 9039c126..2550f15e 100644 --- a/crypto/goolm/crypto/curve25519_test.go +++ b/crypto/goolm/crypto/curve25519_test.go @@ -25,6 +25,8 @@ func TestCurve25519(t *testing.T) { fromPrivate, err := crypto.Curve25519GenerateFromPrivate(firstKeypair.PrivateKey) assert.NoError(t, err) assert.Equal(t, fromPrivate, firstKeypair) + _, err = secondKeypair.SharedSecret(make([]byte, crypto.Curve25519PublicKeyLength)) + assert.Error(t, err) } func TestCurve25519Case1(t *testing.T) { diff --git a/crypto/goolm/goolmbase64/base64.go b/crypto/goolm/goolmbase64/base64.go index 061a052a..58ee26f7 100644 --- a/crypto/goolm/goolmbase64/base64.go +++ b/crypto/goolm/goolmbase64/base64.go @@ -4,7 +4,8 @@ import ( "encoding/base64" ) -// Deprecated: base64.RawStdEncoding should be used directly +// These methods should only be used for raw byte operations, never with string conversion + func Decode(input []byte) ([]byte, error) { decoded := make([]byte, base64.RawStdEncoding.DecodedLen(len(input))) writtenBytes, err := base64.RawStdEncoding.Decode(decoded, input) @@ -14,7 +15,6 @@ func Decode(input []byte) ([]byte, error) { return decoded[:writtenBytes], nil } -// Deprecated: base64.RawStdEncoding should be used directly func Encode(input []byte) []byte { encoded := make([]byte, base64.RawStdEncoding.EncodedLen(len(input))) base64.RawStdEncoding.Encode(encoded, input) diff --git a/crypto/goolm/libolmpickle/picklejson.go b/crypto/goolm/libolmpickle/picklejson.go index 308e472c..f765391f 100644 --- a/crypto/goolm/libolmpickle/picklejson.go +++ b/crypto/goolm/libolmpickle/picklejson.go @@ -50,7 +50,7 @@ func UnpickleAsJSON(object any, pickled, key []byte, pickleVersion byte) error { } } if decrypted[0] != pickleVersion { - return fmt.Errorf("unpickle: %w", olm.ErrWrongPickleVersion) + return fmt.Errorf("unpickle: %w", olm.ErrUnknownJSONPickleVersion) } err = json.Unmarshal(decrypted[1:], object) if err != nil { diff --git a/crypto/goolm/message/decoder.go b/crypto/goolm/message/decoder.go index a71cf302..b06756a9 100644 --- a/crypto/goolm/message/decoder.go +++ b/crypto/goolm/message/decoder.go @@ -3,6 +3,9 @@ package message import ( "bytes" "encoding/binary" + "fmt" + + "maunium.net/go/mautrix/crypto/olm" ) type Decoder struct { @@ -20,6 +23,8 @@ func (d *Decoder) ReadVarInt() (uint64, error) { func (d *Decoder) ReadVarBytes() ([]byte, error) { if n, err := d.ReadVarInt(); err != nil { return nil, err + } else if n > uint64(d.Len()) { + return nil, fmt.Errorf("%w: var bytes length says %d, but only %d bytes left", olm.ErrInputToSmall, n, d.Available()) } else { out := make([]byte, n) _, err = d.Read(out) diff --git a/crypto/goolm/message/group_message.go b/crypto/goolm/message/group_message.go index c2a43b1f..c83540c1 100644 --- a/crypto/goolm/message/group_message.go +++ b/crypto/goolm/message/group_message.go @@ -2,10 +2,12 @@ package message import ( "bytes" + "fmt" "io" "maunium.net/go/mautrix/crypto/goolm/aessha2" "maunium.net/go/mautrix/crypto/goolm/crypto" + "maunium.net/go/mautrix/crypto/olm" ) const ( @@ -36,6 +38,9 @@ func (r *GroupMessage) Decode(input []byte) (err error) { if err != nil { return } + if r.Version != protocolVersion { + return fmt.Errorf("GroupMessage.Decode: %w (got %d, expected %d)", olm.ErrWrongProtocolVersion, r.Version, protocolVersion) + } for { // Read Key diff --git a/crypto/goolm/message/message.go b/crypto/goolm/message/message.go index 8bb6e0cd..b161a2d1 100644 --- a/crypto/goolm/message/message.go +++ b/crypto/goolm/message/message.go @@ -2,10 +2,12 @@ package message import ( "bytes" + "fmt" "io" "maunium.net/go/mautrix/crypto/goolm/aessha2" "maunium.net/go/mautrix/crypto/goolm/crypto" + "maunium.net/go/mautrix/crypto/olm" ) const ( @@ -40,6 +42,9 @@ func (r *Message) Decode(input []byte) (err error) { if err != nil { return } + if r.Version != protocolVersion { + return fmt.Errorf("Message.Decode: %w (got %d, expected %d)", olm.ErrWrongProtocolVersion, r.Version, protocolVersion) + } for { // Read Key diff --git a/crypto/goolm/message/prekey_message.go b/crypto/goolm/message/prekey_message.go index 22ebf9c3..4e3d495d 100644 --- a/crypto/goolm/message/prekey_message.go +++ b/crypto/goolm/message/prekey_message.go @@ -1,6 +1,7 @@ package message import ( + "fmt" "io" "maunium.net/go/mautrix/crypto/goolm/crypto" @@ -22,6 +23,11 @@ type PreKeyMessage struct { Message []byte `json:"message"` } +// TODO deduplicate constant with one in session/olm_session.go +const ( + protocolVersion = 0x3 +) + // Decodes decodes the input and populates the corresponding fileds. func (r *PreKeyMessage) Decode(input []byte) (err error) { r.Version = 0 @@ -41,6 +47,9 @@ func (r *PreKeyMessage) Decode(input []byte) (err error) { } return } + if r.Version != protocolVersion { + return fmt.Errorf("PreKeyMessage.Decode: %w (got %d, expected %d)", olm.ErrWrongProtocolVersion, r.Version, protocolVersion) + } for { // Read Key diff --git a/crypto/goolm/message/session_export.go b/crypto/goolm/message/session_export.go index 956868b2..d58dbb21 100644 --- a/crypto/goolm/message/session_export.go +++ b/crypto/goolm/message/session_export.go @@ -35,7 +35,7 @@ func (s *MegolmSessionExport) Decode(input []byte) error { return fmt.Errorf("decrypt: %w", olm.ErrBadInput) } if input[0] != sessionExportVersion { - return fmt.Errorf("decrypt: %w", olm.ErrBadVersion) + return fmt.Errorf("decrypt: %w", olm.ErrUnknownOlmPickleVersion) } s.Counter = binary.BigEndian.Uint32(input[1:5]) copy(s.RatchetData[:], input[5:133]) diff --git a/crypto/goolm/message/session_sharing.go b/crypto/goolm/message/session_sharing.go index 16240945..d04ef15a 100644 --- a/crypto/goolm/message/session_sharing.go +++ b/crypto/goolm/message/session_sharing.go @@ -42,7 +42,7 @@ func (s *MegolmSessionSharing) VerifyAndDecode(input []byte) error { } s.PublicKey = publicKey if input[0] != sessionSharingVersion { - return fmt.Errorf("verify: %w", olm.ErrBadVersion) + return fmt.Errorf("verify: %w", olm.ErrUnknownOlmPickleVersion) } s.Counter = binary.BigEndian.Uint32(input[1:5]) copy(s.RatchetData[:], input[5:133]) diff --git a/crypto/goolm/pk/decryption.go b/crypto/goolm/pk/decryption.go index afb01f74..cdb20eb1 100644 --- a/crypto/goolm/pk/decryption.go +++ b/crypto/goolm/pk/decryption.go @@ -103,7 +103,7 @@ func (a *Decryption) UnpickleLibOlm(unpickled []byte) error { if pickledVersion == decryptionPickleVersionLibOlm { return a.KeyPair.UnpickleLibOlm(decoder) } else { - return fmt.Errorf("unpickle olmSession: %w (found %d, expected %d)", olm.ErrBadVersion, pickledVersion, decryptionPickleVersionLibOlm) + return fmt.Errorf("unpickle olmSession: %w (found %d, expected %d)", olm.ErrUnknownOlmPickleVersion, pickledVersion, decryptionPickleVersionLibOlm) } } diff --git a/crypto/goolm/pk/encryption.go b/crypto/goolm/pk/encryption.go index 23f67ddf..2897d9b0 100644 --- a/crypto/goolm/pk/encryption.go +++ b/crypto/goolm/pk/encryption.go @@ -37,6 +37,9 @@ func (e Encryption) Encrypt(plaintext []byte, privateKey crypto.Curve25519Privat return nil, nil, err } cipher, err := aessha2.NewAESSHA2(sharedSecret, nil) + if err != nil { + return nil, nil, err + } ciphertext, err = cipher.Encrypt(plaintext) if err != nil { return nil, nil, err diff --git a/crypto/goolm/ratchet/olm.go b/crypto/goolm/ratchet/olm.go index 229c9bd2..9901ada8 100644 --- a/crypto/goolm/ratchet/olm.go +++ b/crypto/goolm/ratchet/olm.go @@ -142,7 +142,7 @@ func (r *Ratchet) Decrypt(input []byte) ([]byte, error) { return nil, err } if message.Version != protocolVersion { - return nil, fmt.Errorf("decrypt: %w", olm.ErrWrongProtocolVersion) + return nil, fmt.Errorf("decrypt: %w (got %d, expected %d)", olm.ErrWrongProtocolVersion, message.Version, protocolVersion) } if !message.HasCounter || len(message.RatchetKey) == 0 || len(message.Ciphertext) == 0 { return nil, fmt.Errorf("decrypt: %w", olm.ErrBadMessageFormat) diff --git a/crypto/goolm/session/megolm_inbound_session.go b/crypto/goolm/session/megolm_inbound_session.go index 80dd71cc..7ccbd26d 100644 --- a/crypto/goolm/session/megolm_inbound_session.go +++ b/crypto/goolm/session/megolm_inbound_session.go @@ -99,7 +99,7 @@ func (o *MegolmInboundSession) getRatchet(messageIndex uint32) (*megolm.Ratchet, } if (messageIndex - o.InitialRatchet.Counter) >= uint32(1<<31) { // the counter is before our initial ratchet - we can't decode this - return nil, fmt.Errorf("decrypt: %w", olm.ErrRatchetNotAvailable) + return nil, fmt.Errorf("decrypt: %w", olm.ErrUnknownMessageIndex) } // otherwise, start from the initial ratchet. Take a copy so that we don't overwrite the initial ratchet copiedRatchet := o.InitialRatchet @@ -126,7 +126,7 @@ func (o *MegolmInboundSession) Decrypt(ciphertext []byte) ([]byte, uint, error) return nil, 0, err } if msg.Version != protocolVersion { - return nil, 0, fmt.Errorf("decrypt: %w", olm.ErrWrongProtocolVersion) + return nil, 0, fmt.Errorf("decrypt: %w (got %d, expected %d)", olm.ErrWrongProtocolVersion, msg.Version, protocolVersion) } if msg.Ciphertext == nil || !msg.HasMessageIndex { return nil, 0, fmt.Errorf("decrypt: %w", olm.ErrBadMessageFormat) @@ -206,7 +206,7 @@ func (o *MegolmInboundSession) UnpickleLibOlm(value []byte) error { return err } if pickledVersion != megolmInboundSessionPickleVersionLibOlm && pickledVersion != 1 { - return fmt.Errorf("unpickle MegolmInboundSession: %w (found version %d)", olm.ErrBadVersion, pickledVersion) + return fmt.Errorf("unpickle MegolmInboundSession: %w (found version %d)", olm.ErrUnknownOlmPickleVersion, pickledVersion) } if err = o.InitialRatchet.UnpickleLibOlm(decoder); err != nil { diff --git a/crypto/goolm/session/megolm_outbound_session.go b/crypto/goolm/session/megolm_outbound_session.go index 2b8e1c84..7f923534 100644 --- a/crypto/goolm/session/megolm_outbound_session.go +++ b/crypto/goolm/session/megolm_outbound_session.go @@ -101,8 +101,10 @@ func (o *MegolmOutboundSession) Unpickle(pickled, key []byte) error { func (o *MegolmOutboundSession) UnpickleLibOlm(buf []byte) error { decoder := libolmpickle.NewDecoder(buf) pickledVersion, err := decoder.ReadUInt32() - if pickledVersion != megolmOutboundSessionPickleVersionLibOlm { - return fmt.Errorf("unpickle MegolmInboundSession: %w (found version %d)", olm.ErrBadVersion, pickledVersion) + if err != nil { + return fmt.Errorf("unpickle MegolmOutboundSession: failed to read version: %w", err) + } else if pickledVersion != megolmOutboundSessionPickleVersionLibOlm { + return fmt.Errorf("unpickle MegolmInboundSession: %w (found version %d)", olm.ErrUnknownOlmPickleVersion, pickledVersion) } if err = o.Ratchet.UnpickleLibOlm(decoder); err != nil { return err diff --git a/crypto/goolm/session/olm_session.go b/crypto/goolm/session/olm_session.go index b99ab630..a1cb8d66 100644 --- a/crypto/goolm/session/olm_session.go +++ b/crypto/goolm/session/olm_session.go @@ -168,11 +168,11 @@ func NewInboundOlmSession(identityKeyAlice *crypto.Curve25519PublicKey, received msg := message.Message{} err = msg.Decode(oneTimeMsg.Message) if err != nil { - return nil, fmt.Errorf("Message decode: %w", err) + return nil, fmt.Errorf("message decode: %w", err) } if len(msg.RatchetKey) == 0 { - return nil, fmt.Errorf("Message missing ratchet key: %w", olm.ErrBadMessageFormat) + return nil, fmt.Errorf("message missing ratchet key: %w", olm.ErrBadMessageFormat) } //Init Ratchet s.Ratchet.InitializeAsBob(secret, msg.RatchetKey) @@ -203,7 +203,7 @@ func (s *OlmSession) ID() id.SessionID { copy(message[crypto.Curve25519PrivateKeyLength:], s.AliceBaseKey) copy(message[2*crypto.Curve25519PrivateKeyLength:], s.BobOneTimeKey) hash := sha256.Sum256(message) - res := id.SessionID(goolmbase64.Encode(hash[:])) + res := id.SessionID(base64.RawStdEncoding.EncodeToString(hash[:])) return res } @@ -325,7 +325,7 @@ func (s *OlmSession) Decrypt(crypttext string, msgType id.OlmMsgType) ([]byte, e if len(crypttext) == 0 { return nil, fmt.Errorf("decrypt: %w", olm.ErrEmptyInput) } - decodedCrypttext, err := goolmbase64.Decode([]byte(crypttext)) + decodedCrypttext, err := base64.RawStdEncoding.DecodeString(crypttext) if err != nil { return nil, err } @@ -365,6 +365,9 @@ func (o *OlmSession) Unpickle(pickled, key []byte) error { func (o *OlmSession) UnpickleLibOlm(buf []byte) error { decoder := libolmpickle.NewDecoder(buf) pickledVersion, err := decoder.ReadUInt32() + if err != nil { + return fmt.Errorf("unpickle olmSession: failed to read version: %w", err) + } var includesChainIndex bool switch pickledVersion { @@ -373,7 +376,7 @@ func (o *OlmSession) UnpickleLibOlm(buf []byte) error { case uint32(0x80000001): includesChainIndex = true default: - return fmt.Errorf("unpickle olmSession: %w (found version %d)", olm.ErrBadVersion, pickledVersion) + return fmt.Errorf("unpickle olmSession: %w (found version %d)", olm.ErrUnknownOlmPickleVersion, pickledVersion) } if o.ReceivedMessage, err = decoder.ReadBool(); err != nil { diff --git a/crypto/goolm/session/register.go b/crypto/goolm/session/register.go index a88d12f6..b95a44ac 100644 --- a/crypto/goolm/session/register.go +++ b/crypto/goolm/session/register.go @@ -14,7 +14,7 @@ func Register() { // Inbound Session olm.InitInboundGroupSessionFromPickled = func(pickled, key []byte) (olm.InboundGroupSession, error) { if len(pickled) == 0 { - return nil, olm.EmptyInput + return nil, olm.ErrEmptyInput } if len(key) == 0 { key = []byte(" ") @@ -23,13 +23,13 @@ func Register() { } olm.InitNewInboundGroupSession = func(sessionKey []byte) (olm.InboundGroupSession, error) { if len(sessionKey) == 0 { - return nil, olm.EmptyInput + return nil, olm.ErrEmptyInput } return NewMegolmInboundSession(sessionKey) } olm.InitInboundGroupSessionImport = func(sessionKey []byte) (olm.InboundGroupSession, error) { if len(sessionKey) == 0 { - return nil, olm.EmptyInput + return nil, olm.ErrEmptyInput } return NewMegolmInboundSessionFromExport(sessionKey) } @@ -40,7 +40,7 @@ func Register() { // Outbound Session olm.InitNewOutboundGroupSessionFromPickled = func(pickled, key []byte) (olm.OutboundGroupSession, error) { if len(pickled) == 0 { - return nil, olm.EmptyInput + return nil, olm.ErrEmptyInput } lenKey := len(key) if lenKey == 0 { diff --git a/crypto/keybackup.go b/crypto/keybackup.go index d8b3d715..7b3c30db 100644 --- a/crypto/keybackup.go +++ b/crypto/keybackup.go @@ -56,11 +56,12 @@ func (mach *OlmMachine) GetAndVerifyLatestKeyBackupVersion(ctx context.Context, // ...by deriving the public key from a private key that it obtained from a trusted source. Trusted sources for the private // key include the user entering the key, retrieving the key stored in secret storage, or obtaining the key via secret sharing // from a verified device belonging to the same user." - megolmBackupDerivedPublicKey := id.Ed25519(base64.RawStdEncoding.EncodeToString(megolmBackupKey.PublicKey().Bytes())) - if megolmBackupKey != nil && versionInfo.AuthData.PublicKey == megolmBackupDerivedPublicKey { - log.Debug().Msg("key backup is trusted based on derived public key") - return versionInfo, nil - } else { + if megolmBackupKey != nil { + megolmBackupDerivedPublicKey := id.Ed25519(base64.RawStdEncoding.EncodeToString(megolmBackupKey.PublicKey().Bytes())) + if versionInfo.AuthData.PublicKey == megolmBackupDerivedPublicKey { + log.Debug().Msg("Key backup is trusted based on derived public key") + return versionInfo, nil + } log.Debug(). Stringer("expected_key", megolmBackupDerivedPublicKey). Stringer("actual_key", versionInfo.AuthData.PublicKey). @@ -199,13 +200,14 @@ func (mach *OlmMachine) ImportRoomKeyFromBackupWithoutSaving( SigningKey: keyBackupData.SenderClaimedKeys.Ed25519, SenderKey: keyBackupData.SenderKey, RoomID: roomID, - ForwardingChains: append(keyBackupData.ForwardingKeyChain, keyBackupData.SenderKey.String()), + ForwardingChains: keyBackupData.ForwardingKeyChain, id: sessionID, ReceivedAt: time.Now().UTC(), MaxAge: maxAge.Milliseconds(), MaxMessages: maxMessages, KeyBackupVersion: version, + KeySource: id.KeySourceBackup, }, nil } diff --git a/crypto/keyexport_test.go b/crypto/keyexport_test.go index 47616a20..fd6f105d 100644 --- a/crypto/keyexport_test.go +++ b/crypto/keyexport_test.go @@ -31,5 +31,5 @@ func TestExportKeys(t *testing.T) { )) data, err := crypto.ExportKeys("meow", []*crypto.InboundGroupSession{sess}) assert.NoError(t, err) - assert.Len(t, data, 836) + assert.Len(t, data, 893) } diff --git a/crypto/keyimport.go b/crypto/keyimport.go index 36ad6b9c..3ffc74a5 100644 --- a/crypto/keyimport.go +++ b/crypto/keyimport.go @@ -108,19 +108,20 @@ func (mach *OlmMachine) importExportedRoomKey(ctx context.Context, session Expor return false, ErrMismatchingExportedSessionID } igs := &InboundGroupSession{ - Internal: igsInternal, - SigningKey: session.SenderClaimedKeys.Ed25519, - SenderKey: session.SenderKey, - RoomID: session.RoomID, - // TODO should we add something here to mark the signing key as unverified like key requests do? + Internal: igsInternal, + SigningKey: session.SenderClaimedKeys.Ed25519, + SenderKey: session.SenderKey, + RoomID: session.RoomID, ForwardingChains: session.ForwardingChains, - - ReceivedAt: time.Now().UTC(), + KeySource: id.KeySourceImport, + ReceivedAt: time.Now().UTC(), } existingIGS, _ := mach.CryptoStore.GetGroupSession(ctx, igs.RoomID, igs.ID()) firstKnownIndex := igs.Internal.FirstKnownIndex() if existingIGS != nil && existingIGS.Internal.FirstKnownIndex() <= firstKnownIndex { - // We already have an equivalent or better session in the store, so don't override it. + // We already have an equivalent or better session in the store, so don't override it, + // but do notify the session received callback just in case. + mach.MarkSessionReceived(ctx, session.RoomID, igs.ID(), existingIGS.Internal.FirstKnownIndex()) return false, nil } err = mach.CryptoStore.PutGroupSession(ctx, igs) diff --git a/crypto/keysharing.go b/crypto/keysharing.go index f1d427af..19a68c87 100644 --- a/crypto/keysharing.go +++ b/crypto/keysharing.go @@ -189,6 +189,7 @@ func (mach *OlmMachine) importForwardedRoomKey(ctx context.Context, evt *Decrypt MaxAge: maxAge.Milliseconds(), MaxMessages: maxMessages, IsScheduled: content.IsScheduled, + KeySource: id.KeySourceForward, } existingIGS, _ := mach.CryptoStore.GetGroupSession(ctx, igs.RoomID, igs.ID()) if existingIGS != nil && existingIGS.Internal.FirstKnownIndex() <= igs.Internal.FirstKnownIndex() { @@ -214,6 +215,7 @@ func (mach *OlmMachine) rejectKeyRequest(ctx context.Context, rejection KeyShare RoomID: request.RoomID, Algorithm: request.Algorithm, SessionID: request.SessionID, + //lint:ignore SA1019 This is just echoing back the deprecated field SenderKey: request.SenderKey, Code: rejection.Code, Reason: rejection.Reason, @@ -263,9 +265,14 @@ func (mach *OlmMachine) defaultAllowKeyShare(ctx context.Context, device *id.Dev log.Err(err).Msg("Rejecting key request due to internal error when checking session sharing") return &KeyShareRejectNoResponse } else if !isShared { - // TODO differentiate session not shared with requester vs session not created by this device? - log.Debug().Msg("Rejecting key request for unshared session") - return &KeyShareRejectNotRecipient + igs, _ := mach.CryptoStore.GetGroupSession(ctx, evt.RoomID, evt.SessionID) + if igs != nil && igs.SenderKey == mach.OwnIdentity().IdentityKey { + log.Debug().Msg("Rejecting key request for unshared session") + return &KeyShareRejectNotRecipient + } + // Note: this case will also happen for redacted sessions and database errors + log.Debug().Msg("Rejecting key request for session created by another device") + return &KeyShareRejectNoResponse } log.Debug().Msg("Accepting key request for shared session") return nil @@ -323,7 +330,9 @@ func (mach *OlmMachine) HandleRoomKeyRequest(ctx context.Context, sender id.User if err != nil { if errors.Is(err, ErrGroupSessionWithheld) { log.Debug().Err(err).Msg("Requested group session not available") - mach.rejectKeyRequest(ctx, KeyShareRejectUnavailable, device, content.Body) + if sender != mach.Client.UserID { + mach.rejectKeyRequest(ctx, KeyShareRejectUnavailable, device, content.Body) + } } else { log.Error().Err(err).Msg("Failed to get group session to forward") mach.rejectKeyRequest(ctx, KeyShareRejectInternalError, device, content.Body) @@ -331,7 +340,9 @@ func (mach *OlmMachine) HandleRoomKeyRequest(ctx context.Context, sender id.User return } else if igs == nil { log.Error().Msg("Didn't find group session to forward") - mach.rejectKeyRequest(ctx, KeyShareRejectUnavailable, device, content.Body) + if sender != mach.Client.UserID { + mach.rejectKeyRequest(ctx, KeyShareRejectUnavailable, device, content.Body) + } return } if internalID := igs.ID(); internalID != content.Body.SessionID { @@ -356,7 +367,7 @@ func (mach *OlmMachine) HandleRoomKeyRequest(ctx context.Context, sender id.User SessionID: igs.ID(), SessionKey: string(exportedKey), }, - SenderKey: content.Body.SenderKey, + SenderKey: igs.SenderKey, ForwardingKeyChain: igs.ForwardingChains, SenderClaimedKey: igs.SigningKey, }, diff --git a/crypto/libolm/account.go b/crypto/libolm/account.go index f6f916e7..0350f083 100644 --- a/crypto/libolm/account.go +++ b/crypto/libolm/account.go @@ -33,7 +33,7 @@ var _ olm.Account = (*Account)(nil) // "INVALID_BASE64". func AccountFromPickled(pickled, key []byte) (*Account, error) { if len(pickled) == 0 { - return nil, olm.EmptyInput + return nil, olm.ErrEmptyInput } a := NewBlankAccount() return a, a.Unpickle(pickled, key) @@ -53,7 +53,7 @@ func NewAccount() (*Account, error) { random := make([]byte, a.createRandomLen()+1) _, err := rand.Read(random) if err != nil { - panic(olm.NotEnoughGoRandom) + panic(olm.ErrNotEnoughGoRandom) } ret := C.olm_create_account( (*C.OlmAccount)(a.int), @@ -128,7 +128,7 @@ func (a *Account) genOneTimeKeysRandomLen(num uint) uint { // supplied key. func (a *Account) Pickle(key []byte) ([]byte, error) { if len(key) == 0 { - return nil, olm.NoKeyProvided + return nil, olm.ErrNoKeyProvided } pickled := make([]byte, a.pickleLen()) r := C.olm_pickle_account( @@ -145,7 +145,7 @@ func (a *Account) Pickle(key []byte) ([]byte, error) { func (a *Account) Unpickle(pickled, key []byte) error { if len(key) == 0 { - return olm.NoKeyProvided + return olm.ErrNoKeyProvided } r := C.olm_unpickle_account( (*C.OlmAccount)(a.int), @@ -198,7 +198,7 @@ func (a *Account) MarshalJSON() ([]byte, error) { // Deprecated func (a *Account) UnmarshalJSON(data []byte) error { if len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' { - return olm.InputNotJSONString + return olm.ErrInputNotJSONString } if a.int == nil { *a = *NewBlankAccount() @@ -235,7 +235,7 @@ func (a *Account) IdentityKeys() (id.Ed25519, id.Curve25519, error) { // Account. func (a *Account) Sign(message []byte) ([]byte, error) { if len(message) == 0 { - panic(olm.EmptyInput) + panic(olm.ErrEmptyInput) } signature := make([]byte, a.signatureLen()) r := C.olm_account_sign( @@ -299,7 +299,7 @@ func (a *Account) GenOneTimeKeys(num uint) error { random := make([]byte, a.genOneTimeKeysRandomLen(num)+1) _, err := rand.Read(random) if err != nil { - return olm.NotEnoughGoRandom + return olm.ErrNotEnoughGoRandom } r := C.olm_account_generate_one_time_keys( (*C.OlmAccount)(a.int), @@ -319,13 +319,13 @@ func (a *Account) GenOneTimeKeys(num uint) error { // keys couldn't be decoded as base64 then the error will be "INVALID_BASE64" func (a *Account) NewOutboundSession(theirIdentityKey, theirOneTimeKey id.Curve25519) (olm.Session, error) { if len(theirIdentityKey) == 0 || len(theirOneTimeKey) == 0 { - return nil, olm.EmptyInput + return nil, olm.ErrEmptyInput } s := NewBlankSession() random := make([]byte, s.createOutboundRandomLen()+1) _, err := rand.Read(random) if err != nil { - panic(olm.NotEnoughGoRandom) + panic(olm.ErrNotEnoughGoRandom) } theirIdentityKeyCopy := []byte(theirIdentityKey) theirOneTimeKeyCopy := []byte(theirOneTimeKey) @@ -357,7 +357,7 @@ func (a *Account) NewOutboundSession(theirIdentityKey, theirOneTimeKey id.Curve2 // time key then the error will be "BAD_MESSAGE_KEY_ID". func (a *Account) NewInboundSession(oneTimeKeyMsg string) (olm.Session, error) { if len(oneTimeKeyMsg) == 0 { - return nil, olm.EmptyInput + return nil, olm.ErrEmptyInput } s := NewBlankSession() oneTimeKeyMsgCopy := []byte(oneTimeKeyMsg) @@ -383,7 +383,7 @@ func (a *Account) NewInboundSession(oneTimeKeyMsg string) (olm.Session, error) { // time key then the error will be "BAD_MESSAGE_KEY_ID". func (a *Account) NewInboundSessionFrom(theirIdentityKey *id.Curve25519, oneTimeKeyMsg string) (olm.Session, error) { if theirIdentityKey == nil || len(oneTimeKeyMsg) == 0 { - return nil, olm.EmptyInput + return nil, olm.ErrEmptyInput } theirIdentityKeyCopy := []byte(*theirIdentityKey) oneTimeKeyMsgCopy := []byte(oneTimeKeyMsg) diff --git a/crypto/libolm/error.go b/crypto/libolm/error.go index 9ca415ee..6fb5512b 100644 --- a/crypto/libolm/error.go +++ b/crypto/libolm/error.go @@ -11,21 +11,21 @@ import ( ) var errorMap = map[string]error{ - "NOT_ENOUGH_RANDOM": olm.NotEnoughRandom, - "OUTPUT_BUFFER_TOO_SMALL": olm.OutputBufferTooSmall, - "BAD_MESSAGE_VERSION": olm.BadMessageVersion, - "BAD_MESSAGE_FORMAT": olm.BadMessageFormat, - "BAD_MESSAGE_MAC": olm.BadMessageMAC, - "BAD_MESSAGE_KEY_ID": olm.BadMessageKeyID, - "INVALID_BASE64": olm.InvalidBase64, - "BAD_ACCOUNT_KEY": olm.BadAccountKey, - "UNKNOWN_PICKLE_VERSION": olm.UnknownPickleVersion, - "CORRUPTED_PICKLE": olm.CorruptedPickle, - "BAD_SESSION_KEY": olm.BadSessionKey, - "UNKNOWN_MESSAGE_INDEX": olm.UnknownMessageIndex, - "BAD_LEGACY_ACCOUNT_PICKLE": olm.BadLegacyAccountPickle, - "BAD_SIGNATURE": olm.BadSignature, - "INPUT_BUFFER_TOO_SMALL": olm.InputBufferTooSmall, + "NOT_ENOUGH_RANDOM": olm.ErrLibolmNotEnoughRandom, + "OUTPUT_BUFFER_TOO_SMALL": olm.ErrLibolmOutputBufferTooSmall, + "BAD_MESSAGE_VERSION": olm.ErrWrongProtocolVersion, + "BAD_MESSAGE_FORMAT": olm.ErrBadMessageFormat, + "BAD_MESSAGE_MAC": olm.ErrBadMAC, + "BAD_MESSAGE_KEY_ID": olm.ErrBadMessageKeyID, + "INVALID_BASE64": olm.ErrLibolmInvalidBase64, + "BAD_ACCOUNT_KEY": olm.ErrLibolmBadAccountKey, + "UNKNOWN_PICKLE_VERSION": olm.ErrUnknownOlmPickleVersion, + "CORRUPTED_PICKLE": olm.ErrLibolmCorruptedPickle, + "BAD_SESSION_KEY": olm.ErrLibolmBadSessionKey, + "UNKNOWN_MESSAGE_INDEX": olm.ErrUnknownMessageIndex, + "BAD_LEGACY_ACCOUNT_PICKLE": olm.ErrLibolmBadLegacyAccountPickle, + "BAD_SIGNATURE": olm.ErrBadSignature, + "INPUT_BUFFER_TOO_SMALL": olm.ErrInputToSmall, } func convertError(errCode string) error { diff --git a/crypto/libolm/inboundgroupsession.go b/crypto/libolm/inboundgroupsession.go index 5606475d..8815ac32 100644 --- a/crypto/libolm/inboundgroupsession.go +++ b/crypto/libolm/inboundgroupsession.go @@ -31,7 +31,7 @@ var _ olm.InboundGroupSession = (*InboundGroupSession)(nil) // base64 couldn't be decoded then the error will be "INVALID_BASE64". func InboundGroupSessionFromPickled(pickled, key []byte) (*InboundGroupSession, error) { if len(pickled) == 0 { - return nil, olm.EmptyInput + return nil, olm.ErrEmptyInput } lenKey := len(key) if lenKey == 0 { @@ -48,7 +48,7 @@ func InboundGroupSessionFromPickled(pickled, key []byte) (*InboundGroupSession, // "OLM_BAD_SESSION_KEY". func NewInboundGroupSession(sessionKey []byte) (*InboundGroupSession, error) { if len(sessionKey) == 0 { - return nil, olm.EmptyInput + return nil, olm.ErrEmptyInput } s := NewBlankInboundGroupSession() r := C.olm_init_inbound_group_session( @@ -69,7 +69,7 @@ func NewInboundGroupSession(sessionKey []byte) (*InboundGroupSession, error) { // error will be "OLM_BAD_SESSION_KEY". func InboundGroupSessionImport(sessionKey []byte) (*InboundGroupSession, error) { if len(sessionKey) == 0 { - return nil, olm.EmptyInput + return nil, olm.ErrEmptyInput } s := NewBlankInboundGroupSession() r := C.olm_import_inbound_group_session( @@ -124,7 +124,7 @@ func (s *InboundGroupSession) pickleLen() uint { // InboundGroupSession using the supplied key. func (s *InboundGroupSession) Pickle(key []byte) ([]byte, error) { if len(key) == 0 { - return nil, olm.NoKeyProvided + return nil, olm.ErrNoKeyProvided } pickled := make([]byte, s.pickleLen()) r := C.olm_pickle_inbound_group_session( @@ -143,9 +143,9 @@ func (s *InboundGroupSession) Pickle(key []byte) ([]byte, error) { func (s *InboundGroupSession) Unpickle(pickled, key []byte) error { if len(key) == 0 { - return olm.NoKeyProvided + return olm.ErrNoKeyProvided } else if len(pickled) == 0 { - return olm.EmptyInput + return olm.ErrEmptyInput } r := C.olm_unpickle_inbound_group_session( (*C.OlmInboundGroupSession)(s.int), @@ -200,7 +200,7 @@ func (s *InboundGroupSession) MarshalJSON() ([]byte, error) { // Deprecated func (s *InboundGroupSession) UnmarshalJSON(data []byte) error { if len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' { - return olm.InputNotJSONString + return olm.ErrInputNotJSONString } if s == nil || s.int == nil { *s = *NewBlankInboundGroupSession() @@ -217,7 +217,7 @@ func (s *InboundGroupSession) UnmarshalJSON(data []byte) error { // will be "BAD_MESSAGE_FORMAT". func (s *InboundGroupSession) decryptMaxPlaintextLen(message []byte) (uint, error) { if len(message) == 0 { - return 0, olm.EmptyInput + return 0, olm.ErrEmptyInput } // olm_group_decrypt_max_plaintext_length destroys the input, so we have to clone it messageCopy := bytes.Clone(message) @@ -244,7 +244,7 @@ func (s *InboundGroupSession) decryptMaxPlaintextLen(message []byte) (uint, erro // was shared with us) the error will be "OLM_UNKNOWN_MESSAGE_INDEX". func (s *InboundGroupSession) Decrypt(message []byte) ([]byte, uint, error) { if len(message) == 0 { - return nil, 0, olm.EmptyInput + return nil, 0, olm.ErrEmptyInput } decryptMaxPlaintextLen, err := s.decryptMaxPlaintextLen(message) if err != nil { diff --git a/crypto/libolm/outboundgroupsession.go b/crypto/libolm/outboundgroupsession.go index 646929eb..ca5b68f7 100644 --- a/crypto/libolm/outboundgroupsession.go +++ b/crypto/libolm/outboundgroupsession.go @@ -84,7 +84,7 @@ func (s *OutboundGroupSession) pickleLen() uint { // OutboundGroupSession using the supplied key. func (s *OutboundGroupSession) Pickle(key []byte) ([]byte, error) { if len(key) == 0 { - return nil, olm.NoKeyProvided + return nil, olm.ErrNoKeyProvided } pickled := make([]byte, s.pickleLen()) r := C.olm_pickle_outbound_group_session( @@ -103,7 +103,7 @@ func (s *OutboundGroupSession) Pickle(key []byte) ([]byte, error) { func (s *OutboundGroupSession) Unpickle(pickled, key []byte) error { if len(key) == 0 { - return olm.NoKeyProvided + return olm.ErrNoKeyProvided } r := C.olm_unpickle_outbound_group_session( (*C.OlmOutboundGroupSession)(s.int), @@ -159,7 +159,7 @@ func (s *OutboundGroupSession) MarshalJSON() ([]byte, error) { // Deprecated func (s *OutboundGroupSession) UnmarshalJSON(data []byte) error { if len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' { - return olm.InputNotJSONString + return olm.ErrInputNotJSONString } if s == nil || s.int == nil { *s = *NewBlankOutboundGroupSession() @@ -183,7 +183,7 @@ func (s *OutboundGroupSession) encryptMsgLen(plainTextLen int) uint { // as base64. func (s *OutboundGroupSession) Encrypt(plaintext []byte) ([]byte, error) { if len(plaintext) == 0 { - return nil, olm.EmptyInput + return nil, olm.ErrEmptyInput } message := make([]byte, s.encryptMsgLen(len(plaintext))) r := C.olm_group_encrypt( diff --git a/crypto/libolm/pk.go b/crypto/libolm/pk.go index 35532140..2683cf15 100644 --- a/crypto/libolm/pk.go +++ b/crypto/libolm/pk.go @@ -86,7 +86,7 @@ func NewPKSigning() (*PKSigning, error) { seed := make([]byte, pkSigningSeedLength()) _, err := rand.Read(seed) if err != nil { - panic(olm.NotEnoughGoRandom) + panic(olm.ErrNotEnoughGoRandom) } pk, err := NewPKSigningFromSeed(seed) return pk, err diff --git a/crypto/libolm/register.go b/crypto/libolm/register.go index f091d822..ddf84613 100644 --- a/crypto/libolm/register.go +++ b/crypto/libolm/register.go @@ -65,7 +65,7 @@ func Register() { olm.InitNewOutboundGroupSessionFromPickled = func(pickled, key []byte) (olm.OutboundGroupSession, error) { if len(pickled) == 0 { - return nil, olm.EmptyInput + return nil, olm.ErrEmptyInput } s := NewBlankOutboundGroupSession() return s, s.Unpickle(pickled, key) diff --git a/crypto/libolm/session.go b/crypto/libolm/session.go index 57e631c3..1441df26 100644 --- a/crypto/libolm/session.go +++ b/crypto/libolm/session.go @@ -51,7 +51,7 @@ func sessionSize() uint { // "INVALID_BASE64". func SessionFromPickled(pickled, key []byte) (*Session, error) { if len(pickled) == 0 { - return nil, olm.EmptyInput + return nil, olm.ErrEmptyInput } s := NewBlankSession() return s, s.Unpickle(pickled, key) @@ -118,7 +118,7 @@ func (s *Session) encryptMsgLen(plainTextLen int) uint { // will be "BAD_MESSAGE_FORMAT". func (s *Session) decryptMaxPlaintextLen(message string, msgType id.OlmMsgType) (uint, error) { if len(message) == 0 { - return 0, olm.EmptyInput + return 0, olm.ErrEmptyInput } messageCopy := []byte(message) r := C.olm_decrypt_max_plaintext_length( @@ -138,7 +138,7 @@ func (s *Session) decryptMaxPlaintextLen(message string, msgType id.OlmMsgType) // supplied key. func (s *Session) Pickle(key []byte) ([]byte, error) { if len(key) == 0 { - return nil, olm.NoKeyProvided + return nil, olm.ErrNoKeyProvided } pickled := make([]byte, s.pickleLen()) r := C.olm_pickle_session( @@ -158,7 +158,7 @@ func (s *Session) Pickle(key []byte) ([]byte, error) { // provided key. This function mutates the input pickled data slice. func (s *Session) Unpickle(pickled, key []byte) error { if len(key) == 0 { - return olm.NoKeyProvided + return olm.ErrNoKeyProvided } r := C.olm_unpickle_session( (*C.OlmSession)(s.int), @@ -213,7 +213,7 @@ func (s *Session) MarshalJSON() ([]byte, error) { // Deprecated func (s *Session) UnmarshalJSON(data []byte) error { if len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' { - return olm.InputNotJSONString + return olm.ErrInputNotJSONString } if s == nil || s.int == nil { *s = *NewBlankSession() @@ -256,7 +256,7 @@ func (s *Session) HasReceivedMessage() bool { // decoded then then the error will be "BAD_MESSAGE_FORMAT". func (s *Session) MatchesInboundSession(oneTimeKeyMsg string) (bool, error) { if len(oneTimeKeyMsg) == 0 { - return false, olm.EmptyInput + return false, olm.ErrEmptyInput } oneTimeKeyMsgCopy := []byte(oneTimeKeyMsg) r := C.olm_matches_inbound_session( @@ -284,7 +284,7 @@ func (s *Session) MatchesInboundSession(oneTimeKeyMsg string) (bool, error) { // decoded then then the error will be "BAD_MESSAGE_FORMAT". func (s *Session) MatchesInboundSessionFrom(theirIdentityKey, oneTimeKeyMsg string) (bool, error) { if len(theirIdentityKey) == 0 || len(oneTimeKeyMsg) == 0 { - return false, olm.EmptyInput + return false, olm.ErrEmptyInput } theirIdentityKeyCopy := []byte(theirIdentityKey) oneTimeKeyMsgCopy := []byte(oneTimeKeyMsg) @@ -325,14 +325,14 @@ func (s *Session) EncryptMsgType() id.OlmMsgType { // as base64. func (s *Session) Encrypt(plaintext []byte) (id.OlmMsgType, []byte, error) { if len(plaintext) == 0 { - return 0, nil, olm.EmptyInput + return 0, nil, olm.ErrEmptyInput } // Make the slice be at least length 1 random := make([]byte, s.encryptRandomLen()+1) _, err := rand.Read(random) if err != nil { // TODO can we just return err here? - return 0, nil, olm.NotEnoughGoRandom + return 0, nil, olm.ErrNotEnoughGoRandom } messageType := s.EncryptMsgType() message := make([]byte, s.encryptMsgLen(len(plaintext))) @@ -362,7 +362,7 @@ func (s *Session) Encrypt(plaintext []byte) (id.OlmMsgType, []byte, error) { // "BAD_MESSAGE_MAC". func (s *Session) Decrypt(message string, msgType id.OlmMsgType) ([]byte, error) { if len(message) == 0 { - return nil, olm.EmptyInput + return nil, olm.ErrEmptyInput } decryptMaxPlaintextLen, err := s.decryptMaxPlaintextLen(message, msgType) if err != nil { diff --git a/crypto/machine.go b/crypto/machine.go index 4d2e3880..fa051f94 100644 --- a/crypto/machine.go +++ b/crypto/machine.go @@ -39,6 +39,7 @@ type OlmMachine struct { cancelBackgroundCtx context.CancelFunc PlaintextMentions bool + MSC4392Relations bool AllowEncryptedState bool // Never ask the server for keys automatically as a side effect during Megolm decryption. @@ -205,7 +206,7 @@ func (mach *OlmMachine) FlushStore(ctx context.Context) error { func (mach *OlmMachine) timeTrace(ctx context.Context, thing string, expectedDuration time.Duration) func() { start := time.Now() return func() { - duration := time.Now().Sub(start) + duration := time.Since(start) if duration > expectedDuration { zerolog.Ctx(ctx).Warn(). Str("action", thing). diff --git a/crypto/olm/errors.go b/crypto/olm/errors.go index 957d7928..9e522b2a 100644 --- a/crypto/olm/errors.go +++ b/crypto/olm/errors.go @@ -10,50 +10,67 @@ import "errors" // Those are the most common used errors var ( - ErrBadSignature = errors.New("bad signature") - ErrBadMAC = errors.New("bad mac") - ErrBadMessageFormat = errors.New("bad message format") - ErrBadVerification = errors.New("bad verification") - ErrWrongProtocolVersion = errors.New("wrong protocol version") - ErrEmptyInput = errors.New("empty input") - ErrNoKeyProvided = errors.New("no key") - ErrBadMessageKeyID = errors.New("bad message key id") - ErrRatchetNotAvailable = errors.New("ratchet not available: attempt to decode a message whose index is earlier than our earliest known session key") - ErrMsgIndexTooHigh = errors.New("message index too high") - ErrProtocolViolation = errors.New("not protocol message order") - ErrMessageKeyNotFound = errors.New("message key not found") - ErrChainTooHigh = errors.New("chain index too high") - ErrBadInput = errors.New("bad input") - ErrBadVersion = errors.New("wrong version") - ErrWrongPickleVersion = errors.New("wrong pickle version") - ErrInputToSmall = errors.New("input too small (truncated?)") - ErrOverflow = errors.New("overflow") + ErrBadSignature = errors.New("bad signature") + ErrBadMAC = errors.New("the message couldn't be decrypted (bad mac)") + ErrBadMessageFormat = errors.New("the message couldn't be decoded") + ErrBadVerification = errors.New("bad verification") + ErrWrongProtocolVersion = errors.New("wrong protocol version") + ErrEmptyInput = errors.New("empty input") + ErrNoKeyProvided = errors.New("no key provided") + ErrBadMessageKeyID = errors.New("the message references an unknown key ID") + ErrUnknownMessageIndex = errors.New("attempt to decode a message whose index is earlier than our earliest known session key") + ErrMsgIndexTooHigh = errors.New("message index too high") + ErrProtocolViolation = errors.New("not protocol message order") + ErrMessageKeyNotFound = errors.New("message key not found") + ErrChainTooHigh = errors.New("chain index too high") + ErrBadInput = errors.New("bad input") + ErrUnknownOlmPickleVersion = errors.New("unknown olm pickle version") + ErrUnknownJSONPickleVersion = errors.New("unknown JSON pickle version") + ErrInputToSmall = errors.New("input too small (truncated?)") ) // Error codes from go-olm var ( - EmptyInput = errors.New("empty input") - NoKeyProvided = errors.New("no pickle key provided") - NotEnoughGoRandom = errors.New("couldn't get enough randomness from crypto/rand") - SignatureNotFound = errors.New("input JSON doesn't contain signature from specified device") - InputNotJSONString = errors.New("input doesn't look like a JSON string") + ErrNotEnoughGoRandom = errors.New("couldn't get enough randomness from crypto/rand") + ErrInputNotJSONString = errors.New("input doesn't look like a JSON string") ) // Error codes from olm code var ( - NotEnoughRandom = errors.New("not enough entropy was supplied") - OutputBufferTooSmall = errors.New("supplied output buffer is too small") - BadMessageVersion = errors.New("the message version is unsupported") - BadMessageFormat = errors.New("the message couldn't be decoded") - BadMessageMAC = errors.New("the message couldn't be decrypted") - BadMessageKeyID = errors.New("the message references an unknown key ID") - InvalidBase64 = errors.New("the input base64 was invalid") - BadAccountKey = errors.New("the supplied account key is invalid") - UnknownPickleVersion = errors.New("the pickled object is too new") - CorruptedPickle = errors.New("the pickled object couldn't be decoded") - BadSessionKey = errors.New("attempt to initialise an inbound group session from an invalid session key") - UnknownMessageIndex = errors.New("attempt to decode a message whose index is earlier than our earliest known session key") - BadLegacyAccountPickle = errors.New("attempt to unpickle an account which uses pickle version 1") - BadSignature = errors.New("received message had a bad signature") - InputBufferTooSmall = errors.New("the input data was too small to be valid") + ErrLibolmInvalidBase64 = errors.New("the input base64 was invalid") + + ErrLibolmNotEnoughRandom = errors.New("not enough entropy was supplied") + ErrLibolmOutputBufferTooSmall = errors.New("supplied output buffer is too small") + ErrLibolmBadAccountKey = errors.New("the supplied account key is invalid") + ErrLibolmCorruptedPickle = errors.New("the pickled object couldn't be decoded") + ErrLibolmBadSessionKey = errors.New("attempt to initialise an inbound group session from an invalid session key") + ErrLibolmBadLegacyAccountPickle = errors.New("attempt to unpickle an account which uses pickle version 1") +) + +// Deprecated: use variables prefixed with Err +var ( + EmptyInput = ErrEmptyInput + BadSignature = ErrBadSignature + InvalidBase64 = ErrLibolmInvalidBase64 + BadMessageKeyID = ErrBadMessageKeyID + BadMessageFormat = ErrBadMessageFormat + BadMessageVersion = ErrWrongProtocolVersion + BadMessageMAC = ErrBadMAC + UnknownPickleVersion = ErrUnknownOlmPickleVersion + NotEnoughRandom = ErrLibolmNotEnoughRandom + OutputBufferTooSmall = ErrLibolmOutputBufferTooSmall + BadAccountKey = ErrLibolmBadAccountKey + CorruptedPickle = ErrLibolmCorruptedPickle + BadSessionKey = ErrLibolmBadSessionKey + UnknownMessageIndex = ErrUnknownMessageIndex + BadLegacyAccountPickle = ErrLibolmBadLegacyAccountPickle + InputBufferTooSmall = ErrInputToSmall + NoKeyProvided = ErrNoKeyProvided + + NotEnoughGoRandom = ErrNotEnoughGoRandom + InputNotJSONString = ErrInputNotJSONString + + ErrBadVersion = ErrUnknownJSONPickleVersion + ErrWrongPickleVersion = ErrUnknownJSONPickleVersion + ErrRatchetNotAvailable = ErrUnknownMessageIndex ) diff --git a/crypto/sessions.go b/crypto/sessions.go index aecb0416..ccc7b784 100644 --- a/crypto/sessions.go +++ b/crypto/sessions.go @@ -18,8 +18,14 @@ import ( ) var ( - SessionNotShared = errors.New("session has not been shared") - SessionExpired = errors.New("session has expired") + ErrSessionNotShared = errors.New("session has not been shared") + ErrSessionExpired = errors.New("session has expired") +) + +// Deprecated: use variables prefixed with Err +var ( + SessionNotShared = ErrSessionNotShared + SessionExpired = ErrSessionExpired ) // OlmSessionList is a list of OlmSessions. @@ -111,6 +117,7 @@ type InboundGroupSession struct { MaxMessages int IsScheduled bool KeyBackupVersion id.KeyBackupVersion + KeySource id.KeySource id id.SessionID } @@ -130,6 +137,7 @@ func NewInboundGroupSession(senderKey id.SenderKey, signingKey id.Ed25519, roomI MaxAge: maxAge.Milliseconds(), MaxMessages: maxMessages, IsScheduled: isScheduled, + KeySource: id.KeySourceDirect, }, nil } @@ -163,7 +171,7 @@ func (igs *InboundGroupSession) export() (*ExportedSession, error) { ForwardingChains: igs.ForwardingChains, RoomID: igs.RoomID, SenderKey: igs.SenderKey, - SenderClaimedKeys: SenderClaimedKeys{}, + SenderClaimedKeys: SenderClaimedKeys{Ed25519: igs.SigningKey}, SessionID: igs.ID(), SessionKey: string(key), }, nil @@ -255,9 +263,9 @@ func (ogs *OutboundGroupSession) Expired() bool { func (ogs *OutboundGroupSession) Encrypt(plaintext []byte) ([]byte, error) { if !ogs.Shared { - return nil, SessionNotShared + return nil, ErrSessionNotShared } else if ogs.Expired() { - return nil, SessionExpired + return nil, ErrSessionExpired } ogs.MessageCount++ ogs.LastEncryptedTime = time.Now() diff --git a/crypto/sql_store.go b/crypto/sql_store.go index ca75b3f6..138cc557 100644 --- a/crypto/sql_store.go +++ b/crypto/sql_store.go @@ -346,22 +346,23 @@ func (store *SQLCryptoStore) PutGroupSession(ctx context.Context, session *Inbou Int("max_messages", session.MaxMessages). Bool("is_scheduled", session.IsScheduled). Stringer("key_backup_version", session.KeyBackupVersion). + Stringer("key_source", session.KeySource). Msg("Upserting megolm inbound group session") _, err = store.DB.Exec(ctx, ` INSERT INTO crypto_megolm_inbound_session ( session_id, sender_key, signing_key, room_id, session, forwarding_chains, - ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version, account_id - ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13) + ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version, key_source, account_id + ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14) ON CONFLICT (session_id, account_id) DO UPDATE SET withheld_code=NULL, withheld_reason=NULL, sender_key=excluded.sender_key, signing_key=excluded.signing_key, room_id=excluded.room_id, session=excluded.session, forwarding_chains=excluded.forwarding_chains, ratchet_safety=excluded.ratchet_safety, received_at=excluded.received_at, max_age=excluded.max_age, max_messages=excluded.max_messages, is_scheduled=excluded.is_scheduled, - key_backup_version=excluded.key_backup_version + key_backup_version=excluded.key_backup_version, key_source=excluded.key_source `, session.ID(), session.SenderKey, session.SigningKey, session.RoomID, sessionBytes, forwardingChains, ratchetSafety, datePtr(session.ReceivedAt), dbutil.NumPtr(session.MaxAge), dbutil.NumPtr(session.MaxMessages), - session.IsScheduled, session.KeyBackupVersion, store.AccountID, + session.IsScheduled, session.KeyBackupVersion, session.KeySource, store.AccountID, ) return err } @@ -374,12 +375,13 @@ func (store *SQLCryptoStore) GetGroupSession(ctx context.Context, roomID id.Room var maxAge, maxMessages sql.NullInt64 var isScheduled bool var version id.KeyBackupVersion + var keySource id.KeySource err := store.DB.QueryRow(ctx, ` - SELECT sender_key, signing_key, session, forwarding_chains, withheld_code, withheld_reason, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version + SELECT sender_key, signing_key, session, forwarding_chains, withheld_code, withheld_reason, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version, key_source FROM crypto_megolm_inbound_session WHERE room_id=$1 AND session_id=$2 AND account_id=$3`, roomID, sessionID, store.AccountID, - ).Scan(&senderKey, &signingKey, &sessionBytes, &forwardingChains, &withheldCode, &withheldReason, &ratchetSafetyBytes, &receivedAt, &maxAge, &maxMessages, &isScheduled, &version) + ).Scan(&senderKey, &signingKey, &sessionBytes, &forwardingChains, &withheldCode, &withheldReason, &ratchetSafetyBytes, &receivedAt, &maxAge, &maxMessages, &isScheduled, &version, &keySource) if errors.Is(err, sql.ErrNoRows) { return nil, nil } else if err != nil { @@ -410,6 +412,7 @@ func (store *SQLCryptoStore) GetGroupSession(ctx context.Context, roomID id.Room MaxMessages: int(maxMessages.Int64), IsScheduled: isScheduled, KeyBackupVersion: version, + KeySource: keySource, }, nil } @@ -534,7 +537,8 @@ func (store *SQLCryptoStore) scanInboundGroupSession(rows dbutil.Scannable) (*In var maxAge, maxMessages sql.NullInt64 var isScheduled bool var version id.KeyBackupVersion - err := rows.Scan(&roomID, &senderKey, &signingKey, &sessionBytes, &forwardingChains, &ratchetSafetyBytes, &receivedAt, &maxAge, &maxMessages, &isScheduled, &version) + var keySource id.KeySource + err := rows.Scan(&roomID, &senderKey, &signingKey, &sessionBytes, &forwardingChains, &ratchetSafetyBytes, &receivedAt, &maxAge, &maxMessages, &isScheduled, &version, &keySource) if err != nil { return nil, err } @@ -554,12 +558,13 @@ func (store *SQLCryptoStore) scanInboundGroupSession(rows dbutil.Scannable) (*In MaxMessages: int(maxMessages.Int64), IsScheduled: isScheduled, KeyBackupVersion: version, + KeySource: keySource, }, nil } func (store *SQLCryptoStore) GetGroupSessionsForRoom(ctx context.Context, roomID id.RoomID) dbutil.RowIter[*InboundGroupSession] { rows, err := store.DB.Query(ctx, ` - SELECT room_id, sender_key, signing_key, session, forwarding_chains, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version + SELECT room_id, sender_key, signing_key, session, forwarding_chains, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version, key_source FROM crypto_megolm_inbound_session WHERE room_id=$1 AND account_id=$2 AND session IS NOT NULL`, roomID, store.AccountID, ) @@ -568,7 +573,7 @@ func (store *SQLCryptoStore) GetGroupSessionsForRoom(ctx context.Context, roomID func (store *SQLCryptoStore) GetAllGroupSessions(ctx context.Context) dbutil.RowIter[*InboundGroupSession] { rows, err := store.DB.Query(ctx, ` - SELECT room_id, sender_key, signing_key, session, forwarding_chains, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version + SELECT room_id, sender_key, signing_key, session, forwarding_chains, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version, key_source FROM crypto_megolm_inbound_session WHERE account_id=$1 AND session IS NOT NULL`, store.AccountID, ) @@ -577,7 +582,7 @@ func (store *SQLCryptoStore) GetAllGroupSessions(ctx context.Context) dbutil.Row func (store *SQLCryptoStore) GetGroupSessionsWithoutKeyBackupVersion(ctx context.Context, version id.KeyBackupVersion) dbutil.RowIter[*InboundGroupSession] { rows, err := store.DB.Query(ctx, ` - SELECT room_id, sender_key, signing_key, session, forwarding_chains, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version + SELECT room_id, sender_key, signing_key, session, forwarding_chains, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version, key_source FROM crypto_megolm_inbound_session WHERE account_id=$1 AND session IS NOT NULL AND key_backup_version != $2`, store.AccountID, version, ) diff --git a/crypto/sql_store_upgrade/00-latest-revision.sql b/crypto/sql_store_upgrade/00-latest-revision.sql index af8ab5cc..3709f1e5 100644 --- a/crypto/sql_store_upgrade/00-latest-revision.sql +++ b/crypto/sql_store_upgrade/00-latest-revision.sql @@ -1,4 +1,4 @@ --- v0 -> v18 (compatible with v15+): Latest revision +-- v0 -> v19 (compatible with v15+): Latest revision CREATE TABLE IF NOT EXISTS crypto_account ( account_id TEXT PRIMARY KEY, device_id TEXT NOT NULL, @@ -71,6 +71,7 @@ CREATE TABLE IF NOT EXISTS crypto_megolm_inbound_session ( max_messages INTEGER, is_scheduled BOOLEAN NOT NULL DEFAULT false, key_backup_version TEXT NOT NULL DEFAULT '', + key_source TEXT NOT NULL DEFAULT '', PRIMARY KEY (account_id, session_id) ); -- Useful index to find keys that need backing up diff --git a/crypto/sql_store_upgrade/19-megolm-session-source.sql b/crypto/sql_store_upgrade/19-megolm-session-source.sql new file mode 100644 index 00000000..f624222f --- /dev/null +++ b/crypto/sql_store_upgrade/19-megolm-session-source.sql @@ -0,0 +1,2 @@ +-- v19 (compatible with v15+): Store megolm session source +ALTER TABLE crypto_megolm_inbound_session ADD COLUMN key_source TEXT NOT NULL DEFAULT ''; diff --git a/crypto/ssss/client.go b/crypto/ssss/client.go index e30925d9..8691d032 100644 --- a/crypto/ssss/client.go +++ b/crypto/ssss/client.go @@ -95,6 +95,22 @@ func (mach *Machine) SetEncryptedAccountData(ctx context.Context, eventType even return mach.Client.SetAccountData(ctx, eventType.Type, &EncryptedAccountDataEventContent{Encrypted: encrypted}) } +// SetEncryptedAccountDataWithMetadata encrypts the given data with the given keys and stores it, +// alongside the unencrypted metadata, on the server. +func (mach *Machine) SetEncryptedAccountDataWithMetadata(ctx context.Context, eventType event.Type, data []byte, metadata map[string]any, keys ...*Key) error { + if len(keys) == 0 { + return ErrNoKeyGiven + } + encrypted := make(map[string]EncryptedKeyData, len(keys)) + for _, key := range keys { + encrypted[key.ID] = key.Encrypt(eventType.Type, data) + } + return mach.Client.SetAccountData(ctx, eventType.Type, &EncryptedAccountDataEventContent{ + Encrypted: encrypted, + Metadata: metadata, + }) +} + // GenerateAndUploadKey generates a new SSSS key and stores the metadata on the server. func (mach *Machine) GenerateAndUploadKey(ctx context.Context, passphrase string) (key *Key, err error) { key, err = NewKey(passphrase) diff --git a/crypto/ssss/key.go b/crypto/ssss/key.go index cd8e3fce..78ebd8f3 100644 --- a/crypto/ssss/key.go +++ b/crypto/ssss/key.go @@ -59,12 +59,12 @@ func NewKey(passphrase string) (*Key, error) { // We store a certain hash in the key metadata so that clients can check if the user entered the correct key. ivBytes := random.Bytes(utils.AESCTRIVLength) keyData.IV = base64.RawStdEncoding.EncodeToString(ivBytes) - var err error - keyData.MAC, err = keyData.calculateHash(ssssKey) + macBytes, err := keyData.calculateHash(ssssKey) if err != nil { // This should never happen because we just generated the IV and key. return nil, fmt.Errorf("failed to calculate hash: %w", err) } + keyData.MAC = base64.RawStdEncoding.EncodeToString(macBytes) return &Key{ Key: ssssKey, diff --git a/crypto/ssss/meta.go b/crypto/ssss/meta.go index 474c85d8..34775fa7 100644 --- a/crypto/ssss/meta.go +++ b/crypto/ssss/meta.go @@ -7,7 +7,10 @@ package ssss import ( + "crypto/hmac" + "crypto/sha256" "encoding/base64" + "errors" "fmt" "strings" @@ -33,7 +36,9 @@ func (kd *KeyMetadata) VerifyPassphrase(keyID, passphrase string) (*Key, error) ssssKey, err := kd.Passphrase.GetKey(passphrase) if err != nil { return nil, err - } else if err = kd.verifyKey(ssssKey); err != nil { + } + err = kd.verifyKey(ssssKey) + if err != nil && !errors.Is(err, ErrUnverifiableKey) { return nil, err } @@ -49,7 +54,9 @@ func (kd *KeyMetadata) VerifyRecoveryKey(keyID, recoveryKey string) (*Key, error ssssKey := utils.DecodeBase58RecoveryKey(recoveryKey) if ssssKey == nil { return nil, ErrInvalidRecoveryKey - } else if err := kd.verifyKey(ssssKey); err != nil { + } + err := kd.verifyKey(ssssKey) + if err != nil && !errors.Is(err, ErrUnverifiableKey) { return nil, err } @@ -57,20 +64,28 @@ func (kd *KeyMetadata) VerifyRecoveryKey(keyID, recoveryKey string) (*Key, error ID: keyID, Key: ssssKey, Metadata: kd, - }, nil + }, err } func (kd *KeyMetadata) verifyKey(key []byte) error { + if kd.MAC == "" || kd.IV == "" { + return ErrUnverifiableKey + } unpaddedMAC := strings.TrimRight(kd.MAC, "=") expectedMACLength := base64.RawStdEncoding.EncodedLen(utils.SHAHashLength) if len(unpaddedMAC) != expectedMACLength { return fmt.Errorf("%w: invalid mac length %d (expected %d)", ErrCorruptedKeyMetadata, len(unpaddedMAC), expectedMACLength) } - hash, err := kd.calculateHash(key) + expectedMAC, err := base64.RawStdEncoding.DecodeString(unpaddedMAC) + if err != nil { + return fmt.Errorf("%w: failed to decode mac: %w", ErrCorruptedKeyMetadata, err) + } + calculatedMAC, err := kd.calculateHash(key) if err != nil { return err } - if unpaddedMAC != hash { + // This doesn't really need to be constant time since it's fully local, but might as well be. + if !hmac.Equal(expectedMAC, calculatedMAC) { return ErrIncorrectSSSSKey } return nil @@ -83,23 +98,26 @@ func (kd *KeyMetadata) VerifyKey(key []byte) bool { // calculateHash calculates the hash used for checking if the key is entered correctly as described // in the spec: https://matrix.org/docs/spec/client_server/unstable#m-secret-storage-v1-aes-hmac-sha2 -func (kd *KeyMetadata) calculateHash(key []byte) (string, error) { +func (kd *KeyMetadata) calculateHash(key []byte) ([]byte, error) { aesKey, hmacKey := utils.DeriveKeysSHA256(key, "") unpaddedIV := strings.TrimRight(kd.IV, "=") expectedIVLength := base64.RawStdEncoding.EncodedLen(utils.AESCTRIVLength) - if len(unpaddedIV) != expectedIVLength { - return "", fmt.Errorf("%w: invalid iv length %d (expected %d)", ErrCorruptedKeyMetadata, len(unpaddedIV), expectedIVLength) + if len(unpaddedIV) < expectedIVLength || len(unpaddedIV) > expectedIVLength*3 { + return nil, fmt.Errorf("%w: invalid iv length %d (expected %d)", ErrCorruptedKeyMetadata, len(unpaddedIV), expectedIVLength) } - - var ivBytes [utils.AESCTRIVLength]byte - _, err := base64.RawStdEncoding.Decode(ivBytes[:], []byte(unpaddedIV)) + rawIVBytes, err := base64.RawStdEncoding.DecodeString(unpaddedIV) if err != nil { - return "", fmt.Errorf("%w: failed to decode iv: %w", ErrCorruptedKeyMetadata, err) + return nil, fmt.Errorf("%w: failed to decode iv: %w", ErrCorruptedKeyMetadata, err) } + // TODO log a warning for non-16 byte IVs? + // Certain broken clients like nheko generated 32-byte IVs where only the first 16 bytes were used. + ivBytes := *(*[utils.AESCTRIVLength]byte)(rawIVBytes[:utils.AESCTRIVLength]) - cipher := utils.XorA256CTR(make([]byte, utils.AESCTRKeyLength), aesKey, ivBytes) - - return utils.HMACSHA256B64(cipher, hmacKey), nil + zeroes := make([]byte, utils.AESCTRKeyLength) + encryptedZeroes := utils.XorA256CTR(zeroes, aesKey, ivBytes) + h := hmac.New(sha256.New, hmacKey[:]) + h.Write(encryptedZeroes) + return h.Sum(nil), nil } // PassphraseMetadata represents server-side metadata about a SSSS key passphrase. diff --git a/crypto/ssss/meta_test.go b/crypto/ssss/meta_test.go index 7a5ef8b9..d59809c7 100644 --- a/crypto/ssss/meta_test.go +++ b/crypto/ssss/meta_test.go @@ -8,7 +8,6 @@ package ssss_test import ( "encoding/json" - "errors" "testing" "github.com/stretchr/testify/assert" @@ -42,10 +41,24 @@ const key2Meta = ` } ` +const key2MetaUnverified = ` +{ + "algorithm": "m.secret_storage.v1.aes-hmac-sha2" +} +` + +const key2MetaLongIV = ` +{ + "algorithm": "m.secret_storage.v1.aes-hmac-sha2", + "iv": "O0BOvTqiIAYjC+RMcyHfW2f/gdxjceTxoYtNlpPduJ8=", + "mac": "7k6OruQlWg0UmQjxGZ0ad4Q6DdwkgnoI7G6X3IjBYtI=" +} +` + const key2MetaBrokenIV = ` { "algorithm": "m.secret_storage.v1.aes-hmac-sha2", - "iv": "O0BOvTqiIAYjC+RMcyHfWwMeowMeowMeow", + "iv": "MeowMeowMeow", "mac": "7k6OruQlWg0UmQjxGZ0ad4Q6DdwkgnoI7G6X3IjBYtI=" } ` @@ -94,17 +107,33 @@ func TestKeyMetadata_VerifyRecoveryKey_Correct2(t *testing.T) { assert.Equal(t, key2RecoveryKey, key.RecoveryKey()) } +func TestKeyMetadata_VerifyRecoveryKey_NonCompliant_LongIV(t *testing.T) { + km := getKeyMeta(key2MetaLongIV) + key, err := km.VerifyRecoveryKey(key2ID, key2RecoveryKey) + assert.NoError(t, err) + assert.NotNil(t, key) + assert.Equal(t, key2RecoveryKey, key.RecoveryKey()) +} + +func TestKeyMetadata_VerifyRecoveryKey_Unverified(t *testing.T) { + km := getKeyMeta(key2MetaUnverified) + key, err := km.VerifyRecoveryKey(key2ID, key2RecoveryKey) + assert.ErrorIs(t, err, ssss.ErrUnverifiableKey) + assert.NotNil(t, key) + assert.Equal(t, key2RecoveryKey, key.RecoveryKey()) +} + func TestKeyMetadata_VerifyRecoveryKey_Invalid(t *testing.T) { km := getKeyMeta(key1Meta) key, err := km.VerifyRecoveryKey(key1ID, "foo") - assert.True(t, errors.Is(err, ssss.ErrInvalidRecoveryKey), "unexpected error: %v", err) + assert.ErrorIs(t, err, ssss.ErrInvalidRecoveryKey) assert.Nil(t, key) } func TestKeyMetadata_VerifyRecoveryKey_Incorrect(t *testing.T) { km := getKeyMeta(key1Meta) key, err := km.VerifyRecoveryKey(key2ID, key2RecoveryKey) - assert.True(t, errors.Is(err, ssss.ErrIncorrectSSSSKey), "unexpected error: %v", err) + assert.ErrorIs(t, err, ssss.ErrIncorrectSSSSKey) assert.Nil(t, key) } @@ -119,27 +148,27 @@ func TestKeyMetadata_VerifyPassphrase_Correct(t *testing.T) { func TestKeyMetadata_VerifyPassphrase_Incorrect(t *testing.T) { km := getKeyMeta(key1Meta) key, err := km.VerifyPassphrase(key1ID, "incorrect horse battery staple") - assert.True(t, errors.Is(err, ssss.ErrIncorrectSSSSKey), "unexpected error %v", err) + assert.ErrorIs(t, err, ssss.ErrIncorrectSSSSKey) assert.Nil(t, key) } func TestKeyMetadata_VerifyPassphrase_NotSet(t *testing.T) { km := getKeyMeta(key2Meta) key, err := km.VerifyPassphrase(key2ID, "hmm") - assert.True(t, errors.Is(err, ssss.ErrNoPassphrase), "unexpected error %v", err) + assert.ErrorIs(t, err, ssss.ErrNoPassphrase) assert.Nil(t, key) } func TestKeyMetadata_VerifyRecoveryKey_CorruptedIV(t *testing.T) { km := getKeyMeta(key2MetaBrokenIV) key, err := km.VerifyRecoveryKey(key2ID, key2RecoveryKey) - assert.True(t, errors.Is(err, ssss.ErrCorruptedKeyMetadata), "unexpected error %v", err) + assert.ErrorIs(t, err, ssss.ErrCorruptedKeyMetadata) assert.Nil(t, key) } func TestKeyMetadata_VerifyRecoveryKey_CorruptedMAC(t *testing.T) { km := getKeyMeta(key2MetaBrokenMAC) key, err := km.VerifyRecoveryKey(key2ID, key2RecoveryKey) - assert.True(t, errors.Is(err, ssss.ErrCorruptedKeyMetadata), "unexpected error %v", err) + assert.ErrorIs(t, err, ssss.ErrCorruptedKeyMetadata) assert.Nil(t, key) } diff --git a/crypto/ssss/types.go b/crypto/ssss/types.go index 345393b0..b7465d3e 100644 --- a/crypto/ssss/types.go +++ b/crypto/ssss/types.go @@ -26,7 +26,8 @@ var ( ErrUnsupportedPassphraseAlgorithm = errors.New("unsupported passphrase KDF algorithm") ErrIncorrectSSSSKey = errors.New("incorrect SSSS key") ErrInvalidRecoveryKey = errors.New("invalid recovery key") - ErrCorruptedKeyMetadata = errors.New("corrupted key metadata") + ErrCorruptedKeyMetadata = errors.New("corrupted recovery key metadata") + ErrUnverifiableKey = errors.New("cannot verify recovery key: missing MAC or IV in metadata") ) // Algorithm is the identifier for an SSSS encryption algorithm. @@ -57,6 +58,7 @@ type EncryptedKeyData struct { type EncryptedAccountDataEventContent struct { Encrypted map[string]EncryptedKeyData `json:"encrypted"` + Metadata map[string]any `json:"com.beeper.metadata,omitzero"` } func (ed *EncryptedAccountDataEventContent) Decrypt(eventType string, key *Key) ([]byte, error) { diff --git a/error.go b/error.go index b7c92a5f..4711b3dc 100644 --- a/error.go +++ b/error.go @@ -67,6 +67,8 @@ var ( MIncompatibleRoomVersion = RespError{ErrCode: "M_INCOMPATIBLE_ROOM_VERSION"} // The client specified a parameter that has the wrong value. MInvalidParam = RespError{ErrCode: "M_INVALID_PARAM", StatusCode: http.StatusBadRequest} + // The client specified a room key backup version that is not the current room key backup version for the user. + MWrongRoomKeysVersion = RespError{ErrCode: "M_WRONG_ROOM_KEYS_VERSION", StatusCode: http.StatusForbidden} MURLNotSet = RespError{ErrCode: "M_URL_NOT_SET"} MBadStatus = RespError{ErrCode: "M_BAD_STATUS"} @@ -80,6 +82,13 @@ var ( var ( ErrClientIsNil = errors.New("client is nil") ErrClientHasNoHomeserver = errors.New("client has no homeserver set") + + ErrResponseTooLong = errors.New("response content length too long") + ErrBodyReadReachedLimit = errors.New("reached response size limit while reading body") + + // Special error that indicates we should retry canceled contexts. Note that on it's own this + // is useless, the context itself must also be replaced. + ErrContextCancelRetry = errors.New("retry canceled context") ) // HTTPError An HTTP Error response, which may wrap an underlying native Go Error. @@ -131,7 +140,10 @@ type RespError struct { Err string ExtraData map[string]any - StatusCode int + StatusCode int + ExtraHeader map[string]string + + CanRetry bool } func (e *RespError) UnmarshalJSON(data []byte) error { @@ -141,6 +153,7 @@ func (e *RespError) UnmarshalJSON(data []byte) error { } e.ErrCode, _ = e.ExtraData["errcode"].(string) e.Err, _ = e.ExtraData["error"].(string) + e.CanRetry, _ = e.ExtraData["com.beeper.can_retry"].(bool) return nil } @@ -148,6 +161,9 @@ func (e *RespError) MarshalJSON() ([]byte, error) { data := exmaps.NonNilClone(e.ExtraData) data["errcode"] = e.ErrCode data["error"] = e.Err + if e.CanRetry { + data["com.beeper.can_retry"] = e.CanRetry + } return json.Marshal(data) } @@ -159,6 +175,9 @@ func (e RespError) Write(w http.ResponseWriter) { if statusCode == 0 { statusCode = http.StatusInternalServerError } + for key, value := range e.ExtraHeader { + w.Header().Set(key, value) + } exhttp.WriteJSONResponse(w, statusCode, &e) } @@ -175,12 +194,29 @@ func (e RespError) WithStatus(status int) RespError { return e } +func (e RespError) WithCanRetry(canRetry bool) RespError { + e.CanRetry = canRetry + return e +} + func (e RespError) WithExtraData(extraData map[string]any) RespError { e.ExtraData = exmaps.NonNilClone(e.ExtraData) maps.Copy(e.ExtraData, extraData) return e } +func (e RespError) WithExtraHeader(key, value string) RespError { + e.ExtraHeader = exmaps.NonNilClone(e.ExtraHeader) + e.ExtraHeader[key] = value + return e +} + +func (e RespError) WithExtraHeaders(headers map[string]string) RespError { + e.ExtraHeader = exmaps.NonNilClone(e.ExtraHeader) + maps.Copy(e.ExtraHeader, headers) + return e +} + // Error returns the errcode and error message. func (e RespError) Error() string { return e.ErrCode + ": " + e.Err diff --git a/event/beeper.go b/event/beeper.go index 95b4a571..a1a60b35 100644 --- a/event/beeper.go +++ b/event/beeper.go @@ -53,6 +53,8 @@ type BeeperMessageStatusEventContent struct { LastRetry id.EventID `json:"last_retry,omitempty"` + TargetTxnID string `json:"relates_to_txn_id,omitempty"` + MutateEventKey string `json:"mutate_event_key,omitempty"` // Indicates the set of users to whom the event was delivered. If nil, then @@ -87,7 +89,19 @@ type BeeperRoomKeyAckEventContent struct { } type BeeperChatDeleteEventContent struct { - DeleteForEveryone bool `json:"delete_for_everyone,omitempty"` + DeleteForEveryone bool `json:"delete_for_everyone,omitempty"` + FromMessageRequest bool `json:"from_message_request,omitempty"` +} + +type BeeperAcceptMessageRequestEventContent struct { + // Whether this was triggered by a message rather than an explicit event + IsImplicit bool `json:"-"` +} + +type BeeperSendStateEventContent struct { + Type string `json:"type"` + StateKey string `json:"state_key"` + Content Content `json:"content"` } type IntOrString int @@ -132,6 +146,7 @@ type BeeperLinkPreview struct { MatchedURL string `json:"matched_url,omitempty"` ImageEncryption *EncryptedFileInfo `json:"beeper:image:encryption,omitempty"` + ImageBlurhash string `json:"matrix:image:blurhash,omitempty"` } type BeeperProfileExtra struct { @@ -151,6 +166,24 @@ type BeeperPerMessageProfile struct { HasFallback bool `json:"has_fallback,omitempty"` } +type BeeperActionMessageType string + +const ( + BeeperActionMessageCall BeeperActionMessageType = "call" +) + +type BeeperActionMessageCallType string + +const ( + BeeperActionMessageCallTypeVoice BeeperActionMessageCallType = "voice" + BeeperActionMessageCallTypeVideo BeeperActionMessageCallType = "video" +) + +type BeeperActionMessage struct { + Type BeeperActionMessageType `json:"type"` + CallType BeeperActionMessageCallType `json:"call_type,omitempty"` +} + func (content *MessageEventContent) AddPerMessageProfileFallback() { if content.BeeperPerMessageProfile == nil || content.BeeperPerMessageProfile.HasFallback || content.BeeperPerMessageProfile.Displayname == "" { return @@ -181,6 +214,15 @@ func (content *MessageEventContent) RemovePerMessageProfileFallback() { } } +type BeeperAIStreamEventContent struct { + TurnID string `json:"turn_id"` + Seq int `json:"seq"` + Part map[string]any `json:"part"` + TargetEvent id.EventID `json:"target_event,omitempty"` + AgentID string `json:"agent_id,omitempty"` + RelatesTo *RelatesTo `json:"m.relates_to,omitempty"` +} + type BeeperEncodedOrder struct { order int64 suborder int16 diff --git a/event/botcommand.go b/event/botcommand.go deleted file mode 100644 index 2b208656..00000000 --- a/event/botcommand.go +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright (c) 2025 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package event - -import ( - "encoding/json" -) - -type BotCommandsEventContent struct { - Sigil string `json:"sigil,omitempty"` - Commands []*BotCommand `json:"commands,omitempty"` -} - -type BotCommand struct { - Syntax string `json:"syntax"` - Aliases []string `json:"fi.mau.aliases,omitempty"` // Not in MSC (yet) - Arguments []*BotCommandArgument `json:"arguments,omitempty"` - Description *ExtensibleTextContainer `json:"description,omitempty"` -} - -type BotArgumentType string - -const ( - BotArgumentTypeString BotArgumentType = "string" - BotArgumentTypeEnum BotArgumentType = "enum" - BotArgumentTypeInteger BotArgumentType = "integer" - BotArgumentTypeBoolean BotArgumentType = "boolean" - BotArgumentTypeUserID BotArgumentType = "user_id" - BotArgumentTypeRoomID BotArgumentType = "room_id" - BotArgumentTypeRoomAlias BotArgumentType = "room_alias" - BotArgumentTypeEventID BotArgumentType = "event_id" -) - -type BotCommandArgument struct { - Type BotArgumentType `json:"type"` - DefaultValue any `json:"fi.mau.default_value,omitempty"` // Not in MSC (yet) - Description *ExtensibleTextContainer `json:"description,omitempty"` - Enum []string `json:"enum,omitempty"` - Variadic bool `json:"variadic,omitempty"` -} - -type BotCommandInput struct { - Syntax string `json:"syntax"` - Arguments json.RawMessage `json:"arguments,omitempty"` -} diff --git a/event/capabilities.d.ts b/event/capabilities.d.ts index 37848575..26aeb347 100644 --- a/event/capabilities.d.ts +++ b/event/capabilities.d.ts @@ -16,6 +16,23 @@ export interface RoomFeatures { * If a message type isn't listed here, it should be treated as support level -2 (will be rejected). */ file?: Record + /** + * Supported state event types and their parameters. Currently, there are no parameters, + * but it is likely there will be some in the future (like max name/topic length, avatar mime types, etc.). + * + * Events that are not listed or have a support level of zero or below should be treated as unsupported. + * + * Clients should at least check `m.room.name`, `m.room.topic`, and `m.room.avatar` here. + * `m.room.member` will not be listed here, as it's controlled by the member_actions field. + * `com.beeper.disappearing_timer` should be listed here, but the parameters are in the disappearing_timer field for now. + */ + state?: Record + /** + * Supported member actions and their support levels. + * + * Actions that are not listed or have a support level of zero or below should be treated as unsupported. + */ + member_actions?: Record /** Maximum length of normal text messages. */ max_text_length?: integer @@ -60,6 +77,11 @@ export interface RoomFeatures { delete_chat?: boolean /** Whether deleting the chat for all participants is supported. */ delete_chat_for_everyone?: boolean + /** What can be done with message requests? */ + message_request?: { + accept_with_message?: CapabilitySupportLevel + accept_with_button?: CapabilitySupportLevel + } } declare type integer = number @@ -72,6 +94,21 @@ declare type MIMETypeOrPattern = | `${MIMEClass}/${string}` | `${MIMEClass}/${string}; ${string}` +export enum MemberAction { + Ban = "ban", + Kick = "kick", + Leave = "leave", + RevokeInvite = "revoke_invite", + Invite = "invite", +} + +declare type EventType = string + +// This is an object for future extensibility (e.g. max name/topic length) +export interface StateFeatures { + level: CapabilitySupportLevel +} + export enum CapabilityMsgType { // Real message types used in the `msgtype` field Image = "m.image", diff --git a/event/capabilities.go b/event/capabilities.go index 42afe5b6..a86c726b 100644 --- a/event/capabilities.go +++ b/event/capabilities.go @@ -28,8 +28,10 @@ type RoomFeatures struct { // N.B. New fields need to be added to the Hash function to be included in the deduplication hash. - Formatting FormattingFeatureMap `json:"formatting,omitempty"` - File FileFeatureMap `json:"file,omitempty"` + Formatting FormattingFeatureMap `json:"formatting,omitempty"` + File FileFeatureMap `json:"file,omitempty"` + State StateFeatureMap `json:"state,omitempty"` + MemberActions MemberFeatureMap `json:"member_actions,omitempty"` MaxTextLength int `json:"max_text_length,omitempty"` @@ -58,6 +60,10 @@ type RoomFeatures struct { MarkAsUnread bool `json:"mark_as_unread,omitempty"` DeleteChat bool `json:"delete_chat,omitempty"` DeleteChatForEveryone bool `json:"delete_chat_for_everyone,omitempty"` + + MessageRequest *MessageRequestFeatures `json:"message_request,omitempty"` + + PerMessageProfileRelay bool `json:"-"` } func (rf *RoomFeatures) GetID() string { @@ -74,13 +80,58 @@ func (rf *RoomFeatures) Clone() *RoomFeatures { clone := *rf clone.File = clone.File.Clone() clone.Formatting = maps.Clone(clone.Formatting) + clone.State = clone.State.Clone() + clone.MemberActions = clone.MemberActions.Clone() clone.EditMaxAge = ptr.Clone(clone.EditMaxAge) clone.DeleteMaxAge = ptr.Clone(clone.DeleteMaxAge) clone.DisappearingTimer = clone.DisappearingTimer.Clone() clone.AllowedReactions = slices.Clone(clone.AllowedReactions) + clone.MessageRequest = clone.MessageRequest.Clone() return &clone } +type MemberFeatureMap map[MemberAction]CapabilitySupportLevel + +func (mfm MemberFeatureMap) Clone() MemberFeatureMap { + return maps.Clone(mfm) +} + +type MemberAction string + +const ( + MemberActionBan MemberAction = "ban" + MemberActionKick MemberAction = "kick" + MemberActionLeave MemberAction = "leave" + MemberActionRevokeInvite MemberAction = "revoke_invite" + MemberActionInvite MemberAction = "invite" +) + +type StateFeatureMap map[string]*StateFeatures + +func (sfm StateFeatureMap) Clone() StateFeatureMap { + dup := maps.Clone(sfm) + for key, value := range dup { + dup[key] = value.Clone() + } + return dup +} + +type StateFeatures struct { + Level CapabilitySupportLevel `json:"level"` +} + +func (sf *StateFeatures) Clone() *StateFeatures { + if sf == nil { + return nil + } + clone := *sf + return &clone +} + +func (sf *StateFeatures) Hash() []byte { + return sf.Level.Hash() +} + type FormattingFeatureMap map[FormattingFeature]CapabilitySupportLevel type FileFeatureMap map[CapabilityMsgType]*FileFeatures @@ -117,6 +168,25 @@ func (dtc *DisappearingTimerCapability) Supports(content *BeeperDisappearingTime return slices.Contains(dtc.Types, content.Type) && (dtc.Timers == nil || slices.Contains(dtc.Timers, content.Timer)) } +type MessageRequestFeatures struct { + AcceptWithMessage CapabilitySupportLevel `json:"accept_with_message,omitempty"` + AcceptWithButton CapabilitySupportLevel `json:"accept_with_button,omitempty"` +} + +func (mrf *MessageRequestFeatures) Clone() *MessageRequestFeatures { + return ptr.Clone(mrf) +} + +func (mrf *MessageRequestFeatures) Hash() []byte { + if mrf == nil { + return nil + } + hasher := sha256.New() + hashValue(hasher, "accept_with_message", mrf.AcceptWithMessage) + hashValue(hasher, "accept_with_button", mrf.AcceptWithButton) + return hasher.Sum(nil) +} + type CapabilityMsgType = MessageType // Message types which are used for event capability signaling, but aren't real values for the msgtype field. @@ -266,6 +336,8 @@ func (rf *RoomFeatures) Hash() []byte { hashMap(hasher, "formatting", rf.Formatting) hashMap(hasher, "file", rf.File) + hashMap(hasher, "state", rf.State) + hashMap(hasher, "member_actions", rf.MemberActions) hashInt(hasher, "max_text_length", rf.MaxTextLength) @@ -297,6 +369,7 @@ func (rf *RoomFeatures) Hash() []byte { hashBool(hasher, "mark_as_unread", rf.MarkAsUnread) hashBool(hasher, "delete_chat", rf.DeleteChat) hashBool(hasher, "delete_chat_for_everyone", rf.DeleteChatForEveryone) + hashValue(hasher, "message_request", rf.MessageRequest) return hasher.Sum(nil) } diff --git a/event/cmdschema/content.go b/event/cmdschema/content.go new file mode 100644 index 00000000..ce07c4c0 --- /dev/null +++ b/event/cmdschema/content.go @@ -0,0 +1,78 @@ +// Copyright (c) 2026 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package cmdschema + +import ( + "crypto/sha256" + "encoding/base64" + "fmt" + "reflect" + "slices" + + "go.mau.fi/util/exsync" + "go.mau.fi/util/ptr" + + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +type EventContent struct { + Command string `json:"command"` + Aliases []string `json:"aliases,omitempty"` + Parameters []*Parameter `json:"parameters,omitempty"` + Description *event.ExtensibleTextContainer `json:"description,omitempty"` + TailParam string `json:"fi.mau.tail_parameter,omitempty"` +} + +func (ec *EventContent) Validate() error { + if ec == nil { + return fmt.Errorf("event content is nil") + } else if ec.Command == "" { + return fmt.Errorf("command is empty") + } + var tailFound bool + dupMap := exsync.NewSet[string]() + for i, p := range ec.Parameters { + if err := p.Validate(); err != nil { + return fmt.Errorf("parameter %q (#%d) is invalid: %w", ptr.Val(p).Key, i+1, err) + } else if !dupMap.Add(p.Key) { + return fmt.Errorf("duplicate parameter key %q at #%d", p.Key, i+1) + } else if p.Key == ec.TailParam { + tailFound = true + } else if tailFound && !p.Optional { + return fmt.Errorf("required parameter %q (#%d) is after tail parameter %q", p.Key, i+1, ec.TailParam) + } + } + if ec.TailParam != "" && !tailFound { + return fmt.Errorf("tail parameter %q not found in parameters", ec.TailParam) + } + return nil +} + +func (ec *EventContent) IsValid() bool { + return ec.Validate() == nil +} + +func (ec *EventContent) StateKey(owner id.UserID) string { + hash := sha256.Sum256([]byte(ec.Command + owner.String())) + return base64.StdEncoding.EncodeToString(hash[:]) +} + +func (ec *EventContent) Equals(other *EventContent) bool { + if ec == nil || other == nil { + return ec == other + } + return ec.Command == other.Command && + slices.Equal(ec.Aliases, other.Aliases) && + slices.EqualFunc(ec.Parameters, other.Parameters, (*Parameter).Equals) && + ec.Description.Equals(other.Description) && + ec.TailParam == other.TailParam +} + +func init() { + event.TypeMap[event.StateMSC4391BotCommand] = reflect.TypeOf(EventContent{}) +} diff --git a/event/cmdschema/parameter.go b/event/cmdschema/parameter.go new file mode 100644 index 00000000..4193b297 --- /dev/null +++ b/event/cmdschema/parameter.go @@ -0,0 +1,286 @@ +// Copyright (c) 2026 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package cmdschema + +import ( + "encoding/json" + "fmt" + "slices" + + "go.mau.fi/util/exslices" + + "maunium.net/go/mautrix/event" +) + +type Parameter struct { + Key string `json:"key"` + Schema *ParameterSchema `json:"schema"` + Optional bool `json:"optional,omitempty"` + Description *event.ExtensibleTextContainer `json:"description,omitempty"` + DefaultValue any `json:"fi.mau.default_value,omitempty"` +} + +func (p *Parameter) Equals(other *Parameter) bool { + if p == nil || other == nil { + return p == other + } + return p.Key == other.Key && + p.Schema.Equals(other.Schema) && + p.Optional == other.Optional && + p.Description.Equals(other.Description) && + p.DefaultValue == other.DefaultValue // TODO this won't work for room/event ID values +} + +func (p *Parameter) Validate() error { + if p == nil { + return fmt.Errorf("parameter is nil") + } else if p.Key == "" { + return fmt.Errorf("key is empty") + } + return p.Schema.Validate() +} + +func (p *Parameter) IsValid() bool { + return p.Validate() == nil +} + +func (p *Parameter) GetDefaultValue() any { + if p != nil && p.DefaultValue != nil { + return p.DefaultValue + } else if p == nil || p.Optional { + return nil + } + return p.Schema.GetDefaultValue() +} + +type PrimitiveType string + +const ( + PrimitiveTypeString PrimitiveType = "string" + PrimitiveTypeInteger PrimitiveType = "integer" + PrimitiveTypeBoolean PrimitiveType = "boolean" + PrimitiveTypeServerName PrimitiveType = "server_name" + PrimitiveTypeUserID PrimitiveType = "user_id" + PrimitiveTypeRoomID PrimitiveType = "room_id" + PrimitiveTypeRoomAlias PrimitiveType = "room_alias" + PrimitiveTypeEventID PrimitiveType = "event_id" +) + +func (pt PrimitiveType) Schema() *ParameterSchema { + return &ParameterSchema{ + SchemaType: SchemaTypePrimitive, + Type: pt, + } +} + +func (pt PrimitiveType) IsValid() bool { + switch pt { + case PrimitiveTypeString, + PrimitiveTypeInteger, + PrimitiveTypeBoolean, + PrimitiveTypeServerName, + PrimitiveTypeUserID, + PrimitiveTypeRoomID, + PrimitiveTypeRoomAlias, + PrimitiveTypeEventID: + return true + default: + return false + } +} + +type SchemaType string + +const ( + SchemaTypePrimitive SchemaType = "primitive" + SchemaTypeArray SchemaType = "array" + SchemaTypeUnion SchemaType = "union" + SchemaTypeLiteral SchemaType = "literal" +) + +type ParameterSchema struct { + SchemaType SchemaType `json:"schema_type"` + Type PrimitiveType `json:"type,omitempty"` // Only for primitive + Items *ParameterSchema `json:"items,omitempty"` // Only for array + Variants []*ParameterSchema `json:"variants,omitempty"` // Only for union + Value any `json:"value,omitempty"` // Only for literal +} + +func Literal(value any) *ParameterSchema { + return &ParameterSchema{ + SchemaType: SchemaTypeLiteral, + Value: value, + } +} + +func Enum(values ...any) *ParameterSchema { + return Union(exslices.CastFunc(values, Literal)...) +} + +func flattenUnion(variants []*ParameterSchema) []*ParameterSchema { + var flattened []*ParameterSchema + for _, variant := range variants { + switch variant.SchemaType { + case SchemaTypeArray: + panic(fmt.Errorf("illegal array schema in union")) + case SchemaTypeUnion: + flattened = append(flattened, flattenUnion(variant.Variants)...) + default: + flattened = append(flattened, variant) + } + } + return flattened +} + +func Union(variants ...*ParameterSchema) *ParameterSchema { + needsFlattening := false + for _, variant := range variants { + if variant.SchemaType == SchemaTypeArray { + panic(fmt.Errorf("illegal array schema in union")) + } else if variant.SchemaType == SchemaTypeUnion { + needsFlattening = true + } + } + if needsFlattening { + variants = flattenUnion(variants) + } + return &ParameterSchema{ + SchemaType: SchemaTypeUnion, + Variants: variants, + } +} + +func Array(items *ParameterSchema) *ParameterSchema { + if items.SchemaType == SchemaTypeArray { + panic(fmt.Errorf("illegal array schema in array")) + } + return &ParameterSchema{ + SchemaType: SchemaTypeArray, + Items: items, + } +} + +func (ps *ParameterSchema) GetDefaultValue() any { + if ps == nil { + return nil + } + switch ps.SchemaType { + case SchemaTypePrimitive: + switch ps.Type { + case PrimitiveTypeInteger: + return 0 + case PrimitiveTypeBoolean: + return false + default: + return "" + } + case SchemaTypeArray: + return []any{} + case SchemaTypeUnion: + if len(ps.Variants) > 0 { + return ps.Variants[0].GetDefaultValue() + } + return nil + case SchemaTypeLiteral: + return ps.Value + default: + return nil + } +} + +func (ps *ParameterSchema) IsValid() bool { + return ps.validate("") == nil +} + +func (ps *ParameterSchema) Validate() error { + return ps.validate("") +} + +func (ps *ParameterSchema) validate(parent SchemaType) error { + if ps == nil { + return fmt.Errorf("schema is nil") + } + switch ps.SchemaType { + case SchemaTypePrimitive: + if !ps.Type.IsValid() { + return fmt.Errorf("invalid primitive type %s", ps.Type) + } else if ps.Items != nil || ps.Variants != nil || ps.Value != nil { + return fmt.Errorf("primitive schema has extra fields") + } + return nil + case SchemaTypeArray: + if parent != "" { + return fmt.Errorf("arrays can't be nested in other types") + } else if err := ps.Items.validate(ps.SchemaType); err != nil { + return fmt.Errorf("item schema is invalid: %w", err) + } else if ps.Type != "" || ps.Variants != nil || ps.Value != nil { + return fmt.Errorf("array schema has extra fields") + } + return nil + case SchemaTypeUnion: + if len(ps.Variants) == 0 { + return fmt.Errorf("no variants specified for union") + } else if parent != "" && parent != SchemaTypeArray { + return fmt.Errorf("unions can't be nested in anything other than arrays") + } + for i, v := range ps.Variants { + if err := v.validate(ps.SchemaType); err != nil { + return fmt.Errorf("variant #%d is invalid: %w", i+1, err) + } + } + if ps.Type != "" || ps.Items != nil || ps.Value != nil { + return fmt.Errorf("union schema has extra fields") + } + return nil + case SchemaTypeLiteral: + switch typedVal := ps.Value.(type) { + case string, float64, int, int64, json.Number, bool, RoomIDValue, *RoomIDValue: + // ok + case map[string]any: + if typedVal["type"] != "event_id" && typedVal["type"] != "room_id" { + return fmt.Errorf("literal value has invalid map data") + } + default: + return fmt.Errorf("literal value has unsupported type %T", ps.Value) + } + if ps.Type != "" || ps.Items != nil || ps.Variants != nil { + return fmt.Errorf("literal schema has extra fields") + } + return nil + default: + return fmt.Errorf("invalid schema type %s", ps.SchemaType) + } +} + +func (ps *ParameterSchema) Equals(other *ParameterSchema) bool { + if ps == nil || other == nil { + return ps == other + } + return ps.SchemaType == other.SchemaType && + ps.Type == other.Type && + ps.Items.Equals(other.Items) && + slices.EqualFunc(ps.Variants, other.Variants, (*ParameterSchema).Equals) && + ps.Value == other.Value // TODO this won't work for room/event ID values +} + +func (ps *ParameterSchema) AllowsPrimitive(prim PrimitiveType) bool { + switch ps.SchemaType { + case SchemaTypePrimitive: + return ps.Type == prim + case SchemaTypeUnion: + for _, variant := range ps.Variants { + if variant.AllowsPrimitive(prim) { + return true + } + } + return false + case SchemaTypeArray: + return ps.Items.AllowsPrimitive(prim) + default: + return false + } +} diff --git a/event/cmdschema/parse.go b/event/cmdschema/parse.go new file mode 100644 index 00000000..92e69b60 --- /dev/null +++ b/event/cmdschema/parse.go @@ -0,0 +1,478 @@ +// Copyright (c) 2026 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package cmdschema + +import ( + "encoding/json" + "errors" + "fmt" + "regexp" + "strconv" + "strings" + + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +const botArrayOpener = "<" +const botArrayCloser = ">" + +func parseQuoted(val string) (parsed, remaining string, quoted bool) { + if len(val) == 0 { + return + } + if !strings.HasPrefix(val, `"`) { + spaceIdx := strings.IndexByte(val, ' ') + if spaceIdx == -1 { + parsed = val + } else { + parsed = val[:spaceIdx] + remaining = strings.TrimLeft(val[spaceIdx+1:], " ") + } + return + } + val = val[1:] + var buf strings.Builder + for { + quoteIdx := strings.IndexByte(val, '"') + var valUntilQuote string + if quoteIdx == -1 { + valUntilQuote = val + } else { + valUntilQuote = val[:quoteIdx] + } + escapeIdx := strings.IndexByte(valUntilQuote, '\\') + if escapeIdx >= 0 { + buf.WriteString(val[:escapeIdx]) + if len(val) > escapeIdx+1 { + buf.WriteByte(val[escapeIdx+1]) + } + val = val[min(escapeIdx+2, len(val)):] + } else if quoteIdx >= 0 { + buf.WriteString(val[:quoteIdx]) + val = val[quoteIdx+1:] + break + } else if buf.Len() == 0 { + // Unterminated quote, no escape characters, val is the whole input + return val, "", true + } else { + // Unterminated quote, but there were escape characters previously + buf.WriteString(val) + val = "" + break + } + } + return buf.String(), strings.TrimLeft(val, " "), true +} + +// ParseInput tries to parse the given text into a bot command event matching this command definition. +// +// If the prefix doesn't match, this will return a nil content and nil error. +// If the prefix does match, some content is always returned, but there may still be an error if parsing failed. +func (ec *EventContent) ParseInput(owner id.UserID, sigils []string, input string) (content *event.MessageEventContent, err error) { + prefix := ec.parsePrefix(input, sigils, owner.String()) + if prefix == "" { + return nil, nil + } + content = &event.MessageEventContent{ + MsgType: event.MsgText, + Body: input, + Mentions: &event.Mentions{UserIDs: []id.UserID{owner}}, + MSC4391BotCommand: &event.MSC4391BotCommandInput{ + Command: ec.Command, + }, + } + content.MSC4391BotCommand.Arguments, err = ec.ParseArguments(input[len(prefix):]) + return content, err +} + +func (ec *EventContent) ParseArguments(input string) (json.RawMessage, error) { + args := make(map[string]any) + var retErr error + setError := func(err error) { + if err != nil && retErr == nil { + retErr = err + } + } + processParameter := func(param *Parameter, isLast, isTail, isNamed bool) { + origInput := input + var nextVal string + var wasQuoted bool + if param.Schema.SchemaType == SchemaTypeArray { + hasOpener := strings.HasPrefix(input, botArrayOpener) + arrayClosed := false + if hasOpener { + input = input[len(botArrayOpener):] + if strings.HasPrefix(input, botArrayCloser) { + input = strings.TrimLeft(input[len(botArrayCloser):], " ") + arrayClosed = true + } + } + var collector []any + for len(input) > 0 && !arrayClosed { + //origInput = input + nextVal, input, wasQuoted = parseQuoted(input) + if !wasQuoted && hasOpener && strings.HasSuffix(nextVal, botArrayCloser) { + // The value wasn't quoted and has the array delimiter at the end, close the array + nextVal = strings.TrimRight(nextVal, botArrayCloser) + arrayClosed = true + } else if hasOpener && strings.HasPrefix(input, botArrayCloser) { + // The value was quoted or there was a space, and the next character is the + // array delimiter, close the array + input = strings.TrimLeft(input[len(botArrayCloser):], " ") + arrayClosed = true + } else if !hasOpener && !isLast { + // For array arguments in the middle without the <> delimiters, stop after the first item + arrayClosed = true + } + parsedVal, err := param.Schema.Items.ParseString(nextVal) + if err == nil { + collector = append(collector, parsedVal) + } else if hasOpener || isLast { + setError(fmt.Errorf("failed to parse item #%d of array %s: %w", len(collector)+1, param.Key, err)) + } else { + //input = origInput + } + } + args[param.Key] = collector + } else { + nextVal, input, wasQuoted = parseQuoted(input) + if (isLast || isTail) && !wasQuoted && len(input) > 0 { + // If the last argument is not quoted, just treat the rest of the string + // as the argument without escapes (arguments with escapes should be quoted). + nextVal += " " + input + input = "" + } + // Special case for named boolean parameters: if no value is given, treat it as true + if nextVal == "" && !wasQuoted && isNamed && param.Schema.AllowsPrimitive(PrimitiveTypeBoolean) { + args[param.Key] = true + return + } + if nextVal == "" && !wasQuoted && !isNamed && !param.Optional { + setError(fmt.Errorf("missing value for required parameter %s", param.Key)) + } + parsedVal, err := param.Schema.ParseString(nextVal) + if err != nil { + args[param.Key] = param.GetDefaultValue() + // For optional parameters that fail to parse, restore the input and try passing it as the next parameter + if param.Optional && !isLast && !isNamed { + input = strings.TrimLeft(origInput, " ") + } else if !param.Optional || isNamed { + setError(fmt.Errorf("failed to parse %s: %w", param.Key, err)) + } + } else { + args[param.Key] = parsedVal + } + } + } + skipParams := make([]bool, len(ec.Parameters)) + for i, param := range ec.Parameters { + for strings.HasPrefix(input, "--") { + nameEndIdx := strings.IndexAny(input, " =") + if nameEndIdx == -1 { + nameEndIdx = len(input) + } + overrideParam, paramIdx := ec.parameterByName(input[2:nameEndIdx]) + if overrideParam != nil { + // Trim the equals sign, but leave spaces alone to let parseQuoted treat it as empty input + input = strings.TrimPrefix(input[nameEndIdx:], "=") + skipParams[paramIdx] = true + processParameter(overrideParam, false, false, true) + } else { + break + } + } + isTail := param.Key == ec.TailParam + if skipParams[i] || (param.Optional && !isTail) { + continue + } + processParameter(param, i == len(ec.Parameters)-1, isTail, false) + } + jsonArgs, marshalErr := json.Marshal(args) + if marshalErr != nil { + return nil, fmt.Errorf("failed to marshal arguments: %w", marshalErr) + } + return jsonArgs, retErr +} + +func (ec *EventContent) parameterByName(name string) (*Parameter, int) { + for i, param := range ec.Parameters { + if strings.EqualFold(param.Key, name) { + return param, i + } + } + return nil, -1 +} + +func (ec *EventContent) parsePrefix(origInput string, sigils []string, owner string) (prefix string) { + input := origInput + var chosenSigil string + for _, sigil := range sigils { + if strings.HasPrefix(input, sigil) { + chosenSigil = sigil + break + } + } + if chosenSigil == "" { + return "" + } + input = input[len(chosenSigil):] + var chosenAlias string + if !strings.HasPrefix(input, ec.Command) { + for _, alias := range ec.Aliases { + if strings.HasPrefix(input, alias) { + chosenAlias = alias + break + } + } + if chosenAlias == "" { + return "" + } + } else { + chosenAlias = ec.Command + } + input = strings.TrimPrefix(input[len(chosenAlias):], owner) + if input == "" || input[0] == ' ' { + input = strings.TrimLeft(input, " ") + return origInput[:len(origInput)-len(input)] + } + return "" +} + +func (pt PrimitiveType) ValidateValue(value any) bool { + _, err := pt.NormalizeValue(value) + return err == nil +} + +func normalizeNumber(value any) (int, error) { + switch typedValue := value.(type) { + case int: + return typedValue, nil + case int64: + return int(typedValue), nil + case float64: + return int(typedValue), nil + case json.Number: + if i, err := typedValue.Int64(); err != nil { + return 0, fmt.Errorf("failed to parse json.Number: %w", err) + } else { + return int(i), nil + } + default: + return 0, fmt.Errorf("unsupported type %T for integer", value) + } +} + +func (pt PrimitiveType) NormalizeValue(value any) (any, error) { + switch pt { + case PrimitiveTypeInteger: + return normalizeNumber(value) + case PrimitiveTypeBoolean: + bv, ok := value.(bool) + if !ok { + return nil, fmt.Errorf("unsupported type %T for boolean", value) + } + return bv, nil + case PrimitiveTypeString, PrimitiveTypeServerName: + str, ok := value.(string) + if !ok { + return nil, fmt.Errorf("unsupported type %T for string", value) + } + return str, pt.validateStringValue(str) + case PrimitiveTypeUserID, PrimitiveTypeRoomAlias: + str, ok := value.(string) + if !ok { + return nil, fmt.Errorf("unsupported type %T for user ID or room alias", value) + } else if plainErr := pt.validateStringValue(str); plainErr == nil { + return str, nil + } else if parsed, err := id.ParseMatrixURIOrMatrixToURL(str); err != nil { + return nil, fmt.Errorf("couldn't parse %q as plain ID nor matrix URI: %w / %w", value, plainErr, err) + } else if parsed.Sigil1 == '@' && pt == PrimitiveTypeUserID { + return parsed.UserID(), nil + } else if parsed.Sigil1 == '#' && pt == PrimitiveTypeRoomAlias { + return parsed.RoomAlias(), nil + } else { + return nil, fmt.Errorf("unexpected sigil %c for user ID or room alias", parsed.Sigil1) + } + case PrimitiveTypeRoomID, PrimitiveTypeEventID: + riv, err := NormalizeRoomIDValue(value) + if err != nil { + return nil, err + } + return riv, riv.Validate() + default: + return nil, fmt.Errorf("cannot normalize value for argument type %s", pt) + } +} + +func (pt PrimitiveType) validateStringValue(value string) error { + switch pt { + case PrimitiveTypeString: + return nil + case PrimitiveTypeServerName: + if !id.ValidateServerName(value) { + return fmt.Errorf("invalid server name: %q", value) + } + return nil + case PrimitiveTypeUserID: + _, _, err := id.UserID(value).ParseAndValidateRelaxed() + return err + case PrimitiveTypeRoomAlias: + sigil, localpart, serverName := id.ParseCommonIdentifier(value) + if sigil != '#' || localpart == "" || serverName == "" { + return fmt.Errorf("invalid room alias: %q", value) + } else if !id.ValidateServerName(serverName) { + return fmt.Errorf("invalid server name in room alias: %q", serverName) + } + return nil + default: + panic(fmt.Errorf("validateStringValue called with invalid type %s", pt)) + } +} + +func parseBoolean(val string) (bool, error) { + if len(val) == 0 { + return false, fmt.Errorf("cannot parse empty string as boolean") + } + switch strings.ToLower(val) { + case "t", "true", "y", "yes", "1": + return true, nil + case "f", "false", "n", "no", "0": + return false, nil + default: + return false, fmt.Errorf("invalid boolean string: %q", val) + } +} + +var markdownLinkRegex = regexp.MustCompile(`^\[.+]\(([^)]+)\)$`) + +func parseRoomOrEventID(value string) (*RoomIDValue, error) { + if strings.HasPrefix(value, "[") && strings.Contains(value, "](") && strings.HasSuffix(value, ")") { + matches := markdownLinkRegex.FindStringSubmatch(value) + if len(matches) == 2 { + value = matches[1] + } + } + parsed, err := id.ParseMatrixURIOrMatrixToURL(value) + if err != nil && strings.HasPrefix(value, "!") { + return &RoomIDValue{ + Type: PrimitiveTypeRoomID, + RoomID: id.RoomID(value), + }, nil + } + if err != nil { + return nil, err + } else if parsed.Sigil1 != '!' { + return nil, fmt.Errorf("unexpected sigil %c for room ID", parsed.Sigil1) + } else if parsed.MXID2 != "" && parsed.Sigil2 != '$' { + return nil, fmt.Errorf("unexpected sigil %c for event ID", parsed.Sigil2) + } + valType := PrimitiveTypeRoomID + if parsed.MXID2 != "" { + valType = PrimitiveTypeEventID + } + return &RoomIDValue{ + Type: valType, + RoomID: parsed.RoomID(), + Via: parsed.Via, + EventID: parsed.EventID(), + }, nil +} + +func (pt PrimitiveType) ParseString(value string) (any, error) { + switch pt { + case PrimitiveTypeInteger: + return strconv.Atoi(value) + case PrimitiveTypeBoolean: + return parseBoolean(value) + case PrimitiveTypeString, PrimitiveTypeServerName, PrimitiveTypeUserID: + return value, pt.validateStringValue(value) + case PrimitiveTypeRoomAlias: + plainErr := pt.validateStringValue(value) + if plainErr == nil { + return value, nil + } + parsed, err := id.ParseMatrixURIOrMatrixToURL(value) + if err != nil { + return nil, fmt.Errorf("couldn't parse %q as plain room alias nor matrix URI: %w / %w", value, plainErr, err) + } else if parsed.Sigil1 != '#' { + return nil, fmt.Errorf("unexpected sigil %c for room alias", parsed.Sigil1) + } + return parsed.RoomAlias(), nil + case PrimitiveTypeRoomID, PrimitiveTypeEventID: + parsed, err := parseRoomOrEventID(value) + if err != nil { + return nil, err + } else if pt != parsed.Type { + return nil, fmt.Errorf("mismatching argument type: expected %s but got %s", pt, parsed.Type) + } + return parsed, nil + default: + return nil, fmt.Errorf("cannot parse string for argument type %s", pt) + } +} + +func (ps *ParameterSchema) ParseString(value string) (any, error) { + if ps == nil { + return nil, fmt.Errorf("parameter schema is nil") + } + switch ps.SchemaType { + case SchemaTypePrimitive: + return ps.Type.ParseString(value) + case SchemaTypeLiteral: + switch typedValue := ps.Value.(type) { + case string: + if value == typedValue { + return typedValue, nil + } else { + return nil, fmt.Errorf("literal value %q does not match %q", typedValue, value) + } + case int, int64, float64, json.Number: + expectedVal, _ := normalizeNumber(typedValue) + intVal, err := strconv.Atoi(value) + if err != nil { + return nil, fmt.Errorf("failed to parse integer literal: %w", err) + } else if intVal != expectedVal { + return nil, fmt.Errorf("literal value %d does not match %d", expectedVal, intVal) + } + return intVal, nil + case bool: + boolVal, err := parseBoolean(value) + if err != nil { + return nil, fmt.Errorf("failed to parse boolean literal: %w", err) + } else if boolVal != typedValue { + return nil, fmt.Errorf("literal value %t does not match %t", typedValue, boolVal) + } + return boolVal, nil + case RoomIDValue, *RoomIDValue, map[string]any, json.RawMessage: + expectedVal, _ := NormalizeRoomIDValue(typedValue) + parsed, err := parseRoomOrEventID(value) + if err != nil { + return nil, fmt.Errorf("failed to parse room or event ID literal: %w", err) + } else if !parsed.Equals(expectedVal) { + return nil, fmt.Errorf("literal value %s does not match %s", expectedVal, parsed) + } + return parsed, nil + default: + return nil, fmt.Errorf("unsupported literal type %T", ps.Value) + } + case SchemaTypeUnion: + var errs []error + for _, variant := range ps.Variants { + if parsed, err := variant.ParseString(value); err == nil { + return parsed, nil + } else { + errs = append(errs, err) + } + } + return nil, fmt.Errorf("no union variant matched: %w", errors.Join(errs...)) + case SchemaTypeArray: + return nil, fmt.Errorf("cannot parse string for array schema type") + default: + return nil, fmt.Errorf("unknown schema type %s", ps.SchemaType) + } +} diff --git a/event/cmdschema/parse_test.go b/event/cmdschema/parse_test.go new file mode 100644 index 00000000..1e0d1817 --- /dev/null +++ b/event/cmdschema/parse_test.go @@ -0,0 +1,118 @@ +// Copyright (c) 2026 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package cmdschema + +import ( + "bytes" + "encoding/json" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "go.mau.fi/util/exbytes" + "go.mau.fi/util/exerrors" + + "maunium.net/go/mautrix/event/cmdschema/testdata" +) + +type QuoteParseOutput struct { + Parsed string + Remaining string + Quoted bool +} + +func (qpo *QuoteParseOutput) UnmarshalJSON(data []byte) error { + var arr []any + if err := json.Unmarshal(data, &arr); err != nil { + return err + } + qpo.Parsed = arr[0].(string) + qpo.Remaining = arr[1].(string) + qpo.Quoted = arr[2].(bool) + return nil +} + +type QuoteParseTestData struct { + Name string `json:"name"` + Input string `json:"input"` + Output QuoteParseOutput `json:"output"` +} + +func loadFile[T any](name string) (into T) { + quoteData := exerrors.Must(testdata.FS.ReadFile(name)) + exerrors.PanicIfNotNil(json.Unmarshal(quoteData, &into)) + return +} + +func TestParseQuoted(t *testing.T) { + qptd := loadFile[[]QuoteParseTestData]("parse_quote.json") + for _, test := range qptd { + t.Run(test.Name, func(t *testing.T) { + parsed, remaining, quoted := parseQuoted(test.Input) + assert.Equalf(t, test.Output, QuoteParseOutput{ + Parsed: parsed, + Remaining: remaining, + Quoted: quoted, + }, "Failed with input `%s`", test.Input) + // Note: can't just test that requoted == input, because some inputs + // have unnecessary escapes which won't survive roundtripping + t.Run("roundtrip", func(t *testing.T) { + requoted := quoteString(parsed) + " " + remaining + reparsed, newRemaining, _ := parseQuoted(requoted) + assert.Equal(t, parsed, reparsed) + assert.Equal(t, remaining, newRemaining) + }) + }) + } +} + +type CommandTestData struct { + Spec *EventContent + Tests []*CommandTestUnit +} + +type CommandTestUnit struct { + Name string `json:"name"` + Input string `json:"input"` + Broken string `json:"broken,omitempty"` + Error bool `json:"error"` + Output json.RawMessage `json:"output"` +} + +func compactJSON(input json.RawMessage) json.RawMessage { + var buf bytes.Buffer + exerrors.PanicIfNotNil(json.Compact(&buf, input)) + return buf.Bytes() +} + +func TestMSC4391BotCommandEventContent_ParseInput(t *testing.T) { + for _, cmd := range exerrors.Must(testdata.FS.ReadDir("commands")) { + t.Run(strings.TrimSuffix(cmd.Name(), ".json"), func(t *testing.T) { + ctd := loadFile[CommandTestData]("commands/" + cmd.Name()) + for _, test := range ctd.Tests { + outputStr := exbytes.UnsafeString(compactJSON(test.Output)) + t.Run(test.Name, func(t *testing.T) { + if test.Broken != "" { + t.Skip(test.Broken) + } + output, err := ctd.Spec.ParseInput("@testbot", []string{"/"}, test.Input) + if test.Error { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + if outputStr == "null" { + assert.Nil(t, output) + } else { + assert.Equal(t, ctd.Spec.Command, output.MSC4391BotCommand.Command) + assert.Equalf(t, outputStr, exbytes.UnsafeString(output.MSC4391BotCommand.Arguments), "Input: %s", test.Input) + } + }) + } + }) + } +} diff --git a/event/cmdschema/roomid.go b/event/cmdschema/roomid.go new file mode 100644 index 00000000..98c421fc --- /dev/null +++ b/event/cmdschema/roomid.go @@ -0,0 +1,135 @@ +// Copyright (c) 2026 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package cmdschema + +import ( + "encoding/json" + "fmt" + "slices" + "strings" + + "maunium.net/go/mautrix/id" +) + +var ParameterSchemaJoinableRoom = Union( + PrimitiveTypeRoomID.Schema(), + PrimitiveTypeRoomAlias.Schema(), +) + +type RoomIDValue struct { + Type PrimitiveType `json:"type"` + RoomID id.RoomID `json:"id"` + Via []string `json:"via,omitempty"` + EventID id.EventID `json:"event_id,omitempty"` +} + +func NormalizeRoomIDValue(input any) (riv *RoomIDValue, err error) { + switch typedValue := input.(type) { + case map[string]any, json.RawMessage: + var raw json.RawMessage + if raw, err = json.Marshal(input); err != nil { + err = fmt.Errorf("failed to roundtrip room ID value: %w", err) + } else if err = json.Unmarshal(raw, &riv); err != nil { + err = fmt.Errorf("failed to roundtrip room ID value: %w", err) + } + case *RoomIDValue: + riv = typedValue + case RoomIDValue: + riv = &typedValue + default: + err = fmt.Errorf("unsupported type %T for room or event ID", input) + } + return +} + +func (riv *RoomIDValue) String() string { + return riv.URI().String() +} + +func (riv *RoomIDValue) URI() *id.MatrixURI { + if riv == nil { + return nil + } + switch riv.Type { + case PrimitiveTypeRoomID: + return riv.RoomID.URI(riv.Via...) + case PrimitiveTypeEventID: + return riv.RoomID.EventURI(riv.EventID, riv.Via...) + default: + return nil + } +} + +func (riv *RoomIDValue) Equals(other *RoomIDValue) bool { + if riv == nil || other == nil { + return riv == other + } + return riv.Type == other.Type && + riv.RoomID == other.RoomID && + riv.EventID == other.EventID && + slices.Equal(riv.Via, other.Via) +} + +func (riv *RoomIDValue) Validate() error { + if riv == nil { + return fmt.Errorf("value is nil") + } + switch riv.Type { + case PrimitiveTypeRoomID: + if riv.EventID != "" { + return fmt.Errorf("event ID must be empty for room ID type") + } + case PrimitiveTypeEventID: + if !strings.HasPrefix(riv.EventID.String(), "$") { + return fmt.Errorf("event ID not valid: %q", riv.EventID) + } + default: + return fmt.Errorf("unexpected type %s for room/event ID value", riv.Type) + } + for _, via := range riv.Via { + if !id.ValidateServerName(via) { + return fmt.Errorf("invalid server name %q in vias", via) + } + } + sigil, localpart, serverName := id.ParseCommonIdentifier(riv.RoomID) + if sigil != '!' { + return fmt.Errorf("room ID does not start with !: %q", riv.RoomID) + } else if localpart == "" && serverName == "" { + return fmt.Errorf("room ID has empty localpart and server name: %q", riv.RoomID) + } else if serverName != "" && !id.ValidateServerName(serverName) { + return fmt.Errorf("invalid server name %q in room ID", serverName) + } + return nil +} + +func (riv *RoomIDValue) IsValid() bool { + return riv.Validate() == nil +} + +type RoomIDOrString string + +func (ros *RoomIDOrString) UnmarshalJSON(data []byte) error { + if len(data) == 0 { + return fmt.Errorf("empty data for room ID or string") + } + if data[0] == '"' { + var str string + if err := json.Unmarshal(data, &str); err != nil { + return err + } + *ros = RoomIDOrString(str) + return nil + } + var riv RoomIDValue + if err := json.Unmarshal(data, &riv); err != nil { + return err + } else if err = riv.Validate(); err != nil { + return err + } + *ros = RoomIDOrString(riv.String()) + return nil +} diff --git a/event/cmdschema/stringify.go b/event/cmdschema/stringify.go new file mode 100644 index 00000000..c5c57c53 --- /dev/null +++ b/event/cmdschema/stringify.go @@ -0,0 +1,122 @@ +// Copyright (c) 2026 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package cmdschema + +import ( + "encoding/json" + "strconv" + "strings" +) + +var quoteEscaper = strings.NewReplacer( + `"`, `\"`, + `\`, `\\`, +) + +const charsToQuote = ` \` + botArrayOpener + botArrayCloser + +func quoteString(val string) string { + if val == "" { + return `""` + } + val = quoteEscaper.Replace(val) + if strings.ContainsAny(val, charsToQuote) { + return `"` + val + `"` + } + return val +} + +func (ec *EventContent) StringifyArgs(args any) string { + var argMap map[string]any + switch typedArgs := args.(type) { + case json.RawMessage: + err := json.Unmarshal(typedArgs, &argMap) + if err != nil { + return "" + } + case map[string]any: + argMap = typedArgs + default: + if b, err := json.Marshal(args); err != nil { + return "" + } else if err = json.Unmarshal(b, &argMap); err != nil { + return "" + } + } + parts := make([]string, 0, len(ec.Parameters)) + for i, param := range ec.Parameters { + isLast := i == len(ec.Parameters)-1 + val := argMap[param.Key] + if val == nil { + val = param.DefaultValue + if val == nil && !param.Optional { + val = param.Schema.GetDefaultValue() + } + } + if val == nil { + continue + } + var stringified string + if param.Schema.SchemaType == SchemaTypeArray { + stringified = arrayArgumentToString(val, isLast) + } else { + stringified = singleArgumentToString(val) + } + if stringified != "" { + parts = append(parts, stringified) + } + } + return strings.Join(parts, " ") +} + +func arrayArgumentToString(val any, isLast bool) string { + valArr, ok := val.([]any) + if !ok { + return "" + } + parts := make([]string, 0, len(valArr)) + for _, elem := range valArr { + stringified := singleArgumentToString(elem) + if stringified != "" { + parts = append(parts, stringified) + } + } + joinedParts := strings.Join(parts, " ") + if isLast && len(parts) > 0 { + return joinedParts + } + return botArrayOpener + joinedParts + botArrayCloser +} + +func singleArgumentToString(val any) string { + switch typedVal := val.(type) { + case string: + return quoteString(typedVal) + case json.Number: + return typedVal.String() + case bool: + return strconv.FormatBool(typedVal) + case int: + return strconv.Itoa(typedVal) + case int64: + return strconv.FormatInt(typedVal, 10) + case float64: + return strconv.FormatInt(int64(typedVal), 10) + case map[string]any, json.RawMessage, RoomIDValue, *RoomIDValue: + normalized, err := NormalizeRoomIDValue(typedVal) + if err != nil { + return "" + } + uri := normalized.URI() + if uri == nil { + return "" + } + return quoteString(uri.String()) + default: + return "" + } +} diff --git a/event/cmdschema/testdata/commands.schema.json b/event/cmdschema/testdata/commands.schema.json new file mode 100644 index 00000000..e53382db --- /dev/null +++ b/event/cmdschema/testdata/commands.schema.json @@ -0,0 +1,281 @@ +{ + "$schema": "https://json-schema.org/draft/2020-12/schema#", + "$id": "commands.schema.json", + "title": "ParseInput test cases", + "description": "JSON schema for test case files containing command specifications and test cases", + "type": "object", + "required": [ + "spec", + "tests" + ], + "additionalProperties": false, + "properties": { + "spec": { + "title": "MSC4391 Command Description", + "description": "JSON schema defining the structure of a bot command event content", + "type": "object", + "required": [ + "command" + ], + "additionalProperties": false, + "properties": { + "command": { + "type": "string", + "description": "The command name that triggers this bot command" + }, + "aliases": { + "type": "array", + "description": "Alternative names/aliases for this command", + "items": { + "type": "string" + } + }, + "parameters": { + "type": "array", + "description": "List of parameters accepted by this command", + "items": { + "$ref": "#/$defs/Parameter" + } + }, + "description": { + "$ref": "#/$defs/ExtensibleTextContainer", + "description": "Human-readable description of the command" + }, + "fi.mau.tail_parameter": { + "type": "string", + "description": "The key of the parameter that accepts remaining arguments as tail text" + }, + "source": { + "type": "string", + "description": "The user ID of the bot that responds to this command" + } + } + }, + "tests": { + "type": "array", + "description": "Array of test cases for the command", + "items": { + "type": "object", + "description": "A single test case for command parsing", + "required": [ + "name", + "input" + ], + "additionalProperties": false, + "properties": { + "name": { + "type": "string", + "description": "The name of the test case" + }, + "input": { + "type": "string", + "description": "The command input string to parse" + }, + "output": { + "description": "The expected parsed parameter values, or null if the parsing is expected to fail", + "oneOf": [ + { + "type": "object", + "additionalProperties": true + }, + { + "type": "null" + } + ] + }, + "error": { + "type": "boolean", + "description": "Whether parsing should result in an error. May still produce output.", + "default": false + } + } + } + } + }, + "$defs": { + "ExtensibleTextContainer": { + "type": "object", + "description": "Container for text that can have multiple representations", + "required": [ + "m.text" + ], + "properties": { + "m.text": { + "type": "array", + "description": "Array of text representations in different formats", + "items": { + "$ref": "#/$defs/ExtensibleText" + } + } + } + }, + "ExtensibleText": { + "type": "object", + "description": "A text representation with a specific MIME type", + "required": [ + "body" + ], + "properties": { + "body": { + "type": "string", + "description": "The text content" + }, + "mimetype": { + "type": "string", + "description": "The MIME type of the text (e.g., text/plain, text/html)", + "default": "text/plain", + "examples": [ + "text/plain", + "text/html" + ] + } + } + }, + "Parameter": { + "type": "object", + "description": "A parameter definition for a command", + "required": [ + "key", + "schema" + ], + "additionalProperties": false, + "properties": { + "key": { + "type": "string", + "description": "The identifier for this parameter" + }, + "schema": { + "$ref": "#/$defs/ParameterSchema", + "description": "The schema defining the type and structure of this parameter" + }, + "optional": { + "type": "boolean", + "description": "Whether this parameter is optional", + "default": false + }, + "description": { + "$ref": "#/$defs/ExtensibleTextContainer", + "description": "Human-readable description of this parameter" + }, + "fi.mau.default_value": { + "description": "Default value for this parameter if not provided" + } + } + }, + "ParameterSchema": { + "type": "object", + "description": "Schema definition for a parameter value", + "required": [ + "schema_type" + ], + "additionalProperties": false, + "properties": { + "schema_type": { + "type": "string", + "enum": [ + "primitive", + "array", + "union", + "literal" + ], + "description": "The type of schema" + } + }, + "allOf": [ + { + "if": { + "properties": { + "schema_type": { + "const": "primitive" + } + } + }, + "then": { + "required": [ + "type" + ], + "properties": { + "type": { + "type": "string", + "enum": [ + "string", + "integer", + "boolean", + "server_name", + "user_id", + "room_id", + "room_alias", + "event_id" + ], + "description": "The primitive type (only for schema_type: primitive)" + } + } + } + }, + { + "if": { + "properties": { + "schema_type": { + "const": "array" + } + } + }, + "then": { + "required": [ + "items" + ], + "properties": { + "items": { + "$ref": "#/$defs/ParameterSchema", + "description": "The schema for array items (only for schema_type: array)" + } + } + } + }, + { + "if": { + "properties": { + "schema_type": { + "const": "union" + } + } + }, + "then": { + "required": [ + "variants" + ], + "properties": { + "variants": { + "type": "array", + "description": "The possible variants (only for schema_type: union)", + "items": { + "$ref": "#/$defs/ParameterSchema" + }, + "minItems": 1 + } + } + } + }, + { + "if": { + "properties": { + "schema_type": { + "const": "literal" + } + } + }, + "then": { + "required": [ + "value" + ], + "properties": { + "value": { + "description": "The literal value (only for schema_type: literal)" + } + } + } + } + ] + } + } +} diff --git a/event/cmdschema/testdata/commands/flags.json b/event/cmdschema/testdata/commands/flags.json new file mode 100644 index 00000000..6ce1f4da --- /dev/null +++ b/event/cmdschema/testdata/commands/flags.json @@ -0,0 +1,126 @@ +{ + "$schema": "../commands.schema.json#", + "spec": { + "command": "flag", + "source": "@testbot", + "parameters": [ + { + "key": "meow", + "schema": { + "schema_type": "primitive", + "type": "string" + } + }, + { + "key": "user", + "schema": { + "schema_type": "primitive", + "type": "user_id" + }, + "optional": true + }, + { + "key": "woof", + "schema": { + "schema_type": "primitive", + "type": "boolean" + }, + "optional": true, + "fi.mau.default_value": false + } + ], + "fi.mau.tail_parameter": "user" + }, + "tests": [ + { + "name": "no flags", + "input": "/flag mrrp", + "output": { + "meow": "mrrp", + "user": null + } + }, + { + "name": "no flags, has tail", + "input": "/flag mrrp @user:example.com", + "output": { + "meow": "mrrp", + "user": "@user:example.com" + } + }, + { + "name": "named flag at start", + "input": "/flag --woof=yes mrrp @user:example.com", + "output": { + "meow": "mrrp", + "user": "@user:example.com", + "woof": true + } + }, + { + "name": "boolean flag without value", + "input": "/flag --woof mrrp @user:example.com", + "output": { + "meow": "mrrp", + "user": "@user:example.com", + "woof": true + } + }, + { + "name": "user id flag without value", + "input": "/flag --user --woof mrrp", + "error": true, + "output": { + "meow": "mrrp", + "user": null, + "woof": true + } + }, + { + "name": "named flag in the middle", + "input": "/flag mrrp --woof=yes @user:example.com", + "output": { + "meow": "mrrp", + "user": "@user:example.com", + "woof": true + } + }, + { + "name": "named flag in the middle with different value", + "input": "/flag mrrp --woof=no @user:example.com", + "output": { + "meow": "mrrp", + "user": "@user:example.com", + "woof": false + } + }, + { + "name": "all variables named", + "input": "/flag --woof=no --meow=mrrp --user=@user:example.com", + "output": { + "meow": "mrrp", + "user": "@user:example.com", + "woof": false + } + }, + { + "name": "all variables named with quotes", + "input": "/flag --woof --meow=\"meow meow mrrp\" --user=\"@user:example.com\"", + "output": { + "meow": "meow meow mrrp", + "user": "@user:example.com", + "woof": true + } + }, + { + "name": "invalid value for named parameter", + "input": "/flag --user=meowings mrrp --woof", + "error": true, + "output": { + "meow": "mrrp", + "user": null, + "woof": true + } + } + ] +} diff --git a/event/cmdschema/testdata/commands/room_id_or_alias.json b/event/cmdschema/testdata/commands/room_id_or_alias.json new file mode 100644 index 00000000..1351c292 --- /dev/null +++ b/event/cmdschema/testdata/commands/room_id_or_alias.json @@ -0,0 +1,85 @@ +{ + "$schema": "../commands.schema.json#", + "spec": { + "command": "test room reference", + "source": "@testbot", + "parameters": [ + { + "key": "room", + "schema": { + "schema_type": "union", + "variants": [ + { + "schema_type": "primitive", + "type": "room_id" + }, + { + "schema_type": "primitive", + "type": "room_alias" + } + ] + } + } + ] + }, + "tests": [ + { + "name": "room alias", + "input": "/test room reference #test:matrix.org", + "output": { + "room": "#test:matrix.org" + } + }, + { + "name": "room id", + "input": "/test room reference !aiwVrNhPwbGBNjqlNu:matrix.org", + "output": { + "room": { + "type": "room_id", + "id": "!aiwVrNhPwbGBNjqlNu:matrix.org" + } + } + }, + { + "name": "room id matrix.to link", + "input": "/test room reference https://matrix.to/#/!aiwVrNhPwbGBNjqlNu:matrix.org?via=example.com", + "output": { + "room": { + "type": "room_id", + "id": "!aiwVrNhPwbGBNjqlNu:matrix.org", + "via": [ + "example.com" + ] + } + } + }, + { + "name": "room id matrix.to link with url encoding", + "input": "/test room reference https://matrix.to/#/!%23test%2Froom%0Aversion%20%3Cu%3E11%3C%2Fu%3E%2C%20with%20%40%F0%9F%90%88%EF%B8%8F%3Amaunium.net?via=maunium.net", + "broken": "Go's url.URL does url decoding on the fragment, which breaks splitting the path segments properly", + "output": { + "room": { + "type": "room_id", + "id": "!#test/room\nversion 11, with @🐈️:maunium.net", + "via": [ + "maunium.net" + ] + } + } + }, + { + "name": "room id matrix: URI", + "input": "/test room reference matrix:roomid/mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ?via=maunium.net&via=matrix.org", + "output": { + "room": { + "type": "room_id", + "id": "!mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ", + "via": [ + "maunium.net", + "matrix.org" + ] + } + } + } + ] +} diff --git a/event/cmdschema/testdata/commands/room_reference_list.json b/event/cmdschema/testdata/commands/room_reference_list.json new file mode 100644 index 00000000..aa266054 --- /dev/null +++ b/event/cmdschema/testdata/commands/room_reference_list.json @@ -0,0 +1,106 @@ +{ + "$schema": "../commands.schema.json#", + "spec": { + "command": "test room reference", + "source": "@testbot", + "parameters": [ + { + "key": "rooms", + "schema": { + "schema_type": "array", + "items": { + "schema_type": "union", + "variants": [ + { + "schema_type": "primitive", + "type": "room_id" + }, + { + "schema_type": "primitive", + "type": "room_alias" + } + ] + } + } + } + ] + }, + "tests": [ + { + "name": "room alias", + "input": "/test room reference #test:matrix.org", + "output": { + "rooms": [ + "#test:matrix.org" + ] + } + }, + { + "name": "room id", + "input": "/test room reference !aiwVrNhPwbGBNjqlNu:matrix.org", + "output": { + "rooms": [ + { + "type": "room_id", + "id": "!aiwVrNhPwbGBNjqlNu:matrix.org" + } + ] + } + }, + { + "name": "two room ids", + "input": "/test room reference !mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ !aiwVrNhPwbGBNjqlNu:matrix.org", + "output": { + "rooms": [ + { + "type": "room_id", + "id": "!mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ" + }, + { + "type": "room_id", + "id": "!aiwVrNhPwbGBNjqlNu:matrix.org" + } + ] + } + }, + { + "name": "room id matrix: URI", + "input": "/test room reference matrix:roomid/mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ?via=maunium.net&via=matrix.org", + "output": { + "rooms": [ + { + "type": "room_id", + "id": "!mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ", + "via": [ + "maunium.net", + "matrix.org" + ] + } + ] + } + }, + { + "name": "room id matrix: URI and matrix.to URL", + "input": "/test room reference https://matrix.to/#/!aiwVrNhPwbGBNjqlNu:matrix.org?via=example.com matrix:roomid/mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ?via=maunium.net&via=matrix.org", + "output": { + "rooms": [ + { + "type": "room_id", + "id": "!aiwVrNhPwbGBNjqlNu:matrix.org", + "via": [ + "example.com" + ] + }, + { + "type": "room_id", + "id": "!mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ", + "via": [ + "maunium.net", + "matrix.org" + ] + } + ] + } + } + ] +} diff --git a/event/cmdschema/testdata/commands/simple.json b/event/cmdschema/testdata/commands/simple.json new file mode 100644 index 00000000..94667323 --- /dev/null +++ b/event/cmdschema/testdata/commands/simple.json @@ -0,0 +1,46 @@ +{ + "$schema": "../commands.schema.json#", + "spec": { + "command": "test simple", + "source": "@testbot", + "parameters": [ + { + "key": "meow", + "schema": { + "schema_type": "primitive", + "type": "string" + } + } + ] + }, + "tests": [ + { + "name": "success", + "input": "/test simple mrrp", + "output": { + "meow": "mrrp" + } + }, + { + "name": "directed success", + "input": "/test simple@testbot mrrp", + "output": { + "meow": "mrrp" + } + }, + { + "name": "missing parameter", + "input": "/test simple", + "error": true, + "output": { + "meow": "" + } + }, + { + "name": "directed at another bot", + "input": "/test simple@anotherbot mrrp", + "error": false, + "output": null + } + ] +} diff --git a/event/cmdschema/testdata/commands/tail.json b/event/cmdschema/testdata/commands/tail.json new file mode 100644 index 00000000..9782f8ec --- /dev/null +++ b/event/cmdschema/testdata/commands/tail.json @@ -0,0 +1,60 @@ +{ + "$schema": "../commands.schema.json#", + "spec": { + "command": "tail", + "source": "@testbot", + "parameters": [ + { + "key": "meow", + "schema": { + "schema_type": "primitive", + "type": "string" + } + }, + { + "key": "reason", + "schema": { + "schema_type": "primitive", + "type": "string" + }, + "optional": true + }, + { + "key": "woof", + "schema": { + "schema_type": "primitive", + "type": "boolean" + }, + "optional": true + } + ], + "fi.mau.tail_parameter": "reason" + }, + "tests": [ + { + "name": "no tail or flag", + "input": "/tail mrrp", + "output": { + "meow": "mrrp", + "reason": "" + } + }, + { + "name": "tail, no flag", + "input": "/tail mrrp meow meow", + "output": { + "meow": "mrrp", + "reason": "meow meow" + } + }, + { + "name": "flag before tail", + "input": "/tail mrrp --woof meow meow", + "output": { + "meow": "mrrp", + "reason": "meow meow", + "woof": true + } + } + ] +} diff --git a/event/cmdschema/testdata/data.go b/event/cmdschema/testdata/data.go new file mode 100644 index 00000000..eceea3d2 --- /dev/null +++ b/event/cmdschema/testdata/data.go @@ -0,0 +1,14 @@ +// Copyright (c) 2026 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package testdata + +import ( + "embed" +) + +//go:embed * +var FS embed.FS diff --git a/event/cmdschema/testdata/parse_quote.json b/event/cmdschema/testdata/parse_quote.json new file mode 100644 index 00000000..8f52b7f5 --- /dev/null +++ b/event/cmdschema/testdata/parse_quote.json @@ -0,0 +1,30 @@ +[ + {"name": "empty string", "input": "", "output": ["", "", false]}, + {"name": "single word", "input": "meow", "output": ["meow", "", false]}, + {"name": "two words", "input": "meow woof", "output": ["meow", "woof", false]}, + {"name": "many words", "input": "meow meow mrrp", "output": ["meow", "meow mrrp", false]}, + {"name": "extra spaces", "input": "meow meow mrrp", "output": ["meow", "meow mrrp", false]}, + {"name": "trailing space", "input": "meow ", "output": ["meow", "", false]}, + {"name": "only spaces", "input": " ", "output": ["", "", false]}, + {"name": "leading spaces", "input": " meow woof", "output": ["", "meow woof", false]}, + {"name": "backslash at end unquoted", "input": "meow\\ woof", "output": ["meow\\", "woof", false]}, + {"name": "quoted word", "input": "\"meow\" meow mrrp", "output": ["meow", "meow mrrp", true]}, + {"name": "quoted words", "input": "\"meow meow\" mrrp", "output": ["meow meow", "mrrp", true]}, + {"name": "spaces in quotes", "input": "\" meow meow \" mrrp", "output": [" meow meow ", "mrrp", true]}, + {"name": "empty quoted string", "input": "\"\"", "output": ["", "", true]}, + {"name": "empty quoted with trailing", "input": "\"\" meow", "output": ["", "meow", true]}, + {"name": "quote no space before next", "input": "\"meow\"woof", "output": ["meow", "woof", true]}, + {"name": "just opening quote", "input": "\"", "output": ["", "", true]}, + {"name": "quote then space then text", "input": "\" meow", "output": [" meow", "", true]}, + {"name": "quotes after word", "input": "meow \" meow mrrp \"", "output": ["meow", "\" meow mrrp \"", false]}, + {"name": "escaped quote", "input": "\"meow\\\" meow\" mrrp", "output": ["meow\" meow", "mrrp", true]}, + {"name": "missing end quote", "input": "\"meow meow mrrp", "output": ["meow meow mrrp", "", true]}, + {"name": "missing end quote with escaped quote", "input": "\"meow\\\" meow mrrp", "output": ["meow\" meow mrrp", "", true]}, + {"name": "quote in the middle", "input": "me\"ow meow mrrp", "output": ["me\"ow", "meow mrrp", false]}, + {"name": "backslash in the middle", "input": "me\\ow meow mrrp", "output": ["me\\ow", "meow mrrp", false]}, + {"name": "other escaped character", "input": "\"m\\eow\" meow mrrp", "output": ["meow", "meow mrrp", true]}, + {"name": "escaped backslashes", "input": "\"m\\\\e\\\"ow\\\\\" meow mrrp", "output": ["m\\e\"ow\\", "meow mrrp", true]}, + {"name": "just quotes", "input": "\"\\\"\\\"\\\\\\\"\" meow", "output": ["\"\"\\\"", "meow", true]}, + {"name": "escape at eof", "input": "\"meow\\", "output": ["meow", "", true]}, + {"name": "escaped backslash at eof", "input": "\"meow\\\\", "output": ["meow\\", "", true]} +] diff --git a/event/cmdschema/testdata/parse_quote.schema.json b/event/cmdschema/testdata/parse_quote.schema.json new file mode 100644 index 00000000..9f249116 --- /dev/null +++ b/event/cmdschema/testdata/parse_quote.schema.json @@ -0,0 +1,46 @@ +{ + "$schema": "https://json-schema.org/draft/2020-12/schema#", + "$id": "parse_quote.schema.json", + "title": "parseQuote test cases", + "description": "Test cases for the parseQuoted function", + "type": "array", + "items": { + "type": "object", + "required": [ + "name", + "input", + "output" + ], + "properties": { + "name": { + "type": "string", + "description": "Name of the test case" + }, + "input": { + "type": "string", + "description": "Input string to be parsed" + }, + "output": { + "type": "array", + "description": "Expected output of parsing: [first word, remaining text, was quoted]", + "minItems": 3, + "maxItems": 3, + "prefixItems": [ + { + "type": "string", + "description": "First parsed word" + }, + { + "type": "string", + "description": "Remaining text after the first word" + }, + { + "type": "boolean", + "description": "Whether the first word was quoted" + } + ] + } + }, + "additionalProperties": false + } +} diff --git a/event/content.go b/event/content.go index c0ff51ad..814aeec4 100644 --- a/event/content.go +++ b/event/content.go @@ -40,6 +40,9 @@ var TypeMap = map[Type]reflect.Type{ StateSpaceParent: reflect.TypeOf(SpaceParentEventContent{}), StateSpaceChild: reflect.TypeOf(SpaceChildEventContent{}), + StateRoomPolicy: reflect.TypeOf(RoomPolicyEventContent{}), + StateUnstableRoomPolicy: reflect.TypeOf(RoomPolicyEventContent{}), + StateLegacyPolicyRoom: reflect.TypeOf(ModPolicyContent{}), StateLegacyPolicyServer: reflect.TypeOf(ModPolicyContent{}), StateLegacyPolicyUser: reflect.TypeOf(ModPolicyContent{}), @@ -50,7 +53,6 @@ var TypeMap = map[Type]reflect.Type{ StateElementFunctionalMembers: reflect.TypeOf(ElementFunctionalMembersContent{}), StateBeeperRoomFeatures: reflect.TypeOf(RoomFeatures{}), StateBeeperDisappearingTimer: reflect.TypeOf(BeeperDisappearingTimer{}), - StateBotCommands: reflect.TypeOf(BotCommandsEventContent{}), EventMessage: reflect.TypeOf(MessageEventContent{}), EventSticker: reflect.TypeOf(MessageEventContent{}), @@ -61,9 +63,11 @@ var TypeMap = map[Type]reflect.Type{ EventUnstablePollStart: reflect.TypeOf(PollStartEventContent{}), EventUnstablePollResponse: reflect.TypeOf(PollResponseEventContent{}), - BeeperMessageStatus: reflect.TypeOf(BeeperMessageStatusEventContent{}), - BeeperTranscription: reflect.TypeOf(BeeperTranscriptionEventContent{}), - BeeperDeleteChat: reflect.TypeOf(BeeperChatDeleteEventContent{}), + BeeperMessageStatus: reflect.TypeOf(BeeperMessageStatusEventContent{}), + BeeperTranscription: reflect.TypeOf(BeeperTranscriptionEventContent{}), + BeeperDeleteChat: reflect.TypeOf(BeeperChatDeleteEventContent{}), + BeeperAcceptMessageRequest: reflect.TypeOf(BeeperAcceptMessageRequestEventContent{}), + BeeperSendState: reflect.TypeOf(BeeperSendStateEventContent{}), AccountDataRoomTags: reflect.TypeOf(TagEventContent{}), AccountDataDirectChats: reflect.TypeOf(DirectChatsEventContent{}), @@ -72,9 +76,11 @@ var TypeMap = map[Type]reflect.Type{ AccountDataMarkedUnread: reflect.TypeOf(MarkedUnreadEventContent{}), AccountDataBeeperMute: reflect.TypeOf(BeeperMuteEventContent{}), - EphemeralEventTyping: reflect.TypeOf(TypingEventContent{}), - EphemeralEventReceipt: reflect.TypeOf(ReceiptEventContent{}), - EphemeralEventPresence: reflect.TypeOf(PresenceEventContent{}), + EphemeralEventTyping: reflect.TypeOf(TypingEventContent{}), + EphemeralEventReceipt: reflect.TypeOf(ReceiptEventContent{}), + EphemeralEventPresence: reflect.TypeOf(PresenceEventContent{}), + EphemeralEventEncrypted: reflect.TypeOf(EncryptedEventContent{}), + BeeperEphemeralEventAIStream: reflect.TypeOf(BeeperAIStreamEventContent{}), InRoomVerificationReady: reflect.TypeOf(VerificationReadyEventContent{}), InRoomVerificationStart: reflect.TypeOf(VerificationStartEventContent{}), diff --git a/event/encryption.go b/event/encryption.go index cf9c2814..c60cb91a 100644 --- a/event/encryption.go +++ b/event/encryption.go @@ -63,7 +63,7 @@ func (content *EncryptedEventContent) UnmarshalJSON(data []byte) error { return json.Unmarshal(content.Ciphertext, &content.OlmCiphertext) case id.AlgorithmMegolmV1: if len(content.Ciphertext) == 0 || content.Ciphertext[0] != '"' || content.Ciphertext[len(content.Ciphertext)-1] != '"' { - return id.InputNotJSONString + return fmt.Errorf("ciphertext %w", id.ErrInputNotJSONString) } content.MegolmCiphertext = content.Ciphertext[1 : len(content.Ciphertext)-1] } @@ -132,8 +132,9 @@ type RoomKeyRequestEventContent struct { type RequestedKeyInfo struct { Algorithm id.Algorithm `json:"algorithm"` RoomID id.RoomID `json:"room_id"` - SenderKey id.SenderKey `json:"sender_key"` SessionID id.SessionID `json:"session_id"` + // Deprecated: Matrix v1.3 + SenderKey id.SenderKey `json:"sender_key"` } type RoomKeyWithheldCode string diff --git a/event/message.go b/event/message.go index 692382cf..3fb3dc82 100644 --- a/event/message.go +++ b/event/message.go @@ -135,6 +135,7 @@ type MessageEventContent struct { BeeperGalleryCaption string `json:"com.beeper.gallery.caption,omitempty"` BeeperGalleryCaptionHTML string `json:"com.beeper.gallery.caption_html,omitempty"` BeeperPerMessageProfile *BeeperPerMessageProfile `json:"com.beeper.per_message_profile,omitempty"` + BeeperActionMessage *BeeperActionMessage `json:"com.beeper.action_message,omitempty"` BeeperLinkPreviews []*BeeperLinkPreview `json:"com.beeper.linkpreviews,omitempty"` @@ -143,7 +144,7 @@ type MessageEventContent struct { MSC1767Audio *MSC1767Audio `json:"org.matrix.msc1767.audio,omitempty"` MSC3245Voice *MSC3245Voice `json:"org.matrix.msc3245.voice,omitempty"` - MSC4332BotCommand *BotCommandInput `json:"org.matrix.msc4332.command,omitempty"` + MSC4391BotCommand *MSC4391BotCommandInput `json:"org.matrix.msc4391.command,omitempty"` } func (content *MessageEventContent) GetCapMsgType() CapabilityMsgType { @@ -287,6 +288,13 @@ func (m *Mentions) Merge(other *Mentions) *Mentions { } } +type MSC4391BotCommandInputCustom[T any] struct { + Command string `json:"command"` + Arguments T `json:"arguments,omitempty"` +} + +type MSC4391BotCommandInput = MSC4391BotCommandInputCustom[json.RawMessage] + type EncryptedFileInfo struct { attachment.EncryptedFile URL id.ContentURIString `json:"url"` diff --git a/event/powerlevels.go b/event/powerlevels.go index 50df2c1f..668eb6d3 100644 --- a/event/powerlevels.go +++ b/event/powerlevels.go @@ -28,6 +28,9 @@ type PowerLevelsEventContent struct { Events map[string]int `json:"events,omitempty"` EventsDefault int `json:"events_default,omitempty"` + beeperEphemeralLock sync.RWMutex + BeeperEphemeral map[string]int `json:"com.beeper.ephemeral,omitempty"` + Notifications *NotificationPowerLevels `json:"notifications,omitempty"` StateDefaultPtr *int `json:"state_default,omitempty"` @@ -37,6 +40,8 @@ type PowerLevelsEventContent struct { BanPtr *int `json:"ban,omitempty"` RedactPtr *int `json:"redact,omitempty"` + BeeperEphemeralDefaultPtr *int `json:"com.beeper.ephemeral_default,omitempty"` + // This is not a part of power levels, it's added by mautrix-go internally in certain places // in order to detect creator power accurately. CreateEvent *Event `json:"-"` @@ -51,6 +56,7 @@ func (pl *PowerLevelsEventContent) Clone() *PowerLevelsEventContent { UsersDefault: pl.UsersDefault, Events: maps.Clone(pl.Events), EventsDefault: pl.EventsDefault, + BeeperEphemeral: maps.Clone(pl.BeeperEphemeral), StateDefaultPtr: ptr.Clone(pl.StateDefaultPtr), Notifications: pl.Notifications.Clone(), @@ -60,6 +66,8 @@ func (pl *PowerLevelsEventContent) Clone() *PowerLevelsEventContent { BanPtr: ptr.Clone(pl.BanPtr), RedactPtr: ptr.Clone(pl.RedactPtr), + BeeperEphemeralDefaultPtr: ptr.Clone(pl.BeeperEphemeralDefaultPtr), + CreateEvent: pl.CreateEvent, } } @@ -119,6 +127,13 @@ func (pl *PowerLevelsEventContent) StateDefault() int { return 50 } +func (pl *PowerLevelsEventContent) BeeperEphemeralDefault() int { + if pl.BeeperEphemeralDefaultPtr != nil { + return *pl.BeeperEphemeralDefaultPtr + } + return pl.EventsDefault +} + func (pl *PowerLevelsEventContent) GetUserLevel(userID id.UserID) int { if pl.isCreator(userID) { return math.MaxInt @@ -132,9 +147,19 @@ func (pl *PowerLevelsEventContent) GetUserLevel(userID id.UserID) int { return level } +const maxPL = 1<<53 - 1 + func (pl *PowerLevelsEventContent) SetUserLevel(userID id.UserID, level int) { pl.usersLock.Lock() defer pl.usersLock.Unlock() + if pl.isCreator(userID) { + return + } + if level == math.MaxInt && maxPL < math.MaxInt { + // Hack to avoid breaking on 32-bit systems (they're only slightly supported) + x := int64(maxPL) + level = int(x) + } if level == pl.UsersDefault { delete(pl.Users, userID) } else { @@ -192,6 +217,29 @@ func (pl *PowerLevelsEventContent) GetEventLevel(eventType Type) int { return level } +func (pl *PowerLevelsEventContent) GetBeeperEphemeralLevel(eventType Type) int { + pl.beeperEphemeralLock.RLock() + defer pl.beeperEphemeralLock.RUnlock() + level, ok := pl.BeeperEphemeral[eventType.String()] + if !ok { + return pl.BeeperEphemeralDefault() + } + return level +} + +func (pl *PowerLevelsEventContent) SetBeeperEphemeralLevel(eventType Type, level int) { + pl.beeperEphemeralLock.Lock() + defer pl.beeperEphemeralLock.Unlock() + if level == pl.BeeperEphemeralDefault() { + delete(pl.BeeperEphemeral, eventType.String()) + } else { + if pl.BeeperEphemeral == nil { + pl.BeeperEphemeral = make(map[string]int) + } + pl.BeeperEphemeral[eventType.String()] = level + } +} + func (pl *PowerLevelsEventContent) SetEventLevel(eventType Type, level int) { pl.eventsLock.Lock() defer pl.eventsLock.Unlock() diff --git a/event/powerlevels_ephemeral_test.go b/event/powerlevels_ephemeral_test.go new file mode 100644 index 00000000..f5861583 --- /dev/null +++ b/event/powerlevels_ephemeral_test.go @@ -0,0 +1,67 @@ +// Copyright (c) 2026 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package event_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "maunium.net/go/mautrix/event" +) + +func TestPowerLevelsEventContent_BeeperEphemeralDefaultFallsBackToEventsDefault(t *testing.T) { + pl := &event.PowerLevelsEventContent{ + EventsDefault: 45, + } + + assert.Equal(t, 45, pl.BeeperEphemeralDefault()) + + override := 60 + pl.BeeperEphemeralDefaultPtr = &override + assert.Equal(t, 60, pl.BeeperEphemeralDefault()) +} + +func TestPowerLevelsEventContent_GetSetBeeperEphemeralLevel(t *testing.T) { + pl := &event.PowerLevelsEventContent{ + EventsDefault: 25, + } + evtType := event.Type{Type: "com.example.ephemeral", Class: event.EphemeralEventType} + + assert.Equal(t, 25, pl.GetBeeperEphemeralLevel(evtType)) + + pl.SetBeeperEphemeralLevel(evtType, 50) + assert.Equal(t, 50, pl.GetBeeperEphemeralLevel(evtType)) + require.NotNil(t, pl.BeeperEphemeral) + assert.Equal(t, 50, pl.BeeperEphemeral[evtType.String()]) + + pl.SetBeeperEphemeralLevel(evtType, 25) + _, exists := pl.BeeperEphemeral[evtType.String()] + assert.False(t, exists) +} + +func TestPowerLevelsEventContent_CloneCopiesBeeperEphemeralFields(t *testing.T) { + override := 70 + pl := &event.PowerLevelsEventContent{ + EventsDefault: 35, + BeeperEphemeral: map[string]int{"com.example.ephemeral": 90}, + BeeperEphemeralDefaultPtr: &override, + } + + cloned := pl.Clone() + require.NotNil(t, cloned) + require.NotNil(t, cloned.BeeperEphemeralDefaultPtr) + assert.Equal(t, 70, *cloned.BeeperEphemeralDefaultPtr) + assert.Equal(t, 90, cloned.BeeperEphemeral["com.example.ephemeral"]) + + cloned.BeeperEphemeral["com.example.ephemeral"] = 99 + *cloned.BeeperEphemeralDefaultPtr = 71 + + assert.Equal(t, 90, pl.BeeperEphemeral["com.example.ephemeral"]) + assert.Equal(t, 70, *pl.BeeperEphemeralDefaultPtr) +} diff --git a/event/reply.go b/event/reply.go index 9ae1c110..5f55bb80 100644 --- a/event/reply.go +++ b/event/reply.go @@ -32,12 +32,13 @@ func TrimReplyFallbackText(text string) string { } func (content *MessageEventContent) RemoveReplyFallback() { - if len(content.RelatesTo.GetReplyTo()) > 0 && !content.replyFallbackRemoved { - if content.Format == FormatHTML { - content.FormattedBody = TrimReplyFallbackHTML(content.FormattedBody) + if len(content.RelatesTo.GetReplyTo()) > 0 && !content.replyFallbackRemoved && content.Format == FormatHTML { + origHTML := content.FormattedBody + content.FormattedBody = TrimReplyFallbackHTML(content.FormattedBody) + if content.FormattedBody != origHTML { + content.Body = TrimReplyFallbackText(content.Body) + content.replyFallbackRemoved = true } - content.Body = TrimReplyFallbackText(content.Body) - content.replyFallbackRemoved = true } } diff --git a/event/state.go b/event/state.go index ed5434c9..ace170a5 100644 --- a/event/state.go +++ b/event/state.go @@ -62,6 +62,13 @@ type ExtensibleTextContainer struct { Text []ExtensibleText `json:"m.text"` } +func (c *ExtensibleTextContainer) Equals(description *ExtensibleTextContainer) bool { + if c == nil || description == nil { + return c == description + } + return slices.Equal(c.Text, description.Text) +} + func MakeExtensibleText(text string) *ExtensibleTextContainer { return &ExtensibleTextContainer{ Text: []ExtensibleText{{ @@ -96,6 +103,13 @@ type TombstoneEventContent struct { ReplacementRoom id.RoomID `json:"replacement_room"` } +func (tec *TombstoneEventContent) GetReplacementRoom() id.RoomID { + if tec == nil { + return "" + } + return tec.ReplacementRoom +} + type Predecessor struct { RoomID id.RoomID `json:"room_id"` EventID id.EventID `json:"event_id"` @@ -135,6 +149,13 @@ type CreateEventContent struct { Creator id.UserID `json:"creator,omitempty"` } +func (cec *CreateEventContent) GetPredecessor() (p Predecessor) { + if cec != nil && cec.Predecessor != nil { + p = *cec.Predecessor + } + return +} + func (cec *CreateEventContent) SupportsCreatorPower() bool { if cec == nil { return false @@ -217,7 +238,8 @@ type BridgeInfoSection struct { AvatarURL id.ContentURIString `json:"avatar_url,omitempty"` ExternalURL string `json:"external_url,omitempty"` - Receiver string `json:"fi.mau.receiver,omitempty"` + Receiver string `json:"fi.mau.receiver,omitempty"` + MessageRequest bool `json:"com.beeper.message_request,omitempty"` } // BridgeEventContent represents the content of a m.bridge state event. @@ -232,7 +254,8 @@ type BridgeEventContent struct { BeeperRoomType string `json:"com.beeper.room_type,omitempty"` BeeperRoomTypeV2 string `json:"com.beeper.room_type.v2,omitempty"` - TempSlackRemoteIDMigratedFlag bool `json:"com.beeper.slack_remote_id_migrated,omitempty"` + TempSlackRemoteIDMigratedFlag bool `json:"com.beeper.slack_remote_id_migrated,omitempty"` + TempSlackRemoteIDMigratedFlag2 bool `json:"com.beeper.slack_remote_id_really_migrated,omitempty"` } // DisappearingType represents the type of a disappearing message timer. @@ -320,3 +343,15 @@ func (efmc *ElementFunctionalMembersContent) Add(mxid id.UserID) bool { efmc.ServiceMembers = append(efmc.ServiceMembers, mxid) return true } + +type PolicyServerPublicKeys struct { + Ed25519 id.Ed25519 `json:"ed25519,omitempty"` +} + +type RoomPolicyEventContent struct { + Via string `json:"via,omitempty"` + PublicKeys *PolicyServerPublicKeys `json:"public_keys,omitempty"` + + // Deprecated, only for legacy use + PublicKey id.Ed25519 `json:"public_key,omitempty"` +} diff --git a/event/type.go b/event/type.go index 56ea82f6..80b86728 100644 --- a/event/type.go +++ b/event/type.go @@ -113,9 +113,9 @@ func (et *Type) GuessClass() TypeClass { StatePinnedEvents.Type, StateTombstone.Type, StateEncryption.Type, StateBridge.Type, StateHalfShotBridge.Type, StateSpaceParent.Type, StateSpaceChild.Type, StatePolicyRoom.Type, StatePolicyServer.Type, StatePolicyUser.Type, StateElementFunctionalMembers.Type, StateBeeperRoomFeatures.Type, StateBeeperDisappearingTimer.Type, - StateBotCommands.Type: + StateMSC4391BotCommand.Type, StateRoomPolicy.Type, StateUnstableRoomPolicy.Type: return StateEventType - case EphemeralEventReceipt.Type, EphemeralEventTyping.Type, EphemeralEventPresence.Type: + case EphemeralEventReceipt.Type, EphemeralEventTyping.Type, EphemeralEventPresence.Type, BeeperEphemeralEventAIStream.Type: return EphemeralEventType case AccountDataDirectChats.Type, AccountDataPushRules.Type, AccountDataRoomTags.Type, AccountDataFullyRead.Type, AccountDataIgnoredUserList.Type, AccountDataMarkedUnread.Type, @@ -128,7 +128,7 @@ func (et *Type) GuessClass() TypeClass { InRoomVerificationKey.Type, InRoomVerificationMAC.Type, InRoomVerificationCancel.Type, CallInvite.Type, CallCandidates.Type, CallAnswer.Type, CallReject.Type, CallSelectAnswer.Type, CallNegotiate.Type, CallHangup.Type, BeeperMessageStatus.Type, EventUnstablePollStart.Type, EventUnstablePollResponse.Type, - EventUnstablePollEnd.Type, BeeperTranscription.Type, BeeperDeleteChat.Type: + EventUnstablePollEnd.Type, BeeperTranscription.Type, BeeperDeleteChat.Type, BeeperAcceptMessageRequest.Type: return MessageEventType case ToDeviceRoomKey.Type, ToDeviceRoomKeyRequest.Type, ToDeviceForwardedRoomKey.Type, ToDeviceRoomKeyWithheld.Type, ToDeviceBeeperRoomKeyAck.Type: @@ -195,6 +195,9 @@ var ( StateSpaceChild = Type{"m.space.child", StateEventType} StateSpaceParent = Type{"m.space.parent", StateEventType} + StateRoomPolicy = Type{"m.room.policy", StateEventType} + StateUnstableRoomPolicy = Type{"org.matrix.msc4284.policy", StateEventType} + StateLegacyPolicyRoom = Type{"m.room.rule.room", StateEventType} StateLegacyPolicyServer = Type{"m.room.rule.server", StateEventType} StateLegacyPolicyUser = Type{"m.room.rule.user", StateEventType} @@ -205,7 +208,7 @@ var ( StateElementFunctionalMembers = Type{"io.element.functional_members", StateEventType} StateBeeperRoomFeatures = Type{"com.beeper.room_features", StateEventType} StateBeeperDisappearingTimer = Type{"com.beeper.disappearing_timer", StateEventType} - StateBotCommands = Type{"org.matrix.msc4332.commands", StateEventType} + StateMSC4391BotCommand = Type{"org.matrix.msc4391.command_description", StateEventType} ) // Message events @@ -234,9 +237,11 @@ var ( CallNegotiate = Type{"m.call.negotiate", MessageEventType} CallHangup = Type{"m.call.hangup", MessageEventType} - BeeperMessageStatus = Type{"com.beeper.message_send_status", MessageEventType} - BeeperTranscription = Type{"com.beeper.transcription", MessageEventType} - BeeperDeleteChat = Type{"com.beeper.delete_chat", MessageEventType} + BeeperMessageStatus = Type{"com.beeper.message_send_status", MessageEventType} + BeeperTranscription = Type{"com.beeper.transcription", MessageEventType} + BeeperDeleteChat = Type{"com.beeper.delete_chat", MessageEventType} + BeeperAcceptMessageRequest = Type{"com.beeper.accept_message_request", MessageEventType} + BeeperSendState = Type{"com.beeper.send_state", MessageEventType} EventUnstablePollStart = Type{Type: "org.matrix.msc3381.poll.start", Class: MessageEventType} EventUnstablePollResponse = Type{Type: "org.matrix.msc3381.poll.response", Class: MessageEventType} @@ -245,9 +250,11 @@ var ( // Ephemeral events var ( - EphemeralEventReceipt = Type{"m.receipt", EphemeralEventType} - EphemeralEventTyping = Type{"m.typing", EphemeralEventType} - EphemeralEventPresence = Type{"m.presence", EphemeralEventType} + EphemeralEventReceipt = Type{"m.receipt", EphemeralEventType} + EphemeralEventTyping = Type{"m.typing", EphemeralEventType} + EphemeralEventPresence = Type{"m.presence", EphemeralEventType} + EphemeralEventEncrypted = Type{"m.room.encrypted", EphemeralEventType} + BeeperEphemeralEventAIStream = Type{"com.beeper.ai.stream_event", EphemeralEventType} ) // Account data events diff --git a/federation/client.go b/federation/client.go index 8f454516..183fb5d1 100644 --- a/federation/client.go +++ b/federation/client.go @@ -30,6 +30,8 @@ type Client struct { ServerName string UserAgent string Key *SigningKey + + ResponseSizeLimit int64 } func NewClient(serverName string, key *SigningKey, cache ResolutionCache) *Client { @@ -37,10 +39,16 @@ func NewClient(serverName string, key *SigningKey, cache ResolutionCache) *Clien HTTP: &http.Client{ Transport: NewServerResolvingTransport(cache), Timeout: 120 * time.Second, + CheckRedirect: func(req *http.Request, via []*http.Request) error { + // Federation requests do not allow redirects. + return http.ErrUseLastResponse + }, }, UserAgent: mautrix.DefaultUserAgent, ServerName: serverName, Key: key, + + ResponseSizeLimit: mautrix.DefaultResponseSizeLimit, } } @@ -81,7 +89,7 @@ type RespSendTransaction struct { } func (c *Client) SendTransaction(ctx context.Context, req *ReqSendTransaction) (resp *RespSendTransaction, err error) { - err = c.MakeRequest(ctx, req.Destination, true, http.MethodPost, URLPath{"v1", "send", req.TxnID}, req, &resp) + err = c.MakeRequest(ctx, req.Destination, true, http.MethodPut, URLPath{"v1", "send", req.TxnID}, req, &resp) return } @@ -255,6 +263,169 @@ func (c *Client) GetOpenIDUserInfo(ctx context.Context, serverName, accessToken return } +type ReqMakeJoin struct { + RoomID id.RoomID + UserID id.UserID + Via string + SupportedVersions []id.RoomVersion +} + +type RespMakeJoin struct { + RoomVersion id.RoomVersion `json:"room_version"` + Event PDU `json:"event"` +} + +type ReqSendJoin struct { + RoomID id.RoomID + EventID id.EventID + OmitMembers bool + Event PDU + Via string +} + +type ReqSendKnock struct { + RoomID id.RoomID + EventID id.EventID + Event PDU + Via string +} + +type RespSendJoin struct { + AuthChain []PDU `json:"auth_chain"` + Event PDU `json:"event"` + MembersOmitted bool `json:"members_omitted"` + ServersInRoom []string `json:"servers_in_room"` + State []PDU `json:"state"` +} + +type RespSendKnock struct { + KnockRoomState []PDU `json:"knock_room_state"` +} + +type ReqSendInvite struct { + RoomID id.RoomID `json:"-"` + UserID id.UserID `json:"-"` + Event PDU `json:"event"` + InviteRoomState []PDU `json:"invite_room_state"` + RoomVersion id.RoomVersion `json:"room_version"` +} + +type RespSendInvite struct { + Event PDU `json:"event"` +} + +type ReqMakeLeave struct { + RoomID id.RoomID + UserID id.UserID + Via string +} + +type ReqSendLeave struct { + RoomID id.RoomID + EventID id.EventID + Event PDU + Via string +} + +type ( + ReqMakeKnock = ReqMakeJoin + RespMakeKnock = RespMakeJoin + RespMakeLeave = RespMakeJoin +) + +func (c *Client) MakeJoin(ctx context.Context, req *ReqMakeJoin) (resp *RespMakeJoin, err error) { + versions := make([]string, len(req.SupportedVersions)) + for i, v := range req.SupportedVersions { + versions[i] = string(v) + } + _, _, err = c.MakeFullRequest(ctx, RequestParams{ + ServerName: req.Via, + Method: http.MethodGet, + Path: URLPath{"v1", "make_join", req.RoomID, req.UserID}, + Query: url.Values{"ver": versions}, + Authenticate: true, + ResponseJSON: &resp, + }) + return +} + +func (c *Client) MakeKnock(ctx context.Context, req *ReqMakeKnock) (resp *RespMakeKnock, err error) { + versions := make([]string, len(req.SupportedVersions)) + for i, v := range req.SupportedVersions { + versions[i] = string(v) + } + _, _, err = c.MakeFullRequest(ctx, RequestParams{ + ServerName: req.Via, + Method: http.MethodGet, + Path: URLPath{"v1", "make_knock", req.RoomID, req.UserID}, + Query: url.Values{"ver": versions}, + Authenticate: true, + ResponseJSON: &resp, + }) + return +} + +func (c *Client) SendJoin(ctx context.Context, req *ReqSendJoin) (resp *RespSendJoin, err error) { + _, _, err = c.MakeFullRequest(ctx, RequestParams{ + ServerName: req.Via, + Method: http.MethodPut, + Path: URLPath{"v2", "send_join", req.RoomID, req.EventID}, + Query: url.Values{ + "omit_members": {strconv.FormatBool(req.OmitMembers)}, + }, + Authenticate: true, + RequestJSON: req.Event, + ResponseJSON: &resp, + }) + return +} + +func (c *Client) SendKnock(ctx context.Context, req *ReqSendKnock) (resp *RespSendKnock, err error) { + _, _, err = c.MakeFullRequest(ctx, RequestParams{ + ServerName: req.Via, + Method: http.MethodPut, + Path: URLPath{"v1", "send_knock", req.RoomID, req.EventID}, + Authenticate: true, + RequestJSON: req.Event, + ResponseJSON: &resp, + }) + return +} + +func (c *Client) SendInvite(ctx context.Context, req *ReqSendInvite) (resp *RespSendInvite, err error) { + _, _, err = c.MakeFullRequest(ctx, RequestParams{ + ServerName: req.UserID.Homeserver(), + Method: http.MethodPut, + Path: URLPath{"v2", "invite", req.RoomID, req.UserID}, + Authenticate: true, + RequestJSON: req, + ResponseJSON: &resp, + }) + return +} + +func (c *Client) MakeLeave(ctx context.Context, req *ReqMakeLeave) (resp *RespMakeLeave, err error) { + _, _, err = c.MakeFullRequest(ctx, RequestParams{ + ServerName: req.Via, + Method: http.MethodGet, + Path: URLPath{"v1", "make_leave", req.RoomID, req.UserID}, + Authenticate: true, + ResponseJSON: &resp, + }) + return +} + +func (c *Client) SendLeave(ctx context.Context, req *ReqSendLeave) (err error) { + _, _, err = c.MakeFullRequest(ctx, RequestParams{ + ServerName: req.Via, + Method: http.MethodPut, + Path: URLPath{"v2", "send_leave", req.RoomID, req.EventID}, + Authenticate: true, + RequestJSON: req.Event, + }) + return +} + type URLPath []any func (fup URLPath) FullPath() []any { @@ -306,15 +477,27 @@ func (c *Client) MakeFullRequest(ctx context.Context, params RequestParams) ([]b WrappedError: err, } } - defer func() { - _ = resp.Body.Close() - }() + if !params.DontReadBody { + defer resp.Body.Close() + } var body []byte - if resp.StatusCode >= 400 { + if resp.StatusCode >= 300 { body, err = mautrix.ParseErrorResponse(req, resp) return body, resp, err } else if params.ResponseJSON != nil || !params.DontReadBody { - body, err = io.ReadAll(resp.Body) + if resp.ContentLength > c.ResponseSizeLimit { + return body, resp, mautrix.HTTPError{ + Request: req, + Response: resp, + + Message: "not reading response", + WrappedError: fmt.Errorf("%w (%.2f MiB)", mautrix.ErrResponseTooLong, float64(resp.ContentLength)/1024/1024), + } + } + body, err = io.ReadAll(io.LimitReader(resp.Body, c.ResponseSizeLimit+1)) + if err == nil && len(body) > int(c.ResponseSizeLimit) { + err = mautrix.ErrBodyReadReachedLimit + } if err != nil { return body, resp, mautrix.HTTPError{ Request: req, diff --git a/federation/eventauth/eventauth.go b/federation/eventauth/eventauth.go index bd102213..c72933c2 100644 --- a/federation/eventauth/eventauth.go +++ b/federation/eventauth/eventauth.go @@ -310,7 +310,7 @@ func authorizeMember(roomVersion id.RoomVersion, evt, createEvt *pdu.PDU, authEv // 5.1. If there is no state_key property, or no membership property in content, reject. return ErrMemberNotState } - authorizedVia := id.UserID(gjson.GetBytes(evt.Content, "authorized_via_users_server").Str) + authorizedVia := id.UserID(gjson.GetBytes(evt.Content, "authorised_via_users_server").Str) if authorizedVia != "" { homeserver := authorizedVia.Homeserver() err := evt.VerifySignature(roomVersion, homeserver, getKey) @@ -484,7 +484,7 @@ func authorizeMember(roomVersion id.RoomVersion, evt, createEvt *pdu.PDU, authEv } return ErrCantLeaveWithoutBeingInRoom } - if senderMembership != event.MembershipLeave { + if senderMembership != event.MembershipJoin { // 5.5.2. If the sender’s current membership state is not join, reject. return ErrCantKickWithoutBeingInRoom } @@ -505,7 +505,7 @@ func authorizeMember(roomVersion id.RoomVersion, evt, createEvt *pdu.PDU, authEv // 5.5.5. Otherwise, reject. return ErrInsufficientPermissionForKick case event.MembershipBan: - if senderMembership != event.MembershipLeave { + if senderMembership != event.MembershipJoin { // 5.6.1. If the sender’s current membership state is not join, reject. return ErrCantBanWithoutBeingInRoom } diff --git a/federation/eventauth/eventauth_internal_test.go b/federation/eventauth/eventauth_internal_test.go new file mode 100644 index 00000000..d316f3c8 --- /dev/null +++ b/federation/eventauth/eventauth_internal_test.go @@ -0,0 +1,66 @@ +// Copyright (c) 2026 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +//go:build goexperiment.jsonv2 + +package eventauth + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" +) + +type pythonIntTest struct { + Name string + Input string + Expected int64 +} + +var pythonIntTests = []pythonIntTest{ + {"True", `true`, 1}, + {"False", `false`, 0}, + {"SmallFloat", `3.1415`, 3}, + {"SmallFloatRoundDown", `10.999999999999999`, 10}, + {"SmallFloatRoundUp", `10.9999999999999999`, 11}, + {"BigFloatRoundDown", `1000000.9999999999`, 1000000}, + {"BigFloatRoundUp", `1000000.99999999999`, 1000001}, + {"BigFloatPrecisionError", `9007199254740993.0`, 9007199254740992}, + {"BigFloatPrecisionError2", `9007199254740993.123`, 9007199254740994}, + {"Int64", `9223372036854775807`, 9223372036854775807}, + {"Int64String", `"9223372036854775807"`, 9223372036854775807}, + {"String", `"123"`, 123}, + {"InvalidFloatInString", `"123.456"`, 0}, + {"StringWithPlusSign", `"+123"`, 123}, + {"StringWithMinusSign", `"-123"`, -123}, + {"StringWithSpaces", `" 123 "`, 123}, + {"StringWithSpacesAndSign", `" -123 "`, -123}, + //{"StringWithUnderscores", `"123_456"`, 123456}, + //{"StringWithUnderscores", `"123_456"`, 123456}, + {"InvalidStringWithTrailingUnderscore", `"123_456_"`, 0}, + {"InvalidStringWithMultipleUnderscores", `"123__456"`, 0}, + {"InvalidStringWithLeadingUnderscore", `"_123_456"`, 0}, + {"InvalidStringWithUnderscoreAfterSign", `"+_123_456"`, 0}, + {"InvalidStringWithUnderscoreAfterSpace", `" _123_456"`, 0}, + //{"StringWithUnderscoresAndSpaces", `" +1_2_3_4_5_6 "`, 123456}, +} + +func TestParsePythonInt(t *testing.T) { + for _, test := range pythonIntTests { + t.Run(test.Name, func(t *testing.T) { + output := parsePythonInt(gjson.Parse(test.Input)) + if strings.HasPrefix(test.Name, "Invalid") { + assert.Nil(t, output) + } else { + require.NotNil(t, output) + assert.Equal(t, int(test.Expected), *output) + } + }) + } +} diff --git a/federation/pdu/auth.go b/federation/pdu/auth.go index 1f98de06..16706fe5 100644 --- a/federation/pdu/auth.go +++ b/federation/pdu/auth.go @@ -61,7 +61,7 @@ func (pdu *PDU) AuthEventSelection(roomVersion id.RoomVersion) (keys AuthEventSe } } if membership == event.MembershipJoin && roomVersion.RestrictedJoins() { - authorizedVia := gjson.GetBytes(pdu.Content, "authorized_via_users_server").Str + authorizedVia := gjson.GetBytes(pdu.Content, "authorised_via_users_server").Str if authorizedVia != "" { keys.Add(event.StateMember.Type, authorizedVia) } diff --git a/federation/pdu/pdu.go b/federation/pdu/pdu.go index cecee5b9..17db6995 100644 --- a/federation/pdu/pdu.go +++ b/federation/pdu/pdu.go @@ -123,6 +123,19 @@ func (pdu *PDU) ToClientEvent(roomVersion id.RoomVersion) (*event.Event, error) return evt, nil } +func (pdu *PDU) AddSignature(serverName string, keyID id.KeyID, signature string) { + if signature == "" { + return + } + if pdu.Signatures == nil { + pdu.Signatures = make(map[string]map[id.KeyID]string) + } + if _, ok := pdu.Signatures[serverName]; !ok { + pdu.Signatures[serverName] = make(map[id.KeyID]string) + } + pdu.Signatures[serverName][keyID] = signature +} + func marshalCanonical(data any) (jsontext.Value, error) { marshaledBytes, err := json.Marshal(data) if err != nil { diff --git a/federation/pdu/signature.go b/federation/pdu/signature.go index a7685cc6..04e7c5ef 100644 --- a/federation/pdu/signature.go +++ b/federation/pdu/signature.go @@ -28,13 +28,7 @@ func (pdu *PDU) Sign(roomVersion id.RoomVersion, serverName string, keyID id.Key return fmt.Errorf("failed to marshal redacted PDU to sign: %w", err) } signature := ed25519.Sign(privateKey, rawJSON) - if pdu.Signatures == nil { - pdu.Signatures = make(map[string]map[id.KeyID]string) - } - if _, ok := pdu.Signatures[serverName]; !ok { - pdu.Signatures[serverName] = make(map[id.KeyID]string) - } - pdu.Signatures[serverName][keyID] = base64.RawStdEncoding.EncodeToString(signature) + pdu.AddSignature(serverName, keyID, base64.RawStdEncoding.EncodeToString(signature)) return nil } diff --git a/federation/resolution.go b/federation/resolution.go index 69d4d3bf..a3188266 100644 --- a/federation/resolution.go +++ b/federation/resolution.go @@ -20,6 +20,8 @@ import ( "time" "github.com/rs/zerolog" + + "maunium.net/go/mautrix" ) type ResolvedServerName struct { @@ -78,7 +80,10 @@ func ResolveServerName(ctx context.Context, serverName string, opts ...*ResolveS } else if wellKnown != nil { output.Expires = expiry output.HostHeader = wellKnown.Server - hostname, port, ok = ParseServerName(wellKnown.Server) + wkHost, wkPort, ok := ParseServerName(wellKnown.Server) + if ok { + hostname, port = wkHost, wkPort + } // Step 3.1 and 3.2: IP literals and hostnames with port inside .well-known if net.ParseIP(hostname) != nil || port != 0 { if port == 0 { @@ -171,9 +176,11 @@ func RequestWellKnown(ctx context.Context, cli *http.Client, hostname string) (* defer resp.Body.Close() if resp.StatusCode != http.StatusOK { return nil, time.Time{}, fmt.Errorf("unexpected status code %d", resp.StatusCode) + } else if resp.ContentLength > mautrix.WellKnownMaxSize { + return nil, time.Time{}, fmt.Errorf("response too large: %d bytes", resp.ContentLength) } var respData RespWellKnown - err = json.NewDecoder(io.LimitReader(resp.Body, 50*1024)).Decode(&respData) + err = json.NewDecoder(io.LimitReader(resp.Body, mautrix.WellKnownMaxSize)).Decode(&respData) if err != nil { return nil, time.Time{}, fmt.Errorf("failed to decode response: %w", err) } else if respData.Server == "" { diff --git a/federation/serverauth.go b/federation/serverauth.go index f46c7991..cd300341 100644 --- a/federation/serverauth.go +++ b/federation/serverauth.go @@ -231,7 +231,7 @@ func (sa *ServerAuth) Authenticate(r *http.Request) (*http.Request, *mautrix.Res } err = (&signableRequest{ Method: r.Method, - URI: r.URL.EscapedPath(), + URI: r.URL.RequestURI(), Origin: parsed.Origin, Destination: destination, Content: reqBody, diff --git a/federation/serverauth_test.go b/federation/serverauth_test.go index 9fa15459..f99fc6cf 100644 --- a/federation/serverauth_test.go +++ b/federation/serverauth_test.go @@ -19,9 +19,9 @@ import ( func TestServerKeyResponse_VerifySelfSignature(t *testing.T) { cli := federation.NewClient("", nil, nil) ctx := context.Background() - for _, name := range []string{"matrix.org", "maunium.net", "continuwuity.org"} { + for _, name := range []string{"matrix.org", "maunium.net", "cd.mau.dev", "uwu.mau.dev"} { t.Run(name, func(t *testing.T) { - resp, err := cli.ServerKeys(ctx, "matrix.org") + resp, err := cli.ServerKeys(ctx, name) require.NoError(t, err) assert.NoError(t, resp.VerifySelfSignature()) }) diff --git a/filter.go b/filter.go index c6c8211b..54973dab 100644 --- a/filter.go +++ b/filter.go @@ -57,7 +57,7 @@ type FilterPart struct { // Validate checks if the filter contains valid property values func (filter *Filter) Validate() error { if filter.EventFormat != EventFormatClient && filter.EventFormat != EventFormatFederation { - return errors.New("Bad event_format value. Must be one of [\"client\", \"federation\"]") + return errors.New("bad event_format value") } return nil } diff --git a/format/htmlparser.go b/format/htmlparser.go index e5f92896..e0507d93 100644 --- a/format/htmlparser.go +++ b/format/htmlparser.go @@ -93,6 +93,30 @@ func DefaultPillConverter(displayname, mxid, eventID string, ctx Context) string } } +func onlyBacktickCount(line string) (count int) { + for i := 0; i < len(line); i++ { + if line[i] != '`' { + return -1 + } + count++ + } + return +} + +func DefaultMonospaceBlockConverter(code, language string, ctx Context) string { + if len(code) == 0 || code[len(code)-1] != '\n' { + code += "\n" + } + fence := "```" + for line := range strings.SplitSeq(code, "\n") { + count := onlyBacktickCount(strings.TrimSpace(line)) + if count >= len(fence) { + fence = strings.Repeat("`", count+1) + } + } + return fmt.Sprintf("%s%s\n%s%s", fence, language, code, fence) +} + // HTMLParser is a somewhat customizable Matrix HTML parser. type HTMLParser struct { PillConverter PillConverter @@ -348,10 +372,7 @@ func (parser *HTMLParser) tagToString(node *html.Node, ctx Context) string { if parser.MonospaceBlockConverter != nil { return parser.MonospaceBlockConverter(preStr, language, ctx) } - if len(preStr) == 0 || preStr[len(preStr)-1] != '\n' { - preStr += "\n" - } - return fmt.Sprintf("```%s\n%s```", language, preStr) + return DefaultMonospaceBlockConverter(preStr, language, ctx) default: return parser.nodeToTagAwareString(node.FirstChild, ctx) } diff --git a/go.mod b/go.mod index fb63cf59..49a1d4e4 100644 --- a/go.mod +++ b/go.mod @@ -1,42 +1,42 @@ module maunium.net/go/mautrix -go 1.24.0 +go 1.25.0 -toolchain go1.25.3 +toolchain go1.26.0 require ( - filippo.io/edwards25519 v1.1.0 + filippo.io/edwards25519 v1.2.0 github.com/chzyer/readline v1.5.1 github.com/coder/websocket v1.8.14 - github.com/lib/pq v1.10.9 - github.com/mattn/go-sqlite3 v1.14.32 + github.com/lib/pq v1.11.2 + github.com/mattn/go-sqlite3 v1.14.34 github.com/rs/xid v1.6.0 github.com/rs/zerolog v1.34.0 github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e github.com/stretchr/testify v1.11.1 github.com/tidwall/gjson v1.18.0 github.com/tidwall/sjson v1.2.5 - github.com/yuin/goldmark v1.7.13 - go.mau.fi/util v0.9.2 + github.com/yuin/goldmark v1.7.16 + go.mau.fi/util v0.9.6 go.mau.fi/zeroconfig v0.2.0 - golang.org/x/crypto v0.43.0 - golang.org/x/exp v0.0.0-20251009144603-d2f985daa21b - golang.org/x/net v0.46.0 - golang.org/x/sync v0.17.0 + golang.org/x/crypto v0.48.0 + golang.org/x/exp v0.0.0-20260218203240-3dfff04db8fa + golang.org/x/net v0.50.0 + golang.org/x/sync v0.19.0 gopkg.in/yaml.v3 v3.0.1 maunium.net/go/mauflag v1.0.0 ) require ( - github.com/coreos/go-systemd/v22 v22.5.0 // indirect + github.com/coreos/go-systemd/v22 v22.6.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-isatty v0.0.20 // indirect - github.com/petermattis/goid v0.0.0-20250904145737-900bdf8bb490 // indirect + github.com/petermattis/goid v0.0.0-20260113132338-7c7de50cc741 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.1 // indirect - golang.org/x/sys v0.37.0 // indirect - golang.org/x/text v0.30.0 // indirect + golang.org/x/sys v0.41.0 // indirect + golang.org/x/text v0.34.0 // indirect gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect ) diff --git a/go.sum b/go.sum index faa4ef4c..871a5156 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,5 @@ -filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= -filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= +filippo.io/edwards25519 v1.2.0 h1:crnVqOiS4jqYleHd9vaKZ+HKtHfllngJIiOpNpoJsjo= +filippo.io/edwards25519 v1.2.0/go.mod h1:xzAOLCNug/yB62zG1bQ8uziwrIqIuxhctzJT18Q77mc= github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU= github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU= github.com/chzyer/logex v1.2.1 h1:XHDu3E6q+gdHgsdTPH6ImJMIp436vR6MPtH8gP05QzM= @@ -10,13 +10,14 @@ github.com/chzyer/test v1.0.0 h1:p3BQDXSxOhOG0P9z6/hGnII4LGiEPOYBhs8asl/fC04= github.com/chzyer/test v1.0.0/go.mod h1:2JlltgoNkt4TW/z9V/IzDdFaMTM2JPIi26O1pF38GC8= github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g= github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg= -github.com/coreos/go-systemd/v22 v22.5.0 h1:RrqgGjYQKalulkV8NGVIfkXQf6YYmOyiJKk8iXXhfZs= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= +github.com/coreos/go-systemd/v22 v22.6.0 h1:aGVa/v8B7hpb0TKl0MWoAavPDmHvobFe5R5zn0bCJWo= +github.com/coreos/go-systemd/v22 v22.6.0/go.mod h1:iG+pp635Fo7ZmV/j14KUcmEyWF+0X7Lua8rrTWzYgWU= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= -github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= -github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +github.com/lib/pq v1.11.2 h1:x6gxUeu39V0BHZiugWe8LXZYZ+Utk7hSJGThs8sdzfs= +github.com/lib/pq v1.11.2/go.mod h1:/p+8NSbOcwzAEI7wiMXFlgydTwcgTr3OSKMsD2BitpA= github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= @@ -24,10 +25,10 @@ github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/ github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= -github.com/mattn/go-sqlite3 v1.14.32 h1:JD12Ag3oLy1zQA+BNn74xRgaBbdhbNIDYvQUEuuErjs= -github.com/mattn/go-sqlite3 v1.14.32/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= -github.com/petermattis/goid v0.0.0-20250904145737-900bdf8bb490 h1:QTvNkZ5ylY0PGgA+Lih+GdboMLY/G9SEGLMEGVjTVA4= -github.com/petermattis/goid v0.0.0-20250904145737-900bdf8bb490/go.mod h1:pxMtw7cyUw6B2bRH0ZBANSPg+AoSud1I1iyJHI69jH4= +github.com/mattn/go-sqlite3 v1.14.34 h1:3NtcvcUnFBPsuRcno8pUtupspG/GM+9nZ88zgJcp6Zk= +github.com/mattn/go-sqlite3 v1.14.34/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= +github.com/petermattis/goid v0.0.0-20260113132338-7c7de50cc741 h1:KPpdlQLZcHfTMQRi6bFQ7ogNO0ltFT4PmtwTLW4W+14= +github.com/petermattis/goid v0.0.0-20260113132338-7c7de50cc741/go.mod h1:pxMtw7cyUw6B2bRH0ZBANSPg+AoSud1I1iyJHI69jH4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= @@ -49,28 +50,28 @@ github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= -github.com/yuin/goldmark v1.7.13 h1:GPddIs617DnBLFFVJFgpo1aBfe/4xcvMc3SB5t/D0pA= -github.com/yuin/goldmark v1.7.13/go.mod h1:ip/1k0VRfGynBgxOz0yCqHrbZXhcjxyuS66Brc7iBKg= -go.mau.fi/util v0.9.2 h1:+S4Z03iCsGqU2WY8X2gySFsFjaLlUHFRDVCYvVwynKM= -go.mau.fi/util v0.9.2/go.mod h1:055elBBCJSdhRsmub7ci9hXZPgGr1U6dYg44cSgRgoU= +github.com/yuin/goldmark v1.7.16 h1:n+CJdUxaFMiDUNnWC3dMWCIQJSkxH4uz3ZwQBkAlVNE= +github.com/yuin/goldmark v1.7.16/go.mod h1:ip/1k0VRfGynBgxOz0yCqHrbZXhcjxyuS66Brc7iBKg= +go.mau.fi/util v0.9.6 h1:2nsvxm49KhI3wrFltr0+wSUBlnQ4CMtykuELjpIU+ts= +go.mau.fi/util v0.9.6/go.mod h1:sIJpRH7Iy5Ad1SBuxQoatxtIeErgzxCtjd/2hCMkYMI= go.mau.fi/zeroconfig v0.2.0 h1:e/OGEERqVRRKlgaro7E6bh8xXiKFSXB3eNNIud7FUjU= go.mau.fi/zeroconfig v0.2.0/go.mod h1:J0Vn0prHNOm493oZoQ84kq83ZaNCYZnq+noI1b1eN8w= -golang.org/x/crypto v0.43.0 h1:dduJYIi3A3KOfdGOHX8AVZ/jGiyPa3IbBozJ5kNuE04= -golang.org/x/crypto v0.43.0/go.mod h1:BFbav4mRNlXJL4wNeejLpWxB7wMbc79PdRGhWKncxR0= -golang.org/x/exp v0.0.0-20251009144603-d2f985daa21b h1:18qgiDvlvH7kk8Ioa8Ov+K6xCi0GMvmGfGW0sgd/SYA= -golang.org/x/exp v0.0.0-20251009144603-d2f985daa21b/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70= -golang.org/x/net v0.46.0 h1:giFlY12I07fugqwPuWJi68oOnpfqFnJIJzaIIm2JVV4= -golang.org/x/net v0.46.0/go.mod h1:Q9BGdFy1y4nkUwiLvT5qtyhAnEHgnQ/zd8PfU6nc210= -golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug= -golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= +golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts= +golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos= +golang.org/x/exp v0.0.0-20260218203240-3dfff04db8fa h1:Zt3DZoOFFYkKhDT3v7Lm9FDMEV06GpzjG2jrqW+QTE0= +golang.org/x/exp v0.0.0-20260218203240-3dfff04db8fa/go.mod h1:K79w1Vqn7PoiZn+TkNpx3BUWUQksGO3JcVX6qIjytmA= +golang.org/x/net v0.50.0 h1:ucWh9eiCGyDR3vtzso0WMQinm2Dnt8cFMuQa9K33J60= +golang.org/x/net v0.50.0/go.mod h1:UgoSli3F/pBgdJBHCTc+tp3gmrU4XswgGRgtnwWTfyM= +golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= +golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ= -golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= -golang.org/x/text v0.30.0 h1:yznKA/E9zq54KzlzBEAWn1NXSQ8DIp/NYMy88xJjl4k= -golang.org/x/text v0.30.0/go.mod h1:yDdHFIX9t+tORqspjENWgzaCVXgk0yYnYuSZ8UzzBVM= +golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k= +golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk= +golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc= diff --git a/id/contenturi.go b/id/contenturi.go index e6a313f5..67127b6c 100644 --- a/id/contenturi.go +++ b/id/contenturi.go @@ -17,8 +17,14 @@ import ( ) var ( - InvalidContentURI = errors.New("invalid Matrix content URI") - InputNotJSONString = errors.New("input doesn't look like a JSON string") + ErrInvalidContentURI = errors.New("invalid Matrix content URI") + ErrInputNotJSONString = errors.New("input doesn't look like a JSON string") +) + +// Deprecated: use variables prefixed with Err +var ( + InvalidContentURI = ErrInvalidContentURI + InputNotJSONString = ErrInputNotJSONString ) // ContentURIString is a string that's expected to be a Matrix content URI. @@ -55,9 +61,9 @@ func ParseContentURI(uri string) (parsed ContentURI, err error) { if len(uri) == 0 { return } else if !strings.HasPrefix(uri, "mxc://") { - err = InvalidContentURI + err = ErrInvalidContentURI } else if index := strings.IndexRune(uri[6:], '/'); index == -1 || index == len(uri)-7 { - err = InvalidContentURI + err = ErrInvalidContentURI } else { parsed.Homeserver = uri[6 : 6+index] parsed.FileID = uri[6+index+1:] @@ -71,9 +77,9 @@ func ParseContentURIBytes(uri []byte) (parsed ContentURI, err error) { if len(uri) == 0 { return } else if !bytes.HasPrefix(uri, mxcBytes) { - err = InvalidContentURI + err = ErrInvalidContentURI } else if index := bytes.IndexRune(uri[6:], '/'); index == -1 || index == len(uri)-7 { - err = InvalidContentURI + err = ErrInvalidContentURI } else { parsed.Homeserver = string(uri[6 : 6+index]) parsed.FileID = string(uri[6+index+1:]) @@ -86,7 +92,7 @@ func (uri *ContentURI) UnmarshalJSON(raw []byte) (err error) { *uri = ContentURI{} return nil } else if len(raw) < 2 || raw[0] != '"' || raw[len(raw)-1] != '"' { - return InputNotJSONString + return fmt.Errorf("ContentURI: %w", ErrInputNotJSONString) } parsed, err := ParseContentURIBytes(raw[1 : len(raw)-1]) if err != nil { diff --git a/id/crypto.go b/id/crypto.go index 355a84a8..ee857f78 100644 --- a/id/crypto.go +++ b/id/crypto.go @@ -53,6 +53,34 @@ const ( KeyBackupAlgorithmMegolmBackupV1 KeyBackupAlgorithm = "m.megolm_backup.v1.curve25519-aes-sha2" ) +type KeySource string + +func (source KeySource) String() string { + return string(source) +} + +func (source KeySource) Int() int { + switch source { + case KeySourceDirect: + return 100 + case KeySourceBackup: + return 90 + case KeySourceImport: + return 80 + case KeySourceForward: + return 50 + default: + return 0 + } +} + +const ( + KeySourceDirect KeySource = "direct" + KeySourceBackup KeySource = "backup" + KeySourceImport KeySource = "import" + KeySourceForward KeySource = "forward" +) + // BackupVersion is an arbitrary string that identifies a server side key backup. type KeyBackupVersion string diff --git a/id/matrixuri.go b/id/matrixuri.go index 8f5ec849..d5c78bc7 100644 --- a/id/matrixuri.go +++ b/id/matrixuri.go @@ -54,7 +54,7 @@ var SigilToPathSegment = map[rune]string{ func (uri *MatrixURI) getQuery() url.Values { q := make(url.Values) - if uri.Via != nil && len(uri.Via) > 0 { + if len(uri.Via) > 0 { q["via"] = uri.Via } if len(uri.Action) > 0 { diff --git a/id/trust.go b/id/trust.go index 04f6e36b..6255093e 100644 --- a/id/trust.go +++ b/id/trust.go @@ -16,6 +16,7 @@ type TrustState int const ( TrustStateBlacklisted TrustState = -100 + TrustStateDeviceKeyMismatch TrustState = -5 TrustStateUnset TrustState = 0 TrustStateUnknownDevice TrustState = 10 TrustStateForwarded TrustState = 20 @@ -23,7 +24,7 @@ const ( TrustStateCrossSignedTOFU TrustState = 100 TrustStateCrossSignedVerified TrustState = 200 TrustStateVerified TrustState = 300 - TrustStateInvalid TrustState = (1 << 31) - 1 + TrustStateInvalid TrustState = -2147483647 ) func (ts *TrustState) UnmarshalText(data []byte) error { @@ -44,6 +45,8 @@ func ParseTrustState(val string) TrustState { switch strings.ToLower(val) { case "blacklisted": return TrustStateBlacklisted + case "device-key-mismatch": + return TrustStateDeviceKeyMismatch case "unverified": return TrustStateUnset case "cross-signed-untrusted": @@ -67,6 +70,8 @@ func (ts TrustState) String() string { switch ts { case TrustStateBlacklisted: return "blacklisted" + case TrustStateDeviceKeyMismatch: + return "device-key-mismatch" case TrustStateUnset: return "unverified" case TrustStateCrossSignedUntrusted: diff --git a/id/userid.go b/id/userid.go index 859d2358..726a0d58 100644 --- a/id/userid.go +++ b/id/userid.go @@ -219,15 +219,15 @@ func DecodeUserLocalpart(str string) (string, error) { for i := 0; i < len(strBytes); i++ { b := strBytes[i] if !isValidByte(b) { - return "", fmt.Errorf("Byte pos %d: Invalid byte", i) + return "", fmt.Errorf("invalid encoded byte at position %d: %c", i, b) } if b == '_' { // next byte is a-z and should be upper-case or is another _ and should be a literal _ if i+1 >= len(strBytes) { - return "", fmt.Errorf("Byte pos %d: expected _[a-z_] encoding but ran out of string", i) + return "", fmt.Errorf("unexpected end of string after underscore at %d", i) } if !isValidEscapedChar(strBytes[i+1]) { // invalid escaping - return "", fmt.Errorf("Byte pos %d: expected _[a-z_] encoding", i) + return "", fmt.Errorf("unexpected byte %c after underscore at %d", strBytes[i+1], i) } if strBytes[i+1] == '_' { outputBuffer.WriteByte('_') @@ -237,7 +237,7 @@ func DecodeUserLocalpart(str string) (string, error) { i++ // skip next byte since we just handled it } else if b == '=' { // next 2 bytes are hex and should be buffered ready to be read as utf8 if i+2 >= len(strBytes) { - return "", fmt.Errorf("Byte pos: %d: expected quote-printable encoding but ran out of string", i) + return "", fmt.Errorf("unexpected end of string after equals sign at %d", i) } dst := make([]byte, 1) _, err := hex.Decode(dst, strBytes[i+1:i+3]) diff --git a/mediaproxy/mediaproxy.go b/mediaproxy/mediaproxy.go index 07e30810..4d2bc7cf 100644 --- a/mediaproxy/mediaproxy.go +++ b/mediaproxy/mediaproxy.go @@ -95,9 +95,13 @@ func (d *GetMediaResponseCallback) GetContentType() string { return d.ContentType } +type FileMeta struct { + ContentType string + ReplacementFile string +} + type GetMediaResponseFile struct { - Callback func(w *os.File) error - ContentType string + Callback func(w *os.File) (*FileMeta, error) } type GetMediaFunc = func(ctx context.Context, mediaID string, params map[string]string) (response GetMediaResponse, err error) @@ -139,6 +143,7 @@ func New(serverName string, serverKey string, getMedia GetMediaFunc) (*MediaProx } mp.FederationRouter = http.NewServeMux() mp.FederationRouter.HandleFunc("GET /v1/media/download/{mediaID}", mp.DownloadMediaFederation) + mp.FederationRouter.HandleFunc("GET /v1/media/thumbnail/{mediaID}", mp.DownloadMediaFederation) mp.FederationRouter.HandleFunc("GET /v1/version", mp.KeyServer.GetServerVersion) mp.ClientMediaRouter = http.NewServeMux() mp.ClientMediaRouter.HandleFunc("GET /download/{serverName}/{mediaID}", mp.DownloadMedia) @@ -453,23 +458,35 @@ func doTempFileDownload( if err != nil { return false, fmt.Errorf("failed to create temp file: %w", err) } + origTempFile := tempFile defer func() { - _ = tempFile.Close() - _ = os.Remove(tempFile.Name()) + _ = origTempFile.Close() + _ = os.Remove(origTempFile.Name()) }() - err = data.Callback(tempFile) + meta, err := data.Callback(tempFile) if err != nil { return false, err } - _, err = tempFile.Seek(0, io.SeekStart) - if err != nil { - return false, fmt.Errorf("failed to seek to start of temp file: %w", err) + if meta.ReplacementFile != "" { + tempFile, err = os.Open(meta.ReplacementFile) + if err != nil { + return false, fmt.Errorf("failed to open replacement file: %w", err) + } + defer func() { + _ = tempFile.Close() + _ = os.Remove(origTempFile.Name()) + }() + } else { + _, err = tempFile.Seek(0, io.SeekStart) + if err != nil { + return false, fmt.Errorf("failed to seek to start of temp file: %w", err) + } } fileInfo, err := tempFile.Stat() if err != nil { return false, fmt.Errorf("failed to stat temp file: %w", err) } - mimeType := data.ContentType + mimeType := meta.ContentType if mimeType == "" { buf := make([]byte, 512) n, err := tempFile.Read(buf) diff --git a/mockserver/mockserver.go b/mockserver/mockserver.go index e52c387a..507c24a5 100644 --- a/mockserver/mockserver.go +++ b/mockserver/mockserver.go @@ -231,7 +231,7 @@ func (ms *MockServer) postKeysUpload(w http.ResponseWriter, r *http.Request) { } func (ms *MockServer) postDeviceSigningUpload(w http.ResponseWriter, r *http.Request) { - var req mautrix.UploadCrossSigningKeysReq + var req mautrix.UploadCrossSigningKeysReq[any] mustDecode(r, &req) userID := ms.getUserID(r).UserID diff --git a/pushrules/action.go b/pushrules/action.go index 9838e88b..b5a884b2 100644 --- a/pushrules/action.go +++ b/pushrules/action.go @@ -105,7 +105,7 @@ func (action *PushAction) UnmarshalJSON(raw []byte) error { if ok { action.Action = ActionSetTweak action.Tweak = PushActionTweak(tweak) - action.Value, _ = val["value"] + action.Value = val["value"] } } return nil diff --git a/pushrules/condition_test.go b/pushrules/condition_test.go index 0d3eaf7a..37af3e34 100644 --- a/pushrules/condition_test.go +++ b/pushrules/condition_test.go @@ -102,14 +102,6 @@ func newEventPropertyIsPushCondition(key string, value any) *pushrules.PushCondi } } -func newEventPropertyContainsPushCondition(key string, value any) *pushrules.PushCondition { - return &pushrules.PushCondition{ - Kind: pushrules.KindEventPropertyContains, - Key: key, - Value: value, - } -} - func TestPushCondition_Match_InvalidKind(t *testing.T) { condition := &pushrules.PushCondition{ Kind: pushrules.PushCondKind("invalid"), diff --git a/requests.go b/requests.go index f0287b3c..cc8b7266 100644 --- a/requests.go +++ b/requests.go @@ -66,14 +66,14 @@ const ( ) // ReqRegister is the JSON request for https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3register -type ReqRegister struct { +type ReqRegister[UIAType any] struct { Username string `json:"username,omitempty"` Password string `json:"password,omitempty"` DeviceID id.DeviceID `json:"device_id,omitempty"` InitialDeviceDisplayName string `json:"initial_device_display_name,omitempty"` InhibitLogin bool `json:"inhibit_login,omitempty"` RefreshToken bool `json:"refresh_token,omitempty"` - Auth interface{} `json:"auth,omitempty"` + Auth UIAType `json:"auth,omitempty"` // Type for registration, only used for appservice user registrations // https://spec.matrix.org/v1.2/application-service-api/#server-admin-style-permissions @@ -320,11 +320,11 @@ func (csk *CrossSigningKeys) FirstKey() id.Ed25519 { return "" } -type UploadCrossSigningKeysReq struct { +type UploadCrossSigningKeysReq[UIAType any] struct { Master CrossSigningKeys `json:"master_key"` SelfSigning CrossSigningKeys `json:"self_signing_key"` UserSigning CrossSigningKeys `json:"user_signing_key"` - Auth interface{} `json:"auth,omitempty"` + Auth UIAType `json:"auth,omitempty"` } type KeyMap map[id.DeviceKeyID]string @@ -367,13 +367,12 @@ type ReqSendToDevice struct { } type ReqSendEvent struct { - Timestamp int64 - TransactionID string - UnstableDelay time.Duration - - DontEncrypt bool - - MeowEventID id.EventID + Timestamp int64 + TransactionID string + UnstableDelay time.Duration + UnstableStickyDuration time.Duration + DontEncrypt bool + MeowEventID id.EventID } type ReqDelayedEvents struct { @@ -393,14 +392,14 @@ type ReqDeviceInfo struct { } // ReqDeleteDevice is the JSON request for https://spec.matrix.org/v1.2/client-server-api/#delete_matrixclientv3devicesdeviceid -type ReqDeleteDevice struct { - Auth interface{} `json:"auth,omitempty"` +type ReqDeleteDevice[UIAType any] struct { + Auth UIAType `json:"auth,omitempty"` } // ReqDeleteDevices is the JSON request for https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3delete_devices -type ReqDeleteDevices struct { +type ReqDeleteDevices[UIAType any] struct { Devices []id.DeviceID `json:"devices"` - Auth interface{} `json:"auth,omitempty"` + Auth UIAType `json:"auth,omitempty"` } type ReqPutPushRule struct { diff --git a/responses.go b/responses.go index 3484c134..4fbe1fbc 100644 --- a/responses.go +++ b/responses.go @@ -6,12 +6,14 @@ import ( "fmt" "maps" "reflect" + "slices" "strconv" "strings" "github.com/tidwall/gjson" "github.com/tidwall/sjson" "go.mau.fi/util/jsontime" + "go.mau.fi/util/ptr" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" @@ -256,15 +258,13 @@ func (r *UserDirectoryEntry) MarshalJSON() ([]byte, error) { type RespMutualRooms struct { Joined []id.RoomID `json:"joined"` NextBatch string `json:"next_batch,omitempty"` + Count int `json:"count,omitempty"` } type RespRoomSummary struct { PublicRoomInfo - Membership event.Membership `json:"membership,omitempty"` - RoomVersion id.RoomVersion `json:"room_version,omitempty"` - Encryption id.Algorithm `json:"encryption,omitempty"` - AllowedRoomIDs []id.RoomID `json:"allowed_room_ids,omitempty"` + Membership event.Membership `json:"membership,omitempty"` UnstableRoomVersion id.RoomVersion `json:"im.nheko.summary.room_version,omitempty"` UnstableRoomVersionOld id.RoomVersion `json:"im.nheko.summary.version,omitempty"` @@ -342,6 +342,24 @@ type LazyLoadSummary struct { InvitedMemberCount *int `json:"m.invited_member_count,omitempty"` } +func (lls *LazyLoadSummary) MemberCount() int { + if lls == nil { + return 0 + } + return ptr.Val(lls.JoinedMemberCount) + ptr.Val(lls.InvitedMemberCount) +} + +func (lls *LazyLoadSummary) Equal(other *LazyLoadSummary) bool { + if lls == other { + return true + } else if lls == nil || other == nil { + return false + } + return ptr.Val(lls.JoinedMemberCount) == ptr.Val(other.JoinedMemberCount) && + ptr.Val(lls.InvitedMemberCount) == ptr.Val(other.InvitedMemberCount) && + slices.Equal(lls.Heroes, other.Heroes) +} + type SyncEventsList struct { Events []*event.Event `json:"events,omitempty"` } @@ -672,6 +690,10 @@ type PublicRoomInfo struct { RoomType event.RoomType `json:"room_type"` Topic string `json:"topic,omitempty"` WorldReadable bool `json:"world_readable"` + + RoomVersion id.RoomVersion `json:"room_version,omitempty"` + Encryption id.Algorithm `json:"encryption,omitempty"` + AllowedRoomIDs []id.RoomID `json:"allowed_room_ids,omitempty"` } // RespHierarchy is the JSON response for https://spec.matrix.org/v1.4/client-server-api/#get_matrixclientv1roomsroomidhierarchy @@ -755,3 +777,23 @@ type RespSuspended struct { type RespLocked struct { Locked bool `json:"locked"` } + +type ConnectionInfo struct { + IP string `json:"ip,omitempty"` + LastSeen jsontime.UnixMilli `json:"last_seen,omitempty"` + UserAgent string `json:"user_agent,omitempty"` +} + +type SessionInfo struct { + Connections []ConnectionInfo `json:"connections,omitempty"` +} + +type DeviceInfo struct { + Sessions []SessionInfo `json:"sessions,omitempty"` +} + +// RespWhoIs is the response body for https://spec.matrix.org/v1.15/client-server-api/#get_matrixclientv3adminwhoisuserid +type RespWhoIs struct { + UserID id.UserID `json:"user_id,omitempty"` + Devices map[id.DeviceID]DeviceInfo `json:"devices,omitempty"` +} diff --git a/room.go b/room.go index c3ddb7e6..4292bff5 100644 --- a/room.go +++ b/room.go @@ -5,8 +5,6 @@ import ( "maunium.net/go/mautrix/id" ) -type RoomStateMap = map[event.Type]map[string]*event.Event - // Room represents a single Matrix room. type Room struct { ID id.RoomID @@ -25,8 +23,8 @@ func (room Room) UpdateState(evt *event.Event) { // GetStateEvent returns the state event for the given type/state_key combo, or nil. func (room Room) GetStateEvent(eventType event.Type, stateKey string) *event.Event { - stateEventMap, _ := room.State[eventType] - evt, _ := stateEventMap[stateKey] + stateEventMap := room.State[eventType] + evt := stateEventMap[stateKey] return evt } diff --git a/sqlstatestore/statestore.go b/sqlstatestore/statestore.go index c4126802..11957dfa 100644 --- a/sqlstatestore/statestore.go +++ b/sqlstatestore/statestore.go @@ -470,3 +470,26 @@ func (store *SQLStateStore) GetCreate(ctx context.Context, roomID id.RoomID) (ev } return } + +func (store *SQLStateStore) SetJoinRules(ctx context.Context, roomID id.RoomID, rules *event.JoinRulesEventContent) error { + if roomID == "" { + return fmt.Errorf("room ID is empty") + } + _, err := store.Exec(ctx, ` + INSERT INTO mx_room_state (room_id, join_rules) VALUES ($1, $2) + ON CONFLICT (room_id) DO UPDATE SET join_rules=excluded.join_rules + `, roomID, dbutil.JSON{Data: rules}) + return err +} + +func (store *SQLStateStore) GetJoinRules(ctx context.Context, roomID id.RoomID) (levels *event.JoinRulesEventContent, err error) { + levels = &event.JoinRulesEventContent{} + err = store. + QueryRow(ctx, "SELECT join_rules FROM mx_room_state WHERE room_id=$1 AND join_rules IS NOT NULL", roomID). + Scan(&dbutil.JSON{Data: &levels}) + if errors.Is(err, sql.ErrNoRows) { + levels = nil + err = nil + } + return +} diff --git a/sqlstatestore/v00-latest-revision.sql b/sqlstatestore/v00-latest-revision.sql index b5a858ec..4679f1c6 100644 --- a/sqlstatestore/v00-latest-revision.sql +++ b/sqlstatestore/v00-latest-revision.sql @@ -1,4 +1,4 @@ --- v0 -> v9 (compatible with v3+): Latest revision +-- v0 -> v10 (compatible with v3+): Latest revision CREATE TABLE mx_registrations ( user_id TEXT PRIMARY KEY @@ -27,5 +27,6 @@ CREATE TABLE mx_room_state ( power_levels jsonb, encryption jsonb, create_event jsonb, + join_rules jsonb, members_fetched BOOLEAN NOT NULL DEFAULT false ); diff --git a/sqlstatestore/v10-join-rules.sql b/sqlstatestore/v10-join-rules.sql new file mode 100644 index 00000000..3074c46a --- /dev/null +++ b/sqlstatestore/v10-join-rules.sql @@ -0,0 +1,2 @@ +-- v10 (compatible with v3+): Add join rules to room state table +ALTER TABLE mx_room_state ADD COLUMN join_rules jsonb; diff --git a/statestore.go b/statestore.go index 1933ab95..2bd498dd 100644 --- a/statestore.go +++ b/statestore.go @@ -37,6 +37,9 @@ type StateStore interface { SetCreate(ctx context.Context, evt *event.Event) error GetCreate(ctx context.Context, roomID id.RoomID) (*event.Event, error) + GetJoinRules(ctx context.Context, roomID id.RoomID) (*event.JoinRulesEventContent, error) + SetJoinRules(ctx context.Context, roomID id.RoomID, content *event.JoinRulesEventContent) error + HasFetchedMembers(ctx context.Context, roomID id.RoomID) (bool, error) MarkMembersFetched(ctx context.Context, roomID id.RoomID) error GetAllMembers(ctx context.Context, roomID id.RoomID) (map[id.UserID]*event.MemberEventContent, error) @@ -73,6 +76,8 @@ func UpdateStateStore(ctx context.Context, store StateStore, evt *event.Event) { err = store.SetEncryptionEvent(ctx, evt.RoomID, content) case *event.CreateEventContent: err = store.SetCreate(ctx, evt) + case *event.JoinRulesEventContent: + err = store.SetJoinRules(ctx, evt.RoomID, content) default: switch evt.Type { case event.StateMember, event.StatePowerLevels, event.StateEncryption, event.StateCreate: @@ -107,11 +112,13 @@ type MemoryStateStore struct { PowerLevels map[id.RoomID]*event.PowerLevelsEventContent `json:"power_levels"` Encryption map[id.RoomID]*event.EncryptionEventContent `json:"encryption"` Create map[id.RoomID]*event.Event `json:"create"` + JoinRules map[id.RoomID]*event.JoinRulesEventContent `json:"join_rules"` registrationsLock sync.RWMutex membersLock sync.RWMutex powerLevelsLock sync.RWMutex encryptionLock sync.RWMutex + joinRulesLock sync.RWMutex } func NewMemoryStateStore() StateStore { @@ -122,6 +129,7 @@ func NewMemoryStateStore() StateStore { PowerLevels: make(map[id.RoomID]*event.PowerLevelsEventContent), Encryption: make(map[id.RoomID]*event.EncryptionEventContent), Create: make(map[id.RoomID]*event.Event), + JoinRules: make(map[id.RoomID]*event.JoinRulesEventContent), } } @@ -354,6 +362,19 @@ func (store *MemoryStateStore) GetEncryptionEvent(_ context.Context, roomID id.R return store.Encryption[roomID], nil } +func (store *MemoryStateStore) SetJoinRules(ctx context.Context, roomID id.RoomID, content *event.JoinRulesEventContent) error { + store.joinRulesLock.Lock() + store.JoinRules[roomID] = content + store.joinRulesLock.Unlock() + return nil +} + +func (store *MemoryStateStore) GetJoinRules(ctx context.Context, roomID id.RoomID) (*event.JoinRulesEventContent, error) { + store.joinRulesLock.RLock() + defer store.joinRulesLock.RUnlock() + return store.JoinRules[roomID], nil +} + func (store *MemoryStateStore) IsEncrypted(ctx context.Context, roomID id.RoomID) (bool, error) { cfg, err := store.GetEncryptionEvent(ctx, roomID) return cfg != nil && cfg.Algorithm == id.AlgorithmMegolmV1, err diff --git a/synapseadmin/roomapi.go b/synapseadmin/roomapi.go index a09ba174..0925b748 100644 --- a/synapseadmin/roomapi.go +++ b/synapseadmin/roomapi.go @@ -75,8 +75,7 @@ type RespListRooms struct { // https://matrix-org.github.io/synapse/latest/admin_api/rooms.html#list-room-api func (cli *Client) ListRooms(ctx context.Context, req ReqListRoom) (RespListRooms, error) { var resp RespListRooms - var reqURL string - reqURL = cli.Client.BuildURLWithQuery(mautrix.SynapseAdminURLPath{"v1", "rooms"}, req.BuildQuery()) + reqURL := cli.Client.BuildURLWithQuery(mautrix.SynapseAdminURLPath{"v1", "rooms"}, req.BuildQuery()) _, err := cli.Client.MakeRequest(ctx, http.MethodGet, reqURL, nil, &resp) return resp, err } @@ -117,6 +116,7 @@ func (cli *Client) RoomMessages(ctx context.Context, roomID id.RoomID, from, to type ReqDeleteRoom struct { Purge bool `json:"purge,omitempty"` + ForcePurge bool `json:"force_purge,omitempty"` Block bool `json:"block,omitempty"` Message string `json:"message,omitempty"` RoomName string `json:"room_name,omitempty"` diff --git a/url.go b/url.go index d888956a..91b3d49d 100644 --- a/url.go +++ b/url.go @@ -98,10 +98,8 @@ func (saup SynapseAdminURLPath) FullPath() []any { // and appservice user ID set already. func (cli *Client) BuildURLWithQuery(urlPath PrefixableURLPath, urlQuery map[string]string) string { return cli.BuildURLWithFullQuery(urlPath, func(q url.Values) { - if urlQuery != nil { - for k, v := range urlQuery { - q.Set(k, v) - } + for k, v := range urlQuery { + q.Set(k, v) } }) } diff --git a/version.go b/version.go index 7b4eea41..f00bbf39 100644 --- a/version.go +++ b/version.go @@ -8,7 +8,7 @@ import ( "strings" ) -const Version = "v0.25.2" +const Version = "v0.26.3" var GoModVersion = "" var Commit = "" diff --git a/versions.go b/versions.go index 0392532e..61b2e4ea 100644 --- a/versions.go +++ b/versions.go @@ -60,23 +60,28 @@ type UnstableFeature struct { } var ( - FeatureAsyncUploads = UnstableFeature{UnstableFlag: "fi.mau.msc2246.stable", SpecVersion: SpecV17} - FeatureAppservicePing = UnstableFeature{UnstableFlag: "fi.mau.msc2659.stable", SpecVersion: SpecV17} - FeatureAuthenticatedMedia = UnstableFeature{UnstableFlag: "org.matrix.msc3916.stable", SpecVersion: SpecV111} - FeatureMutualRooms = UnstableFeature{UnstableFlag: "uk.half-shot.msc2666.query_mutual_rooms"} - FeatureUserRedaction = UnstableFeature{UnstableFlag: "org.matrix.msc4194"} - FeatureViewRedactedContent = UnstableFeature{UnstableFlag: "fi.mau.msc2815"} - FeatureAccountModeration = UnstableFeature{UnstableFlag: "uk.timedout.msc4323"} - FeatureUnstableProfileFields = UnstableFeature{UnstableFlag: "uk.tcpip.msc4133"} - FeatureArbitraryProfileFields = UnstableFeature{UnstableFlag: "uk.tcpip.msc4133.stable", SpecVersion: SpecV116} + FeatureAsyncUploads = UnstableFeature{UnstableFlag: "fi.mau.msc2246.stable", SpecVersion: SpecV17} + FeatureAppservicePing = UnstableFeature{UnstableFlag: "fi.mau.msc2659.stable", SpecVersion: SpecV17} + FeatureAuthenticatedMedia = UnstableFeature{UnstableFlag: "org.matrix.msc3916.stable", SpecVersion: SpecV111} + FeatureUnstableMutualRooms = UnstableFeature{UnstableFlag: "uk.half-shot.msc2666.query_mutual_rooms"} + FeatureStableMutualRooms = UnstableFeature{UnstableFlag: "uk.half-shot.msc2666.query_mutual_rooms.stable" /*, SpecVersion: SpecV118*/} + FeatureUserRedaction = UnstableFeature{UnstableFlag: "org.matrix.msc4194"} + FeatureViewRedactedContent = UnstableFeature{UnstableFlag: "fi.mau.msc2815"} + FeatureUnstableAccountModeration = UnstableFeature{UnstableFlag: "uk.timedout.msc4323"} + FeatureStableAccountModeration = UnstableFeature{UnstableFlag: "uk.timedout.msc4323.stable" /*, SpecVersion: SpecV118*/} + FeatureUnstableProfileFields = UnstableFeature{UnstableFlag: "uk.tcpip.msc4133"} + FeatureArbitraryProfileFields = UnstableFeature{UnstableFlag: "uk.tcpip.msc4133.stable", SpecVersion: SpecV116} + FeatureRedactSendAsEvent = UnstableFeature{UnstableFlag: "com.beeper.msc4169"} - BeeperFeatureHungry = UnstableFeature{UnstableFlag: "com.beeper.hungry"} - BeeperFeatureBatchSending = UnstableFeature{UnstableFlag: "com.beeper.batch_sending"} - BeeperFeatureRoomYeeting = UnstableFeature{UnstableFlag: "com.beeper.room_yeeting"} - BeeperFeatureAutojoinInvites = UnstableFeature{UnstableFlag: "com.beeper.room_create_autojoin_invites"} - BeeperFeatureArbitraryProfileMeta = UnstableFeature{UnstableFlag: "com.beeper.arbitrary_profile_meta"} - BeeperFeatureAccountDataMute = UnstableFeature{UnstableFlag: "com.beeper.account_data_mute"} - BeeperFeatureInboxState = UnstableFeature{UnstableFlag: "com.beeper.inbox_state"} + BeeperFeatureHungry = UnstableFeature{UnstableFlag: "com.beeper.hungry"} + BeeperFeatureBatchSending = UnstableFeature{UnstableFlag: "com.beeper.batch_sending"} + BeeperFeatureRoomYeeting = UnstableFeature{UnstableFlag: "com.beeper.room_yeeting"} + BeeperFeatureAutojoinInvites = UnstableFeature{UnstableFlag: "com.beeper.room_create_autojoin_invites"} + BeeperFeatureArbitraryProfileMeta = UnstableFeature{UnstableFlag: "com.beeper.arbitrary_profile_meta"} + BeeperFeatureAccountDataMute = UnstableFeature{UnstableFlag: "com.beeper.account_data_mute"} + BeeperFeatureInboxState = UnstableFeature{UnstableFlag: "com.beeper.inbox_state"} + BeeperFeatureArbitraryMemberChange = UnstableFeature{UnstableFlag: "com.beeper.arbitrary_member_change"} + BeeperFeatureEphemeralEvents = UnstableFeature{UnstableFlag: "com.beeper.ephemeral"} ) func (versions *RespVersions) Supports(feature UnstableFeature) bool { @@ -121,6 +126,7 @@ var ( SpecV114 = MustParseSpecVersion("v1.14") SpecV115 = MustParseSpecVersion("v1.15") SpecV116 = MustParseSpecVersion("v1.16") + SpecV117 = MustParseSpecVersion("v1.17") ) func (svf SpecVersionFormat) String() string {