diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index c0add220..3cf412b4 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -10,12 +10,12 @@ jobs: runs-on: ubuntu-latest name: Lint (latest) steps: - - uses: actions/checkout@v6 + - uses: actions/checkout@v4 - name: Set up Go - uses: actions/setup-go@v6 + uses: actions/setup-go@v5 with: - go-version: "1.26" + go-version: "1.24" cache: true - name: Install libolm @@ -24,7 +24,6 @@ jobs: - name: Install goimports run: | go install golang.org/x/tools/cmd/goimports@latest - go install honnef.co/go/tools/cmd/staticcheck@latest export PATH="$HOME/go/bin:$PATH" - name: Run pre-commit @@ -35,14 +34,14 @@ jobs: strategy: fail-fast: false matrix: - go-version: ["1.25", "1.26"] - name: Build (${{ matrix.go-version == '1.26' && 'latest' || 'old' }}, libolm) + go-version: ["1.24", "1.25"] + name: Build (${{ matrix.go-version == '1.25' && 'latest' || 'old' }}, libolm) steps: - - uses: actions/checkout@v6 + - uses: actions/checkout@v4 - name: Set up Go ${{ matrix.go-version }} - uses: actions/setup-go@v6 + uses: actions/setup-go@v5 with: go-version: ${{ matrix.go-version }} cache: true @@ -61,28 +60,28 @@ jobs: - name: Test run: go test -json -v ./... 2>&1 | gotestfmt - - name: Test (jsonv2) - env: - GOEXPERIMENT: jsonv2 - run: go test -json -v ./... 2>&1 | gotestfmt - build-goolm: runs-on: ubuntu-latest strategy: fail-fast: false matrix: - go-version: ["1.25", "1.26"] - name: Build (${{ matrix.go-version == '1.26' && 'latest' || 'old' }}, goolm) + go-version: ["1.24", "1.25"] + name: Build (${{ matrix.go-version == '1.25' && 'latest' || 'old' }}, goolm) steps: - - uses: actions/checkout@v6 + - uses: actions/checkout@v4 - name: Set up Go ${{ matrix.go-version }} - uses: actions/setup-go@v6 + uses: actions/setup-go@v5 with: go-version: ${{ matrix.go-version }} cache: true + - name: Set up gotestfmt + uses: GoTestTools/gotestfmt-action@v2 + with: + token: ${{ secrets.GITHUB_TOKEN }} + - name: Build run: | rm -rf crypto/libolm diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml index 9a9e7375..578349c9 100644 --- a/.github/workflows/stale.yml +++ b/.github/workflows/stale.yml @@ -17,7 +17,7 @@ jobs: lock-stale: runs-on: ubuntu-latest steps: - - uses: dessant/lock-threads@v6 + - uses: dessant/lock-threads@v5 id: lock with: issue-inactive-days: 90 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 616fccb2..81701203 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v6.0.0 + rev: v5.0.0 hooks: - id: trailing-whitespace exclude_types: [markdown] @@ -9,7 +9,7 @@ repos: - id: check-added-large-files - repo: https://github.com/tekwizely/pre-commit-golang - rev: v1.0.0-rc.4 + rev: v1.0.0-rc.1 hooks: - id: go-imports-repo args: @@ -18,7 +18,8 @@ repos: - "-w" - id: go-vet-repo-mod - id: go-mod-tidy - - id: go-staticcheck-repo-mod + # TODO enable this + #- id: go-staticcheck-repo-mod - repo: https://github.com/beeper/pre-commit-go rev: v0.4.2 @@ -26,4 +27,3 @@ repos: - id: prevent-literal-http-methods - id: zerolog-ban-global-log - id: zerolog-ban-msgf - - id: zerolog-use-stringer diff --git a/CHANGELOG.md b/CHANGELOG.md index f2829199..22ff47f7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,202 +1,4 @@ -## v0.26.3 (2026-02-16) - -* Bumped minimum Go version to 1.25. -* *(client)* Added fields for sending [MSC4354] sticky events. -* *(bridgev2)* Added automatic message request accepting when sending message. -* *(mediaproxy)* Added support for federation thumbnail endpoint. -* *(crypto/ssss)* Improved support for recovery keys with slightly broken - metadata. -* *(crypto)* Changed key import to call session received callback even for - sessions that already exist in the database. -* *(appservice)* Fixed building websocket URL accidentally using file path - separators instead of always `/`. -* *(crypto)* Fixed key exports not including the `sender_claimed_keys` field. -* *(client)* Fixed incorrect context usage in async uploads. -* *(crypto)* Fixed panic when passing invalid input to megolm message index - parser used for debugging. -* *(bridgev2/provisioning)* Fixed completed or failed logins not being cleaned - up properly. - -[MSC4354]: https://github.com/matrix-org/matrix-spec-proposals/pull/4354 - -## v0.26.2 (2026-01-16) - -* *(bridgev2)* Added chunked portal deletion to avoid database locks when - deleting large portals. -* *(crypto,bridgev2)* Added option to encrypt reaction and reply metadata - as per [MSC4392]. -* *(bridgev2/login)* Added `default_value` for user input fields. -* *(bridgev2)* Added interfaces to let the Matrix connector provide suggested - HTTP client settings and to reset active connections of the network connector. -* *(bridgev2)* Added interface to let network connectors get the provisioning - API HTTP router and add new endpoints. -* *(event)* Added blurhash field to Beeper link preview objects. -* *(event)* Added [MSC4391] support for bot commands. -* *(event)* Dropped [MSC4332] support for bot commands. -* *(client)* Changed media download methods to return an error if the provided - MXC URI is empty. -* *(client)* Stabilized support for [MSC4323]. -* *(bridgev2/matrix)* Fixed `GetEvent` panicking when trying to decrypt events. -* *(bridgev2)* Fixed some deadlocks when room creation happens in parallel with - a portal re-ID call. - -[MSC4391]: https://github.com/matrix-org/matrix-spec-proposals/pull/4391 -[MSC4392]: https://github.com/matrix-org/matrix-spec-proposals/pull/4392 - -## v0.26.1 (2025-12-16) - -* **Breaking change *(mediaproxy)*** Changed `GetMediaResponseFile` to return - the mime type from the callback rather than in the return get media return - value. The callback can now also redirect the caller to a different file. -* *(federation)* Added join/knock/leave functions - (thanks to [@nexy7574] in [#422]). -* *(federation/eventauth)* Fixed various incorrect checks. -* *(client)* Added backoff for retrying media uploads to external URLs - (with MSC3870). -* *(bridgev2/config)* Added support for overriding config fields using - environment variables. -* *(bridgev2/commands)* Added command to mute chat on remote network. -* *(bridgev2)* Added interface for network connectors to redirect to a different - user ID when handling an invite from Matrix. -* *(bridgev2)* Added interface for signaling message request status of portals. -* *(bridgev2)* Changed portal creation to not backfill unless `CanBackfill` flag - is set in chat info. -* *(bridgev2)* Changed Matrix reaction handling to only delete old reaction if - bridging the new one is successful. -* *(bridgev2/mxmain)* Improved error message when trying to run bridge with - pre-megabridge database when no database migration exists. -* *(bridgev2)* Improved reliability of database migration when enabling split - portals. -* *(bridgev2)* Improved detection of orphaned DM rooms when starting new chats. -* *(bridgev2)* Stopped sending redundant invites when joining ghosts to public - portal rooms. -* *(bridgev2)* Stopped hardcoding room versions in favor of checking - server capabilities to determine appropriate `/createRoom` parameters. - -[#422]: https://github.com/mautrix/go/pull/422 - -## v0.26.0 (2025-11-16) - -* *(client,appservice)* Deprecated `SendMassagedStateEvent` as `SendStateEvent` - has been able to do the same for a while now. -* *(client,federation)* Added size limits for responses to make it safer to send - requests to untrusted servers. -* *(client)* Added wrapper for `/admin/whois` client API - (thanks to [@nexy7574] in [#411]). -* *(synapseadmin)* Added `force_purge` option to DeleteRoom - (thanks to [@nexy7574] in [#420]). -* *(statestore)* Added saving join rules for rooms. -* *(bridgev2)* Added optional automatic rollback of room state if bridging the - change to the remote network fails. -* *(bridgev2)* Added management room notices if transient disconnect state - doesn't resolve within 3 minutes. -* *(bridgev2)* Added interface to signal that certain participants couldn't be - invited when creating a group. -* *(bridgev2)* Added `select` type for user input fields in login. -* *(bridgev2)* Added interface to let network connector customize personal - filtering space. -* *(bridgev2/matrix)* Added checks to avoid sending error messages in reply to - other bots. -* *(bridgev2/matrix)* Switched to using [MSC4169] to send redactions whenever - possible. -* *(bridgev2/publicmedia)* Added support for custom path prefixes, file names, - and encrypted files. -* *(bridgev2/commands)* Added command to resync a single portal. -* *(bridgev2/commands)* Added create group command. -* *(bridgev2/config)* Added option to limit maximum number of logins. -* *(bridgev2)* Changed ghost joining to skip unnecessary invite if portal room - is public. -* *(bridgev2/disappear)* Changed read receipt handling to only start - disappearing timers for messages up to the read message (note: may not work in - all cases if the read receipt points at an unknown event). -* *(event/reply)* Changed plaintext reply fallback removal to only happen when - an HTML reply fallback is removed successfully. -* *(bridgev2/matrix)* Fixed unnecessary sleep after registering bot on first run. -* *(crypto/goolm)* Fixed panic when processing certain malformed Olm messages. -* *(federation)* Fixed HTTP method for sending transactions - (thanks to [@nexy7574] in [#426]). -* *(federation)* Fixed response body being closed even when using `DontReadBody` - parameter. -* *(federation)* Fixed validating auth for requests with query params. -* *(federation/eventauth)* Fixed typo causing restricted joins to not work. - -[MSC4169]: https://github.com/matrix-org/matrix-spec-proposals/pull/4169 -[#411]: github.com/mautrix/go/pull/411 -[#420]: github.com/mautrix/go/pull/420 -[#426]: github.com/mautrix/go/pull/426 - -## v0.25.2 (2025-10-16) - -* **Breaking change *(id)*** Split `UserID.ParseAndValidate` into - `ParseAndValidateRelaxed` and `ParseAndValidateStrict`. Strict is the old - behavior, but most users likely want the relaxed version, as there are real - users whose user IDs aren't valid under the strict rules. -* *(crypto)* Added helper methods for generating and verifying with recovery - keys. -* *(bridgev2/matrix)* Added config option to automatically generate a recovery - key for the bridge bot and self-sign the bridge's device. -* *(bridgev2/matrix)* Added initial support for using appservice/MSC3202 mode - for encryption with standard servers like Synapse. -* *(bridgev2)* Added optional support for implicit read receipts. -* *(bridgev2)* Added interface for deleting chats on remote network. -* *(bridgev2)* Added local enforcement of media duration and size limits. -* *(bridgev2)* Extended event duration logging to log any event taking too long. -* *(bridgev2)* Improved validation in group creation provisioning API. -* *(event)* Added event type constant for poll end events. -* *(client)* Added wrapper for searching user directory. -* *(client)* Improved support for managing [MSC4140] delayed events. -* *(crypto/helper)* Changed default sync handling to not block on waiting for - decryption keys. On initial sync, keys won't be requested at all by default. -* *(crypto)* Fixed olm unwedging not working (regressed in v0.25.1). -* *(bridgev2)* Fixed various bugs with migrating to split portals. -* *(event)* Fixed poll start events having incorrect null `m.relates_to`. -* *(client)* Fixed `RespUserProfile` losing standard fields when re-marshaling. -* *(federation)* Fixed various bugs in event auth. - -## v0.25.1 (2025-09-16) - -* *(client)* Fixed HTTP method of delete devices API call - (thanks to [@fmseals] in [#393]). -* *(client)* Added wrappers for [MSC4323]: User suspension & locking endpoints - (thanks to [@nexy7574] in [#407]). -* *(client)* Stabilized support for extensible profiles. -* *(client)* Stabilized support for `state_after` in sync. -* *(client)* Removed deprecated MSC2716 requests. -* *(crypto)* Added fallback to ensure `m.relates_to` is always copied even if - the content struct doesn't implement `Relatable`. -* *(crypto)* Changed olm unwedging to ignore newly created sessions if they - haven't been used successfully in either direction. -* *(federation)* Added utilities for generating, parsing, validating and - authorizing PDUs. - * Note: the new PDU code depends on `GOEXPERIMENT=jsonv2` -* *(event)* Added `is_animated` flag from [MSC4230] to file info. -* *(event)* Added types for [MSC4332]: In-room bot commands. -* *(event)* Added missing poll end event type for [MSC3381]. -* *(appservice)* Fixed URLs not being escaped properly when using unix socket - for homeserver connections. -* *(format)* Added more helpers for forming markdown links. -* *(event,bridgev2)* Added support for Beeper's disappearing message state event. -* *(bridgev2)* Redesigned group creation interface and added support in commands - and provisioning API. -* *(bridgev2)* Added GetEvent to Matrix interface to allow network connectors to - get an old event. The method is best effort only, as some configurations don't - allow fetching old events. -* *(bridgev2)* Added shared logic for provisioning that can be reused by the - API, commands and other sources. -* *(bridgev2)* Fixed mentions and URL previews not being copied over when - caption and media are merged. -* *(bridgev2)* Removed config option to change provisioning API prefix, which - had already broken in the previous release. - -[@fmseals]: https://github.com/fmseals -[#393]: https://github.com/mautrix/go/pull/393 -[#407]: https://github.com/mautrix/go/pull/407 -[MSC3381]: https://github.com/matrix-org/matrix-spec-proposals/pull/3381 -[MSC4230]: https://github.com/matrix-org/matrix-spec-proposals/pull/4230 -[MSC4323]: https://github.com/matrix-org/matrix-spec-proposals/pull/4323 -[MSC4332]: https://github.com/matrix-org/matrix-spec-proposals/pull/4332 - -## v0.25.0 (2025-08-16) +## v0.25.0 (unreleased) * Bumped minimum Go version to 1.24. * **Breaking change *(appservice,bridgev2,federation)*** Replaced gorilla/mux @@ -437,7 +239,6 @@ [MSC4156]: https://github.com/matrix-org/matrix-spec-proposals/pull/4156 [MSC4190]: https://github.com/matrix-org/matrix-spec-proposals/pull/4190 [#288]: https://github.com/mautrix/go/pull/288 -[@onestacked]: https://github.com/onestacked ## v0.22.0 (2024-11-16) diff --git a/README.md b/README.md index b1a2edf8..ac41ca78 100644 --- a/README.md +++ b/README.md @@ -1,9 +1,8 @@ # mautrix-go [![GoDoc](https://pkg.go.dev/badge/maunium.net/go/mautrix)](https://pkg.go.dev/maunium.net/go/mautrix) -A Golang Matrix framework. Used by [gomuks](https://gomuks.app), -[go-neb](https://github.com/matrix-org/go-neb), -[mautrix-whatsapp](https://github.com/mautrix/whatsapp) +A Golang Matrix framework. Used by [gomuks](https://matrix.org/docs/projects/client/gomuks), +[go-neb](https://github.com/matrix-org/go-neb), [mautrix-whatsapp](https://github.com/mautrix/whatsapp) and others. Matrix room: [`#go:maunium.net`](https://matrix.to/#/#go:maunium.net) @@ -14,10 +13,9 @@ The original project is licensed under [Apache 2.0](https://github.com/matrix-or In addition to the basic client API features the original project has, this framework also has: * Appservice support (Intent API like mautrix-python, room state storage, etc) -* End-to-end encryption support (incl. key backup, cross-signing, interactive verification, etc) +* End-to-end encryption support (incl. interactive SAS verification) * High-level module for building puppeting bridges -* Partial federation module (making requests, PDU processing and event authorization) -* A media proxy server which can be used to expose anything as a Matrix media repo +* High-level module for building chat clients * Wrapper functions for the Synapse admin API * Structs for parsing event content * Helpers for parsing and generating Matrix HTML diff --git a/appservice/appservice.go b/appservice/appservice.go index d7037ef6..b0af02cd 100644 --- a/appservice/appservice.go +++ b/appservice/appservice.go @@ -334,7 +334,7 @@ func (as *AppService) SetHomeserverURL(homeserverURL string) error { } else if as.hsURLForClient.Scheme == "" { as.hsURLForClient.Scheme = "https" } - as.hsURLForClient.RawPath = as.hsURLForClient.EscapedPath() + as.hsURLForClient.RawPath = parsedURL.EscapedPath() jar, _ := cookiejar.New(&cookiejar.Options{PublicSuffixList: publicsuffix.List}) as.HTTPClient = &http.Client{Timeout: 180 * time.Second, Jar: jar} @@ -360,7 +360,7 @@ func (as *AppService) NewMautrixClient(userID id.UserID) *mautrix.Client { AccessToken: as.Registration.AppToken, UserAgent: as.UserAgent, StateStore: as.StateStore, - Log: as.Log.With().Stringer("as_user_id", userID).Logger(), + Log: as.Log.With().Str("as_user_id", userID.String()).Logger(), Client: as.HTTPClient, DefaultHTTPRetries: as.DefaultHTTPRetries, SpecVersions: as.SpecVersions, diff --git a/appservice/http.go b/appservice/http.go index 27ce6288..862de7fd 100644 --- a/appservice/http.go +++ b/appservice/http.go @@ -201,7 +201,7 @@ func (as *AppService) handleEvents(ctx context.Context, evts []*event.Event, def } err := evt.Content.ParseRaw(evt.Type) if errors.Is(err, event.ErrUnsupportedContentType) { - log.Debug().Stringer("event_id", evt.ID).Msg("Not parsing content of unsupported event") + log.Debug().Str("event_id", evt.ID.String()).Msg("Not parsing content of unsupported event") } else if err != nil { log.Warn().Err(err). Str("event_id", evt.ID.String()). diff --git a/appservice/intent.go b/appservice/intent.go index 5d43f190..fa9d9e7a 100644 --- a/appservice/intent.go +++ b/appservice/intent.go @@ -51,7 +51,7 @@ func (as *AppService) NewIntentAPI(localpart string) *IntentAPI { } func (intent *IntentAPI) Register(ctx context.Context) error { - _, err := intent.Client.MakeRequest(ctx, http.MethodPost, intent.BuildClientURL("v3", "register"), &mautrix.ReqRegister[any]{ + _, err := intent.Client.MakeRequest(ctx, http.MethodPost, intent.BuildClientURL("v3", "register"), &mautrix.ReqRegister{ Username: intent.Localpart, Type: mautrix.AuthTypeAppservice, InhibitLogin: true, @@ -214,31 +214,23 @@ func (intent *IntentAPI) AddDoublePuppetValueWithTS(into any, ts int64) any { } } -func (intent *IntentAPI) SendMessageEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, contentJSON any, extra ...mautrix.ReqSendEvent) (*mautrix.RespSendEvent, error) { +func (intent *IntentAPI) SendMessageEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, contentJSON interface{}) (*mautrix.RespSendEvent, error) { if err := intent.EnsureJoined(ctx, roomID); err != nil { return nil, err } contentJSON = intent.AddDoublePuppetValue(contentJSON) - return intent.Client.SendMessageEvent(ctx, roomID, eventType, contentJSON, extra...) + return intent.Client.SendMessageEvent(ctx, roomID, eventType, contentJSON) } -func (intent *IntentAPI) BeeperSendEphemeralEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, contentJSON any, extra ...mautrix.ReqSendEvent) (*mautrix.RespSendEvent, error) { - if err := intent.EnsureJoined(ctx, roomID); err != nil { - return nil, err - } - if !intent.SpecVersions.Supports(mautrix.BeeperFeatureEphemeralEvents) { - return nil, mautrix.MUnrecognized.WithMessage("Homeserver does not advertise com.beeper.ephemeral support") - } - contentJSON = intent.AddDoublePuppetValue(contentJSON) - return intent.Client.BeeperSendEphemeralEvent(ctx, roomID, eventType, contentJSON, extra...) -} - -// Deprecated: use SendMessageEvent with mautrix.ReqSendEvent.Timestamp instead func (intent *IntentAPI) SendMassagedMessageEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, contentJSON interface{}, ts int64) (*mautrix.RespSendEvent, error) { - return intent.SendMessageEvent(ctx, roomID, eventType, contentJSON, mautrix.ReqSendEvent{Timestamp: ts}) + if err := intent.EnsureJoined(ctx, roomID); err != nil { + return nil, err + } + contentJSON = intent.AddDoublePuppetValueWithTS(contentJSON, ts) + return intent.Client.SendMessageEvent(ctx, roomID, eventType, contentJSON, mautrix.ReqSendEvent{Timestamp: ts}) } -func (intent *IntentAPI) SendStateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, contentJSON any, extra ...mautrix.ReqSendEvent) (*mautrix.RespSendEvent, error) { +func (intent *IntentAPI) SendStateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, contentJSON interface{}) (*mautrix.RespSendEvent, error) { if eventType != event.StateMember || stateKey != string(intent.UserID) { if err := intent.EnsureJoined(ctx, roomID); err != nil { return nil, err @@ -247,12 +239,15 @@ func (intent *IntentAPI) SendStateEvent(ctx context.Context, roomID id.RoomID, e return nil, err } contentJSON = intent.AddDoublePuppetValue(contentJSON) - return intent.Client.SendStateEvent(ctx, roomID, eventType, stateKey, contentJSON, extra...) + return intent.Client.SendStateEvent(ctx, roomID, eventType, stateKey, contentJSON) } -// Deprecated: use SendStateEvent with mautrix.ReqSendEvent.Timestamp instead func (intent *IntentAPI) SendMassagedStateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, contentJSON interface{}, ts int64) (*mautrix.RespSendEvent, error) { - return intent.SendStateEvent(ctx, roomID, eventType, stateKey, contentJSON, mautrix.ReqSendEvent{Timestamp: ts}) + if err := intent.EnsureJoined(ctx, roomID); err != nil { + return nil, err + } + contentJSON = intent.AddDoublePuppetValueWithTS(contentJSON, ts) + return intent.Client.SendMassagedStateEvent(ctx, roomID, eventType, stateKey, contentJSON, ts) } func (intent *IntentAPI) StateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, outContent interface{}) error { @@ -311,7 +306,7 @@ func (intent *IntentAPI) SendCustomMembershipEvent(ctx context.Context, roomID i func (intent *IntentAPI) JoinRoomByID(ctx context.Context, roomID id.RoomID, extraContent ...map[string]interface{}) (resp *mautrix.RespJoinRoom, err error) { if intent.IsCustomPuppet || len(extraContent) > 0 { _, err = intent.SendCustomMembershipEvent(ctx, roomID, intent.UserID, event.MembershipJoin, "", extraContent...) - return &mautrix.RespJoinRoom{RoomID: roomID}, err + return &mautrix.RespJoinRoom{}, err } return intent.Client.JoinRoomByID(ctx, roomID) } diff --git a/appservice/websocket.go b/appservice/websocket.go index ef65e65a..18768098 100644 --- a/appservice/websocket.go +++ b/appservice/websocket.go @@ -11,10 +11,9 @@ import ( "encoding/json" "errors" "fmt" - "io" "net/http" "net/url" - "path" + "path/filepath" "strings" "sync" "sync/atomic" @@ -56,7 +55,7 @@ func (wsc *WebsocketCommand) MakeResponse(ok bool, data any) *WebsocketRequest { var prefixMessage string for unwrappedErr != nil { errorData, jsonErr = json.Marshal(unwrappedErr) - if len(errorData) > 2 && jsonErr == nil { + if errorData != nil && len(errorData) > 2 && jsonErr == nil { prefixMessage = strings.Replace(err.Error(), unwrappedErr.Error(), "", 1) prefixMessage = strings.TrimRight(prefixMessage, ": ") break @@ -293,16 +292,10 @@ func (as *AppService) consumeWebsocket(ctx context.Context, stopFunc func(error) as.Log.Debug().Msg("Ignoring non-text message from websocket") continue } - data, err := io.ReadAll(reader) - if err != nil { - as.Log.Debug().Err(err).Msg("Error reading data from websocket") - stopFunc(parseCloseError(err)) - return - } var msg WebsocketMessage - err = json.Unmarshal(data, &msg) + err = json.NewDecoder(reader).Decode(&msg) if err != nil { - as.Log.Debug().Err(err).Msg("Error parsing JSON received from websocket") + as.Log.Debug().Err(err).Msg("Error reading JSON from websocket") stopFunc(parseCloseError(err)) return } @@ -374,7 +367,7 @@ func (as *AppService) StartWebsocket(ctx context.Context, baseURL string, onConn copiedURL := *as.hsURLForClient parsed = &copiedURL } - parsed.Path = path.Join(parsed.Path, "_matrix/client/unstable/fi.mau.as_sync") + parsed.Path = filepath.Join(parsed.Path, "_matrix/client/unstable/fi.mau.as_sync") if parsed.Scheme == "http" { parsed.Scheme = "ws" } else if parsed.Scheme == "https" { @@ -419,7 +412,6 @@ func (as *AppService) StartWebsocket(ctx context.Context, baseURL string, onConn } }) } - ws.SetReadLimit(50 * 1024 * 1024) as.ws = ws as.StopWebsocket = stopFunc as.PrepareWebsocket() diff --git a/bridgev2/bridge.go b/bridgev2/bridge.go index 226adc90..24619c79 100644 --- a/bridgev2/bridge.go +++ b/bridgev2/bridge.go @@ -9,14 +9,11 @@ package bridgev2 import ( "context" "fmt" - "os" "sync" - "sync/atomic" "time" "github.com/rs/zerolog" "go.mau.fi/util/dbutil" - "go.mau.fi/util/exhttp" "go.mau.fi/util/exsync" "maunium.net/go/mautrix/bridgev2/bridgeconfig" @@ -54,7 +51,6 @@ type Bridge struct { Background bool ExternallyManagedDB bool - stopping atomic.Bool wakeupBackfillQueue chan struct{} stopBackfillQueue *exsync.Event @@ -130,7 +126,6 @@ func (br *Bridge) Start(ctx context.Context) error { func (br *Bridge) RunOnce(ctx context.Context, loginID networkid.UserLoginID, params *ConnectBackgroundParams) error { br.Background = true - br.stopping.Store(false) err := br.StartConnectors(ctx) if err != nil { return err @@ -166,7 +161,6 @@ func (br *Bridge) RunOnce(ctx context.Context, loginID networkid.UserLoginID, pa case <-time.After(20 * time.Second): case <-ctx.Done(): } - br.stopping.Store(true) return nil } else { br.Log.Info().Str("user_login_id", string(login.ID)).Msg("Starting individual user login in background mode") @@ -176,7 +170,6 @@ func (br *Bridge) RunOnce(ctx context.Context, loginID networkid.UserLoginID, pa func (br *Bridge) StartConnectors(ctx context.Context) error { br.Log.Info().Msg("Starting bridge") - br.stopping.Store(false) if br.BackgroundCtx == nil || br.BackgroundCtx.Err() != nil { br.BackgroundCtx, br.cancelBackgroundCtx = context.WithCancel(context.Background()) br.BackgroundCtx = br.Log.WithContext(br.BackgroundCtx) @@ -189,11 +182,7 @@ func (br *Bridge) StartConnectors(ctx context.Context) error { } } if !br.Background { - var postMigrate func() - br.didSplitPortals, postMigrate = br.MigrateToSplitPortals(ctx) - if postMigrate != nil { - defer postMigrate() - } + br.didSplitPortals = br.MigrateToSplitPortals(ctx) } br.Log.Info().Msg("Starting Matrix connector") err := br.Matrix.Start(ctx) @@ -282,64 +271,20 @@ func (br *Bridge) ResendBridgeInfo(ctx context.Context, resendInfo, resendCaps b Msg("Resent bridge info to all portals") } -func (br *Bridge) MigrateToSplitPortals(ctx context.Context) (bool, func()) { +func (br *Bridge) MigrateToSplitPortals(ctx context.Context) bool { log := zerolog.Ctx(ctx).With().Str("action", "migrate to split portals").Logger() ctx = log.WithContext(ctx) if !br.Config.SplitPortals || br.DB.KV.Get(ctx, database.KeySplitPortalsEnabled) == "true" { - return false, nil + return false } affected, err := br.DB.Portal.MigrateToSplitPortals(ctx) if err != nil { - log.WithLevel(zerolog.FatalLevel).Err(err).Msg("Failed to migrate portals") - os.Exit(31) - return false, nil + log.Err(err).Msg("Failed to migrate portals") + return false } log.Info().Int64("rows_affected", affected).Msg("Migrated to split portals") - affected2, err := br.DB.Portal.FixParentsAfterSplitPortalMigration(ctx) - if err != nil { - log.Err(err).Msg("Failed to fix parent portals after split portal migration") - os.Exit(31) - return false, nil - } - log.Info().Int64("rows_affected", affected2).Msg("Updated parent receivers after split portal migration") - withoutReceiver, err := br.DB.Portal.GetAllWithoutReceiver(ctx) - if err != nil { - log.WithLevel(zerolog.FatalLevel).Err(err).Msg("Failed to get portals that failed to migrate") - os.Exit(31) - return false, nil - } - var roomsToDelete []id.RoomID - log.Info().Int("remaining_portals", len(withoutReceiver)).Msg("Deleting remaining portals without receiver") - for _, portal := range withoutReceiver { - if err = br.DB.Portal.Delete(ctx, portal.PortalKey); err != nil { - log.Err(err). - Str("portal_id", string(portal.ID)). - Stringer("mxid", portal.MXID). - Msg("Failed to delete portal database row that failed to migrate") - } else if portal.MXID != "" { - log.Debug(). - Str("portal_id", string(portal.ID)). - Stringer("mxid", portal.MXID). - Msg("Marked portal room for deletion from homeserver") - roomsToDelete = append(roomsToDelete, portal.MXID) - } else { - log.Debug(). - Str("portal_id", string(portal.ID)). - Msg("Deleted portal row with no Matrix room") - } - } br.DB.KV.Set(ctx, database.KeySplitPortalsEnabled, "true") - log.Info().Msg("Finished split portal migration successfully") - return affected > 0, func() { - for _, roomID := range roomsToDelete { - if err = br.Bot.DeleteRoom(ctx, roomID, true); err != nil { - log.Err(err). - Stringer("mxid", roomID). - Msg("Failed to delete portal room that failed to migrate") - } - } - log.Info().Int("room_count", len(roomsToDelete)).Msg("Finished deleting rooms that failed to migrate") - } + return affected > 0 } func (br *Bridge) StartLogins(ctx context.Context) error { @@ -374,46 +319,6 @@ func (br *Bridge) StartLogins(ctx context.Context) error { return nil } -func (br *Bridge) ResetNetworkConnections() { - nrn, ok := br.Network.(NetworkResettingNetwork) - if ok { - br.Log.Info().Msg("Resetting network connections with NetworkConnector.ResetNetworkConnections") - nrn.ResetNetworkConnections() - return - } - - br.Log.Info().Msg("Network connector doesn't support ResetNetworkConnections, recreating clients manually") - for _, login := range br.GetAllCachedUserLogins() { - login.Log.Debug().Msg("Disconnecting and recreating client for network reset") - ctx := login.Log.WithContext(br.BackgroundCtx) - login.Client.Disconnect() - err := login.recreateClient(ctx) - if err != nil { - login.Log.Err(err).Msg("Failed to recreate client during network reset") - login.BridgeState.Send(status.BridgeState{ - StateEvent: status.StateUnknownError, - Error: "bridgev2-network-reset-fail", - Info: map[string]any{"go_error": err.Error()}, - }) - } else { - login.Client.Connect(ctx) - } - } - br.Log.Info().Msg("Finished resetting all user logins") -} - -func (br *Bridge) GetHTTPClientSettings() exhttp.ClientSettings { - mchs, ok := br.Matrix.(MatrixConnectorWithHTTPSettings) - if ok { - return mchs.GetHTTPClientSettings() - } - return exhttp.SensibleClientSettings -} - -func (br *Bridge) IsStopping() bool { - return br.stopping.Load() -} - func (br *Bridge) Stop() { br.stop(false, 0) } @@ -424,7 +329,6 @@ func (br *Bridge) StopWithTimeout(timeout time.Duration) { func (br *Bridge) stop(isRunOnce bool, timeout time.Duration) { br.Log.Info().Msg("Shutting down bridge") - br.stopping.Store(true) br.DisappearLoop.Stop() br.stopBackfillQueue.Set() br.Matrix.PreStop() diff --git a/bridgev2/bridgeconfig/backfill.go b/bridgev2/bridgeconfig/backfill.go index eedae1e8..53282e41 100644 --- a/bridgev2/bridgeconfig/backfill.go +++ b/bridgev2/bridgeconfig/backfill.go @@ -34,12 +34,10 @@ type BackfillQueueConfig struct { MaxBatchesOverride map[string]int `yaml:"max_batches_override"` } -func (bqc *BackfillQueueConfig) GetOverride(names ...string) int { - for _, name := range names { - override, ok := bqc.MaxBatchesOverride[name] - if ok { - return override - } +func (bqc *BackfillQueueConfig) GetOverride(name string) int { + override, ok := bqc.MaxBatchesOverride[name] + if !ok { + return bqc.MaxBatches } - return bqc.MaxBatches + return override } diff --git a/bridgev2/bridgeconfig/config.go b/bridgev2/bridgeconfig/config.go index bd6b9c06..9bdee5fe 100644 --- a/bridgev2/bridgeconfig/config.go +++ b/bridgev2/bridgeconfig/config.go @@ -33,8 +33,6 @@ type Config struct { Encryption EncryptionConfig `yaml:"encryption"` Logging zeroconfig.Config `yaml:"logging"` - EnvConfigPrefix string `yaml:"env_config_prefix"` - ManagementRoomTexts ManagementRoomTexts `yaml:"management_room_texts"` } @@ -62,40 +60,36 @@ type CleanupOnLogouts struct { } type BridgeConfig struct { - CommandPrefix string `yaml:"command_prefix"` - PersonalFilteringSpaces bool `yaml:"personal_filtering_spaces"` - PrivateChatPortalMeta bool `yaml:"private_chat_portal_meta"` - AsyncEvents bool `yaml:"async_events"` - SplitPortals bool `yaml:"split_portals"` - ResendBridgeInfo bool `yaml:"resend_bridge_info"` - NoBridgeInfoStateKey bool `yaml:"no_bridge_info_state_key"` - BridgeStatusNotices string `yaml:"bridge_status_notices"` - UnknownErrorAutoReconnect time.Duration `yaml:"unknown_error_auto_reconnect"` - UnknownErrorMaxAutoReconnects int `yaml:"unknown_error_max_auto_reconnects"` - BridgeMatrixLeave bool `yaml:"bridge_matrix_leave"` - BridgeNotices bool `yaml:"bridge_notices"` - TagOnlyOnCreate bool `yaml:"tag_only_on_create"` - OnlyBridgeTags []event.RoomTag `yaml:"only_bridge_tags"` - MuteOnlyOnCreate bool `yaml:"mute_only_on_create"` - DeduplicateMatrixMessages bool `yaml:"deduplicate_matrix_messages"` - CrossRoomReplies bool `yaml:"cross_room_replies"` - OutgoingMessageReID bool `yaml:"outgoing_message_re_id"` - RevertFailedStateChanges bool `yaml:"revert_failed_state_changes"` - KickMatrixUsers bool `yaml:"kick_matrix_users"` - CleanupOnLogout CleanupOnLogouts `yaml:"cleanup_on_logout"` - Relay RelayConfig `yaml:"relay"` - Permissions PermissionConfig `yaml:"permissions"` - Backfill BackfillConfig `yaml:"backfill"` + CommandPrefix string `yaml:"command_prefix"` + PersonalFilteringSpaces bool `yaml:"personal_filtering_spaces"` + PrivateChatPortalMeta bool `yaml:"private_chat_portal_meta"` + AsyncEvents bool `yaml:"async_events"` + SplitPortals bool `yaml:"split_portals"` + ResendBridgeInfo bool `yaml:"resend_bridge_info"` + NoBridgeInfoStateKey bool `yaml:"no_bridge_info_state_key"` + BridgeStatusNotices string `yaml:"bridge_status_notices"` + UnknownErrorAutoReconnect time.Duration `yaml:"unknown_error_auto_reconnect"` + BridgeMatrixLeave bool `yaml:"bridge_matrix_leave"` + BridgeNotices bool `yaml:"bridge_notices"` + TagOnlyOnCreate bool `yaml:"tag_only_on_create"` + OnlyBridgeTags []event.RoomTag `yaml:"only_bridge_tags"` + MuteOnlyOnCreate bool `yaml:"mute_only_on_create"` + DeduplicateMatrixMessages bool `yaml:"deduplicate_matrix_messages"` + CrossRoomReplies bool `yaml:"cross_room_replies"` + OutgoingMessageReID bool `yaml:"outgoing_message_re_id"` + CleanupOnLogout CleanupOnLogouts `yaml:"cleanup_on_logout"` + Relay RelayConfig `yaml:"relay"` + Permissions PermissionConfig `yaml:"permissions"` + Backfill BackfillConfig `yaml:"backfill"` } type MatrixConfig struct { - MessageStatusEvents bool `yaml:"message_status_events"` - DeliveryReceipts bool `yaml:"delivery_receipts"` - MessageErrorNotices bool `yaml:"message_error_notices"` - SyncDirectChatList bool `yaml:"sync_direct_chat_list"` - FederateRooms bool `yaml:"federate_rooms"` - UploadFileThreshold int64 `yaml:"upload_file_threshold"` - GhostExtraProfileInfo bool `yaml:"ghost_extra_profile_info"` + MessageStatusEvents bool `yaml:"message_status_events"` + DeliveryReceipts bool `yaml:"delivery_receipts"` + MessageErrorNotices bool `yaml:"message_error_notices"` + SyncDirectChatList bool `yaml:"sync_direct_chat_list"` + FederateRooms bool `yaml:"federate_rooms"` + UploadFileThreshold int64 `yaml:"upload_file_threshold"` } type AnalyticsConfig struct { @@ -105,6 +99,7 @@ type AnalyticsConfig struct { } type ProvisioningConfig struct { + Prefix string `yaml:"prefix"` SharedSecret string `yaml:"shared_secret"` DebugEndpoints bool `yaml:"debug_endpoints"` EnableSessionTransfers bool `yaml:"enable_session_transfers"` @@ -117,12 +112,10 @@ type DirectMediaConfig struct { } type PublicMediaConfig struct { - Enabled bool `yaml:"enabled"` - SigningKey string `yaml:"signing_key"` - Expiry int `yaml:"expiry"` - HashLength int `yaml:"hash_length"` - PathPrefix string `yaml:"path_prefix"` - UseDatabase bool `yaml:"use_database"` + Enabled bool `yaml:"enabled"` + SigningKey string `yaml:"signing_key"` + HashLength int `yaml:"hash_length"` + Expiry int `yaml:"expiry"` } type DoublePuppetConfig struct { diff --git a/bridgev2/bridgeconfig/encryption.go b/bridgev2/bridgeconfig/encryption.go index 934613ca..1ef7e18f 100644 --- a/bridgev2/bridgeconfig/encryption.go +++ b/bridgev2/bridgeconfig/encryption.go @@ -16,8 +16,6 @@ type EncryptionConfig struct { Require bool `yaml:"require"` Appservice bool `yaml:"appservice"` MSC4190 bool `yaml:"msc4190"` - MSC4392 bool `yaml:"msc4392"` - SelfSign bool `yaml:"self_sign"` PlaintextMentions bool `yaml:"plaintext_mentions"` diff --git a/bridgev2/bridgeconfig/legacymigrate.go b/bridgev2/bridgeconfig/legacymigrate.go index 954a37c3..fb2a86d6 100644 --- a/bridgev2/bridgeconfig/legacymigrate.go +++ b/bridgev2/bridgeconfig/legacymigrate.go @@ -133,7 +133,9 @@ func doMigrateLegacy(helper up.Helper, python bool) { CopyToOtherLocation(helper, up.Bool, []string{"bridge", "sync_direct_chat_list"}, []string{"matrix", "sync_direct_chat_list"}) CopyToOtherLocation(helper, up.Bool, []string{"bridge", "federate_rooms"}, []string{"matrix", "federate_rooms"}) + CopyToOtherLocation(helper, up.Str, []string{"bridge", "provisioning", "prefix"}, []string{"provisioning", "prefix"}) CopyToOtherLocation(helper, up.Str, []string{"bridge", "provisioning", "shared_secret"}, []string{"provisioning", "shared_secret"}) + CopyToOtherLocation(helper, up.Str, []string{"appservice", "provisioning", "prefix"}, []string{"provisioning", "prefix"}) CopyToOtherLocation(helper, up.Str, []string{"appservice", "provisioning", "shared_secret"}, []string{"provisioning", "shared_secret"}) CopyToOtherLocation(helper, up.Bool, []string{"bridge", "provisioning", "debug_endpoints"}, []string{"provisioning", "debug_endpoints"}) diff --git a/bridgev2/bridgeconfig/permissions.go b/bridgev2/bridgeconfig/permissions.go index 9efe068e..610051e0 100644 --- a/bridgev2/bridgeconfig/permissions.go +++ b/bridgev2/bridgeconfig/permissions.go @@ -24,7 +24,6 @@ type Permissions struct { DoublePuppet bool `yaml:"double_puppet"` Admin bool `yaml:"admin"` ManageRelay bool `yaml:"manage_relay"` - MaxLogins int `yaml:"max_logins"` } type PermissionConfig map[string]*Permissions @@ -41,7 +40,10 @@ func (pc PermissionConfig) IsConfigured() bool { _, hasExampleDomain := pc["example.com"] _, hasExampleUser := pc["@admin:example.com"] exampleLen := boolToInt(hasWildcard) + boolToInt(hasExampleUser) + boolToInt(hasExampleDomain) - return len(pc) > exampleLen + if len(pc) <= exampleLen { + return false + } + return true } func (pc PermissionConfig) Get(userID id.UserID) Permissions { diff --git a/bridgev2/bridgeconfig/upgrade.go b/bridgev2/bridgeconfig/upgrade.go index 92515ea0..b69a1fdb 100644 --- a/bridgev2/bridgeconfig/upgrade.go +++ b/bridgev2/bridgeconfig/upgrade.go @@ -33,7 +33,6 @@ func doUpgrade(helper up.Helper) { helper.Copy(up.Bool, "bridge", "no_bridge_info_state_key") helper.Copy(up.Str|up.Null, "bridge", "bridge_status_notices") helper.Copy(up.Str|up.Int|up.Null, "bridge", "unknown_error_auto_reconnect") - helper.Copy(up.Int, "bridge", "unknown_error_max_auto_reconnects") helper.Copy(up.Bool, "bridge", "bridge_matrix_leave") helper.Copy(up.Bool, "bridge", "bridge_notices") helper.Copy(up.Bool, "bridge", "tag_only_on_create") @@ -41,8 +40,6 @@ func doUpgrade(helper up.Helper) { helper.Copy(up.Bool, "bridge", "mute_only_on_create") helper.Copy(up.Bool, "bridge", "deduplicate_matrix_messages") helper.Copy(up.Bool, "bridge", "cross_room_replies") - helper.Copy(up.Bool, "bridge", "revert_failed_state_changes") - helper.Copy(up.Bool, "bridge", "kick_matrix_users") helper.Copy(up.Bool, "bridge", "cleanup_on_logout", "enabled") helper.Copy(up.Str, "bridge", "cleanup_on_logout", "manual", "private") helper.Copy(up.Str, "bridge", "cleanup_on_logout", "manual", "relayed") @@ -101,12 +98,12 @@ func doUpgrade(helper up.Helper) { helper.Copy(up.Bool, "matrix", "sync_direct_chat_list") helper.Copy(up.Bool, "matrix", "federate_rooms") helper.Copy(up.Int, "matrix", "upload_file_threshold") - helper.Copy(up.Bool, "matrix", "ghost_extra_profile_info") helper.Copy(up.Str|up.Null, "analytics", "token") helper.Copy(up.Str|up.Null, "analytics", "url") helper.Copy(up.Str|up.Null, "analytics", "user_id") + helper.Copy(up.Str, "provisioning", "prefix") if secret, ok := helper.Get(up.Str, "provisioning", "shared_secret"); !ok || secret == "generate" { sharedSecret := random.String(64) helper.Set(up.Str, sharedSecret, "provisioning", "shared_secret") @@ -136,8 +133,6 @@ func doUpgrade(helper up.Helper) { } helper.Copy(up.Int, "public_media", "expiry") helper.Copy(up.Int, "public_media", "hash_length") - helper.Copy(up.Str|up.Null, "public_media", "path_prefix") - helper.Copy(up.Bool, "public_media", "use_database") helper.Copy(up.Bool, "backfill", "enabled") helper.Copy(up.Int, "backfill", "max_initial_messages") @@ -163,8 +158,6 @@ func doUpgrade(helper up.Helper) { } else { helper.Copy(up.Bool, "encryption", "msc4190") } - helper.Copy(up.Bool, "encryption", "msc4392") - helper.Copy(up.Bool, "encryption", "self_sign") helper.Copy(up.Bool, "encryption", "allow_key_sharing") if secret, ok := helper.Get(up.Str, "encryption", "pickle_key"); !ok || secret == "generate" { helper.Set(up.Str, random.String(64), "encryption", "pickle_key") @@ -187,8 +180,6 @@ func doUpgrade(helper up.Helper) { helper.Copy(up.Int, "encryption", "rotation", "messages") helper.Copy(up.Bool, "encryption", "rotation", "disable_device_change_key_rotation") - helper.Copy(up.Str|up.Null, "env_config_prefix") - helper.Copy(up.Map, "logging") } @@ -216,7 +207,6 @@ var SpacedBlocks = [][]string{ {"backfill"}, {"double_puppet"}, {"encryption"}, - {"env_config_prefix"}, {"logging"}, } diff --git a/bridgev2/bridgestate.go b/bridgev2/bridgestate.go index 96d9fd5c..f31d4e92 100644 --- a/bridgev2/bridgestate.go +++ b/bridgev2/bridgestate.go @@ -15,15 +15,12 @@ import ( "time" "github.com/rs/zerolog" - "go.mau.fi/util/exfmt" "maunium.net/go/mautrix/bridgev2/status" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/format" ) -var CatchBridgeStateQueuePanics = true - type BridgeStateQueue struct { prevUnsent *status.BridgeState prevSent *status.BridgeState @@ -32,13 +29,8 @@ type BridgeStateQueue struct { bridge *Bridge login *UserLogin - firstTransientDisconnect time.Time - cancelScheduledNotice atomic.Pointer[context.CancelFunc] - stopChan chan struct{} stopReconnect atomic.Pointer[context.CancelFunc] - - unknownErrorReconnects int } func (br *Bridge) SendGlobalBridgeState(state status.BridgeState) { @@ -82,63 +74,31 @@ func (bsq *BridgeStateQueue) StopUnknownErrorReconnect() { if cancelFn := bsq.stopReconnect.Swap(nil); cancelFn != nil { (*cancelFn)() } - if cancelFn := bsq.cancelScheduledNotice.Swap(nil); cancelFn != nil { - (*cancelFn)() - } } func (bsq *BridgeStateQueue) loop() { - if CatchBridgeStateQueuePanics { - defer func() { - err := recover() - if err != nil { - bsq.login.Log.Error(). - Bytes(zerolog.ErrorStackFieldName, debug.Stack()). - Any(zerolog.ErrorFieldName, err). - Msg("Panic in bridge state loop") - } - }() - } + defer func() { + err := recover() + if err != nil { + bsq.login.Log.Error(). + Bytes(zerolog.ErrorStackFieldName, debug.Stack()). + Any(zerolog.ErrorFieldName, err). + Msg("Panic in bridge state loop") + } + }() for state := range bsq.ch { bsq.immediateSendBridgeState(state) } } -func (bsq *BridgeStateQueue) scheduleNotice(triggeredBy status.BridgeState) { - log := bsq.login.Log.With().Str("action", "transient disconnect notice").Logger() - ctx := log.WithContext(bsq.bridge.BackgroundCtx) - if !bsq.waitForTransientDisconnectReconnect(ctx) { - return - } - prevUnsent := bsq.GetPrevUnsent() - prev := bsq.GetPrev() - if triggeredBy.Timestamp != prev.Timestamp || len(bsq.ch) > 0 || bsq.errorSent || - prevUnsent.StateEvent != status.StateTransientDisconnect || prev.StateEvent != status.StateTransientDisconnect { - log.Trace().Any("triggered_by", triggeredBy).Msg("Not sending delayed transient disconnect notice") - return - } - log.Debug().Any("triggered_by", triggeredBy).Msg("Sending delayed transient disconnect notice") - bsq.sendNotice(ctx, triggeredBy, true) -} - -func (bsq *BridgeStateQueue) sendNotice(ctx context.Context, state status.BridgeState, isDelayed bool) { +func (bsq *BridgeStateQueue) sendNotice(ctx context.Context, state status.BridgeState) { noticeConfig := bsq.bridge.Config.BridgeStatusNotices isError := state.StateEvent == status.StateBadCredentials || state.StateEvent == status.StateUnknownError || - state.UserAction == status.UserActionOpenNative || - (isDelayed && state.StateEvent == status.StateTransientDisconnect) + state.UserAction == status.UserActionOpenNative sendNotice := noticeConfig == "all" || (noticeConfig == "errors" && (isError || (bsq.errorSent && state.StateEvent == status.StateConnected))) - if state.StateEvent != status.StateTransientDisconnect && state.StateEvent != status.StateUnknownError { - bsq.firstTransientDisconnect = time.Time{} - } if !sendNotice { - if !bsq.errorSent && !isDelayed && noticeConfig == "errors" && state.StateEvent == status.StateTransientDisconnect { - if bsq.firstTransientDisconnect.IsZero() { - bsq.firstTransientDisconnect = time.Now() - } - go bsq.scheduleNotice(state) - } return } managementRoom, err := bsq.login.User.GetManagementRoom(ctx) @@ -154,9 +114,6 @@ func (bsq *BridgeStateQueue) sendNotice(ctx context.Context, state status.Bridge if state.Error != "" { message += fmt.Sprintf(" (`%s`)", state.Error) } - if isDelayed { - message += fmt.Sprintf(" not resolved after waiting %s", exfmt.Duration(TransientDisconnectNoticeDelay)) - } if state.Message != "" { message += fmt.Sprintf(": %s", state.Message) } @@ -194,14 +151,8 @@ func (bsq *BridgeStateQueue) unknownErrorReconnect(triggeredBy status.BridgeStat } else if prevUnsent.StateEvent != status.StateUnknownError || prev.StateEvent != status.StateUnknownError { log.Debug().Msg("Not reconnecting as the previous state was not an unknown error") return - } else if bsq.unknownErrorReconnects > bsq.bridge.Config.UnknownErrorMaxAutoReconnects { - log.Warn().Msg("Not reconnecting as the maximum number of unknown error reconnects has been reached") - return } - bsq.unknownErrorReconnects++ - log.Info(). - Int("reconnect_num", bsq.unknownErrorReconnects). - Msg("Disconnecting and reconnecting login due to unknown error") + log.Info().Msg("Disconnecting and reconnecting login due to unknown error") bsq.login.Disconnect() log.Debug().Msg("Disconnection finished, recreating client and reconnecting") err := bsq.login.recreateClient(ctx) @@ -220,30 +171,14 @@ func (bsq *BridgeStateQueue) waitForUnknownErrorReconnect(ctx context.Context) b return false } reconnectIn += time.Duration(rand.Int64N(int64(float64(reconnectIn)*0.4)) - int64(float64(reconnectIn)*0.2)) - return bsq.waitForReconnect(ctx, reconnectIn, &bsq.stopReconnect) -} - -const TransientDisconnectNoticeDelay = 3 * time.Minute - -func (bsq *BridgeStateQueue) waitForTransientDisconnectReconnect(ctx context.Context) bool { - timeUntilSchedule := time.Until(bsq.firstTransientDisconnect.Add(TransientDisconnectNoticeDelay)) - zerolog.Ctx(ctx).Trace(). - Stringer("duration", timeUntilSchedule). - Msg("Waiting before sending notice about transient disconnect") - return bsq.waitForReconnect(ctx, timeUntilSchedule, &bsq.cancelScheduledNotice) -} - -func (bsq *BridgeStateQueue) waitForReconnect( - ctx context.Context, reconnectIn time.Duration, ptr *atomic.Pointer[context.CancelFunc], -) bool { cancelCtx, cancel := context.WithCancel(ctx) defer cancel() - if oldCancel := ptr.Swap(&cancel); oldCancel != nil { + if oldCancel := bsq.stopReconnect.Swap(&cancel); oldCancel != nil { (*oldCancel)() } select { case <-time.After(reconnectIn): - return ptr.CompareAndSwap(&cancel, nil) + return bsq.stopReconnect.CompareAndSwap(&cancel, nil) case <-cancelCtx.Done(): return false case <-bsq.stopChan: @@ -263,7 +198,7 @@ func (bsq *BridgeStateQueue) immediateSendBridgeState(state status.BridgeState) } ctx := bsq.login.Log.WithContext(context.Background()) - bsq.sendNotice(ctx, state, false) + bsq.sendNotice(ctx, state) retryIn := 2 for { diff --git a/bridgev2/commands/debug.go b/bridgev2/commands/debug.go index 1cae98fe..4c93dbd4 100644 --- a/bridgev2/commands/debug.go +++ b/bridgev2/commands/debug.go @@ -7,13 +7,10 @@ package commands import ( - "encoding/json" "strings" - "time" "maunium.net/go/mautrix/bridgev2" "maunium.net/go/mautrix/bridgev2/networkid" - "maunium.net/go/mautrix/event" ) var CommandRegisterPush = &FullHandler{ @@ -62,64 +59,3 @@ var CommandRegisterPush = &FullHandler{ RequiresLogin: true, NetworkAPI: NetworkAPIImplements[bridgev2.PushableNetworkAPI], } - -var CommandSendAccountData = &FullHandler{ - Func: func(ce *Event) { - if len(ce.Args) < 2 { - ce.Reply("Usage: `$cmdprefix debug-account-data ") - return - } - var content event.Content - evtType := event.Type{Type: ce.Args[0], Class: event.AccountDataEventType} - ce.RawArgs = strings.TrimSpace(strings.Trim(ce.RawArgs, ce.Args[0])) - err := json.Unmarshal([]byte(ce.RawArgs), &content) - if err != nil { - ce.Reply("Failed to parse JSON: %v", err) - return - } - err = content.ParseRaw(evtType) - if err != nil { - ce.Reply("Failed to deserialize content: %v", err) - return - } - res := ce.Bridge.QueueMatrixEvent(ce.Ctx, &event.Event{ - Sender: ce.User.MXID, - Type: evtType, - Timestamp: time.Now().UnixMilli(), - RoomID: ce.RoomID, - Content: content, - }) - ce.Reply("Result: %+v", res) - }, - Name: "debug-account-data", - Help: HelpMeta{ - Section: HelpSectionAdmin, - Description: "Send a room account data event to the bridge", - Args: "<_type_> <_content_>", - }, - RequiresAdmin: true, - RequiresPortal: true, - RequiresLogin: true, -} - -var CommandResetNetwork = &FullHandler{ - Func: func(ce *Event) { - if strings.Contains(strings.ToLower(ce.RawArgs), "--reset-transport") { - nrn, ok := ce.Bridge.Network.(bridgev2.NetworkResettingNetwork) - if ok { - nrn.ResetHTTPTransport() - } else { - ce.Reply("Network connector does not support resetting HTTP transport") - } - } - ce.Bridge.ResetNetworkConnections() - ce.React("✅️") - }, - Name: "debug-reset-network", - Help: HelpMeta{ - Section: HelpSectionAdmin, - Description: "Reset network connections to the remote network", - Args: "[--reset-transport]", - }, - RequiresAdmin: true, -} diff --git a/bridgev2/commands/login.go b/bridgev2/commands/login.go index 96d62d3e..a18564c2 100644 --- a/bridgev2/commands/login.go +++ b/bridgev2/commands/login.go @@ -70,15 +70,6 @@ func fnLogin(ce *Event) { } ce.Args = ce.Args[1:] } - if reauth == nil && ce.User.HasTooManyLogins() { - ce.Reply( - "You have reached the maximum number of logins (%d). "+ - "Please logout from an existing login before creating a new one. "+ - "If you want to re-authenticate an existing login, use the `$cmdprefix relogin` command.", - ce.User.Permissions.MaxLogins, - ) - return - } flows := ce.Bridge.Network.GetLoginFlows() var chosenFlowID string if len(ce.Args) > 0 { @@ -121,7 +112,6 @@ func fnLogin(ce *Event) { ce.Reply("Failed to start login: %v", err) return } - ce.Log.Debug().Any("first_step", nextStep).Msg("Created login process") nextStep = checkLoginCommandDirectParams(ce, login, nextStep) if nextStep != nil { @@ -200,14 +190,11 @@ type userInputLoginCommandState struct { func (uilcs *userInputLoginCommandState) promptNext(ce *Event) { field := uilcs.RemainingFields[0] - parts := []string{fmt.Sprintf("Please enter your %s", field.Name)} if field.Description != "" { - parts = append(parts, field.Description) + ce.Reply("Please enter your %s\n%s", field.Name, field.Description) + } else { + ce.Reply("Please enter your %s", field.Name) } - if len(field.Options) > 0 { - parts = append(parts, fmt.Sprintf("Options: `%s`", strings.Join(field.Options, "`, `"))) - } - ce.Reply(strings.Join(parts, "\n")) StoreCommandState(ce.User, &CommandState{ Next: MinimalCommandHandlerFunc(uilcs.submitNext), Action: "Login", @@ -252,19 +239,14 @@ func sendQR(ce *Event, qr string, prevEventID *id.EventID) error { return fmt.Errorf("failed to upload image: %w", err) } content := &event.MessageEventContent{ - MsgType: event.MsgImage, - FileName: "qr.png", - URL: qrMXC, - File: qrFile, + MsgType: event.MsgImage, + FileName: "qr.png", + URL: qrMXC, + File: qrFile, + Body: qr, Format: event.FormatHTML, FormattedBody: fmt.Sprintf("
%s
", html.EscapeString(qr)), - Info: &event.FileInfo{ - MimeType: "image/png", - Width: qrSizePx, - Height: qrSizePx, - Size: len(qrData), - }, } if *prevEventID != "" { content.SetEdit(*prevEventID) @@ -279,36 +261,6 @@ func sendQR(ce *Event, qr string, prevEventID *id.EventID) error { return nil } -func sendUserInputAttachments(ce *Event, atts []*bridgev2.LoginUserInputAttachment) error { - for _, att := range atts { - if att.FileName == "" { - return fmt.Errorf("missing attachment filename") - } - mxc, file, err := ce.Bot.UploadMedia(ce.Ctx, ce.RoomID, att.Content, att.FileName, att.Info.MimeType) - if err != nil { - return fmt.Errorf("failed to upload attachment %q: %w", att.FileName, err) - } - content := &event.MessageEventContent{ - MsgType: att.Type, - FileName: att.FileName, - URL: mxc, - File: file, - Info: &event.FileInfo{ - MimeType: att.Info.MimeType, - Width: att.Info.Width, - Height: att.Info.Height, - Size: att.Info.Size, - }, - Body: att.FileName, - } - _, err = ce.Bot.SendMessage(ce.Ctx, ce.RoomID, event.EventMessage, &event.Content{Parsed: content}, nil) - if err != nil { - return nil - } - } - return nil -} - type contextKey int const ( @@ -500,7 +452,6 @@ func maybeURLDecodeCookie(val string, field *bridgev2.LoginCookieField) string { } func doLoginStep(ce *Event, login bridgev2.LoginProcess, step *bridgev2.LoginStep, override *bridgev2.UserLogin) { - ce.Log.Debug().Any("next_step", step).Msg("Got next login step") if step.Instructions != "" { ce.Reply(step.Instructions) } @@ -515,10 +466,6 @@ func doLoginStep(ce *Event, login bridgev2.LoginProcess, step *bridgev2.LoginSte Override: override, }).prompt(ce) case bridgev2.LoginStepTypeUserInput: - err := sendUserInputAttachments(ce, step.UserInputParams.Attachments) - if err != nil { - ce.Reply("Failed to send attachments: %v", err) - } (&userInputLoginCommandState{ Login: login.(bridgev2.LoginProcessUserInput), RemainingFields: step.UserInputParams.Fields, diff --git a/bridgev2/commands/processor.go b/bridgev2/commands/processor.go index 391c3685..c28e3a32 100644 --- a/bridgev2/commands/processor.go +++ b/bridgev2/commands/processor.go @@ -41,11 +41,10 @@ func NewProcessor(bridge *bridgev2.Bridge) bridgev2.CommandProcessor { } proc.AddHandlers( CommandHelp, CommandCancel, - CommandRegisterPush, CommandSendAccountData, CommandResetNetwork, - CommandDeletePortal, CommandDeleteAllPortals, CommandSetManagementRoom, + CommandRegisterPush, CommandDeletePortal, CommandDeleteAllPortals, CommandSetManagementRoom, CommandLogin, CommandRelogin, CommandListLogins, CommandLogout, CommandSetPreferredLogin, CommandSetRelay, CommandUnsetRelay, - CommandResolveIdentifier, CommandStartChat, CommandCreateGroup, CommandSearch, CommandSyncChat, CommandMute, + CommandResolveIdentifier, CommandStartChat, CommandSearch, CommandSudo, CommandDoIn, ) return proc diff --git a/bridgev2/commands/relay.go b/bridgev2/commands/relay.go index 94c19739..af756c87 100644 --- a/bridgev2/commands/relay.go +++ b/bridgev2/commands/relay.go @@ -37,7 +37,7 @@ func fnSetRelay(ce *Event) { } onlySetDefaultRelays := !ce.User.Permissions.Admin && ce.Bridge.Config.Relay.AdminOnly var relay *bridgev2.UserLogin - if len(ce.Args) == 0 && ce.Portal.Receiver == "" { + if len(ce.Args) == 0 { relay = ce.User.GetDefaultLogin() isLoggedIn := relay != nil if onlySetDefaultRelays { @@ -73,19 +73,9 @@ func fnSetRelay(ce *Event) { } } } else { - var targetID networkid.UserLoginID - if ce.Portal.Receiver != "" { - targetID = ce.Portal.Receiver - if len(ce.Args) > 0 && ce.Args[0] != string(targetID) { - ce.Reply("In split portals, only the receiver (%s) can be set as relay", targetID) - return - } - } else { - targetID = networkid.UserLoginID(ce.Args[0]) - } - relay = ce.Bridge.GetCachedUserLoginByID(targetID) + relay = ce.Bridge.GetCachedUserLoginByID(networkid.UserLoginID(ce.Args[0])) if relay == nil { - ce.Reply("User login with ID `%s` not found", targetID) + ce.Reply("User login with ID `%s` not found", ce.Args[0]) return } else if slices.Contains(ce.Bridge.Config.Relay.DefaultRelays, relay.ID) { // All good diff --git a/bridgev2/commands/startchat.go b/bridgev2/commands/startchat.go index c7b05a6e..719d3dd5 100644 --- a/bridgev2/commands/startchat.go +++ b/bridgev2/commands/startchat.go @@ -1,4 +1,4 @@ -// Copyright (c) 2025 Tulir Asokan +// Copyright (c) 2024 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this @@ -8,21 +8,13 @@ package commands import ( "context" - "errors" "fmt" "html" - "maps" - "slices" "strings" "time" - "github.com/rs/zerolog" - "maunium.net/go/mautrix/bridgev2" "maunium.net/go/mautrix/bridgev2/networkid" - "maunium.net/go/mautrix/bridgev2/provisionutil" - "maunium.net/go/mautrix/event" - "maunium.net/go/mautrix/format" "maunium.net/go/mautrix/id" ) @@ -38,35 +30,6 @@ var CommandResolveIdentifier = &FullHandler{ NetworkAPI: NetworkAPIImplements[bridgev2.IdentifierResolvingNetworkAPI], } -var CommandSyncChat = &FullHandler{ - Func: func(ce *Event) { - login, _, err := ce.Portal.FindPreferredLogin(ce.Ctx, ce.User, false) - if err != nil { - ce.Log.Err(err).Msg("Failed to find login for sync") - ce.Reply("Failed to find login: %v", err) - return - } else if login == nil { - ce.Reply("No login found for sync") - return - } - info, err := login.Client.GetChatInfo(ce.Ctx, ce.Portal) - if err != nil { - ce.Log.Err(err).Msg("Failed to get chat info for sync") - ce.Reply("Failed to get chat info: %v", err) - return - } - ce.Portal.UpdateInfo(ce.Ctx, info, login, nil, time.Time{}) - ce.React("✅️") - }, - Name: "sync-portal", - Help: HelpMeta{ - Section: HelpSectionChats, - Description: "Sync the current portal room", - }, - RequiresPortal: true, - RequiresLogin: true, -} - var CommandStartChat = &FullHandler{ Func: fnResolveIdentifier, Name: "start-chat", @@ -80,15 +43,9 @@ var CommandStartChat = &FullHandler{ NetworkAPI: NetworkAPIImplements[bridgev2.IdentifierResolvingNetworkAPI], } -func getClientForStartingChat[T bridgev2.NetworkAPI](ce *Event, thing string) (*bridgev2.UserLogin, T, []string) { - var remainingArgs []string - if len(ce.Args) > 1 { - remainingArgs = ce.Args[1:] - } - var login *bridgev2.UserLogin - if len(ce.Args) > 0 { - login = ce.Bridge.GetCachedUserLoginByID(networkid.UserLoginID(ce.Args[0])) - } +func getClientForStartingChat[T bridgev2.IdentifierResolvingNetworkAPI](ce *Event, thing string) (*bridgev2.UserLogin, T, []string) { + remainingArgs := ce.Args[1:] + login := ce.Bridge.GetCachedUserLoginByID(networkid.UserLoginID(ce.Args[0])) if login == nil || login.UserMXID != ce.User.MXID { remainingArgs = ce.Args login = ce.User.GetDefaultLogin() @@ -100,13 +57,24 @@ func getClientForStartingChat[T bridgev2.NetworkAPI](ce *Event, thing string) (* return login, api, remainingArgs } -func formatResolveIdentifierResult(resp *provisionutil.RespResolveIdentifier) string { - if resp.MXID != "" { - return fmt.Sprintf("`%s` / [%s](%s)", resp.ID, resp.Name, resp.MXID.URI().MatrixToURL()) - } else if resp.Name != "" { - return fmt.Sprintf("`%s` / %s", resp.ID, resp.Name) +func formatResolveIdentifierResult(ctx context.Context, resp *bridgev2.ResolveIdentifierResponse) string { + var targetName string + var targetMXID id.UserID + if resp.Ghost != nil { + if resp.UserInfo != nil { + resp.Ghost.UpdateInfo(ctx, resp.UserInfo) + } + targetName = resp.Ghost.Name + targetMXID = resp.Ghost.Intent.GetMXID() + } else if resp.UserInfo != nil && resp.UserInfo.Name != nil { + targetName = *resp.UserInfo.Name + } + if targetMXID != "" { + return fmt.Sprintf("`%s` / [%s](%s)", resp.UserID, targetName, targetMXID.URI().MatrixToURL()) + } else if targetName != "" { + return fmt.Sprintf("`%s` / %s", resp.UserID, targetName) } else { - return fmt.Sprintf("`%s`", resp.ID) + return fmt.Sprintf("`%s`", resp.UserID) } } @@ -119,137 +87,65 @@ func fnResolveIdentifier(ce *Event) { if api == nil { return } - allLogins := ce.User.GetUserLogins() createChat := ce.Command == "start-chat" || ce.Command == "pm" identifier := strings.Join(identifierParts, " ") - resp, err := provisionutil.ResolveIdentifier(ce.Ctx, login, identifier, createChat) - for i := 0; i < len(allLogins) && errors.Is(err, bridgev2.ErrResolveIdentifierTryNext); i++ { - resp, err = provisionutil.ResolveIdentifier(ce.Ctx, allLogins[i], identifier, createChat) - } + resp, err := api.ResolveIdentifier(ce.Ctx, identifier, createChat) if err != nil { + ce.Log.Err(err).Msg("Failed to resolve identifier") ce.Reply("Failed to resolve identifier: %v", err) return } else if resp == nil { ce.ReplyAdvanced(fmt.Sprintf("Identifier %s not found", html.EscapeString(identifier)), false, true) return } - formattedName := formatResolveIdentifierResult(resp) + formattedName := formatResolveIdentifierResult(ce.Ctx, resp) if createChat { - name := resp.Portal.Name - if name == "" { - name = resp.Portal.MXID.String() + if resp.Chat == nil { + ce.Reply("Interface error: network connector did not return chat for create chat request") + return } - if !resp.JustCreated { - ce.Reply("You already have a direct chat with %s at [%s](%s)", formattedName, name, resp.Portal.MXID.URI().MatrixToURL()) + portal := resp.Chat.Portal + if portal == nil { + portal, err = ce.Bridge.GetPortalByKey(ce.Ctx, resp.Chat.PortalKey) + if err != nil { + ce.Log.Err(err).Msg("Failed to get portal") + ce.Reply("Failed to get portal: %v", err) + return + } + } + if resp.Chat.PortalInfo == nil { + resp.Chat.PortalInfo, err = api.GetChatInfo(ce.Ctx, portal) + if err != nil { + ce.Log.Err(err).Msg("Failed to get portal info") + ce.Reply("Failed to get portal info: %v", err) + return + } + } + if portal.MXID != "" { + name := portal.Name + if name == "" { + name = portal.MXID.String() + } + portal.UpdateInfo(ce.Ctx, resp.Chat.PortalInfo, login, nil, time.Time{}) + ce.Reply("You already have a direct chat with %s at [%s](%s)", formattedName, name, portal.MXID.URI().MatrixToURL()) } else { - ce.Reply("Created chat with %s: [%s](%s)", formattedName, name, resp.Portal.MXID.URI().MatrixToURL()) + err = portal.CreateMatrixRoom(ce.Ctx, login, resp.Chat.PortalInfo) + if err != nil { + ce.Log.Err(err).Msg("Failed to create room") + ce.Reply("Failed to create room: %v", err) + return + } + name := portal.Name + if name == "" { + name = portal.MXID.String() + } + ce.Reply("Created chat with %s: [%s](%s)", formattedName, name, portal.MXID.URI().MatrixToURL()) } } else { ce.Reply("Found %s", formattedName) } } -var CommandCreateGroup = &FullHandler{ - Func: fnCreateGroup, - Name: "create-group", - Aliases: []string{"create"}, - Help: HelpMeta{ - Section: HelpSectionChats, - Description: "Create a new group chat for the current Matrix room", - Args: "[_group type_]", - }, - RequiresLogin: true, - NetworkAPI: NetworkAPIImplements[bridgev2.GroupCreatingNetworkAPI], -} - -func getState[T any](ctx context.Context, roomID id.RoomID, evtType event.Type, provider bridgev2.MatrixConnectorWithArbitraryRoomState) (content T) { - evt, err := provider.GetStateEvent(ctx, roomID, evtType, "") - if err != nil { - zerolog.Ctx(ctx).Err(err).Stringer("event_type", evtType).Msg("Failed to get state event for group creation") - } else if evt != nil { - content, _ = evt.Content.Parsed.(T) - } - return -} - -func fnCreateGroup(ce *Event) { - ce.Bridge.Matrix.GetCapabilities() - login, api, remainingArgs := getClientForStartingChat[bridgev2.GroupCreatingNetworkAPI](ce, "creating group") - if api == nil { - return - } - stateProvider, ok := ce.Bridge.Matrix.(bridgev2.MatrixConnectorWithArbitraryRoomState) - if !ok { - ce.Reply("Matrix connector doesn't support fetching room state") - return - } - members, err := ce.Bridge.Matrix.GetMembers(ce.Ctx, ce.RoomID) - if err != nil { - ce.Log.Err(err).Msg("Failed to get room members for group creation") - ce.Reply("Failed to get room members: %v", err) - return - } - caps := ce.Bridge.Network.GetCapabilities() - params := &bridgev2.GroupCreateParams{ - Username: "", - Participants: make([]networkid.UserID, 0, len(members)-2), - Parent: nil, // TODO check space parent event - Name: getState[*event.RoomNameEventContent](ce.Ctx, ce.RoomID, event.StateRoomName, stateProvider), - Avatar: getState[*event.RoomAvatarEventContent](ce.Ctx, ce.RoomID, event.StateRoomAvatar, stateProvider), - Topic: getState[*event.TopicEventContent](ce.Ctx, ce.RoomID, event.StateTopic, stateProvider), - Disappear: getState[*event.BeeperDisappearingTimer](ce.Ctx, ce.RoomID, event.StateBeeperDisappearingTimer, stateProvider), - RoomID: ce.RoomID, - } - for userID, member := range members { - if userID == ce.User.MXID || userID == ce.Bot.GetMXID() || !member.Membership.IsInviteOrJoin() { - continue - } - if parsedUserID, ok := ce.Bridge.Matrix.ParseGhostMXID(userID); ok { - params.Participants = append(params.Participants, parsedUserID) - } else if !ce.Bridge.Config.SplitPortals { - if user, err := ce.Bridge.GetExistingUserByMXID(ce.Ctx, userID); err != nil { - ce.Log.Err(err).Stringer("user_id", userID).Msg("Failed to get user for room member") - } else if user != nil { - // TODO add user logins to participants - //for _, login := range user.GetUserLogins() { - // params.Participants = append(params.Participants, login.GetUserID()) - //} - } - } - } - - if len(caps.Provisioning.GroupCreation) == 0 { - ce.Reply("No group creation types defined in network capabilities") - return - } else if len(remainingArgs) > 0 { - params.Type = remainingArgs[0] - } else if len(caps.Provisioning.GroupCreation) == 1 { - for params.Type = range caps.Provisioning.GroupCreation { - // The loop assigns the variable we want - } - } else { - types := strings.Join(slices.Collect(maps.Keys(caps.Provisioning.GroupCreation)), "`, `") - ce.Reply("Please specify type of group to create: `%s`", types) - return - } - resp, err := provisionutil.CreateGroup(ce.Ctx, login, params) - if err != nil { - ce.Reply("Failed to create group: %v", err) - return - } - var postfix string - if len(resp.FailedParticipants) > 0 { - failedParticipantsStrings := make([]string, len(resp.FailedParticipants)) - i := 0 - for participantID, meta := range resp.FailedParticipants { - failedParticipantsStrings[i] = fmt.Sprintf("* %s: %s", format.SafeMarkdownCode(participantID), meta.Reason) - i++ - } - postfix += "\n\nFailed to add some participants:\n" + strings.Join(failedParticipantsStrings, "\n") - } - ce.Reply("Successfully created group `%s`%s", resp.ID, postfix) -} - var CommandSearch = &FullHandler{ Func: fnSearch, Name: "search", @@ -267,67 +163,35 @@ func fnSearch(ce *Event) { ce.Reply("Usage: `$cmdprefix search `") return } - login, api, queryParts := getClientForStartingChat[bridgev2.UserSearchingNetworkAPI](ce, "searching users") + _, api, queryParts := getClientForStartingChat[bridgev2.UserSearchingNetworkAPI](ce, "searching users") if api == nil { return } - resp, err := provisionutil.SearchUsers(ce.Ctx, login, strings.Join(queryParts, " ")) + results, err := api.SearchUsers(ce.Ctx, strings.Join(queryParts, " ")) if err != nil { + ce.Log.Err(err).Msg("Failed to search for users") ce.Reply("Failed to search for users: %v", err) return } - resultsString := make([]string, len(resp.Results)) - for i, res := range resp.Results { - formattedName := formatResolveIdentifierResult(res) + resultsString := make([]string, len(results)) + for i, res := range results { + formattedName := formatResolveIdentifierResult(ce.Ctx, res) resultsString[i] = fmt.Sprintf("* %s", formattedName) - if res.Portal != nil && res.Portal.MXID != "" { - portalName := res.Portal.Name - if portalName == "" { - portalName = res.Portal.MXID.String() + if res.Chat != nil { + if res.Chat.Portal == nil { + res.Chat.Portal, err = ce.Bridge.GetExistingPortalByKey(ce.Ctx, res.Chat.PortalKey) + if err != nil { + ce.Log.Err(err).Object("portal_key", res.Chat.PortalKey).Msg("Failed to get DM portal") + } + } + if res.Chat.Portal != nil && res.Chat.Portal.MXID != "" { + portalName := res.Chat.Portal.Name + if portalName == "" { + portalName = res.Chat.Portal.MXID.String() + } + resultsString[i] = fmt.Sprintf("%s - DM portal: [%s](%s)", resultsString[i], portalName, res.Chat.Portal.MXID.URI().MatrixToURL()) } - resultsString[i] = fmt.Sprintf("%s - DM portal: [%s](%s)", resultsString[i], portalName, res.Portal.MXID.URI().MatrixToURL()) } } ce.Reply("Search results:\n\n%s", strings.Join(resultsString, "\n")) } - -var CommandMute = &FullHandler{ - Func: fnMute, - Name: "mute", - Aliases: []string{"unmute"}, - Help: HelpMeta{ - Section: HelpSectionChats, - Description: "Mute or unmute a chat on the remote network", - Args: "[duration]", - }, - RequiresPortal: true, - RequiresLogin: true, - NetworkAPI: NetworkAPIImplements[bridgev2.MuteHandlingNetworkAPI], -} - -func fnMute(ce *Event) { - _, api, _ := getClientForStartingChat[bridgev2.MuteHandlingNetworkAPI](ce, "muting chats") - var mutedUntil int64 - if ce.Command == "mute" { - mutedUntil = -1 - if len(ce.Args) > 0 { - duration, err := time.ParseDuration(ce.Args[0]) - if err != nil { - ce.Reply("Invalid duration: %v", err) - return - } - mutedUntil = time.Now().Add(duration).UnixMilli() - } - } - err := api.HandleMute(ce.Ctx, &bridgev2.MatrixMute{ - MatrixEventBase: bridgev2.MatrixEventBase[*event.BeeperMuteEventContent]{ - Content: &event.BeeperMuteEventContent{MutedUntil: mutedUntil}, - Portal: ce.Portal, - }, - }) - if err != nil { - ce.Reply("Failed to %s chat: %v", ce.Command, err) - } else { - ce.React("✅️") - } -} diff --git a/bridgev2/database/database.go b/bridgev2/database/database.go index 05abddf0..f1789441 100644 --- a/bridgev2/database/database.go +++ b/bridgev2/database/database.go @@ -7,7 +7,13 @@ package database import ( + "encoding/json" + "reflect" + "strings" + "go.mau.fi/util/dbutil" + "golang.org/x/exp/constraints" + "golang.org/x/exp/maps" "maunium.net/go/mautrix/bridgev2/networkid" @@ -28,7 +34,6 @@ type Database struct { UserPortal *UserPortalQuery BackfillTask *BackfillTaskQuery KV *KVQuery - PublicMedia *PublicMediaQuery } type MetaMerger interface { @@ -136,12 +141,6 @@ func New(bridgeID networkid.BridgeID, mt MetaTypes, db *dbutil.Database) *Databa BridgeID: bridgeID, Database: db, }, - PublicMedia: &PublicMediaQuery{ - BridgeID: bridgeID, - QueryHelper: dbutil.MakeQueryHelper(db, func(_ *dbutil.QueryHelper[*PublicMedia]) *PublicMedia { - return &PublicMedia{} - }), - }, } } @@ -152,3 +151,55 @@ func ensureBridgeIDMatches(ptr *networkid.BridgeID, expected networkid.BridgeID) panic("bridge ID mismatch") } } + +func GetNumberFromMap[T constraints.Integer | constraints.Float](m map[string]any, key string) (T, bool) { + if val, found := m[key]; found { + floatVal, ok := val.(float64) + if ok { + return T(floatVal), true + } + tVal, ok := val.(T) + if ok { + return tVal, true + } + } + return 0, false +} + +func unmarshalMerge(input []byte, data any, extra *map[string]any) error { + err := json.Unmarshal(input, data) + if err != nil { + return err + } + err = json.Unmarshal(input, extra) + if err != nil { + return err + } + if *extra == nil { + *extra = make(map[string]any) + } + return nil +} + +func marshalMerge(data any, extra map[string]any) ([]byte, error) { + if extra == nil { + return json.Marshal(data) + } + merged := make(map[string]any) + maps.Copy(merged, extra) + dataRef := reflect.ValueOf(data).Elem() + dataType := dataRef.Type() + for _, field := range reflect.VisibleFields(dataType) { + parts := strings.Split(field.Tag.Get("json"), ",") + if len(parts) == 0 || len(parts[0]) == 0 || parts[0] == "-" { + continue + } + fieldVal := dataRef.FieldByIndex(field.Index) + if fieldVal.IsZero() { + delete(merged, parts[0]) + } else { + merged[parts[0]] = fieldVal.Interface() + } + } + return json.Marshal(merged) +} diff --git a/bridgev2/database/disappear.go b/bridgev2/database/disappear.go index df36b205..4e6f5e0a 100644 --- a/bridgev2/database/disappear.go +++ b/bridgev2/database/disappear.go @@ -12,92 +12,54 @@ import ( "time" "go.mau.fi/util/dbutil" - "go.mau.fi/util/jsontime" "maunium.net/go/mautrix/bridgev2/networkid" - "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" ) -// Deprecated: use [event.DisappearingType] -type DisappearingType = event.DisappearingType +// DisappearingType represents the type of a disappearing message timer. +type DisappearingType string -// Deprecated: use constants in event package const ( - DisappearingTypeNone = event.DisappearingTypeNone - DisappearingTypeAfterRead = event.DisappearingTypeAfterRead - DisappearingTypeAfterSend = event.DisappearingTypeAfterSend + DisappearingTypeNone DisappearingType = "" + DisappearingTypeAfterRead DisappearingType = "after_read" + DisappearingTypeAfterSend DisappearingType = "after_send" ) // DisappearingSetting represents a disappearing message timer setting // by combining a type with a timer and an optional start timestamp. type DisappearingSetting struct { - Type event.DisappearingType + Type DisappearingType Timer time.Duration DisappearAt time.Time } -func DisappearingSettingFromEvent(evt *event.BeeperDisappearingTimer) DisappearingSetting { - if evt == nil || evt.Type == event.DisappearingTypeNone { - return DisappearingSetting{} - } - return DisappearingSetting{ - Type: evt.Type, - Timer: evt.Timer.Duration, - } -} - -func (ds DisappearingSetting) Normalize() DisappearingSetting { - if ds.Type == event.DisappearingTypeNone { - ds.Timer = 0 - } else if ds.Timer == 0 { - ds.Type = event.DisappearingTypeNone - } - return ds -} - -func (ds DisappearingSetting) StartingAt(start time.Time) DisappearingSetting { - ds.DisappearAt = start.Add(ds.Timer) - return ds -} - -func (ds DisappearingSetting) ToEventContent() *event.BeeperDisappearingTimer { - if ds.Type == event.DisappearingTypeNone || ds.Timer == 0 { - return &event.BeeperDisappearingTimer{} - } - return &event.BeeperDisappearingTimer{ - Type: ds.Type, - Timer: jsontime.MS(ds.Timer), - } -} - type DisappearingMessageQuery struct { BridgeID networkid.BridgeID *dbutil.QueryHelper[*DisappearingMessage] } type DisappearingMessage struct { - BridgeID networkid.BridgeID - RoomID id.RoomID - EventID id.EventID - Timestamp time.Time + BridgeID networkid.BridgeID + RoomID id.RoomID + EventID id.EventID DisappearingSetting } const ( upsertDisappearingMessageQuery = ` - INSERT INTO disappearing_message (bridge_id, mx_room, mxid, timestamp, type, timer, disappear_at) - VALUES ($1, $2, $3, $4, $5, $6, $7) + INSERT INTO disappearing_message (bridge_id, mx_room, mxid, type, timer, disappear_at) + VALUES ($1, $2, $3, $4, $5, $6) ON CONFLICT (bridge_id, mxid) DO UPDATE SET timer=excluded.timer, disappear_at=excluded.disappear_at ` startDisappearingMessagesQuery = ` UPDATE disappearing_message SET disappear_at=$1 + timer - WHERE bridge_id=$2 AND mx_room=$3 AND disappear_at IS NULL AND type='after_read' AND timestamp<=$4 - RETURNING bridge_id, mx_room, mxid, timestamp, type, timer, disappear_at + WHERE bridge_id=$2 AND mx_room=$3 AND disappear_at IS NULL AND type='after_read' + RETURNING bridge_id, mx_room, mxid, type, timer, disappear_at ` getUpcomingDisappearingMessagesQuery = ` - SELECT bridge_id, mx_room, mxid, timestamp, type, timer, disappear_at + SELECT bridge_id, mx_room, mxid, type, timer, disappear_at FROM disappearing_message WHERE bridge_id = $1 AND disappear_at IS NOT NULL AND disappear_at < $2 ORDER BY disappear_at LIMIT $3 ` @@ -111,8 +73,8 @@ func (dmq *DisappearingMessageQuery) Put(ctx context.Context, dm *DisappearingMe return dmq.Exec(ctx, upsertDisappearingMessageQuery, dm.sqlVariables()...) } -func (dmq *DisappearingMessageQuery) StartAllBefore(ctx context.Context, roomID id.RoomID, beforeTS time.Time) ([]*DisappearingMessage, error) { - return dmq.QueryMany(ctx, startDisappearingMessagesQuery, time.Now().UnixNano(), dmq.BridgeID, roomID, beforeTS.UnixNano()) +func (dmq *DisappearingMessageQuery) StartAll(ctx context.Context, roomID id.RoomID) ([]*DisappearingMessage, error) { + return dmq.QueryMany(ctx, startDisappearingMessagesQuery, time.Now().UnixNano(), dmq.BridgeID, roomID) } func (dmq *DisappearingMessageQuery) GetUpcoming(ctx context.Context, duration time.Duration, limit int) ([]*DisappearingMessage, error) { @@ -124,19 +86,17 @@ func (dmq *DisappearingMessageQuery) Delete(ctx context.Context, eventID id.Even } func (d *DisappearingMessage) Scan(row dbutil.Scannable) (*DisappearingMessage, error) { - var timestamp int64 var disappearAt sql.NullInt64 - err := row.Scan(&d.BridgeID, &d.RoomID, &d.EventID, ×tamp, &d.Type, &d.Timer, &disappearAt) + err := row.Scan(&d.BridgeID, &d.RoomID, &d.EventID, &d.Type, &d.Timer, &disappearAt) if err != nil { return nil, err } if disappearAt.Valid { d.DisappearAt = time.Unix(0, disappearAt.Int64) } - d.Timestamp = time.Unix(0, timestamp) return d, nil } func (d *DisappearingMessage) sqlVariables() []any { - return []any{d.BridgeID, d.RoomID, d.EventID, d.Timestamp.UnixNano(), d.Type, d.Timer, dbutil.ConvertedPtr(d.DisappearAt, time.Time.UnixNano)} + return []any{d.BridgeID, d.RoomID, d.EventID, d.Type, d.Timer, dbutil.ConvertedPtr(d.DisappearAt, time.Time.UnixNano)} } diff --git a/bridgev2/database/ghost.go b/bridgev2/database/ghost.go index 16af35ca..c32929ad 100644 --- a/bridgev2/database/ghost.go +++ b/bridgev2/database/ghost.go @@ -7,17 +7,12 @@ package database import ( - "bytes" "context" "encoding/hex" - "encoding/json" - "fmt" "go.mau.fi/util/dbutil" - "go.mau.fi/util/exerrors" "maunium.net/go/mautrix/bridgev2/networkid" - "maunium.net/go/mautrix/crypto/canonicaljson" "maunium.net/go/mautrix/id" ) @@ -27,55 +22,6 @@ type GhostQuery struct { *dbutil.QueryHelper[*Ghost] } -type ExtraProfile map[string]json.RawMessage - -func (ep *ExtraProfile) Set(key string, value any) error { - if key == "displayname" || key == "avatar_url" { - return fmt.Errorf("cannot set reserved profile key %q", key) - } - marshaled, err := json.Marshal(value) - if err != nil { - return err - } - if *ep == nil { - *ep = make(ExtraProfile) - } - (*ep)[key] = canonicaljson.CanonicalJSONAssumeValid(marshaled) - return nil -} - -func (ep *ExtraProfile) With(key string, value any) *ExtraProfile { - exerrors.PanicIfNotNil(ep.Set(key, value)) - return ep -} - -func canonicalizeIfObject(data json.RawMessage) json.RawMessage { - if len(data) > 0 && (data[0] == '{' || data[0] == '[') { - return canonicaljson.CanonicalJSONAssumeValid(data) - } - return data -} - -func (ep *ExtraProfile) CopyTo(dest *ExtraProfile) (changed bool) { - if len(*ep) == 0 { - return - } - if *dest == nil { - *dest = make(ExtraProfile) - } - for key, val := range *ep { - if key == "displayname" || key == "avatar_url" { - continue - } - existing, exists := (*dest)[key] - if !exists || !bytes.Equal(canonicalizeIfObject(existing), val) { - (*dest)[key] = val - changed = true - } - } - return -} - type Ghost struct { BridgeID networkid.BridgeID ID networkid.UserID @@ -89,14 +35,13 @@ type Ghost struct { ContactInfoSet bool IsBot bool Identifiers []string - ExtraProfile ExtraProfile Metadata any } const ( getGhostBaseQuery = ` SELECT bridge_id, id, name, avatar_id, avatar_hash, avatar_mxc, - name_set, avatar_set, contact_info_set, is_bot, identifiers, extra_profile, metadata + name_set, avatar_set, contact_info_set, is_bot, identifiers, metadata FROM ghost ` getGhostByIDQuery = getGhostBaseQuery + `WHERE bridge_id=$1 AND id=$2` @@ -104,14 +49,13 @@ const ( insertGhostQuery = ` INSERT INTO ghost ( bridge_id, id, name, avatar_id, avatar_hash, avatar_mxc, - name_set, avatar_set, contact_info_set, is_bot, identifiers, extra_profile, metadata + name_set, avatar_set, contact_info_set, is_bot, identifiers, metadata ) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12) ` updateGhostQuery = ` UPDATE ghost SET name=$3, avatar_id=$4, avatar_hash=$5, avatar_mxc=$6, - name_set=$7, avatar_set=$8, contact_info_set=$9, is_bot=$10, - identifiers=$11, extra_profile=$12, metadata=$13 + name_set=$7, avatar_set=$8, contact_info_set=$9, is_bot=$10, identifiers=$11, metadata=$12 WHERE bridge_id=$1 AND id=$2 ` ) @@ -142,7 +86,7 @@ func (g *Ghost) Scan(row dbutil.Scannable) (*Ghost, error) { &g.BridgeID, &g.ID, &g.Name, &g.AvatarID, &avatarHash, &g.AvatarMXC, &g.NameSet, &g.AvatarSet, &g.ContactInfoSet, &g.IsBot, - dbutil.JSON{Data: &g.Identifiers}, dbutil.JSON{Data: &g.ExtraProfile}, dbutil.JSON{Data: g.Metadata}, + dbutil.JSON{Data: &g.Identifiers}, dbutil.JSON{Data: g.Metadata}, ) if err != nil { return nil, err @@ -172,6 +116,6 @@ func (g *Ghost) sqlVariables() []any { g.BridgeID, g.ID, g.Name, g.AvatarID, avatarHash, g.AvatarMXC, g.NameSet, g.AvatarSet, g.ContactInfoSet, g.IsBot, - dbutil.JSON{Data: &g.Identifiers}, dbutil.JSON{Data: g.ExtraProfile}, dbutil.JSON{Data: g.Metadata}, + dbutil.JSON{Data: &g.Identifiers}, dbutil.JSON{Data: g.Metadata}, } } diff --git a/bridgev2/database/kvstore.go b/bridgev2/database/kvstore.go index bca26ed5..52b4984e 100644 --- a/bridgev2/database/kvstore.go +++ b/bridgev2/database/kvstore.go @@ -23,7 +23,6 @@ const ( KeySplitPortalsEnabled Key = "split_portals_enabled" KeyBridgeInfoVersion Key = "bridge_info_version" KeyEncryptionStateResynced Key = "encryption_state_resynced" - KeyRecoveryKey Key = "recovery_key" ) type KVQuery struct { diff --git a/bridgev2/database/message.go b/bridgev2/database/message.go index 4fd599a8..9b3b1493 100644 --- a/bridgev2/database/message.go +++ b/bridgev2/database/message.go @@ -11,12 +11,9 @@ import ( "crypto/sha256" "database/sql" "encoding/base64" - "fmt" "strings" - "sync" "time" - "github.com/rs/zerolog" "go.mau.fi/util/dbutil" "maunium.net/go/mautrix/bridgev2/networkid" @@ -27,7 +24,6 @@ type MessageQuery struct { BridgeID networkid.BridgeID MetaType MetaTypeCreator *dbutil.QueryHelper[*Message] - chunkDeleteLock sync.Mutex } type Message struct { @@ -68,8 +64,8 @@ const ( getFirstMessagePartByIDQuery = getMessageBaseQuery + `WHERE bridge_id=$1 AND (room_receiver=$2 OR room_receiver='') AND id=$3 ORDER BY part_id ASC LIMIT 1` getMessagesBetweenTimeQuery = getMessageBaseQuery + `WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 AND timestamp>$4 AND timestamp<=$5` getOldestMessageInPortal = getMessageBaseQuery + `WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 ORDER BY timestamp ASC, part_id ASC LIMIT 1` - getFirstMessageInThread = getMessageBaseQuery + `WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 AND (id=$4 OR thread_root_id=$4) ORDER BY thread_root_id NULLS FIRST, timestamp ASC, part_id ASC LIMIT 1` - getLastMessageInThread = getMessageBaseQuery + `WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 AND (id=$4 OR thread_root_id=$4) ORDER BY thread_root_id NULLS LAST, timestamp DESC, part_id DESC LIMIT 1` + getFirstMessageInThread = getMessageBaseQuery + `WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 AND (id=$4 OR thread_root_id=$4) ORDER BY timestamp ASC, part_id ASC LIMIT 1` + getLastMessageInThread = getMessageBaseQuery + `WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 AND (id=$4 OR thread_root_id=$4) ORDER BY timestamp DESC, part_id DESC LIMIT 1` getLastNInPortal = getMessageBaseQuery + `WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 ORDER BY timestamp DESC, part_id DESC LIMIT $4` getLastMessagePartAtOrBeforeTimeQuery = getMessageBaseQuery + `WHERE bridge_id = $1 AND room_id=$2 AND room_receiver=$3 AND timestamp<=$4 ORDER BY timestamp DESC, part_id DESC LIMIT 1` @@ -100,10 +96,6 @@ const ( deleteMessagePartByRowIDQuery = ` DELETE FROM message WHERE bridge_id=$1 AND rowid=$2 ` - deleteMessageChunkQuery = ` - DELETE FROM message WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 AND rowid > $4 AND rowid <= $5 - ` - getMaxMessageRowIDQuery = `SELECT MAX(rowid) FROM message WHERE bridge_id=$1` ) func (mq *MessageQuery) GetAllPartsByID(ctx context.Context, receiver networkid.UserLoginID, id networkid.MessageID) ([]*Message, error) { @@ -188,85 +180,6 @@ func (mq *MessageQuery) Delete(ctx context.Context, rowID int64) error { return mq.Exec(ctx, deleteMessagePartByRowIDQuery, mq.BridgeID, rowID) } -func (mq *MessageQuery) deleteChunk(ctx context.Context, portal networkid.PortalKey, minRowID, maxRowID int64) (int64, error) { - res, err := mq.GetDB().Exec(ctx, deleteMessageChunkQuery, mq.BridgeID, portal.ID, portal.Receiver, minRowID, maxRowID) - if err != nil { - return 0, err - } - return res.RowsAffected() -} - -func (mq *MessageQuery) getMaxRowID(ctx context.Context) (maxRowID int64, err error) { - err = mq.GetDB().QueryRow(ctx, getMaxMessageRowIDQuery, mq.BridgeID).Scan(&maxRowID) - return -} - -const deleteChunkSize = 100_000 - -func (mq *MessageQuery) DeleteInChunks(ctx context.Context, portal networkid.PortalKey) error { - if mq.GetDB().Dialect != dbutil.SQLite { - return nil - } - log := zerolog.Ctx(ctx).With(). - Str("action", "delete messages in chunks"). - Stringer("portal_key", portal). - Logger() - if !mq.chunkDeleteLock.TryLock() { - log.Warn().Msg("Portal deletion lock is being held, waiting...") - mq.chunkDeleteLock.Lock() - log.Debug().Msg("Acquired portal deletion lock after waiting") - } - defer mq.chunkDeleteLock.Unlock() - total, err := mq.CountMessagesInPortal(ctx, portal) - if err != nil { - return fmt.Errorf("failed to count messages in portal: %w", err) - } else if total < deleteChunkSize/3 { - return nil - } - globalMaxRowID, err := mq.getMaxRowID(ctx) - if err != nil { - return fmt.Errorf("failed to get max row ID: %w", err) - } - log.Debug(). - Int("total_count", total). - Int64("global_max_row_id", globalMaxRowID). - Msg("Portal has lots of messages, deleting in chunks to avoid database locks") - maxRowID := int64(deleteChunkSize) - globalMaxRowID += deleteChunkSize * 1.2 - var dbTimeUsed time.Duration - globalStart := time.Now() - for total > 500 && maxRowID < globalMaxRowID { - start := time.Now() - count, err := mq.deleteChunk(ctx, portal, maxRowID-deleteChunkSize, maxRowID) - duration := time.Since(start) - dbTimeUsed += duration - if err != nil { - return fmt.Errorf("failed to delete chunk of messages before %d: %w", maxRowID, err) - } - total -= int(count) - maxRowID += deleteChunkSize - sleepTime := max(10*time.Millisecond, min(250*time.Millisecond, time.Duration(count/100)*time.Millisecond)) - log.Debug(). - Int64("max_row_id", maxRowID). - Int64("deleted_count", count). - Int("remaining_count", total). - Dur("duration", duration). - Dur("sleep_time", sleepTime). - Msg("Deleted chunk of messages") - select { - case <-time.After(sleepTime): - case <-ctx.Done(): - return ctx.Err() - } - } - log.Debug(). - Int("remaining_count", total). - Dur("db_time_used", dbTimeUsed). - Dur("total_duration", time.Since(globalStart)). - Msg("Finished chunked delete of messages in portal") - return nil -} - func (mq *MessageQuery) CountMessagesInPortal(ctx context.Context, key networkid.PortalKey) (count int, err error) { err = mq.GetDB().QueryRow(ctx, countMessagesInPortalQuery, mq.BridgeID, key.ID, key.Receiver).Scan(&count) return diff --git a/bridgev2/database/portal.go b/bridgev2/database/portal.go index 0e6be286..17e44b09 100644 --- a/bridgev2/database/portal.go +++ b/bridgev2/database/portal.go @@ -16,7 +16,6 @@ import ( "go.mau.fi/util/dbutil" "maunium.net/go/mautrix/bridgev2/networkid" - "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" ) @@ -35,20 +34,9 @@ type PortalQuery struct { *dbutil.QueryHelper[*Portal] } -type CapStateFlags uint32 - -func (csf CapStateFlags) Has(flag CapStateFlags) bool { - return csf&flag != 0 -} - -const ( - CapStateFlagDisappearingTimerSet CapStateFlags = 1 << iota -) - type CapabilityState struct { Source networkid.UserLoginID `json:"source"` ID string `json:"id"` - Flags CapStateFlags `json:"flags"` } type Portal struct { @@ -56,31 +44,30 @@ type Portal struct { networkid.PortalKey MXID id.RoomID - ParentKey networkid.PortalKey - RelayLoginID networkid.UserLoginID - OtherUserID networkid.UserID - Name string - Topic string - AvatarID networkid.AvatarID - AvatarHash [32]byte - AvatarMXC id.ContentURIString - NameSet bool - TopicSet bool - AvatarSet bool - NameIsCustom bool - InSpace bool - MessageRequest bool - RoomType RoomType - Disappear DisappearingSetting - CapState CapabilityState - Metadata any + ParentKey networkid.PortalKey + RelayLoginID networkid.UserLoginID + OtherUserID networkid.UserID + Name string + Topic string + AvatarID networkid.AvatarID + AvatarHash [32]byte + AvatarMXC id.ContentURIString + NameSet bool + TopicSet bool + AvatarSet bool + NameIsCustom bool + InSpace bool + RoomType RoomType + Disappear DisappearingSetting + CapState CapabilityState + Metadata any } const ( getPortalBaseQuery = ` SELECT bridge_id, id, receiver, mxid, parent_id, parent_receiver, relay_login_id, other_user_id, name, topic, avatar_id, avatar_hash, avatar_mxc, - name_set, topic_set, avatar_set, name_is_custom, in_space, message_request, + name_set, topic_set, avatar_set, name_is_custom, in_space, room_type, disappear_type, disappear_timer, cap_state, metadata FROM portal @@ -89,9 +76,7 @@ const ( getPortalByIDWithUncertainReceiverQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND id=$2 AND (receiver=$3 OR receiver='')` getPortalByMXIDQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND mxid=$2` getAllPortalsWithMXIDQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND mxid IS NOT NULL` - getAllPortalsWithoutReceiver = getPortalBaseQuery + `WHERE bridge_id=$1 AND (receiver='' OR (parent_id<>'' AND parent_receiver='')) ORDER BY parent_id DESC` getAllDMPortalsQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND room_type='dm' AND other_user_id=$2` - getDMPortalQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND room_type='dm' AND receiver=$2 AND other_user_id=$3` getAllPortalsQuery = getPortalBaseQuery + `WHERE bridge_id=$1` getChildPortalsQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND parent_id=$2 AND parent_receiver=$3` @@ -102,11 +87,11 @@ const ( bridge_id, id, receiver, mxid, parent_id, parent_receiver, relay_login_id, other_user_id, name, topic, avatar_id, avatar_hash, avatar_mxc, - name_set, avatar_set, topic_set, name_is_custom, in_space, message_request, + name_set, avatar_set, topic_set, name_is_custom, in_space, room_type, disappear_type, disappear_timer, cap_state, metadata, relay_bridge_id ) VALUES ( - $1, $2, $3, $4, $5, $6, cast($7 AS TEXT), $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24, + $1, $2, $3, $4, $5, $6, cast($7 AS TEXT), $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, CASE WHEN cast($7 AS TEXT) IS NULL THEN NULL ELSE $1 END ) ` @@ -115,8 +100,8 @@ const ( SET mxid=$4, parent_id=$5, parent_receiver=$6, relay_login_id=cast($7 AS TEXT), relay_bridge_id=CASE WHEN cast($7 AS TEXT) IS NULL THEN NULL ELSE bridge_id END, other_user_id=$8, name=$9, topic=$10, avatar_id=$11, avatar_hash=$12, avatar_mxc=$13, - name_set=$14, avatar_set=$15, topic_set=$16, name_is_custom=$17, in_space=$18, message_request=$19, - room_type=$20, disappear_type=$21, disappear_timer=$22, cap_state=$23, metadata=$24 + name_set=$14, avatar_set=$15, topic_set=$16, name_is_custom=$17, in_space=$18, + room_type=$19, disappear_type=$20, disappear_timer=$21, cap_state=$22, metadata=$23 WHERE bridge_id=$1 AND id=$2 AND receiver=$3 ` deletePortalQuery = ` @@ -126,33 +111,15 @@ const ( reIDPortalQuery = `UPDATE portal SET id=$4, receiver=$5 WHERE bridge_id=$1 AND id=$2 AND receiver=$3` migrateToSplitPortalsQuery = ` UPDATE portal - SET receiver=new_receiver - FROM ( - SELECT bridge_id, id, COALESCE(( - SELECT login_id - FROM user_portal - WHERE bridge_id=portal.bridge_id AND portal_id=portal.id AND portal_receiver='' - LIMIT 1 - ), ( - SELECT login_id - FROM user_portal - WHERE portal.parent_id<>'' AND bridge_id=portal.bridge_id AND portal_id=portal.parent_id - LIMIT 1 - ), ( - SELECT id FROM user_login WHERE bridge_id=portal.bridge_id LIMIT 1 - ), '') AS new_receiver - FROM portal - WHERE receiver='' AND bridge_id=$1 - ) updates - WHERE portal.bridge_id=updates.bridge_id AND portal.id=updates.id AND portal.receiver='' AND NOT EXISTS ( - SELECT 1 FROM portal p2 WHERE p2.bridge_id=updates.bridge_id AND p2.id=updates.id AND p2.receiver=updates.new_receiver - ) - ` - fixParentsAfterSplitPortalMigrationQuery = ` - UPDATE portal - SET parent_receiver=receiver - WHERE bridge_id=$1 AND parent_receiver='' AND receiver<>'' AND parent_id<>'' - AND EXISTS(SELECT 1 FROM portal pp WHERE pp.bridge_id=$1 AND pp.id=portal.parent_id AND pp.receiver=portal.receiver); + SET receiver=COALESCE(( + SELECT login_id + FROM user_portal + WHERE bridge_id=portal.bridge_id AND portal_id=portal.id AND portal_receiver='' + LIMIT 1 + ), ( + SELECT id FROM user_login WHERE bridge_id=portal.bridge_id LIMIT 1 + ), '') + WHERE receiver='' AND bridge_id=$1 ` ) @@ -180,10 +147,6 @@ func (pq *PortalQuery) GetAllWithMXID(ctx context.Context) ([]*Portal, error) { return pq.QueryMany(ctx, getAllPortalsWithMXIDQuery, pq.BridgeID) } -func (pq *PortalQuery) GetAllWithoutReceiver(ctx context.Context) ([]*Portal, error) { - return pq.QueryMany(ctx, getAllPortalsWithoutReceiver, pq.BridgeID) -} - func (pq *PortalQuery) GetAll(ctx context.Context) ([]*Portal, error) { return pq.QueryMany(ctx, getAllPortalsQuery, pq.BridgeID) } @@ -192,10 +155,6 @@ func (pq *PortalQuery) GetAllDMsWith(ctx context.Context, otherUserID networkid. return pq.QueryMany(ctx, getAllDMPortalsQuery, pq.BridgeID, otherUserID) } -func (pq *PortalQuery) GetDM(ctx context.Context, receiver networkid.UserLoginID, otherUserID networkid.UserID) (*Portal, error) { - return pq.QueryOne(ctx, getDMPortalQuery, pq.BridgeID, receiver, otherUserID) -} - func (pq *PortalQuery) GetChildren(ctx context.Context, parentKey networkid.PortalKey) ([]*Portal, error) { return pq.QueryMany(ctx, getChildPortalsQuery, pq.BridgeID, parentKey.ID, parentKey.Receiver) } @@ -226,14 +185,6 @@ func (pq *PortalQuery) MigrateToSplitPortals(ctx context.Context) (int64, error) return res.RowsAffected() } -func (pq *PortalQuery) FixParentsAfterSplitPortalMigration(ctx context.Context) (int64, error) { - res, err := pq.GetDB().Exec(ctx, fixParentsAfterSplitPortalMigrationQuery, pq.BridgeID) - if err != nil { - return 0, err - } - return res.RowsAffected() -} - func (p *Portal) Scan(row dbutil.Scannable) (*Portal, error) { var mxid, parentID, parentReceiver, relayLoginID, otherUserID, disappearType sql.NullString var disappearTimer sql.NullInt64 @@ -242,7 +193,7 @@ func (p *Portal) Scan(row dbutil.Scannable) (*Portal, error) { &p.BridgeID, &p.ID, &p.Receiver, &mxid, &parentID, &parentReceiver, &relayLoginID, &otherUserID, &p.Name, &p.Topic, &p.AvatarID, &avatarHash, &p.AvatarMXC, - &p.NameSet, &p.TopicSet, &p.AvatarSet, &p.NameIsCustom, &p.InSpace, &p.MessageRequest, + &p.NameSet, &p.TopicSet, &p.AvatarSet, &p.NameIsCustom, &p.InSpace, &p.RoomType, &disappearType, &disappearTimer, dbutil.JSON{Data: &p.CapState}, dbutil.JSON{Data: p.Metadata}, ) @@ -257,7 +208,7 @@ func (p *Portal) Scan(row dbutil.Scannable) (*Portal, error) { } if disappearType.Valid { p.Disappear = DisappearingSetting{ - Type: event.DisappearingType(disappearType.String), + Type: DisappearingType(disappearType.String), Timer: time.Duration(disappearTimer.Int64), } } @@ -289,7 +240,7 @@ func (p *Portal) sqlVariables() []any { p.BridgeID, p.ID, p.Receiver, dbutil.StrPtr(p.MXID), dbutil.StrPtr(p.ParentKey.ID), p.ParentKey.Receiver, dbutil.StrPtr(p.RelayLoginID), dbutil.StrPtr(p.OtherUserID), p.Name, p.Topic, p.AvatarID, avatarHash, p.AvatarMXC, - p.NameSet, p.TopicSet, p.AvatarSet, p.NameIsCustom, p.InSpace, p.MessageRequest, + p.NameSet, p.TopicSet, p.AvatarSet, p.NameIsCustom, p.InSpace, p.RoomType, dbutil.StrPtr(p.Disappear.Type), dbutil.NumPtr(p.Disappear.Timer), dbutil.JSON{Data: p.CapState}, dbutil.JSON{Data: p.Metadata}, } diff --git a/bridgev2/database/publicmedia.go b/bridgev2/database/publicmedia.go deleted file mode 100644 index b667399c..00000000 --- a/bridgev2/database/publicmedia.go +++ /dev/null @@ -1,72 +0,0 @@ -// Copyright (c) 2025 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package database - -import ( - "context" - "database/sql" - "time" - - "go.mau.fi/util/dbutil" - - "maunium.net/go/mautrix/bridgev2/networkid" - "maunium.net/go/mautrix/crypto/attachment" - "maunium.net/go/mautrix/id" -) - -type PublicMediaQuery struct { - BridgeID networkid.BridgeID - *dbutil.QueryHelper[*PublicMedia] -} - -type PublicMedia struct { - BridgeID networkid.BridgeID - PublicID string - MXC id.ContentURI - Keys *attachment.EncryptedFile - MimeType string - Expiry time.Time -} - -const ( - upsertPublicMediaQuery = ` - INSERT INTO public_media (bridge_id, public_id, mxc, keys, mimetype, expiry) - VALUES ($1, $2, $3, $4, $5, $6) - ON CONFLICT (bridge_id, public_id) DO UPDATE SET expiry=EXCLUDED.expiry - ` - getPublicMediaQuery = ` - SELECT bridge_id, public_id, mxc, keys, mimetype, expiry - FROM public_media WHERE bridge_id=$1 AND public_id=$2 - ` -) - -func (pmq *PublicMediaQuery) Put(ctx context.Context, pm *PublicMedia) error { - ensureBridgeIDMatches(&pm.BridgeID, pmq.BridgeID) - return pmq.Exec(ctx, upsertPublicMediaQuery, pm.sqlVariables()...) -} - -func (pmq *PublicMediaQuery) Get(ctx context.Context, publicID string) (*PublicMedia, error) { - return pmq.QueryOne(ctx, getPublicMediaQuery, pmq.BridgeID, publicID) -} - -func (pm *PublicMedia) Scan(row dbutil.Scannable) (*PublicMedia, error) { - var expiry sql.NullInt64 - var mimetype sql.NullString - err := row.Scan(&pm.BridgeID, &pm.PublicID, &pm.MXC, dbutil.JSON{Data: &pm.Keys}, &mimetype, &expiry) - if err != nil { - return nil, err - } - if expiry.Valid { - pm.Expiry = time.Unix(0, expiry.Int64) - } - pm.MimeType = mimetype.String - return pm, nil -} - -func (pm *PublicMedia) sqlVariables() []any { - return []any{pm.BridgeID, pm.PublicID, &pm.MXC, dbutil.JSONPtr(pm.Keys), dbutil.StrPtr(pm.MimeType), dbutil.ConvertedPtr(pm.Expiry, time.Time.UnixNano)} -} diff --git a/bridgev2/database/upgrades/00-latest.sql b/bridgev2/database/upgrades/00-latest.sql index 6092dc24..4eea05bb 100644 --- a/bridgev2/database/upgrades/00-latest.sql +++ b/bridgev2/database/upgrades/00-latest.sql @@ -1,4 +1,4 @@ --- v0 -> v27 (compatible with v9+): Latest revision +-- v0 -> v22 (compatible with v9+): Latest revision CREATE TABLE "user" ( bridge_id TEXT NOT NULL, mxid TEXT NOT NULL, @@ -48,7 +48,6 @@ CREATE TABLE portal ( topic_set BOOLEAN NOT NULL, name_is_custom BOOLEAN NOT NULL DEFAULT false, in_space BOOLEAN NOT NULL, - message_request BOOLEAN NOT NULL DEFAULT false, room_type TEXT NOT NULL, disappear_type TEXT, disappear_timer BIGINT, @@ -65,7 +64,6 @@ CREATE TABLE portal ( ON DELETE SET NULL ON UPDATE CASCADE ); CREATE UNIQUE INDEX portal_bridge_mxid_idx ON portal (bridge_id, mxid); -CREATE INDEX portal_parent_idx ON portal (bridge_id, parent_id, parent_receiver); CREATE TABLE ghost ( bridge_id TEXT NOT NULL, @@ -80,7 +78,6 @@ CREATE TABLE ghost ( contact_info_set BOOLEAN NOT NULL, is_bot BOOLEAN NOT NULL, identifiers jsonb NOT NULL, - extra_profile jsonb, metadata jsonb NOT NULL, PRIMARY KEY (bridge_id, id) @@ -130,7 +127,6 @@ CREATE TABLE disappearing_message ( bridge_id TEXT NOT NULL, mx_room TEXT NOT NULL, mxid TEXT NOT NULL, - timestamp BIGINT NOT NULL DEFAULT 0, type TEXT NOT NULL, timer BIGINT NOT NULL, disappear_at BIGINT, @@ -141,7 +137,6 @@ CREATE TABLE disappearing_message ( REFERENCES portal (bridge_id, mxid) ON DELETE CASCADE ); -CREATE INDEX disappearing_message_portal_idx ON disappearing_message (bridge_id, mx_room); CREATE TABLE reaction ( bridge_id TEXT NOT NULL, @@ -220,14 +215,3 @@ CREATE TABLE kv_store ( PRIMARY KEY (bridge_id, key) ); - -CREATE TABLE public_media ( - bridge_id TEXT NOT NULL, - public_id TEXT NOT NULL, - mxc TEXT NOT NULL, - keys jsonb, - mimetype TEXT, - expiry BIGINT, - - PRIMARY KEY (bridge_id, public_id) -); diff --git a/bridgev2/database/upgrades/23-disappearing-timer-ts.sql b/bridgev2/database/upgrades/23-disappearing-timer-ts.sql deleted file mode 100644 index ecd00b8d..00000000 --- a/bridgev2/database/upgrades/23-disappearing-timer-ts.sql +++ /dev/null @@ -1,2 +0,0 @@ --- v23 (compatible with v9+): Add event timestamp for disappearing messages -ALTER TABLE disappearing_message ADD COLUMN timestamp BIGINT NOT NULL DEFAULT 0; diff --git a/bridgev2/database/upgrades/24-public-media.sql b/bridgev2/database/upgrades/24-public-media.sql deleted file mode 100644 index c4290090..00000000 --- a/bridgev2/database/upgrades/24-public-media.sql +++ /dev/null @@ -1,11 +0,0 @@ --- v24 (compatible with v9+): Custom URLs for public media -CREATE TABLE public_media ( - bridge_id TEXT NOT NULL, - public_id TEXT NOT NULL, - mxc TEXT NOT NULL, - keys jsonb, - mimetype TEXT, - expiry BIGINT, - - PRIMARY KEY (bridge_id, public_id) -); diff --git a/bridgev2/database/upgrades/25-message-requests.sql b/bridgev2/database/upgrades/25-message-requests.sql deleted file mode 100644 index b9d82a7a..00000000 --- a/bridgev2/database/upgrades/25-message-requests.sql +++ /dev/null @@ -1,2 +0,0 @@ --- v25 (compatible with v9+): Flag for message request portals -ALTER TABLE portal ADD COLUMN message_request BOOLEAN NOT NULL DEFAULT false; diff --git a/bridgev2/database/upgrades/26-disappearing-message-portal-index.sql b/bridgev2/database/upgrades/26-disappearing-message-portal-index.sql deleted file mode 100644 index ae5d8cad..00000000 --- a/bridgev2/database/upgrades/26-disappearing-message-portal-index.sql +++ /dev/null @@ -1,3 +0,0 @@ --- v26 (compatible with v9+): Add room index for disappearing message table and portal parents -CREATE INDEX disappearing_message_portal_idx ON disappearing_message (bridge_id, mx_room); -CREATE INDEX portal_parent_idx ON portal (bridge_id, parent_id, parent_receiver); diff --git a/bridgev2/database/upgrades/27-ghost-extra-profile.sql b/bridgev2/database/upgrades/27-ghost-extra-profile.sql deleted file mode 100644 index e8e0549a..00000000 --- a/bridgev2/database/upgrades/27-ghost-extra-profile.sql +++ /dev/null @@ -1,2 +0,0 @@ --- v27 (compatible with v9+): Add column for extra ghost profile metadata -ALTER TABLE ghost ADD COLUMN extra_profile jsonb; diff --git a/bridgev2/database/userlogin.go b/bridgev2/database/userlogin.go index 00ff01c9..9fa6569a 100644 --- a/bridgev2/database/userlogin.go +++ b/bridgev2/database/userlogin.go @@ -116,7 +116,7 @@ func (u *UserLogin) ensureHasMetadata(metaType MetaTypeCreator) *UserLogin { func (u *UserLogin) sqlVariables() []any { var remoteProfile dbutil.JSON - if !u.RemoteProfile.IsZero() { + if !u.RemoteProfile.IsEmpty() { remoteProfile.Data = &u.RemoteProfile } return []any{u.BridgeID, u.UserMXID, u.ID, u.RemoteName, remoteProfile, dbutil.StrPtr(u.SpaceRoom), dbutil.JSON{Data: u.Metadata}} diff --git a/bridgev2/disappear.go b/bridgev2/disappear.go index b5c37e8f..f072c01f 100644 --- a/bridgev2/disappear.go +++ b/bridgev2/disappear.go @@ -86,8 +86,8 @@ func (dl *DisappearLoop) Stop() { } } -func (dl *DisappearLoop) StartAllBefore(ctx context.Context, roomID id.RoomID, beforeTS time.Time) { - startedMessages, err := dl.br.DB.DisappearingMessage.StartAllBefore(ctx, roomID, beforeTS) +func (dl *DisappearLoop) StartAll(ctx context.Context, roomID id.RoomID) { + startedMessages, err := dl.br.DB.DisappearingMessage.StartAll(ctx, roomID) if err != nil { zerolog.Ctx(ctx).Err(err).Msg("Failed to start disappearing messages") return diff --git a/bridgev2/errors.go b/bridgev2/errors.go index f6677d2e..c023dcdf 100644 --- a/bridgev2/errors.go +++ b/bridgev2/errors.go @@ -38,53 +38,35 @@ var ErrNotLoggedIn = errors.New("not logged in") // but direct media is not enabled. var ErrDirectMediaNotEnabled = errors.New("direct media is not enabled") -var ErrPortalIsDeleted = errors.New("portal is deleted") -var ErrPortalNotFoundInEventHandler = errors.New("portal not found to handle remote event") - // Common message status errors var ( - ErrPanicInEventHandler error = WrapErrorInStatus(errors.New("panic in event handler")).WithSendNotice(true).WithErrorAsMessage() - ErrNoPortal error = WrapErrorInStatus(errors.New("room is not a portal")).WithIsCertain(true).WithSendNotice(false) - ErrIgnoringReactionFromRelayedUser error = WrapErrorInStatus(errors.New("ignoring reaction event from relayed user")).WithIsCertain(true).WithSendNotice(false) - ErrIgnoringPollFromRelayedUser error = WrapErrorInStatus(errors.New("ignoring poll event from relayed user")).WithIsCertain(true).WithSendNotice(false) - ErrIgnoringDeleteChatRelayedUser error = WrapErrorInStatus(errors.New("ignoring delete chat event from relayed user")).WithIsCertain(true).WithSendNotice(false) - ErrIgnoringAcceptRequestRelayedUser error = WrapErrorInStatus(errors.New("ignoring accept message request event from relayed user")).WithIsCertain(true).WithSendNotice(false) - ErrEditsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support edits")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported) - ErrEditsNotSupportedInPortal error = WrapErrorInStatus(errors.New("edits are not allowed in this chat")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported) - ErrCaptionsNotAllowed error = WrapErrorInStatus(errors.New("captions are not supported here")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported) - ErrLocationMessagesNotAllowed error = WrapErrorInStatus(errors.New("location messages are not supported here")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported) - ErrEditTargetTooOld error = WrapErrorInStatus(errors.New("the message is too old to be edited")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported) - ErrEditTargetTooManyEdits error = WrapErrorInStatus(errors.New("the message has been edited too many times")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported) - ErrReactionsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support reactions")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported) - ErrPollsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support polls")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported) - ErrRoomMetadataNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support changing room metadata")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false).WithErrorReason(event.MessageStatusUnsupported) - ErrRoomMetadataNotAllowed error = WrapErrorInStatus(errors.New("changes are not allowed here")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false).WithErrorReason(event.MessageStatusUnsupported) - ErrRedactionsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support deleting messages")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported) - ErrUnexpectedParsedContentType error = WrapErrorInStatus(errors.New("unexpected parsed content type")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(true) - ErrInvalidStateKey error = WrapErrorInStatus(errors.New("room metadata state key is unset or non-empty")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(false) - ErrDatabaseError error = WrapErrorInStatus(errors.New("database error")).WithMessage("internal database error").WithIsCertain(true).WithSendNotice(true) - ErrTargetMessageNotFound error = WrapErrorInStatus(errors.New("target message not found")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(false) - ErrUnsupportedMessageType error = WrapErrorInStatus(errors.New("unsupported message type")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(true).WithErrorReason(event.MessageStatusUnsupported) - ErrUnsupportedMediaType error = WrapErrorInStatus(errors.New("unsupported media type")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(true).WithErrorReason(event.MessageStatusUnsupported) - ErrMediaDurationTooLong error = WrapErrorInStatus(errors.New("media duration too long")).WithErrorAsMessage().WithSendNotice(true).WithErrorReason(event.MessageStatusUnsupported) - ErrVoiceMessageDurationTooLong error = WrapErrorInStatus(errors.New("voice message too long")).WithErrorAsMessage().WithSendNotice(true).WithErrorReason(event.MessageStatusUnsupported) - ErrMediaTooLarge error = WrapErrorInStatus(errors.New("media too large")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(true).WithErrorReason(event.MessageStatusUnsupported) - ErrIgnoringMNotice error = WrapErrorInStatus(errors.New("ignoring m.notice message")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false) - ErrMediaDownloadFailed error = WrapErrorInStatus(errors.New("failed to download media")).WithMessage("failed to download media").WithIsCertain(true).WithSendNotice(true) - ErrMediaReuploadFailed error = WrapErrorInStatus(errors.New("failed to reupload media")).WithMessage("failed to reupload media").WithIsCertain(true).WithSendNotice(true) - ErrMediaConvertFailed error = WrapErrorInStatus(errors.New("failed to convert media")).WithMessage("failed to convert media").WithIsCertain(true).WithSendNotice(true) - ErrMembershipNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support changing group membership")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false).WithErrorReason(event.MessageStatusUnsupported) - ErrDeleteChatNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support deleting chats")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false).WithErrorReason(event.MessageStatusUnsupported) - ErrBeeperAIStreamNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support Beeper AI stream events")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false).WithErrorReason(event.MessageStatusUnsupported) - ErrPowerLevelsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support changing group power levels")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false).WithErrorReason(event.MessageStatusUnsupported) - ErrRemoteEchoTimeout = WrapErrorInStatus(errors.New("remote echo timed out")).WithIsCertain(false).WithSendNotice(true).WithErrorReason(event.MessageStatusTooOld) - ErrRemoteAckTimeout = WrapErrorInStatus(errors.New("remote ack timed out")).WithIsCertain(false).WithSendNotice(true).WithErrorReason(event.MessageStatusTooOld) - - ErrPublicMediaDisabled = WrapErrorInStatus(errors.New("public media is not enabled in the bridge config")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported).WithSendNotice(true) - ErrPublicMediaDatabaseDisabled = WrapErrorInStatus(errors.New("public media database storage is disabled")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported).WithSendNotice(true) - ErrPublicMediaGenerateFailed = WrapErrorInStatus(errors.New("failed to generate public media URL")).WithIsCertain(true).WithMessage("failed to generate public media URL").WithErrorReason(event.MessageStatusUnsupported).WithSendNotice(true) - - ErrDisappearingTimerUnsupported error = WrapErrorInStatus(errors.New("invalid disappearing timer")).WithIsCertain(true) + ErrPanicInEventHandler error = WrapErrorInStatus(errors.New("panic in event handler")).WithSendNotice(true).WithErrorAsMessage() + ErrNoPortal error = WrapErrorInStatus(errors.New("room is not a portal")).WithIsCertain(true).WithSendNotice(false) + ErrIgnoringReactionFromRelayedUser error = WrapErrorInStatus(errors.New("ignoring reaction event from relayed user")).WithIsCertain(true).WithSendNotice(false) + ErrIgnoringPollFromRelayedUser error = WrapErrorInStatus(errors.New("ignoring poll event from relayed user")).WithIsCertain(true).WithSendNotice(false) + ErrEditsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support edits")).WithIsCertain(true).WithErrorAsMessage() + ErrEditsNotSupportedInPortal error = WrapErrorInStatus(errors.New("edits are not allowed in this chat")).WithIsCertain(true).WithErrorAsMessage() + ErrCaptionsNotAllowed error = WrapErrorInStatus(errors.New("captions are not supported here")).WithIsCertain(true).WithErrorAsMessage() + ErrLocationMessagesNotAllowed error = WrapErrorInStatus(errors.New("location messages are not supported here")).WithIsCertain(true).WithErrorAsMessage() + ErrEditTargetTooOld error = WrapErrorInStatus(errors.New("the message is too old to be edited")).WithIsCertain(true).WithErrorAsMessage() + ErrEditTargetTooManyEdits error = WrapErrorInStatus(errors.New("the message has been edited too many times")).WithIsCertain(true).WithErrorAsMessage() + ErrReactionsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support reactions")).WithIsCertain(true).WithErrorAsMessage() + ErrPollsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support polls")).WithIsCertain(true).WithErrorAsMessage() + ErrRoomMetadataNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support changing room metadata")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false) + ErrRedactionsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support deleting messages")).WithIsCertain(true).WithErrorAsMessage() + ErrUnexpectedParsedContentType error = WrapErrorInStatus(errors.New("unexpected parsed content type")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(true) + ErrDatabaseError error = WrapErrorInStatus(errors.New("database error")).WithMessage("internal database error").WithIsCertain(true).WithSendNotice(true) + ErrTargetMessageNotFound error = WrapErrorInStatus(errors.New("target message not found")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(false) + ErrUnsupportedMessageType error = WrapErrorInStatus(errors.New("unsupported message type")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(true) + ErrUnsupportedMediaType error = WrapErrorInStatus(errors.New("unsupported media type")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(true) + ErrIgnoringMNotice error = WrapErrorInStatus(errors.New("ignoring m.notice message")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false) + ErrMediaDownloadFailed error = WrapErrorInStatus(errors.New("failed to download media")).WithMessage("failed to download media").WithIsCertain(true).WithSendNotice(true) + ErrMediaReuploadFailed error = WrapErrorInStatus(errors.New("failed to reupload media")).WithMessage("failed to reupload media").WithIsCertain(true).WithSendNotice(true) + ErrMediaConvertFailed error = WrapErrorInStatus(errors.New("failed to convert media")).WithMessage("failed to convert media").WithIsCertain(true).WithSendNotice(true) + ErrMembershipNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support changing group membership")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false) + ErrPowerLevelsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support changing group power levels")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false) + ErrRemoteEchoTimeout = WrapErrorInStatus(errors.New("remote echo timed out")).WithIsCertain(false).WithSendNotice(true).WithErrorReason(event.MessageStatusTooOld) + ErrRemoteAckTimeout = WrapErrorInStatus(errors.New("remote ack timed out")).WithIsCertain(false).WithSendNotice(true).WithErrorReason(event.MessageStatusTooOld) ) // Common login interface errors diff --git a/bridgev2/ghost.go b/bridgev2/ghost.go index 590dd1dc..6cef6f06 100644 --- a/bridgev2/ghost.go +++ b/bridgev2/ghost.go @@ -9,15 +9,12 @@ package bridgev2 import ( "context" "crypto/sha256" - "encoding/json" "fmt" - "maps" "net/http" - "slices" "github.com/rs/zerolog" - "go.mau.fi/util/exerrors" "go.mau.fi/util/exmime" + "golang.org/x/exp/slices" "maunium.net/go/mautrix/bridgev2/database" "maunium.net/go/mautrix/bridgev2/networkid" @@ -137,11 +134,10 @@ func (a *Avatar) Reupload(ctx context.Context, intent MatrixAPI, currentHash [32 } type UserInfo struct { - Identifiers []string - Name *string - Avatar *Avatar - IsBot *bool - ExtraProfile database.ExtraProfile + Identifiers []string + Name *string + Avatar *Avatar + IsBot *bool ExtraUpdates ExtraUpdater[*Ghost] } @@ -189,9 +185,9 @@ func (ghost *Ghost) UpdateAvatar(ctx context.Context, avatar *Avatar) bool { return true } -func (ghost *Ghost) getExtraProfileMeta() any { +func (ghost *Ghost) getExtraProfileMeta() *event.BeeperProfileExtra { bridgeName := ghost.Bridge.Network.GetName() - baseExtra := &event.BeeperProfileExtra{ + return &event.BeeperProfileExtra{ RemoteID: string(ghost.ID), Identifiers: ghost.Identifiers, Service: bridgeName.BeeperBridgeType, @@ -199,35 +195,23 @@ func (ghost *Ghost) getExtraProfileMeta() any { IsBridgeBot: false, IsNetworkBot: ghost.IsBot, } - if len(ghost.ExtraProfile) == 0 { - return baseExtra - } - mergedExtra := maps.Clone(ghost.ExtraProfile) - baseExtraMarshaled := exerrors.Must(json.Marshal(baseExtra)) - exerrors.PanicIfNotNil(json.Unmarshal(baseExtraMarshaled, &mergedExtra)) - return mergedExtra } -func (ghost *Ghost) UpdateContactInfo(ctx context.Context, identifiers []string, isBot *bool, extraProfile database.ExtraProfile) bool { - if !ghost.Bridge.Matrix.GetCapabilities().ExtraProfileMeta { - ghost.ContactInfoSet = false - return false - } +func (ghost *Ghost) UpdateContactInfo(ctx context.Context, identifiers []string, isBot *bool) bool { if identifiers != nil { slices.Sort(identifiers) } - changed := extraProfile.CopyTo(&ghost.ExtraProfile) + if ghost.ContactInfoSet && + (identifiers == nil || slices.Equal(identifiers, ghost.Identifiers)) && + (isBot == nil || *isBot == ghost.IsBot) { + return false + } if identifiers != nil { - changed = changed || !slices.Equal(identifiers, ghost.Identifiers) ghost.Identifiers = identifiers } if isBot != nil { - changed = changed || *isBot != ghost.IsBot ghost.IsBot = *isBot } - if ghost.ContactInfoSet && !changed { - return false - } err := ghost.Intent.SetExtraProfileMeta(ctx, ghost.getExtraProfileMeta()) if err != nil { zerolog.Ctx(ctx).Err(err).Msg("Failed to set extra profile metadata") @@ -250,7 +234,7 @@ func (br *Bridge) allowAggressiveUpdateForType(evtType RemoteEventType) bool { } func (ghost *Ghost) UpdateInfoIfNecessary(ctx context.Context, source *UserLogin, evtType RemoteEventType) { - if ghost.Name != "" && ghost.NameSet && ghost.AvatarSet && !ghost.Bridge.allowAggressiveUpdateForType(evtType) { + if ghost.Name != "" && ghost.NameSet && !ghost.Bridge.allowAggressiveUpdateForType(evtType) { return } info, err := source.Client.GetUserInfo(ctx, ghost) @@ -260,16 +244,12 @@ func (ghost *Ghost) UpdateInfoIfNecessary(ctx context.Context, source *UserLogin zerolog.Ctx(ctx).Debug(). Bool("has_name", ghost.Name != ""). Bool("name_set", ghost.NameSet). - Bool("has_avatar", ghost.AvatarMXC != ""). - Bool("avatar_set", ghost.AvatarSet). Msg("Updating ghost info in IfNecessary call") ghost.UpdateInfo(ctx, info) } else { zerolog.Ctx(ctx).Trace(). Bool("has_name", ghost.Name != ""). Bool("name_set", ghost.NameSet). - Bool("has_avatar", ghost.AvatarMXC != ""). - Bool("avatar_set", ghost.AvatarSet). Msg("No ghost info received in IfNecessary call") } } @@ -297,14 +277,9 @@ func (ghost *Ghost) UpdateInfo(ctx context.Context, info *UserInfo) { } if info.Avatar != nil { update = ghost.UpdateAvatar(ctx, info.Avatar) || update - } else if oldAvatar == "" && !ghost.AvatarSet { - // Special case: nil avatar means we're not expecting one ever, if we don't currently have - // one we flag it as set to avoid constantly refetching in UpdateInfoIfNecessary. - ghost.AvatarSet = true - update = true } - if info.Identifiers != nil || info.IsBot != nil || info.ExtraProfile != nil { - update = ghost.UpdateContactInfo(ctx, info.Identifiers, info.IsBot, info.ExtraProfile) || update + if info.Identifiers != nil || info.IsBot != nil { + update = ghost.UpdateContactInfo(ctx, info.Identifiers, info.IsBot) || update } if info.ExtraUpdates != nil { update = info.ExtraUpdates(ctx, ghost) || update diff --git a/bridgev2/login.go b/bridgev2/login.go index b8321719..1fa3afbc 100644 --- a/bridgev2/login.go +++ b/bridgev2/login.go @@ -13,7 +13,6 @@ import ( "strings" "maunium.net/go/mautrix/bridgev2/networkid" - "maunium.net/go/mautrix/event" ) // LoginProcess represents a single occurrence of a user logging into the remote network. @@ -179,8 +178,6 @@ const ( LoginInputFieldTypeToken LoginInputFieldType = "token" LoginInputFieldTypeURL LoginInputFieldType = "url" LoginInputFieldTypeDomain LoginInputFieldType = "domain" - LoginInputFieldTypeSelect LoginInputFieldType = "select" - LoginInputFieldTypeCaptchaCode LoginInputFieldType = "captcha_code" ) type LoginInputDataField struct { @@ -192,13 +189,8 @@ type LoginInputDataField struct { Name string `json:"name"` // The description of the field shown to the user. Description string `json:"description"` - // A default value that the client can pre-fill the field with. - DefaultValue string `json:"default_value,omitempty"` // A regex pattern that the client can use to validate input client-side. Pattern string `json:"pattern,omitempty"` - // For fields of type select, the valid options. - // Pattern may also be filled with a regex that matches the same options. - Options []string `json:"options,omitempty"` // A function that validates the input and optionally cleans it up before it's submitted to the connector. Validate func(string) (string, error) `json:"-"` } @@ -273,23 +265,6 @@ func (f *LoginInputDataField) FillDefaultValidate() { type LoginUserInputParams struct { // The fields that the user needs to fill in. Fields []LoginInputDataField `json:"fields"` - - // Attachments to display alongside the input fields. - Attachments []*LoginUserInputAttachment `json:"attachments"` -} - -type LoginUserInputAttachment struct { - Type event.MessageType `json:"type,omitempty"` - FileName string `json:"filename,omitempty"` - Content []byte `json:"content,omitempty"` - Info LoginUserInputAttachmentInfo `json:"info,omitempty"` -} - -type LoginUserInputAttachmentInfo struct { - MimeType string `json:"mimetype,omitempty"` - Width int `json:"w,omitempty"` - Height int `json:"h,omitempty"` - Size int `json:"size,omitempty"` } type LoginCompleteParams struct { diff --git a/bridgev2/matrix/connector.go b/bridgev2/matrix/connector.go index 5a2df953..19eb399b 100644 --- a/bridgev2/matrix/connector.go +++ b/bridgev2/matrix/connector.go @@ -26,7 +26,6 @@ import ( _ "go.mau.fi/util/dbutil/litestream" "go.mau.fi/util/exbytes" "go.mau.fi/util/exsync" - "go.mau.fi/util/ptr" "go.mau.fi/util/random" "golang.org/x/sync/semaphore" @@ -81,8 +80,6 @@ type Connector struct { MediaConfig mautrix.RespMediaConfig SpecVersions *mautrix.RespVersions - SpecCaps *mautrix.RespCapabilities - specCapsLock sync.Mutex Capabilities *bridgev2.MatrixCapabilities IgnoreUnsupportedServer bool @@ -144,20 +141,14 @@ func (br *Connector) Init(bridge *bridgev2.Bridge) { br.EventProcessor.On(event.EventReaction, br.handleRoomEvent) br.EventProcessor.On(event.EventRedaction, br.handleRoomEvent) br.EventProcessor.On(event.EventEncrypted, br.handleEncryptedEvent) - br.EventProcessor.On(event.EphemeralEventEncrypted, br.handleEncryptedEvent) br.EventProcessor.On(event.StateMember, br.handleRoomEvent) br.EventProcessor.On(event.StatePowerLevels, br.handleRoomEvent) br.EventProcessor.On(event.StateRoomName, br.handleRoomEvent) - br.EventProcessor.On(event.BeeperSendState, br.handleRoomEvent) br.EventProcessor.On(event.StateRoomAvatar, br.handleRoomEvent) br.EventProcessor.On(event.StateTopic, br.handleRoomEvent) br.EventProcessor.On(event.StateTombstone, br.handleRoomEvent) - br.EventProcessor.On(event.StateBeeperDisappearingTimer, br.handleRoomEvent) - br.EventProcessor.On(event.BeeperDeleteChat, br.handleRoomEvent) - br.EventProcessor.On(event.BeeperAcceptMessageRequest, br.handleRoomEvent) br.EventProcessor.On(event.EphemeralEventReceipt, br.handleEphemeralEvent) br.EventProcessor.On(event.EphemeralEventTyping, br.handleEphemeralEvent) - br.EventProcessor.On(event.BeeperEphemeralEventAIStream, br.handleEphemeralEvent) br.Bot = br.AS.BotIntent() br.Crypto = NewCryptoHelper(br) br.Bridge.Commands.(*commands.Processor).AddHandlers( @@ -282,7 +273,7 @@ func (br *Connector) GetPublicAddress() string { if br.Config.AppService.PublicAddress == "https://bridge.example.com" { return "" } - return strings.TrimRight(br.Config.AppService.PublicAddress, "/") + return br.Config.AppService.PublicAddress } func (br *Connector) GetRouter() *http.ServeMux { @@ -344,18 +335,16 @@ func (br *Connector) logInitialRequestError(err error, defaultMessage string) { } func (br *Connector) ensureConnection(ctx context.Context) { - triedToRegister := false for { versions, err := br.Bot.Versions(ctx) if err != nil { - if errors.Is(err, mautrix.MForbidden) && !triedToRegister { + if errors.Is(err, mautrix.MForbidden) { br.Log.Debug().Msg("M_FORBIDDEN in /versions, trying to register before retrying") err = br.Bot.EnsureRegistered(ctx) if err != nil { br.logInitialRequestError(err, "Failed to register after /versions failed with M_FORBIDDEN") os.Exit(16) } - triedToRegister = true } else if errors.Is(err, mautrix.MUnknownToken) || errors.Is(err, mautrix.MExclusive) { br.logInitialRequestError(err, "/versions request failed with auth error") os.Exit(16) @@ -368,9 +357,6 @@ func (br *Connector) ensureConnection(ctx context.Context) { *br.AS.SpecVersions = *versions br.Capabilities.AutoJoinInvites = br.SpecVersions.Supports(mautrix.BeeperFeatureAutojoinInvites) br.Capabilities.BatchSending = br.SpecVersions.Supports(mautrix.BeeperFeatureBatchSending) - br.Capabilities.ArbitraryMemberChange = br.SpecVersions.Supports(mautrix.BeeperFeatureArbitraryMemberChange) - br.Capabilities.ExtraProfileMeta = br.SpecVersions.Supports(mautrix.BeeperFeatureArbitraryProfileMeta) || - (br.SpecVersions.Supports(mautrix.FeatureArbitraryProfileFields) && br.Config.Matrix.GhostExtraProfileInfo) break } } @@ -415,21 +401,6 @@ func (br *Connector) ensureConnection(ctx context.Context) { br.Bot.EnsureAppserviceConnection(ctx) } -func (br *Connector) fetchCapabilities(ctx context.Context) *mautrix.RespCapabilities { - br.specCapsLock.Lock() - defer br.specCapsLock.Unlock() - if br.SpecCaps != nil { - return br.SpecCaps - } - caps, err := br.Bot.Capabilities(ctx) - if err != nil { - br.Log.Err(err).Msg("Failed to fetch capabilities from homeserver") - return nil - } - br.SpecCaps = caps - return caps -} - func (br *Connector) fetchMediaConfig(ctx context.Context) { cfg, err := br.Bot.GetMediaConfig(ctx) if err != nil { @@ -538,8 +509,7 @@ func (br *Connector) internalSendMessageStatus(ctx context.Context, ms *bridgev2 Msg("Failed to send MSS event") } } - if ms.SendNotice && br.Config.Matrix.MessageErrorNotices && evt.MessageType != event.MsgNotice && - (ms.Status == event.MessageStatusFail || ms.Status == event.MessageStatusRetriable || ms.Step == status.MsgStepDecrypted) { + if ms.SendNotice && br.Config.Matrix.MessageErrorNotices && (ms.Status == event.MessageStatusFail || ms.Status == event.MessageStatusRetriable || ms.Step == status.MsgStepDecrypted) { content := ms.ToNoticeEvent(evt) if editEvent != "" { content.SetEdit(editEvent) @@ -623,28 +593,13 @@ func (br *Connector) GetPowerLevels(ctx context.Context, roomID id.RoomID) (*eve } func (br *Connector) GetStateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string) (*event.Event, error) { - if stateKey == "" { - switch eventType { - case event.StateCreate: - createEvt, err := br.Bot.StateStore.GetCreate(ctx, roomID) - if err != nil || createEvt != nil { - return createEvt, err - } - case event.StateJoinRules: - joinRulesContent, err := br.Bot.StateStore.GetJoinRules(ctx, roomID) - if err != nil { - return nil, err - } else if joinRulesContent != nil { - return &event.Event{ - Type: event.StateJoinRules, - RoomID: roomID, - StateKey: ptr.Ptr(""), - Content: event.Content{Parsed: joinRulesContent}, - }, nil - } + if eventType == event.StateCreate && stateKey == "" { + createEvt, err := br.Bot.StateStore.GetCreate(ctx, roomID) + if err != nil || createEvt != nil { + return createEvt, err } } - return br.Bot.FullStateEvent(ctx, roomID, eventType, stateKey) + return br.Bot.FullStateEvent(ctx, roomID, eventType, "") } func (br *Connector) GetMembers(ctx context.Context, roomID id.RoomID) (map[id.UserID]*event.MemberEventContent, error) { @@ -687,7 +642,7 @@ func (br *Connector) BatchSend(ctx context.Context, roomID id.RoomID, req *mautr if intent != nil { intent.AddDoublePuppetValueWithTS(&evt.Content, evt.Timestamp) } - if evt.Type != event.EventEncrypted && evt.Type != event.EventReaction { + if evt.Type != event.EventEncrypted { err = br.Crypto.Encrypt(ctx, roomID, evt.Type, &evt.Content) if err != nil { return nil, err diff --git a/bridgev2/matrix/crypto.go b/bridgev2/matrix/crypto.go index 7f18f1f5..47226625 100644 --- a/bridgev2/matrix/crypto.go +++ b/bridgev2/matrix/crypto.go @@ -24,7 +24,6 @@ import ( "maunium.net/go/mautrix" "maunium.net/go/mautrix/bridgev2" - "maunium.net/go/mautrix/bridgev2/database" "maunium.net/go/mautrix/crypto" "maunium.net/go/mautrix/crypto/olm" "maunium.net/go/mautrix/event" @@ -38,9 +37,9 @@ func init() { var _ crypto.StateStore = (*sqlstatestore.SQLStateStore)(nil) -var NoSessionFound = crypto.ErrNoSessionFound -var DuplicateMessageIndex = crypto.ErrDuplicateMessageIndex -var UnknownMessageIndex = olm.ErrUnknownMessageIndex +var NoSessionFound = crypto.NoSessionFound +var DuplicateMessageIndex = crypto.DuplicateMessageIndex +var UnknownMessageIndex = olm.UnknownMessageIndex type CryptoHelper struct { bridge *Connector @@ -136,19 +135,7 @@ func (helper *CryptoHelper) Init(ctx context.Context) error { return err } if isExistingDevice { - if !helper.verifyKeysAreOnServer(ctx) { - return nil - } - } else { - err = helper.ShareKeys(ctx) - if err != nil { - return fmt.Errorf("failed to share device keys: %w", err) - } - } - if helper.bridge.Config.Encryption.SelfSign { - if !helper.doSelfSign(ctx) { - os.Exit(34) - } + helper.verifyKeysAreOnServer(ctx) } go helper.resyncEncryptionInfo(context.TODO()) @@ -156,46 +143,6 @@ func (helper *CryptoHelper) Init(ctx context.Context) error { return nil } -func (helper *CryptoHelper) doSelfSign(ctx context.Context) bool { - log := zerolog.Ctx(ctx) - hasKeys, isVerified, err := helper.mach.GetOwnVerificationStatus(ctx) - if err != nil { - log.WithLevel(zerolog.FatalLevel).Err(err).Msg("Failed to check verification status") - return false - } - log.Debug().Bool("has_keys", hasKeys).Bool("is_verified", isVerified).Msg("Checked verification status") - keyInDB := helper.bridge.Bridge.DB.KV.Get(ctx, database.KeyRecoveryKey) - if !hasKeys || keyInDB == "overwrite" { - if keyInDB != "" && keyInDB != "overwrite" { - log.WithLevel(zerolog.FatalLevel). - Msg("No keys on server, but database already has recovery key. Delete `recovery_key` from `kv_store` manually to continue.") - return false - } - recoveryKey, err := helper.mach.GenerateAndVerifyWithRecoveryKey(ctx) - if recoveryKey != "" { - helper.bridge.Bridge.DB.KV.Set(ctx, database.KeyRecoveryKey, recoveryKey) - } - if err != nil { - log.WithLevel(zerolog.FatalLevel).Err(err).Msg("Failed to generate recovery key and self-sign") - return false - } - log.Info().Msg("Generated new recovery key and self-signed bot device") - } else if !isVerified { - if keyInDB == "" { - log.WithLevel(zerolog.FatalLevel). - Msg("Server already has cross-signing keys, but no key in database. Add `recovery_key` to `kv_store`, or set it to `overwrite` to generate new keys.") - return false - } - err = helper.mach.VerifyWithRecoveryKey(ctx, keyInDB) - if err != nil { - log.WithLevel(zerolog.FatalLevel).Err(err).Msg("Failed to verify with recovery key") - return false - } - log.Info().Msg("Verified bot device with existing recovery key") - } - return true -} - func (helper *CryptoHelper) resyncEncryptionInfo(ctx context.Context) { log := helper.log.With().Str("action", "resync encryption event").Logger() rows, err := helper.store.DB.Query(ctx, `SELECT room_id FROM mx_room_state WHERE encryption='{"resync":true}'`) @@ -210,12 +157,12 @@ func (helper *CryptoHelper) resyncEncryptionInfo(ctx context.Context) { var evt event.EncryptionEventContent err = helper.client.StateEvent(ctx, roomID, event.StateEncryption, "", &evt) if err != nil { - log.Err(err).Stringer("room_id", roomID).Msg("Failed to get encryption event") + log.Err(err).Str("room_id", roomID.String()).Msg("Failed to get encryption event") _, err = helper.store.DB.Exec(ctx, ` UPDATE mx_room_state SET encryption=NULL WHERE room_id=$1 AND encryption='{"resync":true}' `, roomID) if err != nil { - log.Err(err).Stringer("room_id", roomID).Msg("Failed to unmark room for resync after failed sync") + log.Err(err).Str("room_id", roomID.String()).Msg("Failed to unmark room for resync after failed sync") } } else { maxAge := evt.RotationPeriodMillis @@ -238,9 +185,9 @@ func (helper *CryptoHelper) resyncEncryptionInfo(ctx context.Context) { WHERE room_id=$3 AND max_age IS NULL AND max_messages IS NULL `, maxAge, maxMessages, roomID) if err != nil { - log.Err(err).Stringer("room_id", roomID).Msg("Failed to update megolm session table") + log.Err(err).Str("room_id", roomID.String()).Msg("Failed to update megolm session table") } else { - log.Debug().Stringer("room_id", roomID).Msg("Updated megolm session table") + log.Debug().Str("room_id", roomID.String()).Msg("Updated megolm session table") } } } @@ -286,7 +233,7 @@ func (helper *CryptoHelper) loginBot(ctx context.Context) (*mautrix.Client, bool if err != nil { return nil, false, fmt.Errorf("failed to find existing device ID: %w", err) } else if len(deviceID) > 0 { - helper.log.Debug().Stringer("device_id", deviceID).Msg("Found existing device ID for bot in database") + helper.log.Debug().Str("device_id", deviceID.String()).Msg("Found existing device ID for bot in database") } // Create a new client instance with the default AS settings (including as_token), // the Login call will then override the access token in the client. @@ -327,7 +274,7 @@ func (helper *CryptoHelper) loginBot(ctx context.Context) (*mautrix.Client, bool return client, deviceID != "", nil } -func (helper *CryptoHelper) verifyKeysAreOnServer(ctx context.Context) bool { +func (helper *CryptoHelper) verifyKeysAreOnServer(ctx context.Context) { helper.log.Debug().Msg("Making sure keys are still on server") resp, err := helper.client.QueryKeys(ctx, &mautrix.ReqQueryKeys{ DeviceKeys: map[id.UserID]mautrix.DeviceIDList{ @@ -340,11 +287,10 @@ func (helper *CryptoHelper) verifyKeysAreOnServer(ctx context.Context) bool { } device, ok := resp.DeviceKeys[helper.client.UserID][helper.client.DeviceID] if ok && len(device.Keys) > 0 { - return true + return } helper.log.Warn().Msg("Existing device doesn't have keys on server, resetting crypto") helper.Reset(ctx, false) - return false } func (helper *CryptoHelper) Start() { @@ -439,7 +385,7 @@ func (helper *CryptoHelper) Encrypt(ctx context.Context, roomID id.RoomID, evtTy var encrypted *event.EncryptedEventContent encrypted, err = helper.mach.EncryptMegolmEvent(ctx, roomID, evtType, content) if err != nil { - if !errors.Is(err, crypto.ErrSessionExpired) && !errors.Is(err, crypto.ErrSessionNotShared) && !errors.Is(err, crypto.ErrNoGroupSession) { + if !errors.Is(err, crypto.SessionExpired) && !errors.Is(err, crypto.SessionNotShared) && !errors.Is(err, crypto.NoGroupSession) { return } helper.log.Debug().Err(err). diff --git a/bridgev2/matrix/intent.go b/bridgev2/matrix/intent.go index f7254bd4..7d78b5a2 100644 --- a/bridgev2/matrix/intent.go +++ b/bridgev2/matrix/intent.go @@ -9,7 +9,6 @@ package matrix import ( "bytes" "context" - "encoding/json" "errors" "fmt" "io" @@ -28,7 +27,6 @@ import ( "maunium.net/go/mautrix/bridgev2" "maunium.net/go/mautrix/bridgev2/bridgeconfig" "maunium.net/go/mautrix/crypto/attachment" - "maunium.net/go/mautrix/crypto/canonicaljson" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" "maunium.net/go/mautrix/pushrules" @@ -45,13 +43,13 @@ type ASIntent struct { var _ bridgev2.MatrixAPI = (*ASIntent)(nil) var _ bridgev2.MarkAsDMMatrixAPI = (*ASIntent)(nil) -var _ bridgev2.EphemeralSendingMatrixAPI = (*ASIntent)(nil) func (as *ASIntent) SendMessage(ctx context.Context, roomID id.RoomID, eventType event.Type, content *event.Content, extra *bridgev2.MatrixSendExtra) (*mautrix.RespSendEvent, error) { if extra == nil { extra = &bridgev2.MatrixSendExtra{} } - if eventType == event.EventRedaction && !as.Connector.SpecVersions.Supports(mautrix.FeatureRedactSendAsEvent) { + // TODO remove this once hungryserv and synapse support sending m.room.redactions directly in all room versions + if eventType == event.EventRedaction { parsedContent := content.Parsed.(*event.RedactionEventContent) as.Matrix.AddDoublePuppetValue(content) return as.Matrix.RedactEvent(ctx, roomID, parsedContent.Redacts, mautrix.ReqRedact{ @@ -59,7 +57,7 @@ func (as *ASIntent) SendMessage(ctx context.Context, roomID id.RoomID, eventType Extra: content.Raw, }) } - if (eventType != event.EventReaction || as.Connector.Config.Encryption.MSC4392) && eventType != event.EventRedaction { + if eventType != event.EventReaction && eventType != event.EventRedaction { msgContent, ok := content.Parsed.(*event.MessageEventContent) if ok { msgContent.AddPerMessageProfileFallback() @@ -84,27 +82,16 @@ func (as *ASIntent) SendMessage(ctx context.Context, roomID id.RoomID, eventType eventType = event.EventEncrypted } } - return as.Matrix.SendMessageEvent(ctx, roomID, eventType, content, mautrix.ReqSendEvent{Timestamp: extra.Timestamp.UnixMilli()}) -} - -func (as *ASIntent) BeeperSendEphemeralEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, content *event.Content, txnID string) (*mautrix.RespSendEvent, error) { - if !as.Connector.SpecVersions.Supports(mautrix.BeeperFeatureEphemeralEvents) { - return nil, mautrix.MUnrecognized.WithMessage("Homeserver does not advertise com.beeper.ephemeral support") + if extra.Timestamp.IsZero() { + return as.Matrix.SendMessageEvent(ctx, roomID, eventType, content) + } else { + return as.Matrix.SendMassagedMessageEvent(ctx, roomID, eventType, content, extra.Timestamp.UnixMilli()) } - if encrypted, err := as.Matrix.StateStore.IsEncrypted(ctx, roomID); err != nil { - return nil, fmt.Errorf("failed to check if room is encrypted: %w", err) - } else if encrypted && as.Connector.Crypto != nil { - if err = as.Connector.Crypto.Encrypt(ctx, roomID, eventType, content); err != nil { - return nil, err - } - eventType = event.EventEncrypted - } - return as.Matrix.BeeperSendEphemeralEvent(ctx, roomID, eventType, content, mautrix.ReqSendEvent{TransactionID: txnID}) } func (as *ASIntent) fillMemberEvent(ctx context.Context, roomID id.RoomID, userID id.UserID, content *event.Content) { - targetContent, ok := content.Parsed.(*event.MemberEventContent) - if !ok || targetContent.Displayname != "" || targetContent.AvatarURL != "" { + targetContent := content.Parsed.(*event.MemberEventContent) + if targetContent.Displayname != "" || targetContent.AvatarURL != "" { return } memberContent, err := as.Matrix.StateStore.TryGetMember(ctx, roomID, userID) @@ -139,7 +126,11 @@ func (as *ASIntent) SendState(ctx context.Context, roomID id.RoomID, eventType e if eventType == event.StateMember { as.fillMemberEvent(ctx, roomID, id.UserID(stateKey), content) } - resp, err = as.Matrix.SendStateEvent(ctx, roomID, eventType, stateKey, content, mautrix.ReqSendEvent{Timestamp: ts.UnixMilli()}) + if ts.IsZero() { + resp, err = as.Matrix.SendStateEvent(ctx, roomID, eventType, stateKey, content) + } else { + resp, err = as.Matrix.SendMassagedStateEvent(ctx, roomID, eventType, stateKey, content, ts.UnixMilli()) + } if err != nil && eventType == event.StateMember { var httpErr mautrix.HTTPError if errors.As(err, &httpErr) && httpErr.RespError != nil && @@ -421,7 +412,6 @@ func (as *ASIntent) UploadMediaStream( removeAndClose(replFile) removeAndClose(tempFile) } - req.AsyncContext = zerolog.Ctx(ctx).WithContext(as.Connector.Bridge.BackgroundCtx) startedAsyncUpload = true var resp *mautrix.RespCreateMXC resp, err = as.Matrix.UploadAsync(ctx, req) @@ -454,7 +444,6 @@ func (as *ASIntent) doUploadReq(ctx context.Context, file *event.EncryptedFileIn as.Connector.uploadSema.Release(int64(len(req.ContentBytes))) } } - req.AsyncContext = zerolog.Ctx(ctx).WithContext(as.Connector.Bridge.BackgroundCtx) var resp *mautrix.RespCreateMXC resp, err = as.Matrix.UploadAsync(ctx, req) if resp != nil { @@ -486,62 +475,11 @@ func (as *ASIntent) SetAvatarURL(ctx context.Context, avatarURL id.ContentURIStr return as.Matrix.SetAvatarURL(ctx, parsedAvatarURL) } -func dataToFields(data any) (map[string]json.RawMessage, error) { - fields, ok := data.(map[string]json.RawMessage) - if ok { - return fields, nil - } - d, err := json.Marshal(data) - if err != nil { - return nil, err - } - d = canonicaljson.CanonicalJSONAssumeValid(d) - err = json.Unmarshal(d, &fields) - return fields, err -} - -func marshalField(val any) json.RawMessage { - data, _ := json.Marshal(val) - if len(data) > 0 && (data[0] == '{' || data[0] == '[') { - return canonicaljson.CanonicalJSONAssumeValid(data) - } - return data -} - -var nullJSON = json.RawMessage("null") - func (as *ASIntent) SetExtraProfileMeta(ctx context.Context, data any) error { - if as.Connector.SpecVersions.Supports(mautrix.BeeperFeatureArbitraryProfileMeta) { - return as.Matrix.BeeperUpdateProfile(ctx, data) - } else if as.Connector.SpecVersions.Supports(mautrix.FeatureArbitraryProfileFields) && as.Connector.Config.Matrix.GhostExtraProfileInfo { - fields, err := dataToFields(data) - if err != nil { - return fmt.Errorf("failed to marshal fields: %w", err) - } - currentProfile, err := as.Matrix.GetProfile(ctx, as.Matrix.UserID) - if err != nil { - return fmt.Errorf("failed to get current profile: %w", err) - } - for key, val := range fields { - existing, ok := currentProfile.Extra[key] - if !ok { - if bytes.Equal(val, nullJSON) { - continue - } - err = as.Matrix.SetProfileField(ctx, key, val) - } else if !bytes.Equal(marshalField(existing), val) { - if bytes.Equal(val, nullJSON) { - err = as.Matrix.DeleteProfileField(ctx, key) - } else { - err = as.Matrix.SetProfileField(ctx, key, val) - } - } - if err != nil { - return fmt.Errorf("failed to set profile field %q: %w", key, err) - } - } + if !as.Connector.SpecVersions.Supports(mautrix.BeeperFeatureArbitraryProfileMeta) { + return nil } - return nil + return as.Matrix.BeeperUpdateProfile(ctx, data) } func (as *ASIntent) GetMXID() id.UserID { @@ -583,39 +521,6 @@ func (br *Connector) getDefaultEncryptionEvent() *event.EncryptionEventContent { return content } -func (as *ASIntent) filterCreateRequestForV12(ctx context.Context, req *mautrix.ReqCreateRoom) { - if as.Connector.Config.Homeserver.Software == bridgeconfig.SoftwareHungry { - // Hungryserv doesn't override the capabilities endpoint nor do room versions - return - } - caps := as.Connector.fetchCapabilities(ctx) - roomVer := req.RoomVersion - if roomVer == "" && caps != nil && caps.RoomVersions != nil { - roomVer = id.RoomVersion(caps.RoomVersions.Default) - } - if roomVer != "" && !roomVer.PrivilegedRoomCreators() { - return - } - creators, _ := req.CreationContent["additional_creators"].([]id.UserID) - creators = append(slices.Clone(creators), as.GetMXID()) - if req.PowerLevelOverride != nil { - for _, creator := range creators { - delete(req.PowerLevelOverride.Users, creator) - } - } - for _, evt := range req.InitialState { - if evt.Type != event.StatePowerLevels { - continue - } - content, ok := evt.Content.Parsed.(*event.PowerLevelsEventContent) - if ok { - for _, creator := range creators { - delete(content.Users, creator) - } - } - } -} - func (as *ASIntent) CreateRoom(ctx context.Context, req *mautrix.ReqCreateRoom) (id.RoomID, error) { if as.Connector.Config.Encryption.Default { req.InitialState = append(req.InitialState, &event.Event{ @@ -631,7 +536,6 @@ func (as *ASIntent) CreateRoom(ctx context.Context, req *mautrix.ReqCreateRoom) } req.CreationContent["m.federate"] = false } - as.filterCreateRequestForV12(ctx, req) resp, err := as.Matrix.CreateRoom(ctx, req) if err != nil { return "", err @@ -673,9 +577,6 @@ func (as *ASIntent) MarkAsDM(ctx context.Context, roomID id.RoomID, withUser id. } func (as *ASIntent) DeleteRoom(ctx context.Context, roomID id.RoomID, puppetsOnly bool) error { - if roomID == "" { - return nil - } if as.Connector.SpecVersions.Supports(mautrix.BeeperFeatureRoomYeeting) { err := as.Matrix.BeeperDeleteRoom(ctx, roomID) if err != nil { @@ -773,23 +674,3 @@ func (as *ASIntent) MuteRoom(ctx context.Context, roomID id.RoomID, until time.T }) } } - -func (as *ASIntent) GetEvent(ctx context.Context, roomID id.RoomID, eventID id.EventID) (*event.Event, error) { - evt, err := as.Matrix.Client.GetEvent(ctx, roomID, eventID) - if err != nil { - return nil, err - } - err = evt.Content.ParseRaw(evt.Type) - if err != nil { - zerolog.Ctx(ctx).Err(err).Stringer("room_id", roomID).Stringer("event_id", eventID).Msg("failed to parse event content") - } - - if evt.Type == event.EventEncrypted { - if as.Connector.Crypto == nil || as.Connector.Config.Encryption.DeleteKeys.RatchetOnDecrypt { - return nil, errors.New("can't decrypt the event") - } - return as.Connector.Crypto.Decrypt(ctx, evt) - } - - return evt, nil -} diff --git a/bridgev2/matrix/matrix.go b/bridgev2/matrix/matrix.go index 954d0ad9..49c377db 100644 --- a/bridgev2/matrix/matrix.go +++ b/bridgev2/matrix/matrix.go @@ -27,11 +27,6 @@ func (br *Connector) handleRoomEvent(ctx context.Context, evt *event.Event) { if br.shouldIgnoreEvent(evt) { return } - if !br.Config.Bridge.Permissions.Get(evt.Sender).SendEvents && evt.Type != event.StateMember { - zerolog.Ctx(ctx).Debug().Msg("Dropping event from user with no permission to send events") - br.SendMessageStatus(ctx, &bridgev2.ErrNoPermissionToInteract, bridgev2.StatusEventInfoFromEvent(evt)) - return - } if (evt.Type == event.EventMessage || evt.Type == event.EventSticker) && !evt.Mautrix.WasEncrypted && br.Config.Encryption.Require { zerolog.Ctx(ctx).Warn().Msg("Dropping unencrypted event as encryption is configured to be required") br.sendCryptoStatusError(ctx, evt, errMessageNotEncrypted, nil, 0, true) @@ -68,10 +63,6 @@ func (br *Connector) handleEphemeralEvent(ctx context.Context, evt *event.Event) case event.EphemeralEventTyping: typingContent := evt.Content.AsTyping() typingContent.UserIDs = slices.DeleteFunc(typingContent.UserIDs, br.shouldIgnoreEventFromUser) - case event.BeeperEphemeralEventAIStream: - if br.shouldIgnoreEvent(evt) { - return - } } br.Bridge.QueueMatrixEvent(ctx, evt) } @@ -85,11 +76,6 @@ func (br *Connector) handleEncryptedEvent(ctx context.Context, evt *event.Event) Str("event_id", evt.ID.String()). Str("session_id", content.SessionID.String()). Logger() - if !br.Config.Bridge.Permissions.Get(evt.Sender).SendEvents { - log.Debug().Msg("Dropping event from user with no permission to send events") - br.SendMessageStatus(ctx, &bridgev2.ErrNoPermissionToInteract, bridgev2.StatusEventInfoFromEvent(evt)) - return - } ctx = log.WithContext(ctx) if br.Crypto == nil { br.sendCryptoStatusError(ctx, evt, errNoCrypto, nil, 0, true) @@ -101,18 +87,17 @@ func (br *Connector) handleEncryptedEvent(ctx context.Context, evt *event.Event) decryptionStart := time.Now() decrypted, err := br.Crypto.Decrypt(ctx, evt) decryptionRetryCount := 0 - var errorEventID id.EventID if errors.Is(err, NoSessionFound) { decryptionRetryCount = 1 log.Debug(). Int("wait_seconds", int(initialSessionWaitTimeout.Seconds())). Msg("Couldn't find session, waiting for keys to arrive...") - go br.sendCryptoStatusError(ctx, evt, err, &errorEventID, 0, false) + go br.sendCryptoStatusError(ctx, evt, err, nil, 0, false) if br.Crypto.WaitForSession(ctx, evt.RoomID, content.SenderKey, content.SessionID, initialSessionWaitTimeout) { log.Debug().Msg("Got keys after waiting, trying to decrypt event again") decrypted, err = br.Crypto.Decrypt(ctx, evt) } else { - go br.waitLongerForSession(ctx, evt, decryptionStart, &errorEventID) + go br.waitLongerForSession(ctx, evt, decryptionStart) return } } @@ -121,18 +106,18 @@ func (br *Connector) handleEncryptedEvent(ctx context.Context, evt *event.Event) go br.sendCryptoStatusError(ctx, evt, err, nil, decryptionRetryCount, true) return } - br.postDecrypt(ctx, evt, decrypted, decryptionRetryCount, &errorEventID, time.Since(decryptionStart)) + br.postDecrypt(ctx, evt, decrypted, decryptionRetryCount, nil, time.Since(decryptionStart)) } -func (br *Connector) waitLongerForSession(ctx context.Context, evt *event.Event, decryptionStart time.Time, errorEventID *id.EventID) { +func (br *Connector) waitLongerForSession(ctx context.Context, evt *event.Event, decryptionStart time.Time) { log := zerolog.Ctx(ctx) content := evt.Content.AsEncrypted() log.Debug(). Int("wait_seconds", int(extendedSessionWaitTimeout.Seconds())). Msg("Couldn't find session, requesting keys and waiting longer...") - //lint:ignore SA1019 RequestSession will gracefully request from all devices if DeviceID is blank go br.Crypto.RequestSession(ctx, evt.RoomID, content.SenderKey, content.SessionID, evt.Sender, content.DeviceID) + var errorEventID *id.EventID go br.sendCryptoStatusError(ctx, evt, fmt.Errorf("%w. The bridge will retry for %d seconds", errNoDecryptionKeys, int(extendedSessionWaitTimeout.Seconds())), errorEventID, 1, false) if !br.Crypto.WaitForSession(ctx, evt.RoomID, content.SenderKey, content.SessionID, extendedSessionWaitTimeout) { @@ -235,6 +220,7 @@ func (br *Connector) postDecrypt(ctx context.Context, original, decrypted *event go br.sendSuccessCheckpoint(ctx, decrypted, status.MsgStepDecrypted, retryCount) decrypted.Mautrix.CheckpointSent = true decrypted.Mautrix.DecryptionDuration = duration + decrypted.Mautrix.EventSource |= event.SourceDecrypted br.EventProcessor.Dispatch(ctx, decrypted) if errorEventID != nil && *errorEventID != "" { _, _ = br.Bot.RedactEvent(ctx, decrypted.RoomID, *errorEventID) diff --git a/bridgev2/matrix/mxmain/dberror.go b/bridgev2/matrix/mxmain/dberror.go index f5e438de..0f6aa68c 100644 --- a/bridgev2/matrix/mxmain/dberror.go +++ b/bridgev2/matrix/mxmain/dberror.go @@ -66,12 +66,7 @@ func (br *BridgeMain) LogDBUpgradeErrorAndExit(name string, err error, message s } else if errors.Is(err, dbutil.ErrForeignTables) { br.Log.Info().Msg("See https://docs.mau.fi/faq/foreign-tables for more info") } else if errors.Is(err, dbutil.ErrNotOwned) { - var noe dbutil.NotOwnedError - if errors.As(err, &noe) && noe.Owner == br.Name { - br.Log.Info().Msg("The database appears to be on a very old pre-megabridge schema. Perhaps you need to run an older version of the bridge with migration support first?") - } else { - br.Log.Info().Msg("Sharing the same database with different programs is not supported") - } + br.Log.Info().Msg("Sharing the same database with different programs is not supported") } else if errors.Is(err, dbutil.ErrUnsupportedDatabaseVersion) { br.Log.Info().Msg("Downgrading the bridge is not supported") } diff --git a/bridgev2/matrix/mxmain/envconfig.go b/bridgev2/matrix/mxmain/envconfig.go deleted file mode 100644 index 1b4f1467..00000000 --- a/bridgev2/matrix/mxmain/envconfig.go +++ /dev/null @@ -1,161 +0,0 @@ -// Copyright (c) 2025 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package mxmain - -import ( - "fmt" - "iter" - "os" - "reflect" - "strconv" - "strings" - - "go.mau.fi/util/random" -) - -var randomParseFilePrefix = random.String(16) + "READFILE:" - -func parseEnv(prefix string) iter.Seq2[[]string, string] { - return func(yield func([]string, string) bool) { - for _, s := range os.Environ() { - if !strings.HasPrefix(s, prefix) { - continue - } - kv := strings.SplitN(s, "=", 2) - key := strings.TrimPrefix(kv[0], prefix) - value := kv[1] - if strings.HasSuffix(key, "_FILE") { - key = strings.TrimSuffix(key, "_FILE") - value = randomParseFilePrefix + value - } - key = strings.ToLower(key) - if !strings.ContainsRune(key, '.') { - key = strings.ReplaceAll(key, "__", ".") - } - if !yield(strings.Split(key, "."), value) { - return - } - } - } -} - -func reflectYAMLFieldName(f *reflect.StructField) string { - parts := strings.SplitN(f.Tag.Get("yaml"), ",", 2) - fieldName := parts[0] - if fieldName == "-" && len(parts) == 1 { - return "" - } - if fieldName == "" { - return strings.ToLower(f.Name) - } - return fieldName -} - -type reflectGetResult struct { - val reflect.Value - valKind reflect.Kind - remainingPath []string -} - -func reflectGetYAML(rv reflect.Value, path []string) (*reflectGetResult, bool) { - if len(path) == 0 { - return &reflectGetResult{val: rv, valKind: rv.Kind()}, true - } - if rv.Kind() == reflect.Ptr { - rv = rv.Elem() - } - switch rv.Kind() { - case reflect.Map: - return &reflectGetResult{val: rv, remainingPath: path, valKind: rv.Type().Elem().Kind()}, true - case reflect.Struct: - fields := reflect.VisibleFields(rv.Type()) - for _, field := range fields { - fieldName := reflectYAMLFieldName(&field) - if fieldName != "" && fieldName == path[0] { - return reflectGetYAML(rv.FieldByIndex(field.Index), path[1:]) - } - } - default: - } - return nil, false -} - -func reflectGetFromMainOrNetwork(main, network reflect.Value, path []string) (*reflectGetResult, bool) { - if len(path) > 0 && path[0] == "network" { - return reflectGetYAML(network, path[1:]) - } - return reflectGetYAML(main, path) -} - -func formatKeyString(key []string) string { - return strings.Join(key, "->") -} - -func UpdateConfigFromEnv(cfg, networkData any, prefix string) error { - cfgVal := reflect.ValueOf(cfg) - networkVal := reflect.ValueOf(networkData) - for key, value := range parseEnv(prefix) { - field, ok := reflectGetFromMainOrNetwork(cfgVal, networkVal, key) - if !ok { - return fmt.Errorf("%s not found", formatKeyString(key)) - } - if strings.HasPrefix(value, randomParseFilePrefix) { - filepath := strings.TrimPrefix(value, randomParseFilePrefix) - fileData, err := os.ReadFile(filepath) - if err != nil { - return fmt.Errorf("failed to read file %s for %s: %w", filepath, formatKeyString(key), err) - } - value = strings.TrimSpace(string(fileData)) - } - var parsedVal any - var err error - switch field.valKind { - case reflect.String: - parsedVal = value - case reflect.Bool: - parsedVal, err = strconv.ParseBool(value) - if err != nil { - return fmt.Errorf("invalid value for %s: %w", formatKeyString(key), err) - } - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - parsedVal, err = strconv.ParseInt(value, 10, 64) - if err != nil { - return fmt.Errorf("invalid value for %s: %w", formatKeyString(key), err) - } - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - parsedVal, err = strconv.ParseUint(value, 10, 64) - if err != nil { - return fmt.Errorf("invalid value for %s: %w", formatKeyString(key), err) - } - case reflect.Float32, reflect.Float64: - parsedVal, err = strconv.ParseFloat(value, 64) - if err != nil { - return fmt.Errorf("invalid value for %s: %w", formatKeyString(key), err) - } - default: - return fmt.Errorf("unsupported type %s in %s", field.valKind, formatKeyString(key)) - } - if field.val.Kind() == reflect.Ptr { - if field.val.IsNil() { - field.val.Set(reflect.New(field.val.Type().Elem())) - } - field.val = field.val.Elem() - } - if field.val.Kind() == reflect.Map { - key = key[:len(key)-len(field.remainingPath)] - mapKeyStr := strings.Join(field.remainingPath, ".") - key = append(key, mapKeyStr) - if field.val.Type().Key().Kind() != reflect.String { - return fmt.Errorf("unsupported map key type %s in %s", field.val.Type().Key().Kind(), formatKeyString(key)) - } - field.val.SetMapIndex(reflect.ValueOf(mapKeyStr), reflect.ValueOf(parsedVal)) - } else { - field.val.Set(reflect.ValueOf(parsedVal)) - } - } - return nil -} diff --git a/bridgev2/matrix/mxmain/example-config.yaml b/bridgev2/matrix/mxmain/example-config.yaml index ccc81c4b..48e0d528 100644 --- a/bridgev2/matrix/mxmain/example-config.yaml +++ b/bridgev2/matrix/mxmain/example-config.yaml @@ -15,7 +15,6 @@ bridge: # By default, users who are in the same group on the remote network will be # in the same Matrix room bridged to that group. If this is set to true, # every user will get their own Matrix room instead. - # SETTING THIS IS IRREVERSIBLE AND POTENTIALLY DESTRUCTIVE IF PORTALS ALREADY EXIST. split_portals: false # Should the bridge resend `m.bridge` events to all portals on startup? resend_bridge_info: false @@ -29,9 +28,6 @@ bridge: # How long after an unknown error should the bridge attempt a full reconnect? # Must be at least 1 minute. The bridge will add an extra ±20% jitter to this value. unknown_error_auto_reconnect: null - # Maximum number of times to do the auto-reconnect above. - # The counter is per login, but is never reset except on logout and restart. - unknown_error_max_auto_reconnects: 10 # Should leaving Matrix rooms be bridged as leaving groups on the remote network? bridge_matrix_leave: false @@ -50,11 +46,6 @@ bridge: # Should cross-room reply metadata be bridged? # Most Matrix clients don't support this and servers may reject such messages too. cross_room_replies: false - # If a state event fails to bridge, should the bridge revert any state changes made by that event? - revert_failed_state_changes: false - # In portals with no relay set, should Matrix users be kicked if they're - # not logged into an account that's in the remote chat? - kick_matrix_users: true # What should be done to portal rooms when a user logs out or is logged out? # Permitted values: @@ -244,9 +235,6 @@ matrix: # The threshold as bytes after which the bridge should roundtrip uploads via the disk # rather than keeping the whole file in memory. upload_file_threshold: 5242880 - # Should the bridge set additional custom profile info for ghosts? - # This can make a lot of requests, as there's no batch profile update endpoint. - ghost_extra_profile_info: false # Segment-compatible analytics endpoint for tracking some events, like provisioning API login and encryption errors. analytics: @@ -259,8 +247,10 @@ analytics: # Settings for provisioning API provisioning: + # Prefix for the provisioning API paths. + prefix: /_matrix/provision # Shared secret for authentication. If set to "generate" or null, a random secret will be generated, - # or if set to "disable", the provisioning API will be disabled. Must be at least 16 characters. + # or if set to "disable", the provisioning API will be disabled. shared_secret: generate # Whether to allow provisioning API requests to be authed using Matrix access tokens. # This follows the same rules as double puppeting to determine which server to contact to check the token, @@ -286,14 +276,6 @@ public_media: expiry: 0 # Length of hash to use for public media URLs. Must be between 0 and 32. hash_length: 32 - # The path prefix for generated URLs. Note that this will NOT change the path where media is actually served. - # If you change this, you must configure your reverse proxy to rewrite the path accordingly. - path_prefix: /_mautrix/publicmedia - # Should the bridge store media metadata in the database in order to support encrypted media and generate shorter URLs? - # If false, the generated URLs will just have the MXC URI and a HMAC signature. - # The hash_length field will be used to decide the length of the generated URL. - # This also allows invalidating URLs by deleting the database entry. - use_database: false # Settings for converting remote media to custom mxc:// URIs instead of reuploading. # More details can be found at https://docs.mau.fi/bridges/go/discord/direct-media.html @@ -384,12 +366,6 @@ encryption: # Only relevant when using end-to-bridge encryption, required when using encryption with next-gen auth (MSC3861). # Changing this option requires updating the appservice registration file. msc4190: false - # Whether to encrypt reactions and reply metadata as per MSC4392. - msc4392: false - # Should the bridge bot generate a recovery key and cross-signing keys and verify itself? - # Note that without the latest version of MSC4190, this will fail if you reset the bridge database. - # The generated recovery key will be saved in the kv_store table under `recovery_key`. - self_sign: false # Enable key sharing? If enabled, key requests for rooms where users are in will be fulfilled. # You must use a client that supports requesting keys from other users to use this feature. allow_key_sharing: true @@ -452,16 +428,6 @@ encryption: # You should not enable this option unless you understand all the implications. disable_device_change_key_rotation: false -# Prefix for environment variables. All variables with this prefix must map to valid config fields. -# Nesting in variable names is represented with a dot (.). -# If there are no dots in the name, two underscores (__) are replaced with a dot. -# -# e.g. if the prefix is set to `BRIDGE_`, then `BRIDGE_APPSERVICE__AS_TOKEN` will set appservice.as_token. -# `BRIDGE_appservice.as_token` would work as well, but can't be set in a shell as easily. -# -# If this is null, reading config fields from environment will be disabled. -env_config_prefix: null - # Logging config. See https://github.com/tulir/zeroconfig for details. logging: min_level: debug diff --git a/bridgev2/matrix/mxmain/legacymigrate.go b/bridgev2/matrix/mxmain/legacymigrate.go index 97cdeddf..c8eb820b 100644 --- a/bridgev2/matrix/mxmain/legacymigrate.go +++ b/bridgev2/matrix/mxmain/legacymigrate.go @@ -135,10 +135,7 @@ func (br *BridgeMain) CheckLegacyDB( } var dbVersion int err = br.DB.QueryRow(ctx, "SELECT version FROM version").Scan(&dbVersion) - if err != nil { - log.Fatal().Err(err).Msg("Failed to get database version") - return - } else if dbVersion < expectedVersion { + if dbVersion < expectedVersion { log.Fatal(). Int("expected_version", expectedVersion). Int("version", dbVersion). diff --git a/bridgev2/matrix/mxmain/main.go b/bridgev2/matrix/mxmain/main.go index 1e8b51d1..e6219c50 100644 --- a/bridgev2/matrix/mxmain/main.go +++ b/bridgev2/matrix/mxmain/main.go @@ -26,7 +26,6 @@ import ( "go.mau.fi/util/dbutil" "go.mau.fi/util/exerrors" "go.mau.fi/util/exzerolog" - "go.mau.fi/util/progver" "gopkg.in/yaml.v3" flag "maunium.net/go/mauflag" @@ -63,9 +62,6 @@ type BridgeMain struct { // git tag to see if the built version is the release or a dev build. // You can either bump this right after a release or right before, as long as it matches on the release commit. Version string - // SemCalVer defines whether this bridge uses a mix of semantic and calendar versioning, - // such that the Version field is YY.0M.patch, while git tags are major.YY0M.patch. - SemCalVer bool // PostInit is a function that will be called after the bridge has been initialized but before it is started. PostInit func() @@ -90,7 +86,11 @@ type BridgeMain struct { RegistrationPath string SaveConfig bool - ver progver.ProgramVersion + baseVersion string + commit string + LinkifiedVersion string + VersionDesc string + BuildTime time.Time AdditionalShortFlags string AdditionalLongFlags string @@ -99,7 +99,14 @@ type BridgeMain struct { } type VersionJSONOutput struct { - progver.ProgramVersion + Name string + URL string + + Version string + IsRelease bool + Commit string + FormattedVersion string + BuildTime time.Time OS string Arch string @@ -140,11 +147,18 @@ func (br *BridgeMain) PreInit() { flag.PrintHelp() os.Exit(0) } else if *version { - fmt.Println(br.ver.VersionDescription) + fmt.Println(br.VersionDesc) os.Exit(0) } else if *versionJSON { output := VersionJSONOutput{ - ProgramVersion: br.ver, + URL: br.URL, + Name: br.Name, + + Version: br.baseVersion, + IsRelease: br.Version == br.baseVersion, + Commit: br.commit, + FormattedVersion: br.Version, + BuildTime: br.BuildTime, OS: runtime.GOOS, Arch: runtime.GOARCH, @@ -226,8 +240,8 @@ func (br *BridgeMain) Init() { br.Log.Info(). Str("name", br.Name). - Str("version", br.ver.FormattedVersion). - Time("built_at", br.ver.BuildTime). + Str("version", br.Version). + Time("built_at", br.BuildTime). Str("go_version", runtime.Version()). Msg("Initializing bridge") @@ -241,7 +255,7 @@ func (br *BridgeMain) Init() { br.Matrix.AS.DoublePuppetValue = br.Name br.Bridge.Commands.(*commands.Processor).AddHandler(&commands.FullHandler{ Func: func(ce *commands.Event) { - ce.Reply(br.ver.MarkdownDescription()) + ce.Reply("[%s](%s) %s (%s)", br.Name, br.URL, br.LinkifiedVersion, br.BuildTime.Format(time.RFC1123)) }, Name: "version", Help: commands.HelpMeta{ @@ -354,13 +368,6 @@ func (br *BridgeMain) LoadConfig() { } } cfg.Bridge.Backfill = cfg.Backfill - if cfg.EnvConfigPrefix != "" { - err = UpdateConfigFromEnv(&cfg, networkData, cfg.EnvConfigPrefix) - if err != nil { - _, _ = fmt.Fprintln(os.Stderr, "Failed to parse environment variables:", err) - os.Exit(10) - } - } br.Config = &cfg } @@ -439,12 +446,42 @@ func (br *BridgeMain) Stop() { // // (to use both at the same time, simply merge the ldflags into one, `-ldflags "-X '...' -X ..."`) func (br *BridgeMain) InitVersion(tag, commit, rawBuildTime string) { - br.ver = progver.ProgramVersion{ - Name: br.Name, - URL: br.URL, - BaseVersion: br.Version, - SemCalVer: br.SemCalVer, - }.Init(tag, commit, rawBuildTime) - mautrix.DefaultUserAgent = fmt.Sprintf("%s/%s %s", br.Name, br.ver.FormattedVersion, mautrix.DefaultUserAgent) - br.Version = br.ver.FormattedVersion + br.baseVersion = br.Version + if len(tag) > 0 && tag[0] == 'v' { + tag = tag[1:] + } + if tag != br.Version { + suffix := "" + if !strings.HasSuffix(br.Version, "+dev") { + suffix = "+dev" + } + if len(commit) > 8 { + br.Version = fmt.Sprintf("%s%s.%s", br.Version, suffix, commit[:8]) + } else { + br.Version = fmt.Sprintf("%s%s.unknown", br.Version, suffix) + } + } + + br.LinkifiedVersion = fmt.Sprintf("v%s", br.Version) + if tag == br.Version { + br.LinkifiedVersion = fmt.Sprintf("[v%s](%s/releases/v%s)", br.Version, br.URL, tag) + } else if len(commit) > 8 { + br.LinkifiedVersion = strings.Replace(br.LinkifiedVersion, commit[:8], fmt.Sprintf("[%s](%s/commit/%s)", commit[:8], br.URL, commit), 1) + } + var buildTime time.Time + if rawBuildTime != "unknown" { + buildTime, _ = time.Parse(time.RFC3339, rawBuildTime) + } + var builtWith string + if buildTime.IsZero() { + rawBuildTime = "unknown" + builtWith = runtime.Version() + } else { + rawBuildTime = buildTime.Format(time.RFC1123) + builtWith = fmt.Sprintf("built at %s with %s", rawBuildTime, runtime.Version()) + } + mautrix.DefaultUserAgent = fmt.Sprintf("%s/%s %s", br.Name, br.Version, mautrix.DefaultUserAgent) + br.VersionDesc = fmt.Sprintf("%s %s (%s)", br.Name, br.Version, builtWith) + br.commit = commit + br.BuildTime = buildTime } diff --git a/bridgev2/matrix/provisioning.go b/bridgev2/matrix/provisioning.go index 243b91da..df3e1bdf 100644 --- a/bridgev2/matrix/provisioning.go +++ b/bridgev2/matrix/provisioning.go @@ -30,7 +30,6 @@ import ( "maunium.net/go/mautrix" "maunium.net/go/mautrix/bridgev2" "maunium.net/go/mautrix/bridgev2/networkid" - "maunium.net/go/mautrix/bridgev2/provisionutil" "maunium.net/go/mautrix/bridgev2/status" "maunium.net/go/mautrix/federation" "maunium.net/go/mautrix/id" @@ -85,9 +84,10 @@ const ( provisioningUserKey provisioningContextKey = iota provisioningUserLoginKey provisioningLoginProcessKey - ProvisioningKeyRequest ) +const ProvisioningKeyRequest = "fi.mau.provision.request" + func (prov *ProvisioningAPI) GetUser(r *http.Request) *bridgev2.User { return r.Context().Value(provisioningUserKey).(*bridgev2.User) } @@ -96,7 +96,12 @@ func (prov *ProvisioningAPI) GetRouter() *http.ServeMux { return prov.Router } -func (br *Connector) GetProvisioning() bridgev2.IProvisioningAPI { +type IProvisioningAPI interface { + GetRouter() *http.ServeMux + GetUser(r *http.Request) *bridgev2.User +} + +func (br *Connector) GetProvisioning() IProvisioningAPI { return br.Provisioning } @@ -114,7 +119,6 @@ func (prov *ProvisioningAPI) Init() { tp.Transport.TLSHandshakeTimeout = 10 * time.Second prov.Router = http.NewServeMux() prov.Router.HandleFunc("GET /v3/whoami", prov.GetWhoami) - prov.Router.HandleFunc("GET /v3/capabilities", prov.GetCapabilities) prov.Router.HandleFunc("GET /v3/login/flows", prov.GetLoginFlows) prov.Router.HandleFunc("POST /v3/login/start/{flowID}", prov.PostLoginStart) prov.Router.HandleFunc("POST /v3/login/step/{loginProcessID}/{stepID}/{stepType}", prov.PostLoginStep) @@ -124,7 +128,7 @@ func (prov *ProvisioningAPI) Init() { prov.Router.HandleFunc("POST /v3/search_users", prov.PostSearchUsers) prov.Router.HandleFunc("GET /v3/resolve_identifier/{identifier}", prov.GetResolveIdentifier) prov.Router.HandleFunc("POST /v3/create_dm/{identifier}", prov.PostCreateDM) - prov.Router.HandleFunc("POST /v3/create_group/{type}", prov.PostCreateGroup) + prov.Router.HandleFunc("POST /v3/create_group", prov.PostCreateGroup) if prov.br.Config.Provisioning.EnableSessionTransfers { prov.log.Debug().Msg("Enabling session transfer API") @@ -206,20 +210,12 @@ func (prov *ProvisioningAPI) checkFederatedMatrixAuth(ctx context.Context, userI } } -func disabledAuth(w http.ResponseWriter, r *http.Request) { - mautrix.MForbidden.WithMessage("Provisioning API is disabled").Write(w) -} - func (prov *ProvisioningAPI) DebugAuthMiddleware(h http.Handler) http.Handler { - secret := prov.br.Config.Provisioning.SharedSecret - if len(secret) < 16 { - return http.HandlerFunc(disabledAuth) - } return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { auth := strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer ") if auth == "" { mautrix.MMissingToken.WithMessage("Missing auth token").Write(w) - } else if !exstrings.ConstantTimeEqual(auth, secret) { + } else if !exstrings.ConstantTimeEqual(auth, prov.br.Config.Provisioning.SharedSecret) { mautrix.MUnknownToken.WithMessage("Invalid auth token").Write(w) } else { h.ServeHTTP(w, r) @@ -228,10 +224,6 @@ func (prov *ProvisioningAPI) DebugAuthMiddleware(h http.Handler) http.Handler { } func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler { - secret := prov.br.Config.Provisioning.SharedSecret - if len(secret) < 16 { - return http.HandlerFunc(disabledAuth) - } return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { auth := strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer ") if auth == "" && prov.GetAuthFromRequest != nil { @@ -245,7 +237,7 @@ func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler { if userID == "" && prov.GetUserIDFromRequest != nil { userID = prov.GetUserIDFromRequest(r) } - if !exstrings.ConstantTimeEqual(auth, secret) { + if !exstrings.ConstantTimeEqual(auth, prov.br.Config.Provisioning.SharedSecret) { var err error if strings.HasPrefix(auth, "openid:") { err = prov.checkFederatedMatrixAuth(r.Context(), userID, strings.TrimPrefix(auth, "openid:")) @@ -324,7 +316,7 @@ func (prov *ProvisioningAPI) GetWhoami(w http.ResponseWriter, r *http.Request) { prevState.UserID = "" prevState.RemoteID = "" prevState.RemoteName = "" - prevState.RemoteProfile = status.RemoteProfile{} + prevState.RemoteProfile = nil resp.Logins[i] = RespWhoamiLogin{ StateEvent: prevState.StateEvent, StateTS: prevState.Timestamp, @@ -356,24 +348,18 @@ func (prov *ProvisioningAPI) GetLoginFlows(w http.ResponseWriter, r *http.Reques }) } -func (prov *ProvisioningAPI) GetCapabilities(w http.ResponseWriter, r *http.Request) { - exhttp.WriteJSONResponse(w, http.StatusOK, &prov.net.GetCapabilities().Provisioning) -} - var ErrNilStep = errors.New("bridge returned nil step with no error") -var ErrTooManyLogins = bridgev2.RespError{ErrCode: "FI.MAU.BRIDGE.TOO_MANY_LOGINS", Err: "Maximum number of logins exceeded"} func (prov *ProvisioningAPI) PostLoginStart(w http.ResponseWriter, r *http.Request) { overrideLogin, failed := prov.GetExplicitLoginForRequest(w, r) if failed { return } - user := prov.GetUser(r) - if overrideLogin == nil && user.HasTooManyLogins() { - ErrTooManyLogins.AppendMessage(" (%d)", user.Permissions.MaxLogins).Write(w) - return - } - login, err := prov.net.CreateLogin(r.Context(), user, r.PathValue("flowID")) + login, err := prov.net.CreateLogin( + r.Context(), + prov.GetUser(r), + r.PathValue("flowID"), + ) if err != nil { zerolog.Ctx(r.Context()).Err(err).Msg("Failed to create login process") RespondWithError(w, err, "Internal error creating login process") @@ -403,18 +389,10 @@ func (prov *ProvisioningAPI) PostLoginStart(w http.ResponseWriter, r *http.Reque Override: overrideLogin, } prov.loginsLock.Unlock() - zerolog.Ctx(r.Context()).Info(). - Any("first_step", firstStep). - Msg("Created login process") exhttp.WriteJSONResponse(w, http.StatusOK, &RespSubmitLogin{LoginID: loginID, LoginStep: firstStep}) } func (prov *ProvisioningAPI) handleCompleteStep(ctx context.Context, login *ProvLogin, step *bridgev2.LoginStep) { - zerolog.Ctx(ctx).Info(). - Str("step_id", step.StepID). - Str("user_login_id", string(step.CompleteParams.UserLoginID)). - Msg("Login completed successfully") - prov.deleteLogin(login, false) if login.Override == nil || login.Override.ID == step.CompleteParams.UserLoginID { return } @@ -428,15 +406,6 @@ func (prov *ProvisioningAPI) handleCompleteStep(ctx context.Context, login *Prov }, bridgev2.DeleteOpts{LogoutRemote: true}) } -func (prov *ProvisioningAPI) deleteLogin(login *ProvLogin, cancel bool) { - if cancel { - login.Process.Cancel() - } - prov.loginsLock.Lock() - delete(prov.logins, login.ID) - prov.loginsLock.Unlock() -} - func (prov *ProvisioningAPI) PostLoginStep(w http.ResponseWriter, r *http.Request) { loginID := r.PathValue("loginProcessID") prov.loginsLock.RLock() @@ -507,14 +476,11 @@ func (prov *ProvisioningAPI) PostLoginSubmitInput(w http.ResponseWriter, r *http if err != nil { zerolog.Ctx(r.Context()).Err(err).Msg("Failed to submit input") RespondWithError(w, err, "Internal error submitting input") - prov.deleteLogin(login, true) return } login.NextStep = nextStep if nextStep.Type == bridgev2.LoginStepTypeComplete { prov.handleCompleteStep(r.Context(), login, nextStep) - } else { - zerolog.Ctx(r.Context()).Debug().Any("next_step", nextStep).Msg("Returning next login step") } exhttp.WriteJSONResponse(w, http.StatusOK, &RespSubmitLogin{LoginID: login.ID, LoginStep: nextStep}) } @@ -528,14 +494,11 @@ func (prov *ProvisioningAPI) PostLoginWait(w http.ResponseWriter, r *http.Reques if err != nil { zerolog.Ctx(r.Context()).Err(err).Msg("Failed to wait") RespondWithError(w, err, "Internal error waiting for login") - prov.deleteLogin(login, true) return } login.NextStep = nextStep if nextStep.Type == bridgev2.LoginStepTypeComplete { prov.handleCompleteStep(r.Context(), login, nextStep) - } else { - zerolog.Ctx(r.Context()).Debug().Any("next_step", nextStep).Msg("Returning next login step") } exhttp.WriteJSONResponse(w, http.StatusOK, &RespSubmitLogin{LoginID: login.ID, LoginStep: nextStep}) } @@ -619,23 +582,115 @@ func RespondWithError(w http.ResponseWriter, err error, message string) { } } +type RespResolveIdentifier struct { + ID networkid.UserID `json:"id"` + Name string `json:"name,omitempty"` + AvatarURL id.ContentURIString `json:"avatar_url,omitempty"` + Identifiers []string `json:"identifiers,omitempty"` + MXID id.UserID `json:"mxid,omitempty"` + DMRoomID id.RoomID `json:"dm_room_mxid,omitempty"` +} + func (prov *ProvisioningAPI) doResolveIdentifier(w http.ResponseWriter, r *http.Request, createChat bool) { login := prov.GetLoginForRequest(w, r) if login == nil { return } - resp, err := provisionutil.ResolveIdentifier(r.Context(), login, r.PathValue("identifier"), createChat) + api, ok := login.Client.(bridgev2.IdentifierResolvingNetworkAPI) + if !ok { + mautrix.MUnrecognized.WithMessage("This bridge does not support resolving identifiers").Write(w) + return + } + resp, err := api.ResolveIdentifier(r.Context(), r.PathValue("identifier"), createChat) if err != nil { + zerolog.Ctx(r.Context()).Err(err).Msg("Failed to resolve identifier") RespondWithError(w, err, "Internal error resolving identifier") + return } else if resp == nil { mautrix.MNotFound.WithMessage("Identifier not found").Write(w) - } else { - status := http.StatusOK - if resp.JustCreated { - status = http.StatusCreated - } - exhttp.WriteJSONResponse(w, status, resp) + return } + apiResp := &RespResolveIdentifier{ + ID: resp.UserID, + } + status := http.StatusOK + if resp.Ghost != nil { + if resp.UserInfo != nil { + resp.Ghost.UpdateInfo(r.Context(), resp.UserInfo) + } + apiResp.Name = resp.Ghost.Name + apiResp.AvatarURL = resp.Ghost.AvatarMXC + apiResp.Identifiers = resp.Ghost.Identifiers + apiResp.MXID = resp.Ghost.Intent.GetMXID() + } else if resp.UserInfo != nil && resp.UserInfo.Name != nil { + apiResp.Name = *resp.UserInfo.Name + } + if resp.Chat != nil { + if resp.Chat.Portal == nil { + resp.Chat.Portal, err = prov.br.Bridge.GetPortalByKey(r.Context(), resp.Chat.PortalKey) + if err != nil { + zerolog.Ctx(r.Context()).Err(err).Msg("Failed to get portal") + mautrix.MUnknown.WithMessage("Failed to get portal").Write(w) + return + } + } + if createChat && resp.Chat.Portal.MXID == "" { + status = http.StatusCreated + err = resp.Chat.Portal.CreateMatrixRoom(r.Context(), login, resp.Chat.PortalInfo) + if err != nil { + zerolog.Ctx(r.Context()).Err(err).Msg("Failed to create portal room") + mautrix.MUnknown.WithMessage("Failed to create portal room").Write(w) + return + } + } + apiResp.DMRoomID = resp.Chat.Portal.MXID + } + exhttp.WriteJSONResponse(w, status, apiResp) +} + +type RespGetContactList struct { + Contacts []*RespResolveIdentifier `json:"contacts"` +} + +func (prov *ProvisioningAPI) processResolveIdentifiers(ctx context.Context, resp []*bridgev2.ResolveIdentifierResponse) (apiResp []*RespResolveIdentifier) { + apiResp = make([]*RespResolveIdentifier, len(resp)) + for i, contact := range resp { + apiContact := &RespResolveIdentifier{ + ID: contact.UserID, + } + apiResp[i] = apiContact + if contact.UserInfo != nil { + if contact.UserInfo.Name != nil { + apiContact.Name = *contact.UserInfo.Name + } + if contact.UserInfo.Identifiers != nil { + apiContact.Identifiers = contact.UserInfo.Identifiers + } + } + if contact.Ghost != nil { + if contact.Ghost.Name != "" { + apiContact.Name = contact.Ghost.Name + } + if len(contact.Ghost.Identifiers) >= len(apiContact.Identifiers) { + apiContact.Identifiers = contact.Ghost.Identifiers + } + apiContact.AvatarURL = contact.Ghost.AvatarMXC + apiContact.MXID = contact.Ghost.Intent.GetMXID() + } + if contact.Chat != nil { + if contact.Chat.Portal == nil { + var err error + contact.Chat.Portal, err = prov.br.Bridge.GetPortalByKey(ctx, contact.Chat.PortalKey) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to get portal") + } + } + if contact.Chat.Portal != nil { + apiContact.DMRoomID = contact.Chat.Portal.MXID + } + } + } + return } func (prov *ProvisioningAPI) GetContactList(w http.ResponseWriter, r *http.Request) { @@ -643,18 +698,30 @@ func (prov *ProvisioningAPI) GetContactList(w http.ResponseWriter, r *http.Reque if login == nil { return } - resp, err := provisionutil.GetContactList(r.Context(), login) - if err != nil { - RespondWithError(w, err, "Internal error getting contact list") + api, ok := login.Client.(bridgev2.ContactListingNetworkAPI) + if !ok { + mautrix.MUnrecognized.WithMessage("This bridge does not support listing contacts").Write(w) return } - exhttp.WriteJSONResponse(w, http.StatusOK, resp) + resp, err := api.GetContactList(r.Context()) + if err != nil { + zerolog.Ctx(r.Context()).Err(err).Msg("Failed to get contact list") + RespondWithError(w, err, "Internal error fetching contact list") + return + } + exhttp.WriteJSONResponse(w, http.StatusOK, &RespGetContactList{ + Contacts: prov.processResolveIdentifiers(r.Context(), resp), + }) } type ReqSearchUsers struct { Query string `json:"query"` } +type RespSearchUsers struct { + Results []*RespResolveIdentifier `json:"results"` +} + func (prov *ProvisioningAPI) PostSearchUsers(w http.ResponseWriter, r *http.Request) { var req ReqSearchUsers err := json.NewDecoder(r.Body).Decode(&req) @@ -667,12 +734,20 @@ func (prov *ProvisioningAPI) PostSearchUsers(w http.ResponseWriter, r *http.Requ if login == nil { return } - resp, err := provisionutil.SearchUsers(r.Context(), login, req.Query) - if err != nil { - RespondWithError(w, err, "Internal error searching users") + api, ok := login.Client.(bridgev2.UserSearchingNetworkAPI) + if !ok { + mautrix.MUnrecognized.WithMessage("This bridge does not support searching for users").Write(w) return } - exhttp.WriteJSONResponse(w, http.StatusOK, resp) + resp, err := api.SearchUsers(r.Context(), req.Query) + if err != nil { + zerolog.Ctx(r.Context()).Err(err).Msg("Failed to get contact list") + RespondWithError(w, err, "Internal error fetching contact list") + return + } + exhttp.WriteJSONResponse(w, http.StatusOK, &RespSearchUsers{ + Results: prov.processResolveIdentifiers(r.Context(), resp), + }) } func (prov *ProvisioningAPI) GetResolveIdentifier(w http.ResponseWriter, r *http.Request) { @@ -684,24 +759,11 @@ func (prov *ProvisioningAPI) PostCreateDM(w http.ResponseWriter, r *http.Request } func (prov *ProvisioningAPI) PostCreateGroup(w http.ResponseWriter, r *http.Request) { - var req bridgev2.GroupCreateParams - err := json.NewDecoder(r.Body).Decode(&req) - if err != nil { - zerolog.Ctx(r.Context()).Err(err).Msg("Failed to decode request body") - mautrix.MNotJSON.WithMessage("Failed to decode request body").Write(w) - return - } - req.Type = r.PathValue("type") login := prov.GetLoginForRequest(w, r) if login == nil { return } - resp, err := provisionutil.CreateGroup(r.Context(), login, &req) - if err != nil { - RespondWithError(w, err, "Internal error creating group") - return - } - exhttp.WriteJSONResponse(w, http.StatusOK, resp) + mautrix.MUnrecognized.WithMessage("Creating groups is not yet implemented").Write(w) } type ReqExportCredentials struct { diff --git a/bridgev2/matrix/provisioning.yaml b/bridgev2/matrix/provisioning.yaml index 26068db4..b9879ea5 100644 --- a/bridgev2/matrix/provisioning.yaml +++ b/bridgev2/matrix/provisioning.yaml @@ -361,25 +361,14 @@ paths: $ref: '#/components/responses/InternalError' 501: $ref: '#/components/responses/NotSupported' - /v3/create_group/{type}: + /v3/create_group: post: tags: [ snc ] summary: Create a group chat on the remote network. operationId: createGroup parameters: - $ref: "#/components/parameters/loginID" - requestBody: - content: - application/json: - schema: - $ref: '#/components/schemas/GroupCreateParams' responses: - 200: - description: Identifier resolved successfully - content: - application/json: - schema: - $ref: '#/components/schemas/CreatedGroup' 401: $ref: '#/components/responses/Unauthorized' 404: @@ -400,7 +389,7 @@ components: - username - meow@example.com loginID: - name: login_id + name: loginID in: query description: An optional explicit login ID to do the action through. required: false @@ -583,74 +572,6 @@ components: description: The Matrix room ID of the direct chat with the user. examples: - '!OKhS0I5q2fCzdnl2qgeozDQw:t2bot.io' - GroupCreateParams: - type: object - description: | - Parameters for creating a group chat. - The /capabilities endpoint response must be checked to see which fields are actually allowed. - properties: - type: - type: string - description: The type of group to create. - examples: - - channel - username: - type: string - description: The public username for the created group. - participants: - type: array - description: The users to add to the group initially. - items: - type: string - parent: - type: object - name: - type: object - description: The `m.room.name` event content for the room. - properties: - name: - type: string - avatar: - type: object - description: The `m.room.avatar` event content for the room. - properties: - url: - type: string - format: mxc - topic: - type: object - description: The `m.room.topic` event content for the room. - properties: - topic: - type: string - disappear: - type: object - description: The `com.beeper.disappearing_timer` event content for the room. - properties: - type: - type: string - timer: - type: number - room_id: - type: string - format: matrix_room_id - description: | - An existing Matrix room ID to bridge to. - The other parameters must be already in sync with the room state when using this parameter. - CreatedGroup: - type: object - description: A successfully created group chat. - required: [id, mxid] - properties: - id: - type: string - description: The internal chat ID of the created group. - mxid: - type: string - format: matrix_room_id - description: The Matrix room ID of the portal. - examples: - - '!OKhS0I5q2fCzdnl2qgeozDQw:t2bot.io' LoginStep: type: object description: A step in a login process. @@ -714,7 +635,7 @@ components: type: type: string description: The type of field. - enum: [ username, phone_number, email, password, 2fa_code, token, url, domain, select ] + enum: [ username, phone_number, email, password, 2fa_code, token, url, domain ] id: type: string description: The internal ID of the field. This must be used as the key in the object when submitting the data back to the bridge. @@ -728,53 +649,10 @@ components: description: A more detailed description of the field shown to the user. examples: - Include the country code with a + - default_value: - type: string - description: A default value that the client can pre-fill the field with. pattern: type: string format: regex description: A regular expression that the field value must match. - options: - type: array - description: For fields of type select, the valid options. - items: - type: string - attachments: - type: array - description: A list of media attachments to show the user alongside the form fields. - items: - type: object - description: A media attachment to show the user. - required: [ type, filename, content ] - properties: - type: - type: string - description: The type of media attachment, using the same media type identifiers as Matrix attachments. Only some are supported. - enum: [ m.image, m.audio ] - filename: - type: string - description: The filename for the media attachment. - content: - type: string - description: The raw file content for the attachment encoded in base64. - info: - type: object - description: Optional but recommended metadata for the attachment. Can generally be derived from the raw content if omitted. - properties: - mimetype: - type: string - description: The MIME type for the media content. - examples: [ image/png, audio/mpeg ] - w: - type: number - description: The width of the media in pixels. Only applicable for images and videos. - h: - type: number - description: The height of the media in pixels. Only applicable for images and videos. - size: - type: number - description: The size of the media content in number of bytes. Strongly recommended to include. - description: Cookie login step required: [ type, cookies ] properties: diff --git a/bridgev2/matrix/publicmedia.go b/bridgev2/matrix/publicmedia.go index 82ea8c2b..95e37262 100644 --- a/bridgev2/matrix/publicmedia.go +++ b/bridgev2/matrix/publicmedia.go @@ -7,26 +7,16 @@ package matrix import ( - "context" "crypto/hmac" "crypto/sha256" "encoding/base64" "encoding/binary" "fmt" "io" - "mime" "net/http" - "net/url" - "slices" - "strings" "time" - "github.com/rs/zerolog" - "maunium.net/go/mautrix/bridgev2" - "maunium.net/go/mautrix/bridgev2/database" - "maunium.net/go/mautrix/crypto/attachment" - "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" ) @@ -43,10 +33,7 @@ func (br *Connector) initPublicMedia() error { return fmt.Errorf("public media hash length is negative") } br.pubMediaSigKey = []byte(br.Config.PublicMedia.SigningKey) - br.AS.Router.HandleFunc("GET /_mautrix/publicmedia/{customID}", br.serveDatabasePublicMedia) - br.AS.Router.HandleFunc("GET /_mautrix/publicmedia/{customID}/{filename}", br.serveDatabasePublicMedia) br.AS.Router.HandleFunc("GET /_mautrix/publicmedia/{server}/{mediaID}/{checksum}", br.servePublicMedia) - br.AS.Router.HandleFunc("GET /_mautrix/publicmedia/{server}/{mediaID}/{checksum}/{filename}", br.servePublicMedia) return nil } @@ -57,20 +44,6 @@ func (br *Connector) hashContentURI(uri id.ContentURI, expiry []byte) []byte { return hasher.Sum(expiry)[:br.Config.PublicMedia.HashLength+len(expiry)] } -func (br *Connector) hashDBPublicMedia(pm *database.PublicMedia) []byte { - hasher := hmac.New(sha256.New, br.pubMediaSigKey) - hasher.Write([]byte(pm.MXC.String())) - hasher.Write([]byte(pm.MimeType)) - if pm.Keys != nil { - hasher.Write([]byte(pm.Keys.Version)) - hasher.Write([]byte(pm.Keys.Key.Algorithm)) - hasher.Write([]byte(pm.Keys.Key.Key)) - hasher.Write([]byte(pm.Keys.InitVector)) - hasher.Write([]byte(pm.Keys.Hashes.SHA256)) - } - return hasher.Sum(nil)[:br.Config.PublicMedia.HashLength] -} - func (br *Connector) makePublicMediaChecksum(uri id.ContentURI) []byte { var expiresAt []byte if br.Config.PublicMedia.Expiry > 0 { @@ -120,47 +93,9 @@ func (br *Connector) servePublicMedia(w http.ResponseWriter, r *http.Request) { http.Error(w, "checksum expired", http.StatusGone) return } - br.doProxyMedia(w, r, contentURI, nil, "") -} - -func (br *Connector) serveDatabasePublicMedia(w http.ResponseWriter, r *http.Request) { - if !br.Config.PublicMedia.UseDatabase { - http.Error(w, "public media short links are disabled", http.StatusNotFound) - return - } - log := zerolog.Ctx(r.Context()) - media, err := br.Bridge.DB.PublicMedia.Get(r.Context(), r.PathValue("customID")) - if err != nil { - log.Err(err).Msg("Failed to get public media from database") - http.Error(w, "failed to get media metadata", http.StatusInternalServerError) - return - } else if media == nil { - http.Error(w, "media ID not found", http.StatusNotFound) - return - } else if !media.Expiry.IsZero() && media.Expiry.Before(time.Now()) { - // This is not gone as it can still be refreshed in the DB - http.Error(w, "media expired", http.StatusNotFound) - return - } else if media.Keys != nil && media.Keys.PrepareForDecryption() != nil { - http.Error(w, "media keys are malformed", http.StatusInternalServerError) - return - } - br.doProxyMedia(w, r, media.MXC, media.Keys, media.MimeType) -} - -var safeMimes = []string{ - "text/css", "text/plain", "text/csv", - "application/json", "application/ld+json", - "image/jpeg", "image/gif", "image/png", "image/apng", "image/webp", "image/avif", - "video/mp4", "video/webm", "video/ogg", "video/quicktime", - "audio/mp4", "audio/webm", "audio/aac", "audio/mpeg", "audio/ogg", "audio/wave", - "audio/wav", "audio/x-wav", "audio/x-pn-wav", "audio/flac", "audio/x-flac", -} - -func (br *Connector) doProxyMedia(w http.ResponseWriter, r *http.Request, contentURI id.ContentURI, encInfo *attachment.EncryptedFile, mimeType string) { resp, err := br.Bot.Download(r.Context(), contentURI) if err != nil { - zerolog.Ctx(r.Context()).Warn().Stringer("uri", contentURI).Err(err).Msg("Failed to download media to proxy") + br.Log.Warn().Stringer("uri", contentURI).Err(err).Msg("Failed to download media to proxy") http.Error(w, "failed to download media", http.StatusInternalServerError) return } @@ -168,41 +103,11 @@ func (br *Connector) doProxyMedia(w http.ResponseWriter, r *http.Request, conten for _, hdr := range proxyHeadersToCopy { w.Header()[hdr] = resp.Header[hdr] } - stream := resp.Body - if encInfo != nil { - if mimeType == "" { - mimeType = "application/octet-stream" - } - contentDisposition := "attachment" - if slices.Contains(safeMimes, mimeType) { - contentDisposition = "inline" - } - dispositionArgs := map[string]string{} - if filename := r.PathValue("filename"); filename != "" { - dispositionArgs["filename"] = filename - } - w.Header().Set("Content-Type", mimeType) - w.Header().Set("Content-Disposition", mime.FormatMediaType(contentDisposition, dispositionArgs)) - // Note: this won't check the Close result like it should, but it's probably not a big deal here - stream = encInfo.DecryptStream(stream) - } else if filename := r.PathValue("filename"); filename != "" { - contentDisposition, _, _ := mime.ParseMediaType(resp.Header.Get("Content-Disposition")) - if contentDisposition == "" { - contentDisposition = "attachment" - } - w.Header().Set("Content-Disposition", mime.FormatMediaType(contentDisposition, map[string]string{ - "filename": filename, - })) - } w.WriteHeader(http.StatusOK) - _, _ = io.Copy(w, stream) + _, _ = io.Copy(w, resp.Body) } func (br *Connector) GetPublicMediaAddress(contentURI id.ContentURIString) string { - return br.getPublicMediaAddressWithFileName(contentURI, "") -} - -func (br *Connector) getPublicMediaAddressWithFileName(contentURI id.ContentURIString, fileName string) string { if br.pubMediaSigKey == nil { return "" } @@ -210,69 +115,11 @@ func (br *Connector) getPublicMediaAddressWithFileName(contentURI id.ContentURIS if err != nil || !parsed.IsValid() { return "" } - fileName = url.PathEscape(strings.ReplaceAll(fileName, "/", "_")) - if fileName == ".." { - fileName = "" - } - parts := []string{ + return fmt.Sprintf( + "%s/_mautrix/publicmedia/%s/%s/%s", br.GetPublicAddress(), - strings.Trim(br.Config.PublicMedia.PathPrefix, "/"), parsed.Homeserver, parsed.FileID, base64.RawURLEncoding.EncodeToString(br.makePublicMediaChecksum(parsed)), - fileName, - } - if fileName == "" { - parts = parts[:len(parts)-1] - } - return strings.Join(parts, "/") -} - -func (br *Connector) GetPublicMediaAddressForEvent(ctx context.Context, evt *event.MessageEventContent) (string, error) { - if br.pubMediaSigKey == nil { - return "", bridgev2.ErrPublicMediaDisabled - } - if !br.Config.PublicMedia.UseDatabase { - if evt.File != nil { - return "", fmt.Errorf("can't generate address for encrypted file: %w", bridgev2.ErrPublicMediaDatabaseDisabled) - } - return br.getPublicMediaAddressWithFileName(evt.URL, evt.GetFileName()), nil - } - mxc := evt.URL - var keys *attachment.EncryptedFile - if evt.File != nil { - mxc = evt.File.URL - keys = &evt.File.EncryptedFile - } - parsedMXC, err := mxc.Parse() - if err != nil { - return "", fmt.Errorf("%w: failed to parse MXC: %w", bridgev2.ErrPublicMediaGenerateFailed, err) - } - pm := &database.PublicMedia{ - MXC: parsedMXC, - Keys: keys, - MimeType: evt.GetInfo().MimeType, - } - if br.Config.PublicMedia.Expiry > 0 { - pm.Expiry = time.Now().Add(time.Duration(br.Config.PublicMedia.Expiry) * time.Second) - } - pm.PublicID = base64.RawURLEncoding.EncodeToString(br.hashDBPublicMedia(pm)) - err = br.Bridge.DB.PublicMedia.Put(ctx, pm) - if err != nil { - return "", fmt.Errorf("%w: failed to store public media in database: %w", bridgev2.ErrPublicMediaGenerateFailed, err) - } - fileName := url.PathEscape(strings.ReplaceAll(evt.GetFileName(), "/", "_")) - if fileName == ".." { - fileName = "" - } - parts := []string{ - br.GetPublicAddress(), - strings.Trim(br.Config.PublicMedia.PathPrefix, "/"), - pm.PublicID, - fileName, - } - if fileName == "" { - parts = parts[:len(parts)-1] - } - return strings.Join(parts, "/"), nil + ) } diff --git a/bridgev2/matrixinterface.go b/bridgev2/matrixinterface.go index be26db49..b30e274a 100644 --- a/bridgev2/matrixinterface.go +++ b/bridgev2/matrixinterface.go @@ -14,8 +14,6 @@ import ( "os" "time" - "go.mau.fi/util/exhttp" - "maunium.net/go/mautrix" "maunium.net/go/mautrix/bridgev2/database" "maunium.net/go/mautrix/bridgev2/networkid" @@ -25,10 +23,8 @@ import ( ) type MatrixCapabilities struct { - AutoJoinInvites bool - BatchSending bool - ArbitraryMemberChange bool - ExtraProfileMeta bool + AutoJoinInvites bool + BatchSending bool } type MatrixConnector interface { @@ -62,54 +58,35 @@ type MatrixConnector interface { } type MatrixConnectorWithArbitraryRoomState interface { - MatrixConnector GetStateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string) (*event.Event, error) } type MatrixConnectorWithServer interface { - MatrixConnector GetPublicAddress() string GetRouter() *http.ServeMux } -type IProvisioningAPI interface { - GetRouter() *http.ServeMux - GetUser(r *http.Request) *User -} - -type MatrixConnectorWithProvisioning interface { - MatrixConnector - GetProvisioning() IProvisioningAPI -} - type MatrixConnectorWithPublicMedia interface { - MatrixConnector GetPublicMediaAddress(contentURI id.ContentURIString) string - GetPublicMediaAddressForEvent(ctx context.Context, evt *event.MessageEventContent) (string, error) } type MatrixConnectorWithNameDisambiguation interface { - MatrixConnector IsConfusableName(ctx context.Context, roomID id.RoomID, userID id.UserID, name string) ([]id.UserID, error) } type MatrixConnectorWithBridgeIdentifier interface { - MatrixConnector GetUniqueBridgeID() string } type MatrixConnectorWithURLPreviews interface { - MatrixConnector GetURLPreview(ctx context.Context, url string) (*event.LinkPreview, error) } type MatrixConnectorWithPostRoomBridgeHandling interface { - MatrixConnector HandleNewlyBridgedRoom(ctx context.Context, roomID id.RoomID) error } type MatrixConnectorWithAnalytics interface { - MatrixConnector TrackAnalytics(userID id.UserID, event string, properties map[string]any) } @@ -124,15 +101,9 @@ type DirectNotificationData struct { } type MatrixConnectorWithNotifications interface { - MatrixConnector DisplayNotification(ctx context.Context, data *DirectNotificationData) } -type MatrixConnectorWithHTTPSettings interface { - MatrixConnector - GetHTTPClientSettings() exhttp.ClientSettings -} - type MatrixSendExtra struct { Timestamp time.Time MessageMeta *database.Message @@ -205,21 +176,12 @@ type MatrixAPI interface { TagRoom(ctx context.Context, roomID id.RoomID, tag event.RoomTag, isTagged bool) error MuteRoom(ctx context.Context, roomID id.RoomID, until time.Time) error - - GetEvent(ctx context.Context, roomID id.RoomID, eventID id.EventID) (*event.Event, error) } type StreamOrderReadingMatrixAPI interface { - MatrixAPI MarkStreamOrderRead(ctx context.Context, roomID id.RoomID, streamOrder int64, ts time.Time) error } type MarkAsDMMatrixAPI interface { - MatrixAPI MarkAsDM(ctx context.Context, roomID id.RoomID, otherUser id.UserID) error } - -type EphemeralSendingMatrixAPI interface { - MatrixAPI - BeeperSendEphemeralEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, content *event.Content, txnID string) (*mautrix.RespSendEvent, error) -} diff --git a/bridgev2/matrixinvite.go b/bridgev2/matrixinvite.go index 75c00cb0..bfbabd26 100644 --- a/bridgev2/matrixinvite.go +++ b/bridgev2/matrixinvite.go @@ -88,36 +88,6 @@ func sendErrorAndLeave(ctx context.Context, evt *event.Event, intent MatrixAPI, rejectInvite(ctx, evt, intent, "") } -func (portal *Portal) CleanupOrphanedDM(ctx context.Context, userMXID id.UserID) { - if portal.MXID == "" { - return - } - log := zerolog.Ctx(ctx) - existingPortalMembers, err := portal.Bridge.Matrix.GetMembers(ctx, portal.MXID) - if err != nil { - log.Err(err). - Stringer("old_portal_mxid", portal.MXID). - Msg("Failed to check existing portal members, deleting room") - } else if targetUserMember, ok := existingPortalMembers[userMXID]; !ok { - log.Debug(). - Stringer("old_portal_mxid", portal.MXID). - Msg("Inviter has no member event in old portal, deleting room") - } else if targetUserMember.Membership.IsInviteOrJoin() { - return - } else { - log.Debug(). - Stringer("old_portal_mxid", portal.MXID). - Str("membership", string(targetUserMember.Membership)). - Msg("Inviter is not in old portal, deleting room") - } - - if err = portal.RemoveMXID(ctx); err != nil { - log.Err(err).Msg("Failed to delete old portal mxid") - } else if err = portal.Bridge.Bot.DeleteRoom(ctx, portal.MXID, true); err != nil { - log.Err(err).Msg("Failed to clean up old portal room") - } -} - func (br *Bridge) handleGhostDMInvite(ctx context.Context, evt *event.Event, sender *User) EventHandlingResult { ghostID, _ := br.Matrix.ParseGhostMXID(id.UserID(evt.GetStateKey())) validator, ok := br.Network.(IdentifierValidatingNetwork) @@ -195,7 +165,34 @@ func (br *Bridge) handleGhostDMInvite(ctx context.Context, evt *event.Event, sen return EventHandlingResultFailed } } - portal.CleanupOrphanedDM(ctx, sender.MXID) + if portal.MXID != "" { + doCleanup := true + existingPortalMembers, err := br.Matrix.GetMembers(ctx, portal.MXID) + if err != nil { + log.Err(err). + Stringer("old_portal_mxid", portal.MXID). + Msg("Failed to check existing portal members, deleting room") + } else if targetUserMember, ok := existingPortalMembers[sender.MXID]; !ok { + log.Debug(). + Stringer("old_portal_mxid", portal.MXID). + Msg("Inviter has no member event in old portal, deleting room") + } else if targetUserMember.Membership.IsInviteOrJoin() { + doCleanup = false + } else { + log.Debug(). + Stringer("old_portal_mxid", portal.MXID). + Str("membership", string(targetUserMember.Membership)). + Msg("Inviter is not in old portal, deleting room") + } + + if doCleanup { + if err = portal.RemoveMXID(ctx); err != nil { + log.Err(err).Msg("Failed to delete old portal mxid") + } else if err = br.Bot.DeleteRoom(ctx, portal.MXID, true); err != nil { + log.Err(err).Msg("Failed to clean up old portal room") + } + } + } err = invitedGhost.Intent.EnsureInvited(ctx, evt.RoomID, br.Bot.GetMXID()) if err != nil { log.Err(err).Msg("Failed to ensure bot is invited to room") @@ -209,67 +206,72 @@ func (br *Bridge) handleGhostDMInvite(ctx context.Context, evt *event.Event, sen return EventHandlingResultFailed } - portal.roomCreateLock.Lock() - defer portal.roomCreateLock.Unlock() - portalMXID := portal.MXID - if portalMXID != "" { - sendErrorAndLeave(ctx, evt, invitedGhost.Intent, "You already have a direct chat with me at [%s](%s)", portalMXID, portalMXID.URI(br.Matrix.ServerName()).MatrixToURL()) - rejectInvite(ctx, evt, br.Bot, "") - return EventHandlingResultSuccess - } - err = br.givePowerToBot(ctx, evt.RoomID, invitedGhost.Intent) - if err != nil { - log.Err(err).Msg("Failed to give permissions to bridge bot") - sendErrorAndLeave(ctx, evt, invitedGhost.Intent, "Failed to give permissions to bridge bot") - rejectInvite(ctx, evt, br.Bot, "") - return EventHandlingResultSuccess - } - overrideIntent := invitedGhost.Intent - if resp.DMRedirectedTo != "" && resp.DMRedirectedTo != invitedGhost.ID { - log.Debug(). - Str("dm_redirected_to_id", string(resp.DMRedirectedTo)). - Msg("Created DM was redirected to another user ID") - _, err = invitedGhost.Intent.SendState(ctx, evt.RoomID, event.StateMember, invitedGhost.Intent.GetMXID().String(), &event.Content{ - Parsed: &event.MemberEventContent{ - Membership: event.MembershipLeave, - Reason: "Direct chat redirected to another internal user ID", + didSetPortal := portal.setMXIDToExistingRoom(ctx, evt.RoomID) + if didSetPortal { + message := "Private chat portal created" + err = br.givePowerToBot(ctx, evt.RoomID, invitedGhost.Intent) + hasWarning := false + if err != nil { + log.Warn().Err(err).Msg("Failed to give power to bot in new DM") + message += "\n\nWarning: failed to promote bot" + hasWarning = true + } + if resp.DMRedirectedTo != "" && resp.DMRedirectedTo != invitedGhost.ID { + log.Debug(). + Str("dm_redirected_to_id", string(resp.DMRedirectedTo)). + Msg("Created DM was redirected to another user ID") + _, err = invitedGhost.Intent.SendState(ctx, portal.MXID, event.StateMember, invitedGhost.Intent.GetMXID().String(), &event.Content{ + Parsed: &event.MemberEventContent{ + Membership: event.MembershipLeave, + Reason: "Direct chat redirected to another internal user ID", + }, + }, time.Time{}) + if err != nil { + log.Err(err).Msg("Failed to make incorrect ghost leave new DM room") + } + otherUserGhost, err := br.GetGhostByID(ctx, resp.DMRedirectedTo) + if err != nil { + log.Err(err).Msg("Failed to get ghost of real portal other user ID") + } else { + invitedGhost = otherUserGhost + } + } + if resp.PortalInfo != nil { + portal.UpdateInfo(ctx, resp.PortalInfo, sourceLogin, nil, time.Time{}) + } else { + portal.UpdateCapabilities(ctx, sourceLogin, true) + portal.UpdateBridgeInfo(ctx) + } + // TODO this might become unnecessary if UpdateInfo starts taking care of it + _, err = br.Bot.SendState(ctx, portal.MXID, event.StateElementFunctionalMembers, "", &event.Content{ + Parsed: &event.ElementFunctionalMembersContent{ + ServiceMembers: []id.UserID{br.Bot.GetMXID()}, }, }, time.Time{}) if err != nil { - log.Err(err).Msg("Failed to make incorrect ghost leave new DM room") + log.Warn().Err(err).Msg("Failed to set service members in room") + if !hasWarning { + message += "\n\nWarning: failed to set service members" + hasWarning = true + } } - if resp.DMRedirectedTo == SpecialValueDMRedirectedToBot { - overrideIntent = br.Bot - } else if otherUserGhost, err := br.GetGhostByID(ctx, resp.DMRedirectedTo); err != nil { - log.Err(err).Msg("Failed to get ghost of real portal other user ID") - } else { - invitedGhost = otherUserGhost - overrideIntent = otherUserGhost.Intent + mx, ok := br.Matrix.(MatrixConnectorWithPostRoomBridgeHandling) + if ok { + err = mx.HandleNewlyBridgedRoom(ctx, evt.RoomID) + if err != nil { + if hasWarning { + message += fmt.Sprintf(", %s", err.Error()) + } else { + message += fmt.Sprintf("\n\nWarning: %s", err.Error()) + } + } } + sendNotice(ctx, evt, invitedGhost.Intent, message) + } else { + // TODO ensure user is invited even if PortalInfo wasn't provided? + sendErrorAndLeave(ctx, evt, invitedGhost.Intent, "You already have a direct chat with me at [%s](%s)", portal.MXID, portal.MXID.URI(br.Matrix.ServerName()).MatrixToURL()) + rejectInvite(ctx, evt, br.Bot, "") } - err = portal.UpdateMatrixRoomID(ctx, evt.RoomID, UpdateMatrixRoomIDParams{ - // We locked it before checking the mxid - RoomCreateAlreadyLocked: true, - - FailIfMXIDSet: true, - ChatInfo: resp.PortalInfo, - ChatInfoSource: sourceLogin, - }) - if err != nil { - log.Err(err).Msg("Failed to update Matrix room ID for new DM portal") - sendNotice(ctx, evt, overrideIntent, "Failed to finish configuring portal. The chat may or may not work") - return EventHandlingResultSuccess - } - message := "Private chat portal created" - mx, ok := br.Matrix.(MatrixConnectorWithPostRoomBridgeHandling) - if ok { - err = mx.HandleNewlyBridgedRoom(ctx, evt.RoomID) - if err != nil { - log.Err(err).Msg("Error in connector newly bridged room handler") - message += fmt.Sprintf("\n\nWarning: %s", err.Error()) - } - } - sendNotice(ctx, evt, overrideIntent, message) return EventHandlingResultSuccess } @@ -292,3 +294,21 @@ func (br *Bridge) givePowerToBot(ctx context.Context, roomID id.RoomID, userWith } return nil } + +func (portal *Portal) setMXIDToExistingRoom(ctx context.Context, roomID id.RoomID) bool { + portal.roomCreateLock.Lock() + defer portal.roomCreateLock.Unlock() + if portal.MXID != "" { + return false + } + portal.MXID = roomID + portal.updateLogger() + portal.Bridge.cacheLock.Lock() + portal.Bridge.portalsByMXID[portal.MXID] = portal + portal.Bridge.cacheLock.Unlock() + err := portal.Save(ctx) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to save portal after updating mxid") + } + return true +} diff --git a/bridgev2/messagestatus.go b/bridgev2/messagestatus.go index df0c9e4d..7118649d 100644 --- a/bridgev2/messagestatus.go +++ b/bridgev2/messagestatus.go @@ -20,7 +20,6 @@ import ( type MessageStatusEventInfo struct { RoomID id.RoomID - TransactionID string SourceEventID id.EventID NewEventID id.EventID EventType event.Type @@ -42,7 +41,6 @@ func StatusEventInfoFromEvent(evt *event.Event) *MessageStatusEventInfo { return &MessageStatusEventInfo{ RoomID: evt.RoomID, - TransactionID: evt.Unsigned.TransactionID, SourceEventID: evt.ID, EventType: evt.Type, MessageType: evt.Content.AsMessage().MsgType, @@ -184,10 +182,9 @@ func (ms *MessageStatus) ToMSSEvent(evt *MessageStatusEventInfo) *event.BeeperMe Type: event.RelReference, EventID: evt.SourceEventID, }, - TargetTxnID: evt.TransactionID, - Status: ms.Status, - Reason: ms.ErrorReason, - Message: ms.Message, + Status: ms.Status, + Reason: ms.ErrorReason, + Message: ms.Message, } if ms.InternalError != nil { content.InternalError = ms.InternalError.Error() diff --git a/bridgev2/networkid/bridgeid.go b/bridgev2/networkid/bridgeid.go index e3a6df70..443d3655 100644 --- a/bridgev2/networkid/bridgeid.go +++ b/bridgev2/networkid/bridgeid.go @@ -47,8 +47,8 @@ type PortalID string // As a special case, Receiver MUST be set if the Bridge.Config.SplitPortals flag is set to true. // The flag is intended for puppeting-only bridges which want multiple logins to create separate portals for each user. type PortalKey struct { - ID PortalID `json:"portal_id"` - Receiver UserLoginID `json:"portal_receiver,omitempty"` + ID PortalID + Receiver UserLoginID } func (pk PortalKey) IsEmpty() bool { diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index efc5f100..eb38bd2d 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -16,9 +16,7 @@ import ( "github.com/rs/zerolog" "go.mau.fi/util/configupgrade" "go.mau.fi/util/ptr" - "go.mau.fi/util/random" - "maunium.net/go/mautrix" "maunium.net/go/mautrix/bridgev2/database" "maunium.net/go/mautrix/bridgev2/networkid" "maunium.net/go/mautrix/event" @@ -119,15 +117,11 @@ func MergeCaption(textPart, mediaPart *ConvertedMessagePart) *ConvertedMessagePa mediaPart.Content.EnsureHasHTML() mediaPart.Content.Body += "\n\n" + textPart.Content.Body mediaPart.Content.FormattedBody += "

" + textPart.Content.FormattedBody - mediaPart.Content.Mentions = mediaPart.Content.Mentions.Merge(textPart.Content.Mentions) - mediaPart.Content.BeeperLinkPreviews = append(mediaPart.Content.BeeperLinkPreviews, textPart.Content.BeeperLinkPreviews...) } else { mediaPart.Content.FileName = mediaPart.Content.Body mediaPart.Content.Body = textPart.Content.Body mediaPart.Content.Format = textPart.Content.Format mediaPart.Content.FormattedBody = textPart.Content.FormattedBody - mediaPart.Content.Mentions = textPart.Content.Mentions - mediaPart.Content.BeeperLinkPreviews = textPart.Content.BeeperLinkPreviews } if metaMerger, ok := mediaPart.DBMetadata.(database.MetaMerger); ok { metaMerger.CopyFrom(textPart.DBMetadata) @@ -261,7 +255,6 @@ type NetworkConnector interface { } type StoppableNetwork interface { - NetworkConnector // Stop is called when the bridge is stopping, after all network clients have been disconnected. Stop() } @@ -318,16 +311,6 @@ type MaxFileSizeingNetwork interface { SetMaxFileSize(maxSize int64) } -type NetworkResettingNetwork interface { - NetworkConnector - // ResetHTTPTransport should recreate the HTTP client used by the bridge. - // It should refetch settings from the Matrix connector using GetHTTPClientSettings if applicable. - ResetHTTPTransport() - // ResetNetworkConnections should forcefully disconnect and restart any persistent network connections. - // ResetHTTPTransport will usually be called before this, so resetting the transport is not necessary here. - ResetNetworkConnections() -} - type RemoteEchoHandler func(RemoteMessage, *database.Message) (bool, error) type MatrixMessageResponse struct { @@ -359,16 +342,10 @@ type NetworkGeneralCapabilities struct { // Should the bridge re-request user info on incoming messages even if the ghost already has info? // By default, info is only requested for ghosts with no name, and other updating is left to events. AggressiveUpdateInfo bool - // Should the bridge call HandleMatrixReadReceipt with fake data when receiving a new message? - // This should be enabled if the network requires each message to be marked as read independently, - // and doesn't automatically do it when sending a message. - ImplicitReadReceipts bool // If the bridge uses the pending message mechanism ([MatrixMessage.AddPendingToSave]) // to handle asynchronous message responses, this field can be set to enable // automatic timeout errors in case the asynchronous response never arrives. OutgoingMessageTimeouts *OutgoingTimeoutConfig - // Capabilities related to the provisioning API. - Provisioning ProvisioningCapabilities } // NetworkAPI is an interface representing a remote network client for a single user login. @@ -702,35 +679,6 @@ type RoomTopicHandlingNetworkAPI interface { HandleMatrixRoomTopic(ctx context.Context, msg *MatrixRoomTopic) (bool, error) } -type DisappearTimerChangingNetworkAPI interface { - NetworkAPI - // HandleMatrixDisappearingTimer is called when the disappearing timer of a portal room is changed. - // This method should update the Disappear field of the Portal with the new timer and return true - // if the change was successful. If the change is not successful, then the field should not be updated. - HandleMatrixDisappearingTimer(ctx context.Context, msg *MatrixDisappearingTimer) (bool, error) -} - -// DeleteChatHandlingNetworkAPI is an optional interface that network connectors -// can implement to delete a chat from the remote network. -type DeleteChatHandlingNetworkAPI interface { - NetworkAPI - // HandleMatrixDeleteChat is called when the user explicitly deletes a chat. - HandleMatrixDeleteChat(ctx context.Context, msg *MatrixDeleteChat) error -} - -// MessageRequestAcceptingNetworkAPI is an optional interface that network connectors -// can implement to accept message requests from the remote network. -type MessageRequestAcceptingNetworkAPI interface { - NetworkAPI - // HandleMatrixAcceptMessageRequest is called when the user accepts a message request. - HandleMatrixAcceptMessageRequest(ctx context.Context, msg *MatrixAcceptMessageRequest) error -} - -type BeeperAIStreamHandlingNetworkAPI interface { - NetworkAPI - HandleMatrixBeeperAIStream(ctx context.Context, msg *MatrixBeeperAIStream) error -} - type ResolveIdentifierResponse struct { // Ghost is the ghost of the user that the identifier resolves to. // This field should be set whenever possible. However, it is not required, @@ -750,8 +698,6 @@ type ResolveIdentifierResponse struct { Chat *CreateChatResponse } -var SpecialValueDMRedirectedToBot = networkid.UserID("__fi.mau.bridgev2.dm_redirected_to_bot::" + random.String(10)) - type CreateChatResponse struct { PortalKey networkid.PortalKey // Portal and PortalInfo are not required, the caller will fetch them automatically based on PortalKey if necessary. @@ -760,17 +706,6 @@ type CreateChatResponse struct { // If a start DM request (CreateChatWithGhost or ResolveIdentifier) returns the DM to a different user, // this field should have the user ID of said different user. DMRedirectedTo networkid.UserID - - FailedParticipants map[networkid.UserID]*CreateChatFailedParticipant -} - -type CreateChatFailedParticipant struct { - Reason string `json:"reason"` - InviteEventType string `json:"invite_event_type,omitempty"` - InviteContent *event.Content `json:"invite_content,omitempty"` - - UserMXID id.UserID `json:"user_mxid,omitempty"` - DMRoomMXID id.RoomID `json:"dm_room_mxid,omitempty"` } // IdentifierResolvingNetworkAPI is an optional interface that network connectors can implement to support starting new direct chats. @@ -805,83 +740,7 @@ type UserSearchingNetworkAPI interface { type GroupCreatingNetworkAPI interface { IdentifierResolvingNetworkAPI - CreateGroup(ctx context.Context, params *GroupCreateParams) (*CreateChatResponse, error) -} - -type PersonalFilteringCustomizingNetworkAPI interface { - NetworkAPI - CustomizePersonalFilteringSpace(req *mautrix.ReqCreateRoom) -} - -type ProvisioningCapabilities struct { - ResolveIdentifier ResolveIdentifierCapabilities `json:"resolve_identifier"` - GroupCreation map[string]GroupTypeCapabilities `json:"group_creation"` -} - -type ResolveIdentifierCapabilities struct { - // Can DMs be created after resolving an identifier? - CreateDM bool `json:"create_dm"` - // Can users be looked up by phone number? - LookupPhone bool `json:"lookup_phone"` - // Can users be looked up by email address? - LookupEmail bool `json:"lookup_email"` - // Can users be looked up by network-specific username? - LookupUsername bool `json:"lookup_username"` - // Can any phone number be contacted without having to validate it via lookup first? - AnyPhone bool `json:"any_phone"` - // Can a contact list be retrieved from the bridge? - ContactList bool `json:"contact_list"` - // Can users be searched by name on the remote network? - Search bool `json:"search"` -} - -type GroupTypeCapabilities struct { - TypeDescription string `json:"type_description"` - - Name GroupFieldCapability `json:"name"` - Username GroupFieldCapability `json:"username"` - Avatar GroupFieldCapability `json:"avatar"` - Topic GroupFieldCapability `json:"topic"` - Disappear GroupFieldCapability `json:"disappear"` - Participants GroupFieldCapability `json:"participants"` - Parent GroupFieldCapability `json:"parent"` -} - -type GroupFieldCapability struct { - // Is setting this field allowed at all in the create request? - // Even if false, the network connector should attempt to set the metadata after group creation, - // as the allowed flag can't be enforced properly when creating a group for an existing Matrix room. - Allowed bool `json:"allowed"` - // Is setting this field mandatory for the creation to succeed? - Required bool `json:"required,omitempty"` - // The minimum/maximum length of the field, if applicable. - // For members, length means the number of members excluding the creator. - MinLength int `json:"min_length,omitempty"` - MaxLength int `json:"max_length,omitempty"` - - // Only for the disappear field: allowed disappearing settings - DisappearSettings *event.DisappearingTimerCapability `json:"settings,omitempty"` - - // This can be used to tell provisionutil not to call ValidateUserID on each participant. - // It only meant to allow hacks where ResolveIdentifier returns a fake ID that isn't actually valid for MXIDs. - SkipIdentifierValidation bool `json:"-"` -} - -type GroupCreateParams struct { - Type string `json:"type,omitempty"` - - Username string `json:"username,omitempty"` - // Clients may also provide MXIDs here, but provisionutil will normalize them, so bridges only need to handle network IDs - Participants []networkid.UserID `json:"participants,omitempty"` - Parent *networkid.PortalKey `json:"parent,omitempty"` - - Name *event.RoomNameEventContent `json:"name,omitempty"` - Avatar *event.RoomAvatarEventContent `json:"avatar,omitempty"` - Topic *event.TopicEventContent `json:"topic,omitempty"` - Disappear *event.BeeperDisappearingTimer `json:"disappear,omitempty"` - - // An existing room ID to bridge to. If unset, a new room will be created. - RoomID id.RoomID `json:"room_id,omitempty"` + CreateGroup(ctx context.Context, name string, users ...networkid.UserID) (*CreateChatResponse, error) } type MembershipChangeType struct { @@ -921,15 +780,16 @@ type MatrixMembershipChange struct { MatrixRoomMeta[*event.MemberEventContent] Target GhostOrUserLogin Type MembershipChangeType -} -type MatrixMembershipResult struct { - RedirectTo networkid.UserID + // Deprecated: Use Target instead + TargetGhost *Ghost + // Deprecated: Use Target instead + TargetUserLogin *UserLogin } type MembershipHandlingNetworkAPI interface { NetworkAPI - HandleMatrixMembership(ctx context.Context, msg *MatrixMembershipChange) (*MatrixMembershipResult, error) + HandleMatrixMembership(ctx context.Context, msg *MatrixMembershipChange) (bool, error) } type SinglePowerLevelChange struct { @@ -1168,11 +1028,6 @@ type RemoteChatDelete interface { RemoteDeleteOnlyForMe } -type RemoteChatDeleteWithChildren interface { - RemoteChatDelete - DeleteChildren() bool -} - type RemoteEventThatMayCreatePortal interface { RemoteEvent ShouldCreatePortal() bool @@ -1405,14 +1260,12 @@ type MatrixMessageRemove struct { type MatrixRoomMeta[ContentType any] struct { MatrixEventBase[ContentType] - PrevContent ContentType - IsStateRequest bool + PrevContent ContentType } type MatrixRoomName = MatrixRoomMeta[*event.RoomNameEventContent] type MatrixRoomAvatar = MatrixRoomMeta[*event.RoomAvatarEventContent] type MatrixRoomTopic = MatrixRoomMeta[*event.TopicEventContent] -type MatrixDisappearingTimer = MatrixRoomMeta[*event.BeeperDisappearingTimer] type MatrixReadReceipt struct { Portal *Portal @@ -1427,8 +1280,6 @@ type MatrixReadReceipt struct { LastRead time.Time // The receipt metadata. Receipt event.ReadReceipt - // Whether the receipt is implicit, i.e. triggered by an incoming timeline event rather than an explicit receipt. - Implicit bool } type MatrixTyping struct { @@ -1442,9 +1293,6 @@ type MatrixViewingChat struct { Portal *Portal } -type MatrixDeleteChat = MatrixEventBase[*event.BeeperChatDeleteEventContent] -type MatrixAcceptMessageRequest = MatrixEventBase[*event.BeeperAcceptMessageRequestEventContent] -type MatrixBeeperAIStream = MatrixEventBase[*event.BeeperAIStreamEventContent] type MatrixMarkedUnread = MatrixRoomMeta[*event.MarkedUnreadEventContent] type MatrixMute = MatrixRoomMeta[*event.BeeperMuteEventContent] type MatrixRoomTag = MatrixRoomMeta[*event.TagEventContent] diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 16aa703b..d343a651 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -19,7 +19,6 @@ import ( "github.com/rs/zerolog" "go.mau.fi/util/exfmt" - "go.mau.fi/util/exmaps" "go.mau.fi/util/exslices" "go.mau.fi/util/exsync" "go.mau.fi/util/ptr" @@ -86,15 +85,12 @@ type Portal struct { lastCapUpdate time.Time - roomCreateLock sync.Mutex - cancelRoomCreate atomic.Pointer[context.CancelFunc] - RoomCreated *exsync.Event + roomCreateLock sync.Mutex functionalMembersLock sync.Mutex functionalMembersCache *event.ElementFunctionalMembersContent - events chan portalEvent - deleted *exsync.Event + events chan portalEvent eventsLock sync.Mutex eventIdx int @@ -126,15 +122,7 @@ func (br *Bridge) loadPortal(ctx context.Context, dbPortal *database.Portal, que currentlyTypingLogins: make(map[id.UserID]*UserLogin), currentlyTypingGhosts: exsync.NewSet[id.UserID](), outgoingMessages: make(map[networkid.TransactionID]*outgoingMessage), - - RoomCreated: exsync.NewEvent(), - deleted: exsync.NewEvent(), } - if portal.MXID != "" { - portal.RoomCreated.Set() - } - // Putting the portal in the cache before it's fully initialized is mildly dangerous, - // but loading the relay user login may depend on it. br.portalsByKey[portal.PortalKey] = portal if portal.MXID != "" { br.portalsByMXID[portal.MXID] = portal @@ -143,20 +131,12 @@ func (br *Bridge) loadPortal(ctx context.Context, dbPortal *database.Portal, que if portal.ParentKey.ID != "" { portal.Parent, err = br.UnlockedGetPortalByKey(ctx, portal.ParentKey, false) if err != nil { - delete(br.portalsByKey, portal.PortalKey) - if portal.MXID != "" { - delete(br.portalsByMXID, portal.MXID) - } return nil, fmt.Errorf("failed to load parent portal (%s): %w", portal.ParentKey, err) } } if portal.RelayLoginID != "" { portal.Relay, err = br.unlockedGetExistingUserLoginByID(ctx, portal.RelayLoginID) if err != nil { - delete(br.portalsByKey, portal.PortalKey) - if portal.MXID != "" { - delete(br.portalsByMXID, portal.MXID) - } return nil, fmt.Errorf("failed to load relay login (%s): %w", portal.RelayLoginID, err) } } @@ -169,9 +149,7 @@ func (br *Bridge) loadPortal(ctx context.Context, dbPortal *database.Portal, que } func (portal *Portal) updateLogger() { - logWith := portal.Bridge.Log.With(). - Str("portal_id", string(portal.ID)). - Str("portal_receiver", string(portal.Receiver)) + logWith := portal.Bridge.Log.With().Str("portal_id", string(portal.ID)) if portal.MXID != "" { logWith = logWith.Stringer("portal_mxid", portal.MXID) } @@ -195,16 +173,6 @@ func (br *Bridge) loadManyPortals(ctx context.Context, portals []*database.Porta return output, nil } -func (br *Bridge) loadPortalWithCacheCheck(ctx context.Context, dbPortal *database.Portal) (*Portal, error) { - if dbPortal == nil { - return nil, nil - } else if cached, ok := br.portalsByKey[dbPortal.PortalKey]; ok { - return cached, nil - } else { - return br.loadPortal(ctx, dbPortal, nil, nil) - } -} - func (br *Bridge) UnlockedGetPortalByKey(ctx context.Context, key networkid.PortalKey, onlyIfExists bool) (*Portal, error) { if br.Config.SplitPortals && key.Receiver == "" { return nil, fmt.Errorf("receiver must always be set when split portals is enabled") @@ -294,26 +262,6 @@ func (br *Bridge) GetDMPortalsWith(ctx context.Context, otherUserID networkid.Us return br.loadManyPortals(ctx, rows) } -func (br *Bridge) GetChildPortals(ctx context.Context, parent networkid.PortalKey) ([]*Portal, error) { - br.cacheLock.Lock() - defer br.cacheLock.Unlock() - rows, err := br.DB.Portal.GetChildren(ctx, parent) - if err != nil { - return nil, err - } - return br.loadManyPortals(ctx, rows) -} - -func (br *Bridge) GetDMPortal(ctx context.Context, receiver networkid.UserLoginID, otherUserID networkid.UserID) (*Portal, error) { - br.cacheLock.Lock() - defer br.cacheLock.Unlock() - dbPortal, err := br.DB.Portal.GetDM(ctx, receiver, otherUserID) - if err != nil { - return nil, err - } - return br.loadPortalWithCacheCheck(ctx, dbPortal) -} - func (br *Bridge) GetPortalByKey(ctx context.Context, key networkid.PortalKey) (*Portal, error) { br.cacheLock.Lock() defer br.cacheLock.Unlock() @@ -339,23 +287,15 @@ func (br *Bridge) GetExistingPortalByKey(ctx context.Context, key networkid.Port } func (portal *Portal) queueEvent(ctx context.Context, evt portalEvent) EventHandlingResult { - if portal.deleted.IsSet() { - return EventHandlingResultIgnored - } if PortalEventBuffer == 0 { portal.eventsLock.Lock() defer portal.eventsLock.Unlock() portal.eventIdx++ - return portal.handleSingleEventWithDelayLogging(portal.eventIdx, evt) + return portal.handleSingleEventAsync(portal.eventIdx, evt) } else { - if portal.events == nil { - panic(fmt.Errorf("queueEvent into uninitialized portal %s", portal.PortalKey)) - } select { case portal.events <- evt: return EventHandlingResultQueued - case <-portal.deleted.GetChan(): - return EventHandlingResultIgnored default: zerolog.Ctx(ctx).Error(). Str("portal_id", string(portal.ID)). @@ -380,71 +320,64 @@ func (portal *Portal) eventLoop() { go portal.pendingMessageTimeoutLoop(ctx, cfg) defer cancel() } - deleteCh := portal.deleted.GetChan() - for i := 0; ; i++ { - select { - case rawEvt := <-portal.events: - if rawEvt == nil { - return - } - if portal.Bridge.Config.AsyncEvents { - go portal.handleSingleEventWithDelayLogging(i, rawEvt) - } else { - portal.handleSingleEventWithDelayLogging(i, rawEvt) - } - case <-deleteCh: - return - } + i := 0 + for rawEvt := range portal.events { + i++ + portal.handleSingleEventAsync(i, rawEvt) } } -func (portal *Portal) handleSingleEventWithDelayLogging(idx int, rawEvt any) (outerRes EventHandlingResult) { +func (portal *Portal) handleSingleEventAsync(idx int, rawEvt any) (outerRes EventHandlingResult) { ctx := portal.getEventCtxWithLog(rawEvt, idx) - log := zerolog.Ctx(ctx) - doneCh := make(chan struct{}) - var backgrounded atomic.Bool - start := time.Now() - var handleDuration time.Duration - // Note: this will not set the success flag if the handler times out - outerRes = EventHandlingResult{Queued: true} - go portal.handleSingleEvent(ctx, rawEvt, func(res EventHandlingResult) { - outerRes = res - handleDuration = time.Since(start) - close(doneCh) - if backgrounded.Load() { - log.Debug(). - Time("started_at", start). - Stringer("duration", handleDuration). - Msg("Event that took too long finally finished handling") - } - }) - tick := time.NewTicker(30 * time.Second) - _, isCreate := rawEvt.(*portalCreateEvent) - defer tick.Stop() - for i := 0; i < 10; i++ { - select { - case <-doneCh: - if i > 0 { + if _, isCreate := rawEvt.(*portalCreateEvent); isCreate { + portal.handleSingleEvent(ctx, rawEvt, func(res EventHandlingResult) { + outerRes = res + }) + } else if portal.Bridge.Config.AsyncEvents { + outerRes = EventHandlingResultQueued + go portal.handleSingleEvent(ctx, rawEvt, func(res EventHandlingResult) {}) + } else { + log := zerolog.Ctx(ctx) + doneCh := make(chan struct{}) + var backgrounded atomic.Bool + start := time.Now() + var handleDuration time.Duration + // Note: this will not set the success flag if the handler times out + outerRes = EventHandlingResult{Queued: true} + go portal.handleSingleEvent(ctx, rawEvt, func(res EventHandlingResult) { + outerRes = res + handleDuration = time.Since(start) + close(doneCh) + if backgrounded.Load() { log.Debug(). Time("started_at", start). Stringer("duration", handleDuration). - Msg("Event that took long finished handling") + Msg("Event that took too long finally finished handling") } - return - case <-tick.C: - log.Warn(). - Time("started_at", start). - Msg("Event handling is taking long") - if isCreate { - // Never background portal creation events - i = 1 + }) + tick := time.NewTicker(30 * time.Second) + defer tick.Stop() + for i := 0; i < 10; i++ { + select { + case <-doneCh: + if i > 0 { + log.Debug(). + Time("started_at", start). + Stringer("duration", handleDuration). + Msg("Event that took long finished handling") + } + return + case <-tick.C: + log.Warn(). + Time("started_at", start). + Msg("Event handling is taking long") } } + log.Warn(). + Time("started_at", start). + Msg("Event handling is taking too long, continuing in background") + backgrounded.Store(true) } - log.Warn(). - Time("started_at", start). - Msg("Event handling is taking too long, continuing in background") - backgrounded.Store(true) return } @@ -486,11 +419,6 @@ func (portal *Portal) getEventCtxWithLog(rawEvt any, idx int) context.Context { logWith = logWith.Int64("remote_stream_order", remoteStreamOrder) } } - if remoteMsg, ok := evt.evt.(RemoteEventWithTimestamp); ok { - if remoteTimestamp := remoteMsg.GetTimestamp(); !remoteTimestamp.IsZero() { - logWith = logWith.Time("remote_timestamp", remoteTimestamp) - } - } case *portalCreateEvent: return evt.ctx } @@ -530,14 +458,7 @@ func (portal *Portal) handleSingleEvent(ctx context.Context, rawEvt any, doneCal }() switch evt := rawEvt.(type) { case *portalMatrixEvent: - isStateRequest := evt.evt.Type == event.BeeperSendState - if isStateRequest { - if err := portal.unwrapBeeperSendState(ctx, evt.evt); err != nil { - portal.sendErrorStatus(ctx, evt.evt, err) - return - } - } - res = portal.handleMatrixEvent(ctx, evt.sender, evt.evt, isStateRequest) + res = portal.handleMatrixEvent(ctx, evt.sender, evt.evt) if res.SendMSS { if res.Error != nil { portal.sendErrorStatus(ctx, evt.evt, res.Error) @@ -545,21 +466,6 @@ func (portal *Portal) handleSingleEvent(ctx context.Context, rawEvt any, doneCal portal.sendSuccessStatus(ctx, evt.evt, 0, "") } } - if !isStateRequest && res.Error != nil && evt.evt.StateKey != nil { - portal.revertRoomMeta(ctx, evt.evt) - } - if isStateRequest && res.Success && !res.SkipStateEcho { - portal.sendRoomMeta( - ctx, - evt.sender.DoublePuppet(ctx), - time.UnixMilli(evt.evt.Timestamp), - evt.evt.Type, - evt.evt.GetStateKey(), - evt.evt.Content.Parsed, - false, - evt.evt.Content.Raw, - ) - } case *portalRemoteEvent: res = portal.handleRemoteEvent(ctx, evt.source, evt.evtType, evt.evt) case *portalCreateEvent: @@ -571,44 +477,18 @@ func (portal *Portal) handleSingleEvent(ctx context.Context, rawEvt any, doneCal } } -func (portal *Portal) unwrapBeeperSendState(ctx context.Context, evt *event.Event) error { - content, ok := evt.Content.Parsed.(*event.BeeperSendStateEventContent) - if !ok { - return fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed) - } - evt.Content = content.Content - evt.StateKey = &content.StateKey - evt.Type = event.Type{Type: content.Type, Class: event.StateEventType} - _ = evt.Content.ParseRaw(evt.Type) - mx, ok := portal.Bridge.Matrix.(MatrixConnectorWithArbitraryRoomState) - if !ok { - return fmt.Errorf("matrix connector doesn't support fetching state") - } - prevEvt, err := mx.GetStateEvent(ctx, portal.MXID, evt.Type, evt.GetStateKey()) - if err != nil && !errors.Is(err, mautrix.MNotFound) { - return fmt.Errorf("failed to get prev event: %w", err) - } else if prevEvt != nil { - evt.Unsigned.PrevContent = &prevEvt.Content - evt.Unsigned.PrevSender = prevEvt.Sender - } - return nil -} - func (portal *Portal) FindPreferredLogin(ctx context.Context, user *User, allowRelay bool) (*UserLogin, *database.UserPortal, error) { if portal.Receiver != "" { login, err := portal.Bridge.GetExistingUserLoginByID(ctx, portal.Receiver) if err != nil { return nil, nil, err } - if login == nil { - return nil, nil, fmt.Errorf("%w (receiver login is nil)", ErrNotLoggedIn) - } else if !login.Client.IsLoggedIn() { - return nil, nil, fmt.Errorf("%w (receiver login is not logged in)", ErrNotLoggedIn) - } else if login.UserMXID != user.MXID { + if login == nil || login.UserMXID != user.MXID || !login.Client.IsLoggedIn() { if allowRelay && portal.Relay != nil { return nil, nil, nil } - return nil, nil, fmt.Errorf("%w (relay not set and receiver login is owned by %s, not %s)", ErrNotLoggedIn, login.UserMXID, user.MXID) + // TODO different error for this case? + return nil, nil, ErrNotLoggedIn } up, err := portal.Bridge.DB.UserPortal.Get(ctx, login.UserLogin, portal.PortalKey) return login, up, err @@ -691,7 +571,7 @@ func (portal *Portal) checkConfusableName(ctx context.Context, userID id.UserID, var fakePerMessageProfileEventType = event.Type{Class: event.StateEventType, Type: "m.per_message_profile"} -func (portal *Portal) handleMatrixEvent(ctx context.Context, sender *User, evt *event.Event, isStateRequest bool) EventHandlingResult { +func (portal *Portal) handleMatrixEvent(ctx context.Context, sender *User, evt *event.Event) EventHandlingResult { log := zerolog.Ctx(ctx) if evt.Mautrix.EventSource&event.SourceEphemeral != 0 { switch evt.Type { @@ -699,8 +579,6 @@ func (portal *Portal) handleMatrixEvent(ctx context.Context, sender *User, evt * return portal.handleMatrixReceipts(ctx, evt) case event.EphemeralEventTyping: return portal.handleMatrixTyping(ctx, evt) - case event.BeeperEphemeralEventAIStream: - return portal.handleMatrixAIStream(ctx, sender, evt) default: return EventHandlingResultIgnored } @@ -709,7 +587,7 @@ func (portal *Portal) handleMatrixEvent(ctx context.Context, sender *User, evt * // Tombstones aren't bridged so they don't need a login return portal.handleMatrixTombstone(ctx, evt) } - login, userPortal, err := portal.FindPreferredLogin(ctx, sender, true) + login, _, err := portal.FindPreferredLogin(ctx, sender, true) if err != nil { log.Err(err).Msg("Failed to get user login to handle Matrix event") if errors.Is(err, ErrNotLoggedIn) { @@ -725,9 +603,6 @@ func (portal *Portal) handleMatrixEvent(ctx context.Context, sender *User, evt * } var origSender *OrigSender if login == nil { - if isStateRequest { - return EventHandlingResultFailed.WithMSSError(ErrCantRelayStateRequest) - } login = portal.Relay origSender = &OrigSender{ User: sender, @@ -771,21 +646,6 @@ func (portal *Portal) handleMatrixEvent(ctx context.Context, sender *User, evt * } // Copy logger because many of the handlers will use UpdateContext ctx = log.With().Str("login_id", string(login.ID)).Logger().WithContext(ctx) - - if origSender == nil && portal.Bridge.Network.GetCapabilities().ImplicitReadReceipts && !evt.Type.IsAccountData() { - rrLog := log.With().Str("subaction", "implicit read receipt").Logger() - rrCtx := rrLog.WithContext(ctx) - rrLog.Debug().Msg("Sending implicit read receipt for event") - evtTS := time.UnixMilli(evt.Timestamp) - portal.callReadReceiptHandler(rrCtx, login, nil, &MatrixReadReceipt{ - Portal: portal, - EventID: evt.ID, - Implicit: true, - ReadUpTo: evtTS, - Receipt: event.ReadReceipt{Timestamp: evtTS}, - }, userPortal) - } - switch evt.Type { case event.EventMessage, event.EventSticker, event.EventUnstablePollStart, event.EventUnstablePollResponse: return portal.handleMatrixMessage(ctx, login, origSender, evt) @@ -798,13 +658,11 @@ func (portal *Portal) handleMatrixEvent(ctx context.Context, sender *User, evt * case event.EventRedaction: return portal.handleMatrixRedaction(ctx, login, origSender, evt) case event.StateRoomName: - return handleMatrixRoomMeta(portal, ctx, login, origSender, evt, isStateRequest, RoomNameHandlingNetworkAPI.HandleMatrixRoomName) + return handleMatrixRoomMeta(portal, ctx, login, origSender, evt, RoomNameHandlingNetworkAPI.HandleMatrixRoomName) case event.StateTopic: - return handleMatrixRoomMeta(portal, ctx, login, origSender, evt, isStateRequest, RoomTopicHandlingNetworkAPI.HandleMatrixRoomTopic) + return handleMatrixRoomMeta(portal, ctx, login, origSender, evt, RoomTopicHandlingNetworkAPI.HandleMatrixRoomTopic) case event.StateRoomAvatar: - return handleMatrixRoomMeta(portal, ctx, login, origSender, evt, isStateRequest, RoomAvatarHandlingNetworkAPI.HandleMatrixRoomAvatar) - case event.StateBeeperDisappearingTimer: - return handleMatrixRoomMeta(portal, ctx, login, origSender, evt, isStateRequest, DisappearTimerChangingNetworkAPI.HandleMatrixDisappearingTimer) + return handleMatrixRoomMeta(portal, ctx, login, origSender, evt, RoomAvatarHandlingNetworkAPI.HandleMatrixRoomAvatar) case event.StateEncryption: // TODO? return EventHandlingResultIgnored @@ -815,13 +673,9 @@ func (portal *Portal) handleMatrixEvent(ctx context.Context, sender *User, evt * case event.AccountDataBeeperMute: return handleMatrixAccountData(portal, ctx, login, evt, MuteHandlingNetworkAPI.HandleMute) case event.StateMember: - return portal.handleMatrixMembership(ctx, login, origSender, evt, isStateRequest) + return portal.handleMatrixMembership(ctx, login, origSender, evt) case event.StatePowerLevels: - return portal.handleMatrixPowerLevels(ctx, login, origSender, evt, isStateRequest) - case event.BeeperDeleteChat: - return portal.handleMatrixDeleteChat(ctx, login, origSender, evt) - case event.BeeperAcceptMessageRequest: - return portal.handleMatrixAcceptMessageRequest(ctx, login, origSender, evt) + return portal.handleMatrixPowerLevels(ctx, login, origSender, evt) default: return EventHandlingResultIgnored } @@ -879,10 +733,15 @@ func (portal *Portal) handleMatrixReadReceipt(ctx context.Context, user *User, e EventID: eventID, Receipt: receipt, } + if userPortal == nil { + userPortal = database.UserPortalFor(login.UserLogin, portal.PortalKey) + } else { + evt.LastRead = userPortal.LastRead + userPortal = userPortal.CopyWithoutValues() + } evt.ExactMessage, err = portal.Bridge.DB.Message.GetPartByMXID(ctx, eventID) if err != nil { log.Err(err).Msg("Failed to get exact message from database") - evt.ReadUpTo = receipt.Timestamp } else if evt.ExactMessage != nil { log.UpdateContext(func(c zerolog.Context) zerolog.Context { return c.Str("exact_message_id", string(evt.ExactMessage.ID)).Time("exact_message_ts", evt.ExactMessage.Timestamp) @@ -891,40 +750,21 @@ func (portal *Portal) handleMatrixReadReceipt(ctx context.Context, user *User, e } else { evt.ReadUpTo = receipt.Timestamp } - portal.callReadReceiptHandler(ctx, login, rrClient, evt, userPortal) -} - -func (portal *Portal) callReadReceiptHandler( - ctx context.Context, - login *UserLogin, - rrClient ReadReceiptHandlingNetworkAPI, - evt *MatrixReadReceipt, - userPortal *database.UserPortal, -) { - if rrClient == nil { - var ok bool - rrClient, ok = login.Client.(ReadReceiptHandlingNetworkAPI) - if !ok { - return - } - } - if userPortal == nil { - userPortal = database.UserPortalFor(login.UserLogin, portal.PortalKey) - } else { - evt.LastRead = userPortal.LastRead - userPortal = userPortal.CopyWithoutValues() - } - err := rrClient.HandleMatrixReadReceipt(ctx, evt) + err = rrClient.HandleMatrixReadReceipt(ctx, evt) if err != nil { - zerolog.Ctx(ctx).Err(err).Msg("Failed to handle read receipt") + log.Err(err).Msg("Failed to handle read receipt") return } - userPortal.LastRead = evt.ReadUpTo + if evt.ExactMessage != nil { + userPortal.LastRead = evt.ExactMessage.Timestamp + } else { + userPortal.LastRead = receipt.Timestamp + } err = portal.Bridge.DB.UserPortal.Put(ctx, userPortal) if err != nil { - zerolog.Ctx(ctx).Err(err).Msg("Failed to save user portal metadata") + log.Err(err).Msg("Failed to save user portal metadata") } - portal.Bridge.DisappearLoop.StartAllBefore(ctx, portal.MXID, evt.ReadUpTo) + portal.Bridge.DisappearLoop.StartAll(ctx, portal.MXID) } func (portal *Portal) handleMatrixTyping(ctx context.Context, evt *event.Event) EventHandlingResult { @@ -945,50 +785,6 @@ func (portal *Portal) handleMatrixTyping(ctx context.Context, evt *event.Event) return EventHandlingResultSuccess } -func (portal *Portal) handleMatrixAIStream(ctx context.Context, sender *User, evt *event.Event) EventHandlingResult { - log := zerolog.Ctx(ctx) - if sender == nil { - log.Error().Msg("Missing sender for Matrix AI stream event") - return EventHandlingResultIgnored - } - login, _, err := portal.FindPreferredLogin(ctx, sender, true) - if err != nil { - log.Err(err).Msg("Failed to get user login to handle Matrix AI stream event") - return EventHandlingResultFailed.WithMSSError(err) - } - var origSender *OrigSender - if login == nil { - if portal.Relay == nil { - return EventHandlingResultIgnored - } - login = portal.Relay - origSender = &OrigSender{ - User: sender, - UserID: sender.MXID, - } - } - content, ok := evt.Content.Parsed.(*event.BeeperAIStreamEventContent) - if !ok { - log.Error().Type("content_type", evt.Content.Parsed).Msg("Unexpected parsed content type") - return EventHandlingResultFailed.WithMSSError(fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed)) - } - api, ok := login.Client.(BeeperAIStreamHandlingNetworkAPI) - if !ok { - return EventHandlingResultIgnored.WithMSSError(ErrBeeperAIStreamNotSupported) - } - err = api.HandleMatrixBeeperAIStream(ctx, &MatrixBeeperAIStream{ - Event: evt, - Content: content, - Portal: portal, - OrigSender: origSender, - }) - if err != nil { - log.Err(err).Msg("Failed to handle Matrix AI stream event") - return EventHandlingResultFailed.WithMSSError(err) - } - return EventHandlingResultSuccess.WithMSS() -} - func (portal *Portal) sendTypings(ctx context.Context, userIDs []id.UserID, typing bool) { for _, userID := range userIDs { login, ok := portal.currentlyTypingLogins[userID] @@ -1097,18 +893,8 @@ func (portal *Portal) checkMessageContentCaps(caps *event.RoomFeatures, content feat.Caption.Reject() { return ErrCaptionsNotAllowed } - if content.Info != nil { - dur := time.Duration(content.Info.Duration) * time.Millisecond - if feat.MaxDuration != nil && dur > feat.MaxDuration.Duration { - if capMsgType == event.CapMsgVoice { - return fmt.Errorf("%w: %s supports voice messages up to %s long", ErrVoiceMessageDurationTooLong, portal.Bridge.Network.GetName().DisplayName, exfmt.Duration(feat.MaxDuration.Duration)) - } - return fmt.Errorf("%w: %s is longer than the maximum of %s", ErrMediaDurationTooLong, exfmt.Duration(dur), exfmt.Duration(feat.MaxDuration.Duration)) - } - if feat.MaxSize != 0 && int64(content.Info.Size) > feat.MaxSize { - return fmt.Errorf("%w: %.1f MiB is larger than the maximum of %.1f MiB", ErrMediaTooLarge, float64(content.Info.Size)/1024/1024, float64(feat.MaxSize)/1024/1024) - } - if content.Info.MimeType != "" && feat.GetMimeSupport(content.Info.MimeType).Reject() { + if content.Info != nil && content.Info.MimeType != "" { + if feat.GetMimeSupport(content.Info.MimeType).Reject() { return fmt.Errorf("%w (%s in %s)", ErrUnsupportedMediaType, content.Info.MimeType, capMsgType) } } @@ -1168,12 +954,10 @@ func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin log.Debug().Msg("Ignoring poll event from relayed user") return EventHandlingResultIgnored.WithMSSError(ErrIgnoringPollFromRelayedUser) } - if !caps.PerMessageProfileRelay { - msgContent, err = portal.Bridge.Config.Relay.FormatMessage(msgContent, origSender) - if err != nil { - log.Err(err).Msg("Failed to format message for relaying") - return EventHandlingResultFailed.WithMSSError(err) - } + msgContent, err = portal.Bridge.Config.Relay.FormatMessage(msgContent, origSender) + if err != nil { + log.Err(err).Msg("Failed to format message for relaying") + return EventHandlingResultFailed.WithMSSError(err) } } if msgContent != nil { @@ -1241,16 +1025,6 @@ func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin } } } - var messageTimer *event.BeeperDisappearingTimer - if msgContent != nil { - messageTimer = msgContent.BeeperDisappearingTimer - } - if messageTimer != nil && *portal.Disappear.ToEventContent() != *messageTimer { - log.Warn(). - Any("event_timer", messageTimer). - Any("portal_timer", portal.Disappear.ToEventContent()). - Msg("Mismatching disappearing timer in event") - } wrappedMsgEvt := &MatrixMessage{ MatrixEventBase: MatrixEventBase[*event.MessageEventContent]{ @@ -1276,12 +1050,6 @@ func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin } } - err = portal.autoAcceptMessageRequest(ctx, evt, sender, origSender, caps) - if err != nil { - log.Warn().Err(err).Msg("Failed to auto-accept message request on message") - // TODO stop processing? - } - var resp *MatrixMessageResponse if msgContent != nil { resp, err = sender.Client.HandleMatrixMessage(ctx, wrappedMsgEvt) @@ -1333,23 +1101,22 @@ func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin } portal.sendSuccessStatus(ctx, evt, resp.StreamOrder, message.MXID) } - ds := portal.Disappear - if messageTimer != nil { - ds = database.DisappearingSettingFromEvent(messageTimer) - } - if ds.Type != event.DisappearingTypeNone { + if portal.Disappear.Type != database.DisappearingTypeNone { go portal.Bridge.DisappearLoop.Add(ctx, &database.DisappearingMessage{ - RoomID: portal.MXID, - EventID: message.MXID, - Timestamp: message.Timestamp, - DisappearingSetting: ds.StartingAt(message.Timestamp), + RoomID: portal.MXID, + EventID: message.MXID, + DisappearingSetting: database.DisappearingSetting{ + Type: portal.Disappear.Type, + Timer: portal.Disappear.Timer, + DisappearAt: message.Timestamp.Add(portal.Disappear.Timer), + }, }) } if resp.Pending { // Not exactly queued, but not finished either return EventHandlingResultQueued } - return EventHandlingResultSuccess.WithEventID(message.MXID).WithStreamOrder(resp.StreamOrder) + return EventHandlingResultSuccess } // AddPendingToIgnore adds a transaction ID that should be ignored if encountered as a new message. @@ -1538,7 +1305,7 @@ func (portal *Portal) handleMatrixEdit( return EventHandlingResultSuccess } -func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogin, evt *event.Event) (handleRes EventHandlingResult) { +func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogin, evt *event.Event) EventHandlingResult { log := zerolog.Ctx(ctx) reactingAPI, ok := sender.Client.(ReactionHandlingNetworkAPI) if !ok { @@ -1561,12 +1328,6 @@ func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogi log.Warn().Msg("Reaction target message not found in database") return EventHandlingResultFailed.WithMSSError(fmt.Errorf("reaction %w", ErrTargetMessageNotFound)) } - caps := sender.Client.GetCapabilities(ctx, portal) - err = portal.autoAcceptMessageRequest(ctx, evt, sender, nil, caps) - if err != nil { - log.Warn().Err(err).Msg("Failed to auto-accept message request on reaction") - // TODO stop processing? - } log.UpdateContext(func(c zerolog.Context) zerolog.Context { return c.Str("reaction_target_remote_id", string(reactionTarget.ID)) }) @@ -1589,31 +1350,6 @@ func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogi if portal.Bridge.Config.OutgoingMessageReID { deterministicID = portal.Bridge.Matrix.GenerateReactionEventID(portal.MXID, reactionTarget, preResp.SenderID, preResp.EmojiID) } - defer func() { - // Do this in a defer so that it happens after any potential defer calls to removeOutdatedReaction - if handleRes.Success { - portal.sendSuccessStatus(ctx, evt, 0, deterministicID) - } - }() - removeOutdatedReaction := func(oldReact *database.Reaction, deleteDB bool) { - if !handleRes.Success { - return - } - _, err := portal.Bridge.Bot.SendMessage(ctx, portal.MXID, event.EventRedaction, &event.Content{ - Parsed: &event.RedactionEventContent{ - Redacts: oldReact.MXID, - }, - }, nil) - if err != nil { - log.Err(err).Msg("Failed to remove old reaction") - } - if deleteDB { - err = portal.Bridge.DB.Reaction.Delete(ctx, oldReact) - if err != nil { - log.Err(err).Msg("Failed to delete old reaction from database") - } - } - } existing, err := portal.Bridge.DB.Reaction.GetByID(ctx, portal.Receiver, reactionTarget.ID, reactionTarget.PartID, preResp.SenderID, preResp.EmojiID) if err != nil { log.Err(err).Msg("Failed to check if reaction is a duplicate") @@ -1622,10 +1358,17 @@ func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogi if existing.EmojiID != "" || existing.Emoji == preResp.Emoji { log.Debug().Msg("Ignoring duplicate reaction") portal.sendSuccessStatus(ctx, evt, 0, deterministicID) - return EventHandlingResultIgnored.WithEventID(deterministicID) + return EventHandlingResultIgnored } react.ReactionToOverride = existing - defer removeOutdatedReaction(existing, false) + _, err = portal.Bridge.Bot.SendMessage(ctx, portal.MXID, event.EventRedaction, &event.Content{ + Parsed: &event.RedactionEventContent{ + Redacts: existing.MXID, + }, + }, nil) + if err != nil { + log.Err(err).Msg("Failed to remove old reaction") + } } react.PreHandleResp = &preResp if preResp.MaxReactions > 0 { @@ -1640,14 +1383,18 @@ func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogi // Keep n-1 previous reactions and remove the rest react.ExistingReactionsToKeep = allReactions[:preResp.MaxReactions-1] for _, oldReaction := range allReactions[preResp.MaxReactions-1:] { - if existing != nil && oldReaction.EmojiID == existing.EmojiID { - // Don't double-delete on networks that only allow one emoji - continue + _, err = portal.Bridge.Bot.SendMessage(ctx, portal.MXID, event.EventRedaction, &event.Content{ + Parsed: &event.RedactionEventContent{ + Redacts: oldReaction.MXID, + }, + }, nil) + if err != nil { + log.Err(err).Msg("Failed to remove previous reaction after limit was exceeded") + } + err = portal.Bridge.DB.Reaction.Delete(ctx, oldReaction) + if err != nil { + log.Err(err).Msg("Failed to delete previous reaction from database after limit was exceeded") } - // Intentionally defer in a loop, there won't be that many items, - // and we want all of them to be done after this function completes successfully - //goland:noinspection GoDeferInLoop - defer removeOutdatedReaction(oldReaction, true) } } } @@ -1692,7 +1439,8 @@ func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogi if err != nil { log.Err(err).Msg("Failed to save reaction to database") } - return EventHandlingResultSuccess.WithEventID(deterministicID) + portal.sendSuccessStatus(ctx, evt, 0, deterministicID) + return EventHandlingResultSuccess } func handleMatrixRoomMeta[APIType any, ContentType any]( @@ -1701,19 +1449,11 @@ func handleMatrixRoomMeta[APIType any, ContentType any]( sender *UserLogin, origSender *OrigSender, evt *event.Event, - isStateRequest bool, fn func(APIType, context.Context, *MatrixRoomMeta[ContentType]) (bool, error), ) EventHandlingResult { - if evt.StateKey == nil || *evt.StateKey != "" { - return EventHandlingResultFailed.WithMSSError(ErrInvalidStateKey) - } - //caps := sender.Client.GetCapabilities(ctx, portal) - //if stateCap, ok := caps.State[evt.Type.Type]; !ok || stateCap.Level <= event.CapLevelUnsupported { - // return EventHandlingResultIgnored.WithMSSError(fmt.Errorf("%s %w", evt.Type.Type, ErrRoomMetadataNotAllowed)) - //} api, ok := sender.Client.(APIType) if !ok { - return EventHandlingResultIgnored.WithMSSError(fmt.Errorf("%w of type %s", ErrRoomMetadataNotSupported, evt.Type)) + return EventHandlingResultIgnored.WithMSSError(ErrRoomMetadataNotSupported) } log := zerolog.Ctx(ctx) content, ok := evt.Content.Parsed.(ContentType) @@ -1737,18 +1477,6 @@ func handleMatrixRoomMeta[APIType any, ContentType any]( portal.sendSuccessStatus(ctx, evt, 0, "") return EventHandlingResultIgnored } - case *event.BeeperDisappearingTimer: - if typedContent.Type == event.DisappearingTypeNone || typedContent.Timer.Duration <= 0 { - typedContent.Type = event.DisappearingTypeNone - typedContent.Timer.Duration = 0 - } - if typedContent.Type == portal.Disappear.Type && typedContent.Timer.Duration == portal.Disappear.Timer { - portal.sendSuccessStatus(ctx, evt, 0, "") - return EventHandlingResultIgnored - } - if !sender.Client.GetCapabilities(ctx, portal).DisappearingTimer.Supports(typedContent) { - return EventHandlingResultFailed.WithMSSError(ErrDisappearingTimerUnsupported) - } } var prevContent ContentType if evt.Unsigned.PrevContent != nil { @@ -1765,17 +1493,14 @@ func handleMatrixRoomMeta[APIType any, ContentType any]( InputTransactionID: portal.parseInputTransactionID(origSender, evt), }, - IsStateRequest: isStateRequest, - PrevContent: prevContent, + PrevContent: prevContent, }) if err != nil { log.Err(err).Msg("Failed to handle Matrix room metadata") return EventHandlingResultFailed.WithMSSError(err) } if changed { - if evt.Type != event.StateBeeperDisappearingTimer { - portal.UpdateBridgeInfo(ctx) - } + portal.UpdateBridgeInfo(ctx) err = portal.Save(ctx) if err != nil { log.Err(err).Msg("Failed to save portal after updating room metadata") @@ -1836,139 +1561,12 @@ func (portal *Portal) getTargetUser(ctx context.Context, userID id.UserID) (Ghos } } -func (portal *Portal) handleMatrixAcceptMessageRequest( - ctx context.Context, - sender *UserLogin, - origSender *OrigSender, - evt *event.Event, -) EventHandlingResult { - if origSender != nil { - return EventHandlingResultFailed.WithMSSError(ErrIgnoringAcceptRequestRelayedUser) - } - log := zerolog.Ctx(ctx) - content, ok := evt.Content.Parsed.(*event.BeeperAcceptMessageRequestEventContent) - if !ok { - log.Error().Type("content_type", evt.Content.Parsed).Msg("Unexpected parsed content type") - return EventHandlingResultFailed.WithMSSError(fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed)) - } - api, ok := sender.Client.(MessageRequestAcceptingNetworkAPI) - if !ok { - return EventHandlingResultIgnored.WithMSSError(ErrDeleteChatNotSupported) - } - err := api.HandleMatrixAcceptMessageRequest(ctx, &MatrixAcceptMessageRequest{ - Event: evt, - Content: content, - Portal: portal, - }) - if err != nil { - log.Err(err).Msg("Failed to handle Matrix accept message request") - return EventHandlingResultFailed.WithMSSError(err) - } - if portal.MessageRequest { - portal.MessageRequest = false - portal.UpdateBridgeInfo(ctx) - err = portal.Save(ctx) - if err != nil { - log.Err(err).Msg("Failed to save portal after accepting message request") - } - } - return EventHandlingResultSuccess.WithMSS() -} - -func (portal *Portal) autoAcceptMessageRequest( - ctx context.Context, evt *event.Event, sender *UserLogin, origSender *OrigSender, caps *event.RoomFeatures, -) error { - if !portal.MessageRequest || caps.MessageRequest == nil || caps.MessageRequest.AcceptWithMessage == event.CapLevelFullySupported { - return nil - } - mran, ok := sender.Client.(MessageRequestAcceptingNetworkAPI) - if !ok { - return nil - } - err := mran.HandleMatrixAcceptMessageRequest(ctx, &MatrixAcceptMessageRequest{ - Event: evt, - Content: &event.BeeperAcceptMessageRequestEventContent{ - IsImplicit: true, - }, - Portal: portal, - OrigSender: origSender, - }) - if err != nil { - return err - } - if portal.MessageRequest { - portal.MessageRequest = false - portal.UpdateBridgeInfo(ctx) - err = portal.Save(ctx) - if err != nil { - zerolog.Ctx(ctx).Err(err).Msg("Failed to save portal after accepting message request") - } - } - return nil -} - -func (portal *Portal) handleMatrixDeleteChat( - ctx context.Context, - sender *UserLogin, - origSender *OrigSender, - evt *event.Event, -) EventHandlingResult { - if origSender != nil { - return EventHandlingResultFailed.WithMSSError(ErrIgnoringDeleteChatRelayedUser) - } - log := zerolog.Ctx(ctx) - content, ok := evt.Content.Parsed.(*event.BeeperChatDeleteEventContent) - if !ok { - log.Error().Type("content_type", evt.Content.Parsed).Msg("Unexpected parsed content type") - return EventHandlingResultFailed.WithMSSError(fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed)) - } - api, ok := sender.Client.(DeleteChatHandlingNetworkAPI) - if !ok { - return EventHandlingResultIgnored.WithMSSError(ErrDeleteChatNotSupported) - } - err := api.HandleMatrixDeleteChat(ctx, &MatrixDeleteChat{ - Event: evt, - Content: content, - Portal: portal, - }) - if err != nil { - log.Err(err).Msg("Failed to handle Matrix chat delete") - return EventHandlingResultFailed.WithMSSError(err) - } - if portal.Receiver == "" { - _, others, err := portal.findOtherLogins(ctx, sender) - if err != nil { - log.Err(err).Msg("Failed to check if portal has other logins") - return EventHandlingResultFailed.WithError(err) - } else if len(others) > 0 { - log.Debug().Msg("Not deleting portal after chat delete as other logins are present") - return EventHandlingResultSuccess - } - } - err = portal.Delete(ctx) - if err != nil { - log.Err(err).Msg("Failed to delete portal from database") - return EventHandlingResultFailed.WithMSSError(err) - } - err = portal.Bridge.Bot.DeleteRoom(ctx, portal.MXID, false) - if err != nil { - log.Err(err).Msg("Failed to delete Matrix room") - return EventHandlingResultFailed.WithMSSError(err) - } - // No MSS here as the portal was deleted - return EventHandlingResultSuccess -} - func (portal *Portal) handleMatrixMembership( ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event, - isStateRequest bool, ) EventHandlingResult { - if evt.StateKey == nil { - return EventHandlingResultFailed.WithMSSError(ErrInvalidStateKey) - } log := zerolog.Ctx(ctx) content, ok := evt.Content.Parsed.(*event.MemberEventContent) if !ok { @@ -2004,6 +1602,7 @@ func (portal *Portal) handleMatrixMembership( return EventHandlingResultIgnored //.WithMSSError(ErrIgnoringLeaveEvent) } targetGhost, _ := target.(*Ghost) + targetUserLogin, _ := target.(*UserLogin) membershipChange := &MatrixMembershipChange{ MatrixRoomMeta: MatrixRoomMeta[*event.MemberEventContent]{ MatrixEventBase: MatrixEventBase[*event.MemberEventContent]{ @@ -2014,60 +1613,19 @@ func (portal *Portal) handleMatrixMembership( InputTransactionID: portal.parseInputTransactionID(origSender, evt), }, - IsStateRequest: isStateRequest, - PrevContent: prevContent, + PrevContent: prevContent, }, - Target: target, - Type: membershipChangeType, + Target: target, + TargetGhost: targetGhost, + TargetUserLogin: targetUserLogin, + Type: membershipChangeType, } - res, err := api.HandleMatrixMembership(ctx, membershipChange) + _, err = api.HandleMatrixMembership(ctx, membershipChange) if err != nil { log.Err(err).Msg("Failed to handle Matrix membership change") return EventHandlingResultFailed.WithMSSError(err) } - didRedirectInvite := membershipChangeType == Invite && - targetGhost != nil && - res != nil && - res.RedirectTo != "" && - res.RedirectTo != targetGhost.ID - if didRedirectInvite { - log.Debug(). - Str("orig_id", string(targetGhost.ID)). - Str("redirect_id", string(res.RedirectTo)). - Msg("Invite was redirected to different ghost") - var redirectGhost *Ghost - redirectGhost, err = portal.Bridge.GetGhostByID(ctx, res.RedirectTo) - if err != nil { - log.Err(err).Msg("Failed to get redirect target ghost") - return EventHandlingResultFailed.WithError(err) - } - if !isStateRequest { - portal.sendRoomMeta( - ctx, - sender.User.DoublePuppet(ctx), - time.UnixMilli(evt.Timestamp), - event.StateMember, - evt.GetStateKey(), - &event.MemberEventContent{ - Membership: event.MembershipLeave, - Reason: fmt.Sprintf("Invite redirected to %s", res.RedirectTo), - }, - true, - nil, - ) - } - portal.sendRoomMeta( - ctx, - sender.User.DoublePuppet(ctx), - time.UnixMilli(evt.Timestamp), - event.StateMember, - redirectGhost.Intent.GetMXID().String(), - content, - false, - nil, - ) - } - return EventHandlingResultSuccess.WithMSS().WithSkipStateEcho(didRedirectInvite) + return EventHandlingResultSuccess.WithMSS() } func makePLChange(old, new int, newIsSet bool) *SinglePowerLevelChange { @@ -2092,11 +1650,7 @@ func (portal *Portal) handleMatrixPowerLevels( sender *UserLogin, origSender *OrigSender, evt *event.Event, - isStateRequest bool, ) EventHandlingResult { - if evt.StateKey == nil || *evt.StateKey != "" { - return EventHandlingResultFailed.WithMSSError(ErrInvalidStateKey) - } log := zerolog.Ctx(ctx) content, ok := evt.Content.Parsed.(*event.PowerLevelsEventContent) if !ok { @@ -2134,8 +1688,7 @@ func (portal *Portal) handleMatrixPowerLevels( InputTransactionID: portal.parseInputTransactionID(origSender, evt), }, - IsStateRequest: isStateRequest, - PrevContent: prevContent, + PrevContent: prevContent, }, Users: make(map[id.UserID]*UserPowerLevelChange), Events: make(map[string]*SinglePowerLevelChange), @@ -2267,136 +1820,42 @@ func (portal *Portal) handleMatrixTombstone(ctx context.Context, evt *event.Even return EventHandlingResultIgnored } } - err = portal.UpdateMatrixRoomID(ctx, content.ReplacementRoom, UpdateMatrixRoomIDParams{ - DeleteOldRoom: true, - FetchInfoVia: senderUser, - }) - if errors.Is(err, ErrTargetRoomIsPortal) { - return EventHandlingResultIgnored - } else if err != nil { - return EventHandlingResultFailed.WithError(err) - } - return EventHandlingResultSuccess -} -var ErrTargetRoomIsPortal = errors.New("target room is already a portal") -var ErrRoomAlreadyExists = errors.New("this portal already has a room") - -type UpdateMatrixRoomIDParams struct { - SyncDBMetadata func() - FailIfMXIDSet bool - OverwriteOldPortal bool - TombstoneOldRoom bool - DeleteOldRoom bool - - RoomCreateAlreadyLocked bool - - FetchInfoVia *User - ChatInfo *ChatInfo - ChatInfoSource *UserLogin -} - -func (portal *Portal) UpdateMatrixRoomID( - ctx context.Context, - newRoomID id.RoomID, - params UpdateMatrixRoomIDParams, -) error { - if !params.RoomCreateAlreadyLocked { - portal.roomCreateLock.Lock() - defer portal.roomCreateLock.Unlock() - } - oldRoom := portal.MXID - if oldRoom == newRoomID { - return nil - } else if oldRoom != "" && params.FailIfMXIDSet { - return ErrRoomAlreadyExists - } - log := zerolog.Ctx(ctx) portal.Bridge.cacheLock.Lock() - // Wrap unlock in a sync.OnceFunc because we want to both defer it to catch early returns - // and unlock it before return if nothing goes wrong. - unlockCacheLock := sync.OnceFunc(portal.Bridge.cacheLock.Unlock) - defer unlockCacheLock() - if existingPortal, alreadyExists := portal.Bridge.portalsByMXID[newRoomID]; alreadyExists && !params.OverwriteOldPortal { - log.Warn().Msg("Replacement room is already a portal, ignoring") - return ErrTargetRoomIsPortal - } else if alreadyExists { - log.Debug().Msg("Replacement room is already a portal, overwriting") - existingPortal.MXID = "" - existingPortal.RoomCreated.Clear() - err := existingPortal.Save(ctx) - if err != nil { - return fmt.Errorf("failed to clear mxid of existing portal: %w", err) - } - delete(portal.Bridge.portalsByMXID, portal.MXID) + if _, alreadyExists := portal.Bridge.portalsByMXID[content.ReplacementRoom]; alreadyExists { + log.Warn().Msg("Replacement room is already a portal, ignoring tombstone") + portal.Bridge.cacheLock.Unlock() + return EventHandlingResultIgnored } - portal.MXID = newRoomID - portal.RoomCreated.Set() + delete(portal.Bridge.portalsByMXID, portal.MXID) + portal.MXID = content.ReplacementRoom portal.Bridge.portalsByMXID[portal.MXID] = portal portal.NameSet = false portal.AvatarSet = false portal.TopicSet = false portal.InSpace = false portal.CapState = database.CapabilityState{} - portal.lastCapUpdate = time.Time{} - if params.SyncDBMetadata != nil { - params.SyncDBMetadata() - } - unlockCacheLock() - portal.updateLogger() + portal.Bridge.cacheLock.Unlock() - err := portal.Save(ctx) + err = portal.Save(ctx) if err != nil { - log.Err(err).Msg("Failed to save portal in UpdateMatrixRoomID") - return err + log.Err(err).Msg("Failed to save portal after tombstone") + return EventHandlingResultFailed.WithError(err) } log.Info().Msg("Successfully followed tombstone and updated portal MXID") err = portal.Bridge.DB.UserPortal.MarkAllNotInSpace(ctx, portal.PortalKey) if err != nil { - log.Err(err).Msg("Failed to update in_space flag for user portals after updating portal MXID") + log.Err(err).Msg("Failed to update in_space flag for user portals after tombstone") } go portal.addToUserSpaces(ctx) - if params.FetchInfoVia != nil { - go portal.updateInfoAfterTombstone(ctx, params.FetchInfoVia) - } else if params.ChatInfo != nil { - go portal.UpdateInfo(ctx, params.ChatInfo, params.ChatInfoSource, nil, time.Time{}) - } else if params.ChatInfoSource != nil { - portal.UpdateCapabilities(ctx, params.ChatInfoSource, true) - portal.UpdateBridgeInfo(ctx) - } + go portal.updateInfoAfterTombstone(ctx, senderUser) go func() { - // TODO this might become unnecessary if UpdateInfo starts taking care of it - _, err = portal.Bridge.Bot.SendState(ctx, portal.MXID, event.StateElementFunctionalMembers, "", &event.Content{ - Parsed: &event.ElementFunctionalMembersContent{ - ServiceMembers: []id.UserID{portal.Bridge.Bot.GetMXID()}, - }, - }, time.Time{}) + err = portal.Bridge.Bot.DeleteRoom(ctx, evt.RoomID, true) if err != nil { - if err != nil { - log.Warn().Err(err).Msg("Failed to set service members in new room") - } + log.Err(err).Msg("Failed to clean up Matrix room after following tombstone") } }() - if params.TombstoneOldRoom && oldRoom != "" { - _, err = portal.Bridge.Bot.SendState(ctx, portal.MXID, event.StateTombstone, "", &event.Content{ - Parsed: &event.TombstoneEventContent{ - Body: "Room has been replaced.", - ReplacementRoom: newRoomID, - }, - }, time.Now()) - if err != nil { - log.Err(err).Msg("Failed to send tombstone event to old room") - } - } - if params.DeleteOldRoom && oldRoom != "" { - go func() { - err = portal.Bridge.Bot.DeleteRoom(ctx, oldRoom, true) - if err != nil { - log.Err(err).Msg("Failed to clean up old Matrix room after updating portal MXID") - } - }() - } - return nil + return EventHandlingResultSuccess } func (portal *Portal) updateInfoAfterTombstone(ctx context.Context, senderUser *User) { @@ -2589,7 +2048,7 @@ func (portal *Portal) handleRemoteEvent(ctx context.Context, source *UserLogin, } func (portal *Portal) ensureFunctionalMember(ctx context.Context, ghost *Ghost) { - if !ghost.IsBot || portal.RoomType != database.RoomTypeDM || portal.OtherUserID == ghost.ID || portal.MXID == "" { + if !ghost.IsBot || portal.RoomType != database.RoomTypeDM || portal.OtherUserID == ghost.ID { return } ars, ok := portal.Bridge.Matrix.(MatrixConnectorWithArbitraryRoomState) @@ -2763,7 +2222,7 @@ func (portal *Portal) getRelationMeta( log.Err(err).Msg("Failed to get last thread message from database") } if prevThreadEvent == nil { - prevThreadEvent = ptr.Clone(threadRoot) + prevThreadEvent = threadRoot } } return @@ -2822,7 +2281,6 @@ func (portal *Portal) sendConvertedMessage( allSuccess := true for i, part := range converted.Parts { portal.applyRelationMeta(ctx, part.Content, replyTo, threadRoot, prevThreadEvent) - part.Content.BeeperDisappearingTimer = converted.Disappear.ToEventContent() dbMessage := &database.Message{ ID: id, PartID: part.ID, @@ -2867,14 +2325,13 @@ func (portal *Portal) sendConvertedMessage( logContext(log.Err(err)).Str("part_id", string(part.ID)).Msg("Failed to save message part to database") allSuccess = false } - if converted.Disappear.Type != event.DisappearingTypeNone && !dbMessage.HasFakeMXID() { - if converted.Disappear.Type == event.DisappearingTypeAfterSend && converted.Disappear.DisappearAt.IsZero() { + if converted.Disappear.Type != database.DisappearingTypeNone && !dbMessage.HasFakeMXID() { + if converted.Disappear.Type == database.DisappearingTypeAfterSend && converted.Disappear.DisappearAt.IsZero() { converted.Disappear.DisappearAt = dbMessage.Timestamp.Add(converted.Disappear.Timer) } portal.Bridge.DisappearLoop.Add(ctx, &database.DisappearingMessage{ RoomID: portal.MXID, EventID: dbMessage.MXID, - Timestamp: dbMessage.Timestamp, DisappearingSetting: converted.Disappear, }) } @@ -3543,11 +3000,11 @@ func (portal *Portal) handleRemoteMessageRemove(ctx context.Context, source *Use onlyForMeProvider, ok := evt.(RemoteDeleteOnlyForMe) onlyForMe := ok && onlyForMeProvider.DeleteOnlyForMe() if onlyForMe && portal.Receiver == "" { - _, others, err := portal.findOtherLogins(ctx, source) + logins, err := portal.Bridge.DB.UserPortal.GetAllInPortal(ctx, portal.PortalKey) if err != nil { log.Err(err).Msg("Failed to check if portal has other logins") return EventHandlingResultFailed.WithError(err) - } else if len(others) > 0 { + } else if len(logins) > 1 { log.Debug().Msg("Ignoring delete for me event in portal with multiple logins") return EventHandlingResultIgnored } @@ -3661,15 +3118,11 @@ func (portal *Portal) handleRemoteReadReceipt(ctx context.Context, source *UserL return evt.Int64("target_stream_order", targetStreamOrder) } err = soIntent.MarkStreamOrderRead(ctx, portal.MXID, targetStreamOrder, getEventTS(evt)) - if readUpTo.IsZero() { - readUpTo = getEventTS(evt) - } } else { addTargetLog = func(evt *zerolog.Event) *zerolog.Event { return evt.Stringer("target_mxid", lastTarget.MXID) } err = intent.MarkRead(ctx, portal.MXID, lastTarget.MXID, getEventTS(evt)) - readUpTo = lastTarget.Timestamp } if err != nil { addTargetLog(log.Err(err)).Msg("Failed to bridge read receipt") @@ -3678,7 +3131,7 @@ func (portal *Portal) handleRemoteReadReceipt(ctx context.Context, source *UserL addTargetLog(log.Debug()).Msg("Bridged read receipt") } if sender.IsFromMe { - portal.Bridge.DisappearLoop.StartAllBefore(ctx, portal.MXID, readUpTo) + portal.Bridge.DisappearLoop.StartAll(ctx, portal.MXID) } return EventHandlingResultSuccess } @@ -3701,7 +3154,7 @@ func (portal *Portal) handleRemoteMarkUnread(ctx context.Context, source *UserLo } func (portal *Portal) handleRemoteDeliveryReceipt(ctx context.Context, source *UserLogin, evt RemoteDeliveryReceipt) EventHandlingResult { - if portal.RoomType != database.RoomTypeDM || (evt.GetSender().Sender != portal.OtherUserID && portal.OtherUserID != "") { + if portal.RoomType != database.RoomTypeDM || evt.GetSender().Sender != portal.OtherUserID { return EventHandlingResultIgnored } intent, ok := portal.GetIntentFor(ctx, evt.GetSender(), source, RemoteEventDeliveryReceipt) @@ -3800,43 +3253,22 @@ func (portal *Portal) handleRemoteChatResync(ctx context.Context, source *UserLo return EventHandlingResultSuccess } -func (portal *Portal) findOtherLogins(ctx context.Context, source *UserLogin) (ownUP *database.UserPortal, others []*database.UserPortal, err error) { - others, err = portal.Bridge.DB.UserPortal.GetAllInPortal(ctx, portal.PortalKey) - if err != nil { - return - } - others = slices.DeleteFunc(others, func(up *database.UserPortal) bool { - if up.LoginID == source.ID { - ownUP = up - return true - } - return false - }) - return -} - -type childDeleteProxy struct { - RemoteChatDeleteWithChildren - child networkid.PortalKey - done func() -} - -func (cdp *childDeleteProxy) AddLogContext(c zerolog.Context) zerolog.Context { - return cdp.RemoteChatDeleteWithChildren.AddLogContext(c).Str("subaction", "delete children") -} -func (cdp *childDeleteProxy) GetPortalKey() networkid.PortalKey { return cdp.child } -func (cdp *childDeleteProxy) ShouldCreatePortal() bool { return false } -func (cdp *childDeleteProxy) PreHandle(ctx context.Context, portal *Portal) {} -func (cdp *childDeleteProxy) PostHandle(ctx context.Context, portal *Portal) { cdp.done() } - func (portal *Portal) handleRemoteChatDelete(ctx context.Context, source *UserLogin, evt RemoteChatDelete) EventHandlingResult { log := zerolog.Ctx(ctx) if portal.Receiver == "" && evt.DeleteOnlyForMe() { - ownUP, logins, err := portal.findOtherLogins(ctx, source) + logins, err := portal.Bridge.DB.UserPortal.GetAllInPortal(ctx, portal.PortalKey) if err != nil { log.Err(err).Msg("Failed to check if portal has other logins") return EventHandlingResultFailed.WithError(err) } + var ownUP *database.UserPortal + logins = slices.DeleteFunc(logins, func(up *database.UserPortal) bool { + if up.LoginID == source.ID { + ownUP = up + return true + } + return false + }) if len(logins) > 0 { log.Debug().Msg("Not deleting portal with other logins in remote chat delete event") if ownUP != nil { @@ -3864,31 +3296,6 @@ func (portal *Portal) handleRemoteChatDelete(ctx context.Context, source *UserLo } } } - if childDeleter, ok := evt.(RemoteChatDeleteWithChildren); ok && childDeleter.DeleteChildren() && portal.RoomType == database.RoomTypeSpace { - children, err := portal.Bridge.GetChildPortals(ctx, portal.PortalKey) - if err != nil { - log.Err(err).Msg("Failed to fetch children to delete") - return EventHandlingResultFailed.WithError(err) - } - log.Debug(). - Int("portal_count", len(children)). - Msg("Deleting child portals before remote chat delete") - var wg sync.WaitGroup - wg.Add(len(children)) - for _, child := range children { - child.queueEvent(ctx, &portalRemoteEvent{ - evt: &childDeleteProxy{ - RemoteChatDeleteWithChildren: childDeleter, - child: child.PortalKey, - done: wg.Done, - }, - source: source, - evtType: RemoteEventChatDelete, - }) - } - wg.Wait() - log.Debug().Msg("Finished deleting child portals") - } err := portal.Delete(ctx) if err != nil { log.Err(err).Msg("Failed to delete portal from database") @@ -3944,43 +3351,12 @@ type PortalInfo = ChatInfo type ChatMember struct { EventSender Membership event.Membership - // Per-room nickname for the user. Not yet used. - Nickname *string - // The power level to set for the user when syncing power levels. + Nickname *string PowerLevel *int - // Optional user info to sync the ghost user while updating membership. - UserInfo *UserInfo - // The user who sent the membership change (user who invited/kicked/banned this user). - // Not yet used. Not applicable if Membership is join or knock. - MemberSender EventSender - // Extra fields to include in the member event. + UserInfo *UserInfo + MemberEventExtra map[string]any - // The expected previous membership. If this doesn't match, the change is ignored. - PrevMembership event.Membership -} - -type ChatMemberMap map[networkid.UserID]ChatMember - -// Set adds the given entry to this map, overwriting any existing entry with the same Sender field. -func (cmm ChatMemberMap) Set(member ChatMember) ChatMemberMap { - if member.Sender == "" && member.SenderLogin == "" && !member.IsFromMe { - return cmm - } - cmm[member.Sender] = member - return cmm -} - -// Add adds the given entry to this map, but will ignore it if an entry with the same Sender field already exists. -// It returns true if the entry was added, false otherwise. -func (cmm ChatMemberMap) Add(member ChatMember) bool { - if member.Sender == "" && member.SenderLogin == "" && !member.IsFromMe { - return false - } - if _, exists := cmm[member.Sender]; exists { - return false - } - cmm[member.Sender] = member - return true + PrevMembership event.Membership } type ChatMemberList struct { @@ -3990,10 +3366,6 @@ type ChatMemberList struct { // Should the bridge call IsThisUser for every member in the list? // This should be used when SenderLogin can't be filled accurately. CheckAllLogins bool - // Should any changes have the `com.beeper.exclude_from_timeline` flag set by default? - // This is recommended for syncs with non-real-time changes. - // Real-time changes (e.g. a user joining) should not set this flag set. - ExcludeChangesFromTimeline bool // The total number of members in the chat, regardless of how many of those members are included in MemberMap. TotalMemberCount int @@ -4004,7 +3376,7 @@ type ChatMemberList struct { // Deprecated: Use MemberMap instead to avoid duplicate entries Members []ChatMember - MemberMap ChatMemberMap + MemberMap map[networkid.UserID]ChatMember PowerLevels *PowerLevelOverrides } @@ -4106,11 +3478,9 @@ type ChatInfo struct { Disappear *database.DisappearingSetting ParentID *networkid.PortalID - UserLocal *UserLocalPortalInfo - MessageRequest *bool - CanBackfill bool + UserLocal *UserLocalPortalInfo - ExcludeChangesFromTimeline bool + CanBackfill bool ExtraUpdates ExtraUpdater[*Portal] } @@ -4144,35 +3514,25 @@ type UserLocalPortalInfo struct { Tag *event.RoomTag } -func (portal *Portal) updateName( - ctx context.Context, name string, sender MatrixAPI, ts time.Time, excludeFromTimeline bool, -) bool { +func (portal *Portal) updateName(ctx context.Context, name string, sender MatrixAPI, ts time.Time) bool { if portal.Name == name && (portal.NameSet || portal.MXID == "") { return false } portal.Name = name - portal.NameSet = portal.sendRoomMeta( - ctx, sender, ts, event.StateRoomName, "", &event.RoomNameEventContent{Name: name}, excludeFromTimeline, nil, - ) + portal.NameSet = portal.sendRoomMeta(ctx, sender, ts, event.StateRoomName, "", &event.RoomNameEventContent{Name: name}) return true } -func (portal *Portal) updateTopic( - ctx context.Context, topic string, sender MatrixAPI, ts time.Time, excludeFromTimeline bool, -) bool { +func (portal *Portal) updateTopic(ctx context.Context, topic string, sender MatrixAPI, ts time.Time) bool { if portal.Topic == topic && (portal.TopicSet || portal.MXID == "") { return false } portal.Topic = topic - portal.TopicSet = portal.sendRoomMeta( - ctx, sender, ts, event.StateTopic, "", &event.TopicEventContent{Topic: topic}, excludeFromTimeline, nil, - ) + portal.TopicSet = portal.sendRoomMeta(ctx, sender, ts, event.StateTopic, "", &event.TopicEventContent{Topic: topic}) return true } -func (portal *Portal) updateAvatar( - ctx context.Context, avatar *Avatar, sender MatrixAPI, ts time.Time, excludeFromTimeline bool, -) bool { +func (portal *Portal) updateAvatar(ctx context.Context, avatar *Avatar, sender MatrixAPI, ts time.Time) bool { if portal.AvatarID == avatar.ID && (avatar.Remove || portal.AvatarMXC != "") && (portal.AvatarSet || portal.MXID == "") { return false } @@ -4195,9 +3555,7 @@ func (portal *Portal) updateAvatar( portal.AvatarMXC = newMXC portal.AvatarHash = newHash } - portal.AvatarSet = portal.sendRoomMeta( - ctx, sender, ts, event.StateRoomAvatar, "", &event.RoomAvatarEventContent{URL: portal.AvatarMXC}, excludeFromTimeline, nil, - ) + portal.AvatarSet = portal.sendRoomMeta(ctx, sender, ts, event.StateRoomAvatar, "", &event.RoomAvatarEventContent{URL: portal.AvatarMXC}) return true } @@ -4228,11 +3586,10 @@ func (portal *Portal) getBridgeInfo() (string, event.BridgeEventContent) { Creator: portal.Bridge.Bot.GetMXID(), Protocol: portal.Bridge.Network.GetName().AsBridgeInfoSection(), Channel: event.BridgeInfoSection{ - ID: string(portal.ID), - DisplayName: portal.Name, - AvatarURL: portal.AvatarMXC, - Receiver: string(portal.Receiver), - MessageRequest: portal.MessageRequest, + ID: string(portal.ID), + DisplayName: portal.Name, + AvatarURL: portal.AvatarMXC, + Receiver: string(portal.Receiver), // TODO external URL? }, BeeperRoomTypeV2: string(portal.RoomType), @@ -4240,10 +3597,6 @@ func (portal *Portal) getBridgeInfo() (string, event.BridgeEventContent) { if portal.RoomType == database.RoomTypeDM || portal.RoomType == database.RoomTypeGroupDM { bridgeInfo.BeeperRoomType = "dm" } - if bridgeInfo.Protocol.ID == "slackgo" { - bridgeInfo.TempSlackRemoteIDMigratedFlag = true - bridgeInfo.TempSlackRemoteIDMigratedFlag2 = true - } parent := portal.GetTopLevelParent() if parent != nil { bridgeInfo.Network = &event.BridgeInfoSection{ @@ -4265,8 +3618,8 @@ func (portal *Portal) UpdateBridgeInfo(ctx context.Context) { return } stateKey, bridgeInfo := portal.getBridgeInfo() - portal.sendRoomMeta(ctx, nil, time.Now(), event.StateBridge, stateKey, &bridgeInfo, false, nil) - portal.sendRoomMeta(ctx, nil, time.Now(), event.StateHalfShotBridge, stateKey, &bridgeInfo, false, nil) + portal.sendRoomMeta(ctx, nil, time.Now(), event.StateBridge, stateKey, &bridgeInfo) + portal.sendRoomMeta(ctx, nil, time.Now(), event.StateHalfShotBridge, stateKey, &bridgeInfo) } func (portal *Portal) UpdateCapabilities(ctx context.Context, source *UserLogin, implicit bool) bool { @@ -4288,22 +3641,13 @@ func (portal *Portal) UpdateCapabilities(ctx context.Context, source *UserLogin, Str("old_id", portal.CapState.ID). Str("new_id", capID). Msg("Sending new room capability event") - success := portal.sendRoomMeta(ctx, nil, time.Now(), event.StateBeeperRoomFeatures, portal.getBridgeInfoStateKey(), caps, false, nil) + success := portal.sendRoomMeta(ctx, nil, time.Now(), event.StateBeeperRoomFeatures, portal.getBridgeInfoStateKey(), caps) if !success { return false } portal.CapState = database.CapabilityState{ Source: source.ID, ID: capID, - Flags: portal.CapState.Flags, - } - if caps.DisappearingTimer != nil && !portal.CapState.Flags.Has(database.CapStateFlagDisappearingTimerSet) { - zerolog.Ctx(ctx).Debug().Msg("Disappearing timer capability was added, sending disappearing timer state event") - success = portal.sendRoomMeta(ctx, nil, time.Now(), event.StateBeeperDisappearingTimer, "", portal.Disappear.ToEventContent(), true, nil) - if !success { - return false - } - portal.CapState.Flags |= database.CapStateFlagDisappearingTimerSet } portal.lastCapUpdate = time.Now() if implicit { @@ -4330,27 +3674,15 @@ func (portal *Portal) sendStateWithIntentOrBot(ctx context.Context, sender Matri return } -func (portal *Portal) sendRoomMeta( - ctx context.Context, - sender MatrixAPI, - ts time.Time, - eventType event.Type, - stateKey string, - content any, - excludeFromTimeline bool, - extra map[string]any, -) bool { +func (portal *Portal) sendRoomMeta(ctx context.Context, sender MatrixAPI, ts time.Time, eventType event.Type, stateKey string, content any) bool { if portal.MXID == "" { return false } - if extra == nil { - extra = make(map[string]any) - } - if excludeFromTimeline { - extra["com.beeper.exclude_from_timeline"] = true - } + var extra map[string]any if !portal.NameIsCustom && (eventType == event.StateRoomName || eventType == event.StateRoomAvatar) { - extra["fi.mau.implicit_name"] = true + extra = map[string]any{ + "fi.mau.implicit_name": true, + } } _, err := portal.sendStateWithIntentOrBot(ctx, sender, eventType, stateKey, &event.Content{ Parsed: content, @@ -4362,55 +3694,9 @@ func (portal *Portal) sendRoomMeta( Msg("Failed to set room metadata") return false } - if eventType == event.StateBeeperDisappearingTimer { - // TODO remove this debug log at some point - zerolog.Ctx(ctx).Debug(). - Any("content", content). - Msg("Sent new disappearing timer event") - } return true } -func (portal *Portal) revertRoomMeta(ctx context.Context, evt *event.Event) { - if !portal.Bridge.Config.RevertFailedStateChanges { - return - } - if evt.GetStateKey() != "" && evt.Type != event.StateMember { - return - } - switch evt.Type { - case event.StateRoomName: - portal.sendRoomMeta(ctx, nil, time.Time{}, event.StateRoomName, "", &event.RoomNameEventContent{Name: portal.Name}, true, nil) - case event.StateRoomAvatar: - portal.sendRoomMeta(ctx, nil, time.Time{}, event.StateRoomAvatar, "", &event.RoomAvatarEventContent{URL: portal.AvatarMXC}, true, nil) - case event.StateTopic: - portal.sendRoomMeta(ctx, nil, time.Time{}, event.StateTopic, "", &event.TopicEventContent{Topic: portal.Topic}, true, nil) - case event.StateBeeperDisappearingTimer: - portal.sendRoomMeta(ctx, nil, time.Time{}, event.StateBeeperDisappearingTimer, "", portal.Disappear.ToEventContent(), true, nil) - case event.StateMember: - var prevContent *event.MemberEventContent - var extra map[string]any - if evt.Unsigned.PrevContent != nil { - _ = evt.Unsigned.PrevContent.ParseRaw(evt.Type) - prevContent = evt.Unsigned.PrevContent.AsMember() - newContent := evt.Content.AsMember() - if prevContent.Membership == newContent.Membership { - return - } - extra = evt.Unsigned.PrevContent.Raw - } else { - prevContent = &event.MemberEventContent{Membership: event.MembershipLeave} - } - if portal.Bridge.Matrix.GetCapabilities().ArbitraryMemberChange { - if extra == nil { - extra = make(map[string]any) - } - extra["com.beeper.member_rollback"] = true - portal.sendRoomMeta(ctx, nil, time.Time{}, event.StateMember, evt.GetStateKey(), prevContent, true, extra) - } - } -} - func (portal *Portal) getInitialMemberList(ctx context.Context, members *ChatMemberList, source *UserLogin, pl *event.PowerLevelsEventContent) (invite, functional []id.UserID, err error) { if members == nil { invite = []id.UserID{source.UserMXID} @@ -4494,39 +3780,6 @@ func (portal *Portal) updateOtherUser(ctx context.Context, members *ChatMemberLi return false } -func looksDirectlyJoinable(rule *event.JoinRulesEventContent) bool { - switch rule.JoinRule { - case event.JoinRulePublic: - return true - case event.JoinRuleKnockRestricted, event.JoinRuleRestricted: - for _, allow := range rule.Allow { - if allow.Type == "fi.mau.spam_checker" { - return true - } - } - } - return false -} - -func (portal *Portal) roomIsPublic(ctx context.Context) bool { - mx, ok := portal.Bridge.Matrix.(MatrixConnectorWithArbitraryRoomState) - if !ok { - return false - } - evt, err := mx.GetStateEvent(ctx, portal.MXID, event.StateJoinRules, "") - if err != nil { - zerolog.Ctx(ctx).Warn().Err(err).Msg("Failed to get join rules to check if room is public") - return false - } else if evt == nil { - return false - } - content, ok := evt.Content.Parsed.(*event.JoinRulesEventContent) - if !ok { - return false - } - return looksDirectlyJoinable(content) -} - func (portal *Portal) syncParticipants( ctx context.Context, members *ChatMemberList, @@ -4557,12 +3810,6 @@ func (portal *Portal) syncParticipants( } delete(currentMembers, portal.Bridge.Bot.GetMXID()) powerChanged := members.PowerLevels.Apply(portal.Bridge.Bot.GetMXID(), currentPower) - addExcludeFromTimeline := func(raw map[string]any) { - _, hasKey := raw["com.beeper.exclude_from_timeline"] - if !hasKey && members.ExcludeChangesFromTimeline { - raw["com.beeper.exclude_from_timeline"] = true - } - } syncUser := func(extraUserID id.UserID, member ChatMember, intent MatrixAPI) bool { if member.Membership == "" { member.Membership = event.MembershipJoin @@ -4592,10 +3839,12 @@ func (portal *Portal) syncParticipants( Displayname: currentMember.Displayname, AvatarURL: currentMember.AvatarURL, } - wrappedContent := &event.Content{Parsed: content, Raw: exmaps.NonNilClone(member.MemberEventExtra)} - addExcludeFromTimeline(wrappedContent.Raw) + wrappedContent := &event.Content{Parsed: content, Raw: maps.Clone(member.MemberEventExtra)} + if wrappedContent.Raw == nil { + wrappedContent.Raw = make(map[string]any) + } thisEvtSender := sender - if member.Membership == event.MembershipJoin && (intent == nil || !portal.roomIsPublic(ctx)) { + if member.Membership == event.MembershipJoin { content.Membership = event.MembershipInvite if intent != nil { wrappedContent.Raw["fi.mau.will_auto_accept"] = true @@ -4625,11 +3874,7 @@ func (portal *Portal) syncParticipants( currentMember.Membership = event.MembershipLeave } } - if content.Membership == event.MembershipJoin && intent != nil && intent.GetMXID() == extraUserID { - _, err = intent.SendState(ctx, portal.MXID, event.StateMember, extraUserID.String(), wrappedContent, ts) - } else { - _, err = portal.sendStateWithIntentOrBot(ctx, thisEvtSender, event.StateMember, extraUserID.String(), wrappedContent, ts) - } + _, err = portal.sendStateWithIntentOrBot(ctx, thisEvtSender, event.StateMember, extraUserID.String(), wrappedContent, ts) if err != nil { addLogContext(log.Err(err)). Str("new_membership", string(content.Membership)). @@ -4642,8 +3887,7 @@ func (portal *Portal) syncParticipants( if intent != nil && content.Membership == event.MembershipInvite && member.Membership == event.MembershipJoin { content.Membership = event.MembershipJoin - wrappedJoinContent := &event.Content{Parsed: content, Raw: exmaps.NonNilClone(member.MemberEventExtra)} - addExcludeFromTimeline(wrappedContent.Raw) + wrappedJoinContent := &event.Content{Parsed: content, Raw: member.MemberEventExtra} _, err = intent.SendState(ctx, portal.MXID, event.StateMember, intent.GetMXID().String(), wrappedJoinContent, ts) if err != nil { addLogContext(log.Err(err)). @@ -4706,7 +3950,7 @@ func (portal *Portal) syncParticipants( if memberEvt.Membership == event.MembershipLeave || memberEvt.Membership == event.MembershipBan { continue } - if !portal.Bridge.IsGhostMXID(extraMember) && (portal.Relay != nil || !portal.Bridge.Config.KickMatrixUsers) { + if !portal.Bridge.IsGhostMXID(extraMember) && portal.Relay != nil { continue } _, err = portal.Bridge.Bot.SendState(ctx, portal.MXID, event.StateMember, extraMember.String(), &event.Content{ @@ -4716,9 +3960,6 @@ func (portal *Portal) syncParticipants( Displayname: memberEvt.Displayname, Reason: "User is not in remote chat", }, - Raw: map[string]any{ - "com.beeper.exclude_from_timeline": members.ExcludeChangesFromTimeline, - }, }, time.Now()) if err != nil { zerolog.Ctx(ctx).Err(err). @@ -4787,28 +4028,16 @@ func DisappearingMessageNotice(expiration time.Duration, implicit bool) *event.M return content } -type UpdateDisappearingSettingOpts struct { - Sender MatrixAPI - Timestamp time.Time - Implicit bool - Save bool - SendNotice bool - - ExcludeFromTimeline bool -} - -func (portal *Portal) UpdateDisappearingSetting( - ctx context.Context, - setting database.DisappearingSetting, - opts UpdateDisappearingSettingOpts, -) bool { - setting = setting.Normalize() +func (portal *Portal) UpdateDisappearingSetting(ctx context.Context, setting database.DisappearingSetting, sender MatrixAPI, ts time.Time, implicit, save bool) bool { + if setting.Timer == 0 { + setting.Type = "" + } if portal.Disappear.Timer == setting.Timer && portal.Disappear.Type == setting.Type { return false } portal.Disappear.Type = setting.Type portal.Disappear.Timer = setting.Timer - if opts.Save { + if save { err := portal.Save(ctx) if err != nil { zerolog.Ctx(ctx).Err(err).Msg("Failed to save portal to database after updating disappearing setting") @@ -4817,45 +4046,19 @@ func (portal *Portal) UpdateDisappearingSetting( if portal.MXID == "" { return true } - - if opts.Sender == nil { - opts.Sender = portal.Bridge.Bot + content := DisappearingMessageNotice(setting.Timer, implicit) + if sender == nil { + sender = portal.Bridge.Bot } - if opts.Timestamp.IsZero() { - opts.Timestamp = time.Now() - } - portal.sendRoomMeta( - ctx, - opts.Sender, - opts.Timestamp, - event.StateBeeperDisappearingTimer, - "", - setting.ToEventContent(), - opts.ExcludeFromTimeline, - nil, - ) - - if !opts.SendNotice { - return true - } - content := DisappearingMessageNotice(setting.Timer, opts.Implicit) - _, err := opts.Sender.SendMessage(ctx, portal.MXID, event.EventMessage, &event.Content{ + _, err := sender.SendMessage(ctx, portal.MXID, event.EventMessage, &event.Content{ Parsed: content, - Raw: map[string]any{ - "com.beeper.action_message": map[string]any{ - "type": "disappearing_timer", - "timer": setting.Timer.Milliseconds(), - "timer_type": setting.Type, - "implicit": opts.Implicit, - }, - }, - }, &MatrixSendExtra{Timestamp: opts.Timestamp}) + }, &MatrixSendExtra{Timestamp: ts}) if err != nil { zerolog.Ctx(ctx).Err(err).Msg("Failed to send disappearing messages notice") } else { zerolog.Ctx(ctx).Debug(). Dur("new_timer", portal.Disappear.Timer). - Bool("implicit", opts.Implicit). + Bool("implicit", implicit). Msg("Sent disappearing messages notice") } return true @@ -4917,13 +4120,13 @@ func (portal *Portal) UpdateInfoFromGhost(ctx context.Context, ghost *Ghost) (ch return } } - changed = portal.updateName(ctx, ghost.Name, nil, time.Time{}, false) || changed + changed = portal.updateName(ctx, ghost.Name, nil, time.Time{}) || changed changed = portal.updateAvatar(ctx, &Avatar{ ID: ghost.AvatarID, MXC: ghost.AvatarMXC, Hash: ghost.AvatarHash, Remove: ghost.AvatarID == "", - }, nil, time.Time{}, false) || changed + }, nil, time.Time{}) || changed return } @@ -4932,36 +4135,28 @@ func (portal *Portal) UpdateInfo(ctx context.Context, info *ChatInfo, source *Us if info.Name == DefaultChatName { if portal.NameIsCustom { portal.NameIsCustom = false - changed = portal.updateName(ctx, "", sender, ts, info.ExcludeChangesFromTimeline) || changed + changed = portal.updateName(ctx, "", sender, ts) || changed } } else if info.Name != nil { portal.NameIsCustom = true - changed = portal.updateName(ctx, *info.Name, sender, ts, info.ExcludeChangesFromTimeline) || changed + changed = portal.updateName(ctx, *info.Name, sender, ts) || changed } if info.Topic != nil { - changed = portal.updateTopic(ctx, *info.Topic, sender, ts, info.ExcludeChangesFromTimeline) || changed + changed = portal.updateTopic(ctx, *info.Topic, sender, ts) || changed } if info.Avatar != nil { portal.NameIsCustom = true - changed = portal.updateAvatar(ctx, info.Avatar, sender, ts, info.ExcludeChangesFromTimeline) || changed + changed = portal.updateAvatar(ctx, info.Avatar, sender, ts) || changed } if info.Disappear != nil { - changed = portal.UpdateDisappearingSetting(ctx, *info.Disappear, UpdateDisappearingSettingOpts{ - Sender: sender, - Timestamp: ts, - Implicit: false, - Save: false, - - SendNotice: !info.ExcludeChangesFromTimeline, - ExcludeFromTimeline: info.ExcludeChangesFromTimeline, - }) || changed + changed = portal.UpdateDisappearingSetting(ctx, *info.Disappear, sender, ts, false, false) || changed } if info.ParentID != nil { changed = portal.updateParent(ctx, *info.ParentID, source) || changed } if info.JoinRule != nil { // TODO change detection instead of spamming this every time? - portal.sendRoomMeta(ctx, sender, ts, event.StateJoinRules, "", info.JoinRule, info.ExcludeChangesFromTimeline, nil) + portal.sendRoomMeta(ctx, sender, ts, event.StateJoinRules, "", info.JoinRule) } if info.Type != nil && portal.RoomType != *info.Type { if portal.MXID != "" && (*info.Type == database.RoomTypeSpace || portal.RoomType == database.RoomTypeSpace) { @@ -4974,10 +4169,6 @@ func (portal *Portal) UpdateInfo(ctx context.Context, info *ChatInfo, source *Us portal.RoomType = *info.Type } } - if info.MessageRequest != nil && *info.MessageRequest != portal.MessageRequest { - changed = true - portal.MessageRequest = *info.MessageRequest - } if info.Members != nil && portal.MXID != "" && source != nil { err := portal.syncParticipants(ctx, info.Members, source, nil, time.Time{}) if err != nil { @@ -5019,9 +4210,6 @@ func (portal *Portal) CreateMatrixRoom(ctx context.Context, source *UserLogin, i } return nil } - if portal.deleted.IsSet() { - return ErrPortalIsDeleted - } waiter := make(chan struct{}) closed := false evt := &portalCreateEvent{ @@ -5039,11 +4227,7 @@ func (portal *Portal) CreateMatrixRoom(ctx context.Context, source *UserLogin, i if PortalEventBuffer == 0 { go portal.queueEvent(ctx, evt) } else { - select { - case portal.events <- evt: - case <-portal.deleted.GetChan(): - return ErrPortalIsDeleted - } + portal.events <- evt } select { case <-ctx.Done(): @@ -5054,11 +4238,7 @@ func (portal *Portal) CreateMatrixRoom(ctx context.Context, source *UserLogin, i } func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLogin, info *ChatInfo, backfillBundle any) error { - cancellableCtx, cancel := context.WithCancel(ctx) - defer cancel() - portal.cancelRoomCreate.CompareAndSwap(nil, &cancel) portal.roomCreateLock.Lock() - portal.cancelRoomCreate.Store(&cancel) defer portal.roomCreateLock.Unlock() if portal.MXID != "" { if source != nil { @@ -5069,7 +4249,6 @@ func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLo log := zerolog.Ctx(ctx).With(). Str("action", "create matrix room"). Logger() - cancellableCtx = log.WithContext(cancellableCtx) ctx = log.WithContext(ctx) log.Info().Msg("Creating Matrix room") @@ -5078,16 +4257,16 @@ func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLo if info != nil { log.Warn().Msg("CreateMatrixRoom got info without members. Refetching info") } - info, err = source.Client.GetChatInfo(cancellableCtx, portal) + info, err = source.Client.GetChatInfo(ctx, portal) if err != nil { log.Err(err).Msg("Failed to update portal info for creation") return err } } - portal.UpdateInfo(cancellableCtx, info, source, nil, time.Time{}) - if cancellableCtx.Err() != nil { - return cancellableCtx.Err() + portal.UpdateInfo(ctx, info, source, nil, time.Time{}) + if ctx.Err() != nil { + return ctx.Err() } powerLevels := &event.PowerLevelsEventContent{ @@ -5100,7 +4279,7 @@ func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLo portal.Bridge.Bot.GetMXID(): 9001, }, } - initialMembers, extraFunctionalMembers, err := portal.getInitialMemberList(cancellableCtx, info.Members, source, powerLevels) + initialMembers, extraFunctionalMembers, err := portal.getInitialMemberList(ctx, info.Members, source, powerLevels) if err != nil { log.Err(err).Msg("Failed to process participant list for portal creation") return err @@ -5109,12 +4288,15 @@ func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLo req := mautrix.ReqCreateRoom{ Visibility: "private", + Name: portal.Name, + Topic: portal.Topic, CreationContent: make(map[string]any), InitialState: make([]*event.Event, 0, 6), Preset: "private_chat", IsDirect: portal.RoomType == database.RoomTypeDM, PowerLevelOverride: powerLevels, BeeperLocalRoomID: portal.Bridge.Matrix.GenerateDeterministicRoomID(portal.PortalKey), + RoomVersion: id.RoomV11, } autoJoinInvites := portal.Bridge.Matrix.GetCapabilities().AutoJoinInvites if autoJoinInvites { @@ -5127,7 +4309,7 @@ func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLo req.CreationContent["type"] = event.RoomTypeSpace } bridgeInfoStateKey, bridgeInfo := portal.getBridgeInfo() - roomFeatures := source.Client.GetCapabilities(cancellableCtx, portal) + roomFeatures := source.Client.GetCapabilities(ctx, portal) portal.CapState = database.CapabilityState{ Source: source.ID, ID: roomFeatures.GetID(), @@ -5150,47 +4332,19 @@ func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLo StateKey: &bridgeInfoStateKey, Type: event.StateBeeperRoomFeatures, Content: event.Content{Parsed: roomFeatures}, - }, &event.Event{ - Type: event.StateTopic, - Content: event.Content{ - Parsed: &event.TopicEventContent{Topic: portal.Topic}, - Raw: map[string]any{ - "com.beeper.exclude_from_timeline": true, - }, - }, }) - if roomFeatures.DisappearingTimer != nil { + if req.Topic == "" { + // Add explicit topic event if topic is empty to ensure the event is set. + // This ensures that there won't be an extra event later if PUT /state/... is called. req.InitialState = append(req.InitialState, &event.Event{ - Type: event.StateBeeperDisappearingTimer, - Content: event.Content{ - Parsed: portal.Disappear.ToEventContent(), - Raw: map[string]any{ - "com.beeper.exclude_from_timeline": true, - }, - }, - }) - portal.CapState.Flags |= database.CapStateFlagDisappearingTimerSet - } - if portal.Name != "" { - req.InitialState = append(req.InitialState, &event.Event{ - Type: event.StateRoomName, - Content: event.Content{ - Parsed: &event.RoomNameEventContent{Name: portal.Name}, - Raw: map[string]any{ - "com.beeper.exclude_from_timeline": true, - }, - }, + Type: event.StateTopic, + Content: event.Content{Parsed: &event.TopicEventContent{Topic: ""}}, }) } if portal.AvatarMXC != "" { req.InitialState = append(req.InitialState, &event.Event{ - Type: event.StateRoomAvatar, - Content: event.Content{ - Parsed: &event.RoomAvatarEventContent{URL: portal.AvatarMXC}, - Raw: map[string]any{ - "com.beeper.exclude_from_timeline": true, - }, - }, + Type: event.StateRoomAvatar, + Content: event.Content{Parsed: &event.RoomAvatarEventContent{URL: portal.AvatarMXC}}, }) } if portal.Parent != nil && portal.Parent.MXID != "" { @@ -5209,9 +4363,6 @@ func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLo Content: event.Content{Parsed: info.JoinRule}, }) } - if cancellableCtx.Err() != nil { - return cancellableCtx.Err() - } roomID, err := portal.Bridge.Bot.CreateRoom(ctx, &req) if err != nil { log.Err(err).Msg("Failed to create Matrix room") @@ -5222,7 +4373,6 @@ func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLo portal.TopicSet = true portal.NameSet = true portal.MXID = roomID - portal.RoomCreated.Set() portal.Bridge.cacheLock.Lock() portal.Bridge.portalsByMXID[roomID] = portal portal.Bridge.cacheLock.Unlock() @@ -5270,10 +4420,7 @@ func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLo } } portal.addToUserSpaces(ctx) - if info.CanBackfill && - portal.Bridge.Config.Backfill.Enabled && - portal.RoomType != database.RoomTypeSpace && - !portal.Bridge.Background { + if portal.Bridge.Config.Backfill.Enabled && portal.RoomType != database.RoomTypeSpace && !portal.Bridge.Background { portal.doForwardBackfill(ctx, source, nil, backfillBundle) } return nil @@ -5288,7 +4435,7 @@ func (portal *Portal) addToUserSpaces(ctx context.Context) { if portal.Receiver != "" { login := portal.Bridge.GetCachedUserLoginByID(portal.Receiver) if login != nil { - up, err := portal.Bridge.DB.UserPortal.GetOrCreate(ctx, login.UserLogin, portal.PortalKey) + up, err := portal.Bridge.DB.UserPortal.Get(ctx, login.UserLogin, portal.PortalKey) if err != nil { log.Err(err).Msg("Failed to get user portal to add portal to spaces") } else { @@ -5313,11 +4460,8 @@ func (portal *Portal) addToUserSpaces(ctx context.Context) { } func (portal *Portal) Delete(ctx context.Context) error { - if portal.deleted.IsSet() { - return nil - } portal.removeInPortalCache(ctx) - err := portal.safeDBDelete(ctx) + err := portal.Bridge.DB.Portal.Delete(ctx, portal.PortalKey) if err != nil { return err } @@ -5327,21 +4471,11 @@ func (portal *Portal) Delete(ctx context.Context) error { return nil } -func (portal *Portal) safeDBDelete(ctx context.Context) error { - err := portal.Bridge.DB.Message.DeleteInChunks(ctx, portal.PortalKey) - if err != nil { - return fmt.Errorf("failed to delete messages in portal: %w", err) - } - // TODO delete child portals? - return portal.Bridge.DB.Portal.Delete(ctx, portal.PortalKey) -} - func (portal *Portal) RemoveMXID(ctx context.Context) error { if portal.MXID == "" { return nil } portal.MXID = "" - portal.RoomCreated.Clear() err := portal.Save(ctx) if err != nil { return err @@ -5374,10 +4508,8 @@ func (portal *Portal) removeInPortalCache(ctx context.Context) { } func (portal *Portal) unlockedDelete(ctx context.Context) error { - if portal.deleted.IsSet() { - return nil - } - err := portal.safeDBDelete(ctx) + // TODO delete child portals? + err := portal.Bridge.DB.Portal.Delete(ctx, portal.PortalKey) if err != nil { return err } @@ -5386,14 +4518,10 @@ func (portal *Portal) unlockedDelete(ctx context.Context) error { } func (portal *Portal) unlockedDeleteCache() { - if portal.deleted.IsSet() { - return - } delete(portal.Bridge.portalsByKey, portal.PortalKey) if portal.MXID != "" { delete(portal.Bridge.portalsByMXID, portal.MXID) } - portal.deleted.Set() if portal.events != nil { // TODO there's a small risk of this racing with a queueEvent call close(portal.events) @@ -5405,9 +4533,6 @@ func (portal *Portal) Save(ctx context.Context) error { } func (portal *Portal) SetRelay(ctx context.Context, relay *UserLogin) error { - if portal.Receiver != "" && relay.ID != portal.Receiver { - return fmt.Errorf("can't set non-receiver login as relay") - } portal.Relay = relay if relay == nil { portal.RelayLoginID = "" diff --git a/bridgev2/portalbackfill.go b/bridgev2/portalbackfill.go index 879f07ae..9883fb12 100644 --- a/bridgev2/portalbackfill.go +++ b/bridgev2/portalbackfill.go @@ -194,9 +194,6 @@ func (portal *Portal) doThreadBackfill(ctx context.Context, source *UserLogin, t if err != nil { log.Err(err).Msg("Failed to get last thread message") return - } else if anchorMessage == nil { - log.Warn().Msg("No messages found in thread?") - return } resp := portal.fetchThreadBackfill(ctx, source, anchorMessage) if resp != nil { @@ -342,7 +339,6 @@ func (portal *Portal) compileBatchMessage(ctx context.Context, source *UserLogin for i, part := range msg.Parts { partIDs = append(partIDs, part.ID) portal.applyRelationMeta(ctx, part.Content, replyTo, threadRoot, prevThreadEvent) - part.Content.BeeperDisappearingTimer = msg.Disappear.ToEventContent() evtID := portal.Bridge.Matrix.GenerateDeterministicEventID(portal.MXID, portal.PortalKey, msg.ID, part.ID) dbMessage := &database.Message{ ID: msg.ID, @@ -383,23 +379,19 @@ func (portal *Portal) compileBatchMessage(ctx context.Context, source *UserLogin prevThreadEvent.MXID = evtID out.PrevThreadEvents[*msg.ThreadRoot] = evtID } - if msg.Disappear.Type != event.DisappearingTypeNone { - if msg.Disappear.Type == event.DisappearingTypeAfterSend && msg.Disappear.DisappearAt.IsZero() { + if msg.Disappear.Type != database.DisappearingTypeNone { + if msg.Disappear.Type == database.DisappearingTypeAfterSend && msg.Disappear.DisappearAt.IsZero() { msg.Disappear.DisappearAt = msg.Timestamp.Add(msg.Disappear.Timer) } out.Disappear = append(out.Disappear, &database.DisappearingMessage{ RoomID: portal.MXID, EventID: evtID, - Timestamp: msg.Timestamp, DisappearingSetting: msg.Disappear, }) } } slices.Sort(partIDs) for _, reaction := range msg.Reactions { - if reaction == nil { - continue - } reactionIntent, ok := portal.GetIntentFor(ctx, reaction.Sender, source, RemoteEventReactionRemove) if !ok { continue @@ -410,7 +402,6 @@ func (portal *Portal) compileBatchMessage(ctx context.Context, source *UserLogin if reaction.Timestamp.IsZero() { reaction.Timestamp = msg.Timestamp.Add(10 * time.Millisecond) } - //lint:ignore SA4006 it's a todo targetPart, ok := partMap[*reaction.TargetPart] if !ok { // TODO warning log and/or skip reaction? diff --git a/bridgev2/portalinternal.go b/bridgev2/portalinternal.go index 4c7e2447..e82c481a 100644 --- a/bridgev2/portalinternal.go +++ b/bridgev2/portalinternal.go @@ -37,8 +37,8 @@ func (portal *PortalInternals) EventLoop() { (*Portal)(portal).eventLoop() } -func (portal *PortalInternals) HandleSingleEventWithDelayLogging(idx int, rawEvt any) (outerRes EventHandlingResult) { - return (*Portal)(portal).handleSingleEventWithDelayLogging(idx, rawEvt) +func (portal *PortalInternals) HandleSingleEventAsync(idx int, rawEvt any) (outerRes EventHandlingResult) { + return (*Portal)(portal).handleSingleEventAsync(idx, rawEvt) } func (portal *PortalInternals) GetEventCtxWithLog(rawEvt any, idx int) context.Context { @@ -49,10 +49,6 @@ func (portal *PortalInternals) HandleSingleEvent(ctx context.Context, rawEvt any (*Portal)(portal).handleSingleEvent(ctx, rawEvt, doneCallback) } -func (portal *PortalInternals) UnwrapBeeperSendState(ctx context.Context, evt *event.Event) error { - return (*Portal)(portal).unwrapBeeperSendState(ctx, evt) -} - func (portal *PortalInternals) SendSuccessStatus(ctx context.Context, evt *event.Event, streamOrder int64, newEventID id.EventID) { (*Portal)(portal).sendSuccessStatus(ctx, evt, streamOrder, newEventID) } @@ -65,8 +61,8 @@ func (portal *PortalInternals) CheckConfusableName(ctx context.Context, userID i return (*Portal)(portal).checkConfusableName(ctx, userID, name) } -func (portal *PortalInternals) HandleMatrixEvent(ctx context.Context, sender *User, evt *event.Event, isStateRequest bool) EventHandlingResult { - return (*Portal)(portal).handleMatrixEvent(ctx, sender, evt, isStateRequest) +func (portal *PortalInternals) HandleMatrixEvent(ctx context.Context, sender *User, evt *event.Event) EventHandlingResult { + return (*Portal)(portal).handleMatrixEvent(ctx, sender, evt) } func (portal *PortalInternals) HandleMatrixReceipts(ctx context.Context, evt *event.Event) EventHandlingResult { @@ -77,10 +73,6 @@ func (portal *PortalInternals) HandleMatrixReadReceipt(ctx context.Context, user (*Portal)(portal).handleMatrixReadReceipt(ctx, user, eventID, receipt) } -func (portal *PortalInternals) CallReadReceiptHandler(ctx context.Context, login *UserLogin, rrClient ReadReceiptHandlingNetworkAPI, evt *MatrixReadReceipt, userPortal *database.UserPortal) { - (*Portal)(portal).callReadReceiptHandler(ctx, login, rrClient, evt, userPortal) -} - func (portal *PortalInternals) HandleMatrixTyping(ctx context.Context, evt *event.Event) EventHandlingResult { return (*Portal)(portal).handleMatrixTyping(ctx, evt) } @@ -125,24 +117,12 @@ func (portal *PortalInternals) GetTargetUser(ctx context.Context, userID id.User return (*Portal)(portal).getTargetUser(ctx, userID) } -func (portal *PortalInternals) HandleMatrixDeleteChat(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event) EventHandlingResult { - return (*Portal)(portal).handleMatrixDeleteChat(ctx, sender, origSender, evt) +func (portal *PortalInternals) HandleMatrixMembership(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event) EventHandlingResult { + return (*Portal)(portal).handleMatrixMembership(ctx, sender, origSender, evt) } -func (portal *PortalInternals) HandleMatrixMembership(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event, isStateRequest bool) EventHandlingResult { - return (*Portal)(portal).handleMatrixMembership(ctx, sender, origSender, evt, isStateRequest) -} - -func (portal *PortalInternals) HandleMatrixPowerLevels(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event, isStateRequest bool) EventHandlingResult { - return (*Portal)(portal).handleMatrixPowerLevels(ctx, sender, origSender, evt, isStateRequest) -} - -func (portal *PortalInternals) HandleMatrixTombstone(ctx context.Context, evt *event.Event) EventHandlingResult { - return (*Portal)(portal).handleMatrixTombstone(ctx, evt) -} - -func (portal *PortalInternals) UpdateInfoAfterTombstone(ctx context.Context, senderUser *User) { - (*Portal)(portal).updateInfoAfterTombstone(ctx, senderUser) +func (portal *PortalInternals) HandleMatrixPowerLevels(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event) EventHandlingResult { + return (*Portal)(portal).handleMatrixPowerLevels(ctx, sender, origSender, evt) } func (portal *PortalInternals) HandleMatrixRedaction(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event) EventHandlingResult { @@ -153,10 +133,6 @@ func (portal *PortalInternals) HandleRemoteEvent(ctx context.Context, source *Us return (*Portal)(portal).handleRemoteEvent(ctx, source, evtType, evt) } -func (portal *PortalInternals) EnsureFunctionalMember(ctx context.Context, ghost *Ghost) { - (*Portal)(portal).ensureFunctionalMember(ctx, ghost) -} - func (portal *PortalInternals) GetIntentAndUserMXIDFor(ctx context.Context, sender EventSender, source *UserLogin, otherLogins []*UserLogin, evtType RemoteEventType) (intent MatrixAPI, extraUserID id.UserID, err error) { return (*Portal)(portal).getIntentAndUserMXIDFor(ctx, sender, source, otherLogins, evtType) } @@ -257,10 +233,6 @@ func (portal *PortalInternals) HandleRemoteChatResync(ctx context.Context, sourc return (*Portal)(portal).handleRemoteChatResync(ctx, source, evt) } -func (portal *PortalInternals) FindOtherLogins(ctx context.Context, source *UserLogin) (ownUP *database.UserPortal, others []*database.UserPortal, err error) { - return (*Portal)(portal).findOtherLogins(ctx, source) -} - func (portal *PortalInternals) HandleRemoteChatDelete(ctx context.Context, source *UserLogin, evt RemoteChatDelete) EventHandlingResult { return (*Portal)(portal).handleRemoteChatDelete(ctx, source, evt) } @@ -269,16 +241,16 @@ func (portal *PortalInternals) HandleRemoteBackfill(ctx context.Context, source return (*Portal)(portal).handleRemoteBackfill(ctx, source, backfill) } -func (portal *PortalInternals) UpdateName(ctx context.Context, name string, sender MatrixAPI, ts time.Time, excludeFromTimeline bool) bool { - return (*Portal)(portal).updateName(ctx, name, sender, ts, excludeFromTimeline) +func (portal *PortalInternals) UpdateName(ctx context.Context, name string, sender MatrixAPI, ts time.Time) bool { + return (*Portal)(portal).updateName(ctx, name, sender, ts) } -func (portal *PortalInternals) UpdateTopic(ctx context.Context, topic string, sender MatrixAPI, ts time.Time, excludeFromTimeline bool) bool { - return (*Portal)(portal).updateTopic(ctx, topic, sender, ts, excludeFromTimeline) +func (portal *PortalInternals) UpdateTopic(ctx context.Context, topic string, sender MatrixAPI, ts time.Time) bool { + return (*Portal)(portal).updateTopic(ctx, topic, sender, ts) } -func (portal *PortalInternals) UpdateAvatar(ctx context.Context, avatar *Avatar, sender MatrixAPI, ts time.Time, excludeFromTimeline bool) bool { - return (*Portal)(portal).updateAvatar(ctx, avatar, sender, ts, excludeFromTimeline) +func (portal *PortalInternals) UpdateAvatar(ctx context.Context, avatar *Avatar, sender MatrixAPI, ts time.Time) bool { + return (*Portal)(portal).updateAvatar(ctx, avatar, sender, ts) } func (portal *PortalInternals) GetBridgeInfoStateKey() string { @@ -293,12 +265,8 @@ func (portal *PortalInternals) SendStateWithIntentOrBot(ctx context.Context, sen return (*Portal)(portal).sendStateWithIntentOrBot(ctx, sender, eventType, stateKey, content, ts) } -func (portal *PortalInternals) SendRoomMeta(ctx context.Context, sender MatrixAPI, ts time.Time, eventType event.Type, stateKey string, content any, excludeFromTimeline bool, extra map[string]any) bool { - return (*Portal)(portal).sendRoomMeta(ctx, sender, ts, eventType, stateKey, content, excludeFromTimeline, extra) -} - -func (portal *PortalInternals) RevertRoomMeta(ctx context.Context, evt *event.Event) { - (*Portal)(portal).revertRoomMeta(ctx, evt) +func (portal *PortalInternals) SendRoomMeta(ctx context.Context, sender MatrixAPI, ts time.Time, eventType event.Type, stateKey string, content any) bool { + return (*Portal)(portal).sendRoomMeta(ctx, sender, ts, eventType, stateKey, content) } func (portal *PortalInternals) GetInitialMemberList(ctx context.Context, members *ChatMemberList, source *UserLogin, pl *event.PowerLevelsEventContent) (invite, functional []id.UserID, err error) { @@ -309,10 +277,6 @@ func (portal *PortalInternals) UpdateOtherUser(ctx context.Context, members *Cha return (*Portal)(portal).updateOtherUser(ctx, members) } -func (portal *PortalInternals) RoomIsPublic(ctx context.Context) bool { - return (*Portal)(portal).roomIsPublic(ctx) -} - func (portal *PortalInternals) SyncParticipants(ctx context.Context, members *ChatMemberList, source *UserLogin, sender MatrixAPI, ts time.Time) error { return (*Portal)(portal).syncParticipants(ctx, members, source, sender, ts) } @@ -333,10 +297,6 @@ func (portal *PortalInternals) CreateMatrixRoomInLoop(ctx context.Context, sourc return (*Portal)(portal).createMatrixRoomInLoop(ctx, source, info, backfillBundle) } -func (portal *PortalInternals) AddToUserSpaces(ctx context.Context) { - (*Portal)(portal).addToUserSpaces(ctx) -} - func (portal *PortalInternals) RemoveInPortalCache(ctx context.Context) { (*Portal)(portal).removeInPortalCache(ctx) } @@ -400,3 +360,7 @@ func (portal *PortalInternals) AddToParentSpaceAndSave(ctx context.Context, save func (portal *PortalInternals) ToggleSpace(ctx context.Context, spaceID id.RoomID, canonical, remove bool) error { return (*Portal)(portal).toggleSpace(ctx, spaceID, canonical, remove) } + +func (portal *PortalInternals) SetMXIDToExistingRoom(ctx context.Context, roomID id.RoomID) bool { + return (*Portal)(portal).setMXIDToExistingRoom(ctx, roomID) +} diff --git a/bridgev2/portalreid.go b/bridgev2/portalreid.go index c976d97c..a25fe820 100644 --- a/bridgev2/portalreid.go +++ b/bridgev2/portalreid.go @@ -32,40 +32,21 @@ func (br *Bridge) ReIDPortal(ctx context.Context, source, target networkid.Porta if source == target { return ReIDResultError, nil, fmt.Errorf("illegal re-ID call: source and target are the same") } - log := zerolog.Ctx(ctx).With(). - Str("action", "re-id portal"). - Stringer("source_portal_key", source). - Stringer("target_portal_key", target). - Logger() - ctx = log.WithContext(ctx) + log := zerolog.Ctx(ctx) + log.Debug().Msg("Re-ID'ing portal") defer func() { log.Debug().Msg("Finished handling portal re-ID") }() - acquireCacheLock := func() { - if !br.cacheLock.TryLock() { - log.Debug().Msg("Waiting for global cache lock") - br.cacheLock.Lock() - log.Debug().Msg("Acquired global cache lock after waiting") - } else { - log.Trace().Msg("Acquired global cache lock without waiting") - } - } - log.Debug().Msg("Re-ID'ing portal") - sourcePortal, err := br.GetExistingPortalByKey(ctx, source) + br.cacheLock.Lock() + defer br.cacheLock.Unlock() + sourcePortal, err := br.UnlockedGetPortalByKey(ctx, source, true) if err != nil { return ReIDResultError, nil, fmt.Errorf("failed to get source portal: %w", err) } else if sourcePortal == nil { log.Debug().Msg("Source portal not found, re-ID is no-op") return ReIDResultNoOp, nil, nil } - if !sourcePortal.roomCreateLock.TryLock() { - if cancelCreate := sourcePortal.cancelRoomCreate.Swap(nil); cancelCreate != nil { - (*cancelCreate)() - } - log.Debug().Msg("Waiting for source portal room creation lock") - sourcePortal.roomCreateLock.Lock() - log.Debug().Msg("Acquired source portal room creation lock after waiting") - } + sourcePortal.roomCreateLock.Lock() defer sourcePortal.roomCreateLock.Unlock() if sourcePortal.MXID == "" { log.Info().Msg("Source portal doesn't have Matrix room, deleting row") @@ -78,37 +59,22 @@ func (br *Bridge) ReIDPortal(ctx context.Context, source, target networkid.Porta log.UpdateContext(func(c zerolog.Context) zerolog.Context { return c.Stringer("source_portal_mxid", sourcePortal.MXID) }) - - acquireCacheLock() targetPortal, err := br.UnlockedGetPortalByKey(ctx, target, true) if err != nil { - br.cacheLock.Unlock() return ReIDResultError, nil, fmt.Errorf("failed to get target portal: %w", err) } if targetPortal == nil { log.Info().Msg("Target portal doesn't exist, re-ID'ing source portal") err = sourcePortal.unlockedReID(ctx, target) - br.cacheLock.Unlock() if err != nil { return ReIDResultError, nil, fmt.Errorf("failed to re-ID source portal: %w", err) } return ReIDResultSourceReIDd, sourcePortal, nil } - br.cacheLock.Unlock() - - if !targetPortal.roomCreateLock.TryLock() { - if cancelCreate := targetPortal.cancelRoomCreate.Swap(nil); cancelCreate != nil { - (*cancelCreate)() - } - log.Debug().Msg("Waiting for target portal room creation lock") - targetPortal.roomCreateLock.Lock() - log.Debug().Msg("Acquired target portal room creation lock after waiting") - } + targetPortal.roomCreateLock.Lock() defer targetPortal.roomCreateLock.Unlock() if targetPortal.MXID == "" { log.Info().Msg("Target portal row exists, but doesn't have a Matrix room. Deleting target portal row and re-ID'ing source portal") - acquireCacheLock() - defer br.cacheLock.Unlock() err = targetPortal.unlockedDelete(ctx) if err != nil { return ReIDResultError, nil, fmt.Errorf("failed to delete target portal: %w", err) @@ -123,9 +89,6 @@ func (br *Bridge) ReIDPortal(ctx context.Context, source, target networkid.Porta return c.Stringer("target_portal_mxid", targetPortal.MXID) }) log.Info().Msg("Both target and source portals have Matrix rooms, tombstoning source portal") - sourcePortal.removeInPortalCache(ctx) - acquireCacheLock() - defer br.cacheLock.Unlock() err = sourcePortal.unlockedDelete(ctx) if err != nil { return ReIDResultError, nil, fmt.Errorf("failed to delete source portal row: %w", err) @@ -133,7 +96,7 @@ func (br *Bridge) ReIDPortal(ctx context.Context, source, target networkid.Porta go func() { _, err := br.Bot.SendState(ctx, sourcePortal.MXID, event.StateTombstone, "", &event.Content{ Parsed: &event.TombstoneEventContent{ - Body: "This room has been merged", + Body: fmt.Sprintf("This room has been merged"), ReplacementRoom: targetPortal.MXID, }, }, time.Now()) diff --git a/bridgev2/provisionutil/creategroup.go b/bridgev2/provisionutil/creategroup.go deleted file mode 100644 index 72bacaff..00000000 --- a/bridgev2/provisionutil/creategroup.go +++ /dev/null @@ -1,149 +0,0 @@ -// Copyright (c) 2025 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package provisionutil - -import ( - "context" - - "github.com/rs/zerolog" - "go.mau.fi/util/ptr" - - "maunium.net/go/mautrix" - "maunium.net/go/mautrix/bridgev2" - "maunium.net/go/mautrix/bridgev2/networkid" - "maunium.net/go/mautrix/event" - "maunium.net/go/mautrix/id" -) - -type RespCreateGroup struct { - ID networkid.PortalID `json:"id"` - MXID id.RoomID `json:"mxid"` - Portal *bridgev2.Portal `json:"-"` - - FailedParticipants map[networkid.UserID]*bridgev2.CreateChatFailedParticipant `json:"failed_participants,omitempty"` -} - -func CreateGroup(ctx context.Context, login *bridgev2.UserLogin, params *bridgev2.GroupCreateParams) (*RespCreateGroup, error) { - api, ok := login.Client.(bridgev2.GroupCreatingNetworkAPI) - if !ok { - return nil, bridgev2.RespError(mautrix.MUnrecognized.WithMessage("This bridge does not support creating groups")) - } - zerolog.Ctx(ctx).Debug(). - Any("create_params", params). - Msg("Creating group chat on remote network") - caps := login.Bridge.Network.GetCapabilities() - typeSpec, validType := caps.Provisioning.GroupCreation[params.Type] - if !validType { - return nil, bridgev2.RespError(mautrix.MUnrecognized.WithMessage("Unrecognized group type %s", params.Type)) - } - if len(params.Participants) < typeSpec.Participants.MinLength { - return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Must have at least %d members", typeSpec.Participants.MinLength)) - } else if typeSpec.Participants.MaxLength > 0 && len(params.Participants) > typeSpec.Participants.MaxLength { - return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Must have at most %d members", typeSpec.Participants.MaxLength)) - } - userIDValidatingNetwork, uidValOK := login.Bridge.Network.(bridgev2.IdentifierValidatingNetwork) - for i, participant := range params.Participants { - parsedParticipant, ok := login.Bridge.Matrix.ParseGhostMXID(id.UserID(participant)) - if ok { - participant = parsedParticipant - params.Participants[i] = participant - } - if !typeSpec.Participants.SkipIdentifierValidation { - if uidValOK && !userIDValidatingNetwork.ValidateUserID(participant) { - return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("User ID %q is not valid on this network", participant)) - } - } - if api.IsThisUser(ctx, participant) { - return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("You can't include yourself in the participants list", participant)) - } - } - if (params.Name == nil || params.Name.Name == "") && typeSpec.Name.Required { - return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Name is required")) - } else if nameLen := len(ptr.Val(params.Name).Name); nameLen > 0 && nameLen < typeSpec.Name.MinLength { - return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Name must be at least %d characters", typeSpec.Name.MinLength)) - } else if typeSpec.Name.MaxLength > 0 && nameLen > typeSpec.Name.MaxLength { - return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Name must be at most %d characters", typeSpec.Name.MaxLength)) - } - if (params.Avatar == nil || params.Avatar.URL == "") && typeSpec.Avatar.Required { - return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Avatar is required")) - } - if (params.Topic == nil || params.Topic.Topic == "") && typeSpec.Topic.Required { - return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Topic is required")) - } else if topicLen := len(ptr.Val(params.Topic).Topic); topicLen > 0 && topicLen < typeSpec.Topic.MinLength { - return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Topic must be at least %d characters", typeSpec.Topic.MinLength)) - } else if typeSpec.Topic.MaxLength > 0 && topicLen > typeSpec.Topic.MaxLength { - return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Topic must be at most %d characters", typeSpec.Topic.MaxLength)) - } - if (params.Disappear == nil || params.Disappear.Timer.Duration == 0) && typeSpec.Disappear.Required { - return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Disappearing timer is required")) - } else if !typeSpec.Disappear.DisappearSettings.Supports(params.Disappear) { - return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Unsupported value for disappearing timer")) - } - if params.Username == "" && typeSpec.Username.Required { - return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Username is required")) - } else if len(params.Username) > 0 && len(params.Username) < typeSpec.Username.MinLength { - return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Username must be at least %d characters", typeSpec.Username.MinLength)) - } else if typeSpec.Username.MaxLength > 0 && len(params.Username) > typeSpec.Username.MaxLength { - return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Username must be at most %d characters", typeSpec.Username.MaxLength)) - } - if params.Parent == nil && typeSpec.Parent.Required { - return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Parent is required")) - } - resp, err := api.CreateGroup(ctx, params) - if err != nil { - zerolog.Ctx(ctx).Err(err).Msg("Failed to create group") - return nil, err - } - if resp.PortalKey.IsEmpty() { - return nil, ErrNoPortalKey - } - zerolog.Ctx(ctx).Debug(). - Object("portal_key", resp.PortalKey). - Msg("Successfully created group on remote network") - if resp.Portal == nil { - resp.Portal, err = login.Bridge.GetPortalByKey(ctx, resp.PortalKey) - if err != nil { - zerolog.Ctx(ctx).Err(err).Msg("Failed to get portal") - return nil, bridgev2.RespError(mautrix.MUnknown.WithMessage("Failed to get portal")) - } - } - if resp.Portal.MXID == "" { - err = resp.Portal.CreateMatrixRoom(ctx, login, resp.PortalInfo) - if err != nil { - zerolog.Ctx(ctx).Err(err).Msg("Failed to create portal room") - return nil, bridgev2.RespError(mautrix.MUnknown.WithMessage("Failed to create portal room")) - } - } - for key, fp := range resp.FailedParticipants { - if fp.InviteEventType == "" { - fp.InviteEventType = event.EventMessage.Type - } - if fp.UserMXID == "" { - ghost, err := login.Bridge.GetGhostByID(ctx, key) - if err != nil { - zerolog.Ctx(ctx).Err(err).Msg("Failed to get ghost for failed participant") - } else if ghost != nil { - fp.UserMXID = ghost.Intent.GetMXID() - } - } - if fp.DMRoomMXID == "" { - portal, err := login.Bridge.GetDMPortal(ctx, login.ID, key) - if err != nil { - zerolog.Ctx(ctx).Err(err).Msg("Failed to get DM portal for failed participant") - } else if portal != nil { - fp.DMRoomMXID = portal.MXID - } - } - } - return &RespCreateGroup{ - ID: resp.Portal.ID, - MXID: resp.Portal.MXID, - Portal: resp.Portal, - - FailedParticipants: resp.FailedParticipants, - }, nil -} diff --git a/bridgev2/provisionutil/listcontacts.go b/bridgev2/provisionutil/listcontacts.go deleted file mode 100644 index ce163e67..00000000 --- a/bridgev2/provisionutil/listcontacts.go +++ /dev/null @@ -1,98 +0,0 @@ -// Copyright (c) 2025 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package provisionutil - -import ( - "context" - - "github.com/rs/zerolog" - - "maunium.net/go/mautrix" - "maunium.net/go/mautrix/bridgev2" -) - -type RespGetContactList struct { - Contacts []*RespResolveIdentifier `json:"contacts"` -} - -type RespSearchUsers struct { - Results []*RespResolveIdentifier `json:"results"` -} - -func GetContactList(ctx context.Context, login *bridgev2.UserLogin) (*RespGetContactList, error) { - api, ok := login.Client.(bridgev2.ContactListingNetworkAPI) - if !ok { - return nil, bridgev2.RespError(mautrix.MUnrecognized.WithMessage("This bridge does not support listing contacts")) - } - resp, err := api.GetContactList(ctx) - if err != nil { - zerolog.Ctx(ctx).Err(err).Msg("Failed to get contact list") - return nil, err - } - return &RespGetContactList{ - Contacts: processResolveIdentifiers(ctx, login.Bridge, resp, false), - }, nil -} - -func SearchUsers(ctx context.Context, login *bridgev2.UserLogin, query string) (*RespSearchUsers, error) { - api, ok := login.Client.(bridgev2.UserSearchingNetworkAPI) - if !ok { - return nil, bridgev2.RespError(mautrix.MUnrecognized.WithMessage("This bridge does not support searching for users")) - } - resp, err := api.SearchUsers(ctx, query) - if err != nil { - zerolog.Ctx(ctx).Err(err).Msg("Failed to get contact list") - return nil, err - } - return &RespSearchUsers{ - Results: processResolveIdentifiers(ctx, login.Bridge, resp, true), - }, nil -} - -func processResolveIdentifiers(ctx context.Context, br *bridgev2.Bridge, resp []*bridgev2.ResolveIdentifierResponse, syncInfo bool) (apiResp []*RespResolveIdentifier) { - apiResp = make([]*RespResolveIdentifier, len(resp)) - for i, contact := range resp { - apiContact := &RespResolveIdentifier{ - ID: contact.UserID, - } - apiResp[i] = apiContact - if contact.UserInfo != nil { - if contact.UserInfo.Name != nil { - apiContact.Name = *contact.UserInfo.Name - } - if contact.UserInfo.Identifiers != nil { - apiContact.Identifiers = contact.UserInfo.Identifiers - } - } - if contact.Ghost != nil { - if syncInfo && contact.UserInfo != nil { - contact.Ghost.UpdateInfo(ctx, contact.UserInfo) - } - if contact.Ghost.Name != "" { - apiContact.Name = contact.Ghost.Name - } - if len(contact.Ghost.Identifiers) >= len(apiContact.Identifiers) { - apiContact.Identifiers = contact.Ghost.Identifiers - } - apiContact.AvatarURL = contact.Ghost.AvatarMXC - apiContact.MXID = contact.Ghost.Intent.GetMXID() - } - if contact.Chat != nil { - if contact.Chat.Portal == nil { - var err error - contact.Chat.Portal, err = br.GetPortalByKey(ctx, contact.Chat.PortalKey) - if err != nil { - zerolog.Ctx(ctx).Err(err).Msg("Failed to get portal") - } - } - if contact.Chat.Portal != nil { - apiContact.DMRoomID = contact.Chat.Portal.MXID - } - } - } - return -} diff --git a/bridgev2/provisionutil/resolveidentifier.go b/bridgev2/provisionutil/resolveidentifier.go deleted file mode 100644 index cfc388d0..00000000 --- a/bridgev2/provisionutil/resolveidentifier.go +++ /dev/null @@ -1,125 +0,0 @@ -// Copyright (c) 2025 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package provisionutil - -import ( - "context" - "errors" - - "github.com/rs/zerolog" - - "maunium.net/go/mautrix" - "maunium.net/go/mautrix/bridgev2" - "maunium.net/go/mautrix/bridgev2/networkid" - "maunium.net/go/mautrix/id" -) - -type RespResolveIdentifier struct { - ID networkid.UserID `json:"id"` - Name string `json:"name,omitempty"` - AvatarURL id.ContentURIString `json:"avatar_url,omitempty"` - Identifiers []string `json:"identifiers,omitempty"` - MXID id.UserID `json:"mxid,omitempty"` - DMRoomID id.RoomID `json:"dm_room_mxid,omitempty"` - - Portal *bridgev2.Portal `json:"-"` - Ghost *bridgev2.Ghost `json:"-"` - JustCreated bool `json:"-"` -} - -var ErrNoPortalKey = errors.New("network API didn't return portal key for createChat request") - -func ResolveIdentifier( - ctx context.Context, - login *bridgev2.UserLogin, - identifier string, - createChat bool, -) (*RespResolveIdentifier, error) { - api, ok := login.Client.(bridgev2.IdentifierResolvingNetworkAPI) - if !ok { - return nil, bridgev2.RespError(mautrix.MUnrecognized.WithMessage("This bridge does not support resolving identifiers")) - } - var resp *bridgev2.ResolveIdentifierResponse - parsedUserID, ok := login.Bridge.Matrix.ParseGhostMXID(id.UserID(identifier)) - validator, vOK := login.Bridge.Network.(bridgev2.IdentifierValidatingNetwork) - if ok && (!vOK || validator.ValidateUserID(parsedUserID)) { - ghost, err := login.Bridge.GetGhostByID(ctx, parsedUserID) - if err != nil { - zerolog.Ctx(ctx).Err(err).Msg("Failed to get ghost by ID") - return nil, err - } - resp = &bridgev2.ResolveIdentifierResponse{ - Ghost: ghost, - UserID: parsedUserID, - } - gdcAPI, ok := api.(bridgev2.GhostDMCreatingNetworkAPI) - if ok && createChat { - resp.Chat, err = gdcAPI.CreateChatWithGhost(ctx, ghost) - if err != nil { - zerolog.Ctx(ctx).Err(err).Msg("Failed to create chat") - return nil, err - } - } else if createChat || ghost.Name == "" { - zerolog.Ctx(ctx).Debug(). - Bool("create_chat", createChat). - Bool("has_name", ghost.Name != ""). - Msg("Falling back to resolving identifier") - resp = nil - identifier = string(parsedUserID) - } - } - if resp == nil { - var err error - resp, err = api.ResolveIdentifier(ctx, identifier, createChat) - if err != nil { - zerolog.Ctx(ctx).Err(err).Msg("Failed to resolve identifier") - return nil, err - } else if resp == nil { - return nil, nil - } - } - apiResp := &RespResolveIdentifier{ - ID: resp.UserID, - Ghost: resp.Ghost, - } - if resp.Ghost != nil { - if resp.UserInfo != nil { - resp.Ghost.UpdateInfo(ctx, resp.UserInfo) - } - apiResp.Name = resp.Ghost.Name - apiResp.AvatarURL = resp.Ghost.AvatarMXC - apiResp.Identifiers = resp.Ghost.Identifiers - apiResp.MXID = resp.Ghost.Intent.GetMXID() - } else if resp.UserInfo != nil && resp.UserInfo.Name != nil { - apiResp.Name = *resp.UserInfo.Name - } - if resp.Chat != nil { - if resp.Chat.PortalKey.IsEmpty() { - return nil, ErrNoPortalKey - } - if resp.Chat.Portal == nil { - var err error - resp.Chat.Portal, err = login.Bridge.GetPortalByKey(ctx, resp.Chat.PortalKey) - if err != nil { - zerolog.Ctx(ctx).Err(err).Msg("Failed to get portal") - return nil, bridgev2.RespError(mautrix.MUnknown.WithMessage("Failed to get portal")) - } - } - resp.Chat.Portal.CleanupOrphanedDM(ctx, login.UserMXID) - if createChat && resp.Chat.Portal.MXID == "" { - apiResp.JustCreated = true - err := resp.Chat.Portal.CreateMatrixRoom(ctx, login, resp.Chat.PortalInfo) - if err != nil { - zerolog.Ctx(ctx).Err(err).Msg("Failed to create portal room") - return nil, bridgev2.RespError(mautrix.MUnknown.WithMessage("Failed to create portal room")) - } - } - apiResp.Portal = resp.Chat.Portal - apiResp.DMRoomID = resp.Chat.Portal.MXID - } - return apiResp, nil -} diff --git a/bridgev2/queue.go b/bridgev2/queue.go index 3775c825..95011cda 100644 --- a/bridgev2/queue.go +++ b/bridgev2/queue.go @@ -63,13 +63,6 @@ func (br *Bridge) rejectInviteOnNoPermission(ctx context.Context, evt *event.Eve return true } -var ( - ErrEventSenderUserNotFound = WrapErrorInStatus(errors.New("sender not found for event")).WithIsCertain(true).WithErrorAsMessage() - ErrNoPermissionToInteract = WrapErrorInStatus(errors.New("you don't have permission to send messages")).WithIsCertain(true).WithSendNotice(false).WithErrorAsMessage() - ErrNoPermissionForCommands = WrapErrorInStatus(WrapErrorInStatus(errors.New("you don't have permission to use commands")).WithIsCertain(true).WithSendNotice(false).WithErrorAsMessage()) - ErrCantRelayStateRequest = WrapErrorInStatus(errors.New("relayed users can't use beeper state requests")).WithIsCertain(true).WithErrorAsMessage() -) - func (br *Bridge) QueueMatrixEvent(ctx context.Context, evt *event.Event) EventHandlingResult { // TODO maybe HandleMatrixEvent would be more appropriate as this also handles bot invites and commands @@ -85,11 +78,13 @@ func (br *Bridge) QueueMatrixEvent(ctx context.Context, evt *event.Event) EventH return EventHandlingResultFailed } else if sender == nil { log.Error().Msg("Couldn't get sender for incoming non-ephemeral Matrix event") - br.Matrix.SendMessageStatus(ctx, &ErrEventSenderUserNotFound, StatusEventInfoFromEvent(evt)) + status := WrapErrorInStatus(errors.New("sender not found for event")).WithIsCertain(true).WithErrorAsMessage() + br.Matrix.SendMessageStatus(ctx, &status, StatusEventInfoFromEvent(evt)) return EventHandlingResultFailed } else if !sender.Permissions.SendEvents { if !br.rejectInviteOnNoPermission(ctx, evt, "interact with") { - br.Matrix.SendMessageStatus(ctx, &ErrNoPermissionToInteract, StatusEventInfoFromEvent(evt)) + status := WrapErrorInStatus(errors.New("you don't have permission to send messages")).WithIsCertain(true).WithSendNotice(false).WithErrorAsMessage() + br.Matrix.SendMessageStatus(ctx, &status, StatusEventInfoFromEvent(evt)) } return EventHandlingResultIgnored } else if !sender.Permissions.Commands && br.rejectInviteOnNoPermission(ctx, evt, "send commands to") { @@ -97,7 +92,8 @@ func (br *Bridge) QueueMatrixEvent(ctx context.Context, evt *event.Event) EventH } } else if evt.Type.Class != event.EphemeralEventType { log.Error().Msg("Missing sender for incoming non-ephemeral Matrix event") - br.Matrix.SendMessageStatus(ctx, &ErrEventSenderUserNotFound, StatusEventInfoFromEvent(evt)) + status := WrapErrorInStatus(errors.New("sender not found for event")).WithIsCertain(true).WithErrorAsMessage() + br.Matrix.SendMessageStatus(ctx, &status, StatusEventInfoFromEvent(evt)) return EventHandlingResultIgnored } if evt.Type == event.EventMessage && sender != nil { @@ -106,7 +102,8 @@ func (br *Bridge) QueueMatrixEvent(ctx context.Context, evt *event.Event) EventH msg.RemovePerMessageProfileFallback() if strings.HasPrefix(msg.Body, br.Config.CommandPrefix) || evt.RoomID == sender.ManagementRoom { if !sender.Permissions.Commands { - br.Matrix.SendMessageStatus(ctx, &ErrNoPermissionForCommands, StatusEventInfoFromEvent(evt)) + status := WrapErrorInStatus(errors.New("you don't have permission to use commands")).WithIsCertain(true).WithSendNotice(false).WithErrorAsMessage() + br.Matrix.SendMessageStatus(ctx, &status, StatusEventInfoFromEvent(evt)) return EventHandlingResultIgnored } go br.Commands.Handle( @@ -160,27 +157,10 @@ type EventHandlingResult struct { Ignored bool Queued bool - SkipStateEcho bool - // Error is an optional reason for failure. It is not required, Success may be false even without a specific error. Error error // Whether the Error should be sent as a MSS event. SendMSS bool - - // EventID from the network - EventID id.EventID - // Stream order from the network - StreamOrder int64 -} - -func (ehr EventHandlingResult) WithEventID(id id.EventID) EventHandlingResult { - ehr.EventID = id - return ehr -} - -func (ehr EventHandlingResult) WithStreamOrder(order int64) EventHandlingResult { - ehr.StreamOrder = order - return ehr } func (ehr EventHandlingResult) WithError(err error) EventHandlingResult { @@ -197,11 +177,6 @@ func (ehr EventHandlingResult) WithMSS() EventHandlingResult { return ehr } -func (ehr EventHandlingResult) WithSkipStateEcho(skip bool) EventHandlingResult { - ehr.SkipStateEcho = skip - return ehr -} - func (ehr EventHandlingResult) WithMSSError(err error) EventHandlingResult { if err == nil { return ehr @@ -220,7 +195,7 @@ func (ul *UserLogin) QueueRemoteEvent(evt RemoteEvent) EventHandlingResult { return ul.Bridge.QueueRemoteEvent(ul, evt) } -func (br *Bridge) QueueRemoteEvent(login *UserLogin, evt RemoteEvent) EventHandlingResult { +func (br *Bridge) QueueRemoteEvent(login *UserLogin, evt RemoteEvent) (res EventHandlingResult) { log := login.Log ctx := log.WithContext(br.BackgroundCtx) maybeUncertain, ok := evt.(RemoteEventWithUncertainPortalReceiver) @@ -236,14 +211,14 @@ func (br *Bridge) QueueRemoteEvent(login *UserLogin, evt RemoteEvent) EventHandl if err != nil { log.Err(err).Object("portal_key", key).Bool("uncertain_receiver", isUncertain). Msg("Failed to get portal to handle remote event") - return EventHandlingResultFailed.WithError(fmt.Errorf("failed to get portal: %w", err)) + return } else if portal == nil { log.Warn(). Stringer("event_type", evt.GetType()). Object("portal_key", key). Bool("uncertain_receiver", isUncertain). Msg("Portal not found to handle remote event") - return EventHandlingResultFailed.WithError(ErrPortalNotFoundInEventHandler) + return } // TODO put this in a better place, and maybe cache to avoid constant db queries login.MarkInPortal(ctx, portal) diff --git a/bridgev2/simplevent/chat.go b/bridgev2/simplevent/chat.go index 56e3a6b1..c725141b 100644 --- a/bridgev2/simplevent/chat.go +++ b/bridgev2/simplevent/chat.go @@ -65,19 +65,14 @@ func (evt *ChatResync) GetChatInfo(ctx context.Context, portal *bridgev2.Portal) type ChatDelete struct { EventMeta OnlyForMe bool - Children bool } -var _ bridgev2.RemoteChatDeleteWithChildren = (*ChatDelete)(nil) +var _ bridgev2.RemoteChatDelete = (*ChatDelete)(nil) func (evt *ChatDelete) DeleteOnlyForMe() bool { return evt.OnlyForMe } -func (evt *ChatDelete) DeleteChildren() bool { - return evt.Children -} - // ChatInfoChange is a simple implementation of [bridgev2.RemoteChatInfoChange]. type ChatInfoChange struct { EventMeta diff --git a/bridgev2/simplevent/message.go b/bridgev2/simplevent/message.go index f8f8d7e1..f648ab12 100644 --- a/bridgev2/simplevent/message.go +++ b/bridgev2/simplevent/message.go @@ -59,41 +59,6 @@ func (evt *Message[T]) GetTransactionID() networkid.TransactionID { return evt.TransactionID } -// PreConvertedMessage is a simple implementation of [bridgev2.RemoteMessage] with pre-converted data. -type PreConvertedMessage struct { - EventMeta - Data *bridgev2.ConvertedMessage - ID networkid.MessageID - TransactionID networkid.TransactionID - - HandleExistingFunc func(ctx context.Context, portal *bridgev2.Portal, intent bridgev2.MatrixAPI, existing []*database.Message) (bridgev2.UpsertResult, error) -} - -var ( - _ bridgev2.RemoteMessage = (*PreConvertedMessage)(nil) - _ bridgev2.RemoteMessageUpsert = (*PreConvertedMessage)(nil) - _ bridgev2.RemoteMessageWithTransactionID = (*PreConvertedMessage)(nil) -) - -func (evt *PreConvertedMessage) ConvertMessage(ctx context.Context, portal *bridgev2.Portal, intent bridgev2.MatrixAPI) (*bridgev2.ConvertedMessage, error) { - return evt.Data, nil -} - -func (evt *PreConvertedMessage) GetID() networkid.MessageID { - return evt.ID -} - -func (evt *PreConvertedMessage) GetTransactionID() networkid.TransactionID { - return evt.TransactionID -} - -func (evt *PreConvertedMessage) HandleExisting(ctx context.Context, portal *bridgev2.Portal, intent bridgev2.MatrixAPI, existing []*database.Message) (bridgev2.UpsertResult, error) { - if evt.HandleExistingFunc == nil { - return bridgev2.UpsertResult{}, nil - } - return evt.HandleExistingFunc(ctx, portal, intent, existing) -} - type MessageRemove struct { EventMeta diff --git a/bridgev2/simplevent/meta.go b/bridgev2/simplevent/meta.go index 449a8773..8aa91866 100644 --- a/bridgev2/simplevent/meta.go +++ b/bridgev2/simplevent/meta.go @@ -101,18 +101,6 @@ func (evt EventMeta) WithLogContext(f func(c zerolog.Context) zerolog.Context) E return evt } -func (evt EventMeta) WithMoreLogContext(f func(c zerolog.Context) zerolog.Context) EventMeta { - origFunc := evt.LogContext - if origFunc == nil { - evt.LogContext = f - return evt - } - evt.LogContext = func(c zerolog.Context) zerolog.Context { - return f(origFunc(c)) - } - return evt -} - func (evt EventMeta) WithPortalKey(p networkid.PortalKey) EventMeta { evt.PortalKey = p return evt diff --git a/bridgev2/space.go b/bridgev2/space.go index 2ca2bce3..ae9013cb 100644 --- a/bridgev2/space.go +++ b/bridgev2/space.go @@ -164,17 +164,14 @@ func (ul *UserLogin) GetSpaceRoom(ctx context.Context) (id.RoomID, error) { ul.UserMXID: 50, }, }, - Invite: []id.UserID{ul.UserMXID}, + RoomVersion: id.RoomV11, + Invite: []id.UserID{ul.UserMXID}, } if autoJoin { req.BeeperInitialMembers = []id.UserID{ul.UserMXID} // TODO remove this after initial_members is supported in hungryserv req.BeeperAutoJoinInvites = true } - pfc, ok := ul.Client.(PersonalFilteringCustomizingNetworkAPI) - if ok { - pfc.CustomizePersonalFilteringSpace(req) - } ul.SpaceRoom, err = ul.Bridge.Bot.CreateRoom(ctx, req) if err != nil { return "", fmt.Errorf("failed to create space room: %w", err) diff --git a/bridgev2/status/bridgestate.go b/bridgev2/status/bridgestate.go index 5925dd4f..01a235a0 100644 --- a/bridgev2/status/bridgestate.go +++ b/bridgev2/status/bridgestate.go @@ -19,10 +19,9 @@ import ( "github.com/tidwall/sjson" "go.mau.fi/util/jsontime" + "go.mau.fi/util/ptr" "maunium.net/go/mautrix" - "maunium.net/go/mautrix/bridgev2/networkid" - "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" ) @@ -88,8 +87,6 @@ type RemoteProfile struct { Username string `json:"username,omitempty"` Name string `json:"name,omitempty"` Avatar id.ContentURIString `json:"avatar,omitempty"` - - AvatarFile *event.EncryptedFileInfo `json:"avatar_file,omitempty"` } func coalesce[T ~string](a, b T) T { @@ -105,14 +102,11 @@ func (rp *RemoteProfile) Merge(other RemoteProfile) RemoteProfile { other.Username = coalesce(rp.Username, other.Username) other.Name = coalesce(rp.Name, other.Name) other.Avatar = coalesce(rp.Avatar, other.Avatar) - if rp.AvatarFile != nil { - other.AvatarFile = rp.AvatarFile - } return other } -func (rp *RemoteProfile) IsZero() bool { - return rp == nil || (rp.Phone == "" && rp.Email == "" && rp.Username == "" && rp.Name == "" && rp.Avatar == "" && rp.AvatarFile == nil) +func (rp *RemoteProfile) IsEmpty() bool { + return rp == nil || (rp.Phone == "" && rp.Email == "" && rp.Username == "" && rp.Name == "" && rp.Avatar == "") } type BridgeState struct { @@ -126,10 +120,10 @@ type BridgeState struct { UserAction BridgeStateUserAction `json:"user_action,omitempty"` - UserID id.UserID `json:"user_id,omitempty"` - RemoteID networkid.UserLoginID `json:"remote_id,omitempty"` - RemoteName string `json:"remote_name,omitempty"` - RemoteProfile RemoteProfile `json:"remote_profile,omitzero"` + UserID id.UserID `json:"user_id,omitempty"` + RemoteID string `json:"remote_id,omitempty"` + RemoteName string `json:"remote_name,omitempty"` + RemoteProfile *RemoteProfile `json:"remote_profile,omitempty"` Reason string `json:"reason,omitempty"` Info map[string]interface{} `json:"info,omitempty"` @@ -209,7 +203,7 @@ func (pong *BridgeState) ShouldDeduplicate(newPong *BridgeState) bool { pong.StateEvent == newPong.StateEvent && pong.RemoteName == newPong.RemoteName && pong.UserAction == newPong.UserAction && - pong.RemoteProfile == newPong.RemoteProfile && + ptr.Val(pong.RemoteProfile) == ptr.Val(newPong.RemoteProfile) && pong.Error == newPong.Error && maps.EqualFunc(pong.Info, newPong.Info, reflect.DeepEqual) && pong.Timestamp.Add(time.Duration(pong.TTL)*time.Second).After(time.Now()) diff --git a/bridgev2/user.go b/bridgev2/user.go index 9a7896d6..87ced1d7 100644 --- a/bridgev2/user.go +++ b/bridgev2/user.go @@ -176,10 +176,6 @@ func (user *User) GetUserLogins() []*UserLogin { return maps.Values(user.logins) } -func (user *User) HasTooManyLogins() bool { - return user.Permissions.MaxLogins > 0 && len(user.GetUserLoginIDs()) >= user.Permissions.MaxLogins -} - func (user *User) GetFormattedUserLogins() string { user.Bridge.cacheLock.Lock() logins := make([]string, len(user.logins)) @@ -229,8 +225,9 @@ func (user *User) GetManagementRoom(ctx context.Context) (id.RoomID, error) { user.MXID: 50, }, }, - Invite: []id.UserID{user.MXID}, - IsDirect: true, + RoomVersion: id.RoomV11, + Invite: []id.UserID{user.MXID}, + IsDirect: true, } if autoJoin { req.BeeperInitialMembers = []id.UserID{user.MXID} diff --git a/bridgev2/userlogin.go b/bridgev2/userlogin.go index d56dc4cc..203dc122 100644 --- a/bridgev2/userlogin.go +++ b/bridgev2/userlogin.go @@ -10,7 +10,6 @@ import ( "cmp" "context" "fmt" - "maps" "slices" "sync" "time" @@ -51,8 +50,6 @@ func (br *Bridge) loadUserLogin(ctx context.Context, user *User, dbUserLogin *da if err != nil { return nil, fmt.Errorf("failed to get user: %w", err) } - // TODO if loading the user caused the provided userlogin to be loaded, cancel here? - // Currently this will double-load it } userLogin := &UserLogin{ UserLogin: dbUserLogin, @@ -143,12 +140,6 @@ func (br *Bridge) GetCachedUserLoginByID(id networkid.UserLoginID) *UserLogin { return br.userLoginsByID[id] } -func (br *Bridge) GetAllCachedUserLogins() (logins []*UserLogin) { - br.cacheLock.Lock() - defer br.cacheLock.Unlock() - return slices.Collect(maps.Values(br.userLoginsByID)) -} - func (br *Bridge) GetCurrentBridgeStates() (states []status.BridgeState) { br.cacheLock.Lock() defer br.cacheLock.Unlock() @@ -510,9 +501,9 @@ var _ status.BridgeStateFiller = (*UserLogin)(nil) func (ul *UserLogin) FillBridgeState(state status.BridgeState) status.BridgeState { state.UserID = ul.UserMXID - state.RemoteID = ul.ID + state.RemoteID = string(ul.ID) state.RemoteName = ul.RemoteName - state.RemoteProfile = ul.RemoteProfile + state.RemoteProfile = &ul.RemoteProfile filler, ok := ul.Client.(status.BridgeStateFiller) if ok { return filler.FillBridgeState(state) diff --git a/client.go b/client.go index 7062d9b9..4906169f 100644 --- a/client.go +++ b/client.go @@ -111,8 +111,6 @@ type Client struct { // Set to true to disable automatically sleeping on 429 errors. IgnoreRateLimit bool - ResponseSizeLimit int64 - txnID int32 // Should the ?user_id= query parameter be set in requests? @@ -145,8 +143,6 @@ func DiscoverClientAPI(ctx context.Context, serverName string) (*ClientWellKnown return DiscoverClientAPIWithClient(ctx, &http.Client{Timeout: 30 * time.Second}, serverName) } -const WellKnownMaxSize = 64 * 1024 - func DiscoverClientAPIWithClient(ctx context.Context, client *http.Client, serverName string) (*ClientWellKnown, error) { wellKnownURL := url.URL{ Scheme: "https", @@ -172,15 +168,11 @@ func DiscoverClientAPIWithClient(ctx context.Context, client *http.Client, serve if resp.StatusCode == http.StatusNotFound { return nil, nil - } else if resp.ContentLength > WellKnownMaxSize { - return nil, errors.New(".well-known response too large") } - data, err := io.ReadAll(io.LimitReader(resp.Body, WellKnownMaxSize)) + data, err := io.ReadAll(resp.Body) if err != nil { return nil, err - } else if len(data) >= WellKnownMaxSize { - return nil, errors.New(".well-known response too large") } var wellKnown ClientWellKnown @@ -331,7 +323,6 @@ const ( LogBodyContextKey contextKey = iota LogRequestIDContextKey MaxAttemptsContextKey - SyncTokenContextKey ) func (cli *Client) RequestStart(req *http.Request) { @@ -386,14 +377,7 @@ func (cli *Client) LogRequestDone(req *http.Request, resp *http.Response, err er } } if body := req.Context().Value(LogBodyContextKey); body != nil { - switch typedLogBody := body.(type) { - case json.RawMessage: - evt.RawJSON("req_body", typedLogBody) - case string: - evt.Str("req_body", typedLogBody) - default: - panic(fmt.Errorf("invalid type for LogBodyContextKey: %T", body)) - } + evt.Interface("req_body", body) } if errors.Is(err, context.Canceled) { evt.Msg("Request canceled") @@ -410,43 +394,32 @@ func (cli *Client) MakeRequest(ctx context.Context, method string, httpURL strin return cli.MakeFullRequest(ctx, FullRequest{Method: method, URL: httpURL, RequestJSON: reqBody, ResponseJSON: resBody}) } -type ClientResponseHandler = func(req *http.Request, res *http.Response, responseJSON any, sizeLimit int64) ([]byte, error) +type ClientResponseHandler = func(req *http.Request, res *http.Response, responseJSON interface{}) ([]byte, error) type FullRequest struct { - Method string - URL string - Headers http.Header - RequestJSON interface{} - RequestBytes []byte - RequestBody io.Reader - RequestLength int64 - ResponseJSON interface{} - MaxAttempts int - BackoffDuration time.Duration - SensitiveContent bool - Handler ClientResponseHandler - DontReadResponse bool - ResponseSizeLimit int64 - Logger *zerolog.Logger - Client *http.Client + Method string + URL string + Headers http.Header + RequestJSON interface{} + RequestBytes []byte + RequestBody io.Reader + RequestLength int64 + ResponseJSON interface{} + MaxAttempts int + BackoffDuration time.Duration + SensitiveContent bool + Handler ClientResponseHandler + DontReadResponse bool + Logger *zerolog.Logger + Client *http.Client } var requestID int32 var logSensitiveContent = os.Getenv("MAUTRIX_LOG_SENSITIVE_CONTENT") == "yes" func (params *FullRequest) compileRequest(ctx context.Context) (*http.Request, error) { - reqID := atomic.AddInt32(&requestID, 1) - logger := zerolog.Ctx(ctx) - if logger.GetLevel() == zerolog.Disabled || logger == zerolog.DefaultContextLogger { - logger = params.Logger - } - ctx = logger.With(). - Int32("req_id", reqID). - Logger().WithContext(ctx) - var logBody any - var reqBody io.Reader - var reqLen int64 + reqBody := params.RequestBody if params.RequestJSON != nil { jsonStr, err := json.Marshal(params.RequestJSON) if err != nil { @@ -457,38 +430,33 @@ func (params *FullRequest) compileRequest(ctx context.Context) (*http.Request, e } if params.SensitiveContent && !logSensitiveContent { logBody = "" - } else if len(jsonStr) > 32768 { - logBody = fmt.Sprintf("", len(jsonStr)) } else { - logBody = json.RawMessage(jsonStr) + logBody = params.RequestJSON } reqBody = bytes.NewReader(jsonStr) - reqLen = int64(len(jsonStr)) } else if params.RequestBytes != nil { logBody = fmt.Sprintf("<%d bytes>", len(params.RequestBytes)) reqBody = bytes.NewReader(params.RequestBytes) - reqLen = int64(len(params.RequestBytes)) - } else if params.RequestBody != nil { - logBody = "" - reqLen = -1 - if params.RequestLength > 0 { - logBody = fmt.Sprintf("<%d bytes>", params.RequestLength) - reqLen = params.RequestLength - } else if params.RequestLength == 0 { - zerolog.Ctx(ctx).Warn(). - Msg("RequestBody passed without specifying request length") - } - reqBody = params.RequestBody + params.RequestLength = int64(len(params.RequestBytes)) + } else if params.RequestLength > 0 && params.RequestBody != nil { + logBody = fmt.Sprintf("<%d bytes>", params.RequestLength) if rsc, ok := params.RequestBody.(io.ReadSeekCloser); ok { // Prevent HTTP from closing the request body, it might be needed for retries reqBody = nopCloseSeeker{rsc} } } else if params.Method != http.MethodGet && params.Method != http.MethodHead { params.RequestJSON = struct{}{} - logBody = json.RawMessage("{}") + logBody = params.RequestJSON reqBody = bytes.NewReader([]byte("{}")) - reqLen = 2 } + reqID := atomic.AddInt32(&requestID, 1) + logger := zerolog.Ctx(ctx) + if logger.GetLevel() == zerolog.Disabled || logger == zerolog.DefaultContextLogger { + logger = params.Logger + } + ctx = logger.With(). + Int32("req_id", reqID). + Logger().WithContext(ctx) ctx = context.WithValue(ctx, LogBodyContextKey, logBody) ctx = context.WithValue(ctx, LogRequestIDContextKey, int(reqID)) req, err := http.NewRequestWithContext(ctx, params.Method, params.URL, reqBody) @@ -504,7 +472,9 @@ func (params *FullRequest) compileRequest(ctx context.Context) (*http.Request, e if params.RequestJSON != nil { req.Header.Set("Content-Type", "application/json") } - req.ContentLength = reqLen + if params.RequestLength > 0 && params.RequestBody != nil { + req.ContentLength = params.RequestLength + } return req, nil } @@ -555,25 +525,10 @@ func (cli *Client) MakeFullRequestWithResp(ctx context.Context, params FullReque if len(cli.AccessToken) > 0 { req.Header.Set("Authorization", "Bearer "+cli.AccessToken) } - if params.ResponseSizeLimit == 0 { - params.ResponseSizeLimit = cli.ResponseSizeLimit - } - if params.ResponseSizeLimit == 0 { - params.ResponseSizeLimit = DefaultResponseSizeLimit - } if params.Client == nil { params.Client = cli.Client } - return cli.executeCompiledRequest( - req, - params.MaxAttempts-1, - params.BackoffDuration, - params.ResponseJSON, - params.Handler, - params.DontReadResponse, - params.ResponseSizeLimit, - params.Client, - ) + return cli.executeCompiledRequest(req, params.MaxAttempts-1, params.BackoffDuration, params.ResponseJSON, params.Handler, params.DontReadResponse, params.Client) } func (cli *Client) cliOrContextLog(ctx context.Context) *zerolog.Logger { @@ -584,17 +539,7 @@ func (cli *Client) cliOrContextLog(ctx context.Context) *zerolog.Logger { return log } -func (cli *Client) doRetry( - req *http.Request, - cause error, - retries int, - backoff time.Duration, - responseJSON any, - handler ClientResponseHandler, - dontReadResponse bool, - sizeLimit int64, - client *http.Client, -) ([]byte, *http.Response, error) { +func (cli *Client) doRetry(req *http.Request, cause error, retries int, backoff time.Duration, responseJSON any, handler ClientResponseHandler, dontReadResponse bool, client *http.Client) ([]byte, *http.Response, error) { log := zerolog.Ctx(req.Context()) if req.Body != nil { var err error @@ -623,30 +568,16 @@ func (cli *Client) doRetry( select { case <-time.After(backoff): case <-req.Context().Done(): - if !errors.Is(context.Cause(req.Context()), ErrContextCancelRetry) { - return nil, nil, req.Context().Err() - } + return nil, nil, req.Context().Err() } if cli.UpdateRequestOnRetry != nil { req = cli.UpdateRequestOnRetry(req, cause) } - return cli.executeCompiledRequest(req, retries-1, backoff*2, responseJSON, handler, dontReadResponse, sizeLimit, client) + return cli.executeCompiledRequest(req, retries-1, backoff*2, responseJSON, handler, dontReadResponse, client) } -func readResponseBody(req *http.Request, res *http.Response, limit int64) ([]byte, error) { - if res.ContentLength > limit { - return nil, HTTPError{ - Request: req, - Response: res, - - Message: "not reading response", - WrappedError: fmt.Errorf("%w (%.2f MiB)", ErrResponseTooLong, float64(res.ContentLength)/1024/1024), - } - } - contents, err := io.ReadAll(io.LimitReader(res.Body, limit+1)) - if err == nil && len(contents) > int(limit) { - err = ErrBodyReadReachedLimit - } +func readResponseBody(req *http.Request, res *http.Response) ([]byte, error) { + contents, err := io.ReadAll(res.Body) if err != nil { return nil, HTTPError{ Request: req, @@ -667,20 +598,17 @@ func closeTemp(log *zerolog.Logger, file *os.File) { } } -func streamResponse(req *http.Request, res *http.Response, responseJSON any, limit int64) ([]byte, error) { +func streamResponse(req *http.Request, res *http.Response, responseJSON interface{}) ([]byte, error) { log := zerolog.Ctx(req.Context()) file, err := os.CreateTemp("", "mautrix-response-") if err != nil { log.Warn().Err(err).Msg("Failed to create temporary file for streaming response") - _, err = handleNormalResponse(req, res, responseJSON, limit) + _, err = handleNormalResponse(req, res, responseJSON) return nil, err } defer closeTemp(log, file) - var n int64 - if n, err = io.Copy(file, io.LimitReader(res.Body, limit+1)); err != nil { + if _, err = io.Copy(file, res.Body); err != nil { return nil, fmt.Errorf("failed to copy response to file: %w", err) - } else if n > limit { - return nil, ErrBodyReadReachedLimit } else if _, err = file.Seek(0, 0); err != nil { return nil, fmt.Errorf("failed to seek to beginning of response file: %w", err) } else if err = json.NewDecoder(file).Decode(responseJSON); err != nil { @@ -690,12 +618,12 @@ func streamResponse(req *http.Request, res *http.Response, responseJSON any, lim } } -func noopHandleResponse(req *http.Request, res *http.Response, responseJSON any, limit int64) ([]byte, error) { +func noopHandleResponse(req *http.Request, res *http.Response, responseJSON interface{}) ([]byte, error) { return nil, nil } -func handleNormalResponse(req *http.Request, res *http.Response, responseJSON any, limit int64) ([]byte, error) { - if contents, err := readResponseBody(req, res, limit); err != nil { +func handleNormalResponse(req *http.Request, res *http.Response, responseJSON interface{}) ([]byte, error) { + if contents, err := readResponseBody(req, res); err != nil { return nil, err } else if responseJSON == nil { return contents, nil @@ -713,13 +641,8 @@ func handleNormalResponse(req *http.Request, res *http.Response, responseJSON an } } -const ErrorResponseSizeLimit = 512 * 1024 - -var DefaultResponseSizeLimit int64 = 512 * 1024 * 1024 - func ParseErrorResponse(req *http.Request, res *http.Response) ([]byte, error) { - defer res.Body.Close() - contents, err := readResponseBody(req, res, ErrorResponseSizeLimit) + contents, err := readResponseBody(req, res) if err != nil { return contents, err } @@ -738,31 +661,17 @@ func ParseErrorResponse(req *http.Request, res *http.Response) ([]byte, error) { } } -func (cli *Client) executeCompiledRequest( - req *http.Request, - retries int, - backoff time.Duration, - responseJSON any, - handler ClientResponseHandler, - dontReadResponse bool, - sizeLimit int64, - client *http.Client, -) ([]byte, *http.Response, error) { +func (cli *Client) executeCompiledRequest(req *http.Request, retries int, backoff time.Duration, responseJSON any, handler ClientResponseHandler, dontReadResponse bool, client *http.Client) ([]byte, *http.Response, error) { cli.RequestStart(req) startTime := time.Now() res, err := client.Do(req) - duration := time.Since(startTime) + duration := time.Now().Sub(startTime) if res != nil && !dontReadResponse { defer res.Body.Close() } if err != nil { - // Either error is *not* canceled or the underlying cause of cancelation explicitly asks to retry - canRetry := !errors.Is(err, context.Canceled) || - errors.Is(context.Cause(req.Context()), ErrContextCancelRetry) - if retries > 0 && canRetry { - return cli.doRetry( - req, err, retries, backoff, responseJSON, handler, dontReadResponse, sizeLimit, client, - ) + if retries > 0 && !errors.Is(err, context.Canceled) { + return cli.doRetry(req, err, retries, backoff, responseJSON, handler, dontReadResponse, client) } err = HTTPError{ Request: req, @@ -777,9 +686,7 @@ func (cli *Client) executeCompiledRequest( if retries > 0 && retryafter.Should(res.StatusCode, !cli.IgnoreRateLimit) { backoff = retryafter.Parse(res.Header.Get("Retry-After"), backoff) - return cli.doRetry( - req, fmt.Errorf("HTTP %d", res.StatusCode), retries, backoff, responseJSON, handler, dontReadResponse, sizeLimit, client, - ) + return cli.doRetry(req, fmt.Errorf("HTTP %d", res.StatusCode), retries, backoff, responseJSON, handler, dontReadResponse, client) } var body []byte @@ -787,7 +694,7 @@ func (cli *Client) executeCompiledRequest( body, err = ParseErrorResponse(req, res) cli.LogRequestDone(req, res, nil, nil, len(body), duration) } else { - body, err = handler(req, res, responseJSON, sizeLimit) + body, err = handler(req, res, responseJSON) cli.LogRequestDone(req, res, nil, err, len(body), duration) } return body, res, err @@ -847,7 +754,7 @@ func (req *ReqSync) BuildQuery() map[string]string { query["full_state"] = "true" } if req.UseStateAfter { - query["use_state_after"] = "true" + query["org.matrix.msc4222.use_state_after"] = "true" } if req.BeeperStreaming { query["com.beeper.streaming"] = "true" @@ -871,7 +778,7 @@ func (cli *Client) FullSyncRequest(ctx context.Context, req ReqSync) (resp *Resp } start := time.Now() _, err = cli.MakeFullRequest(ctx, fullReq) - duration := time.Since(start) + duration := time.Now().Sub(start) timeout := time.Duration(req.Timeout) * time.Millisecond buffer := 10 * time.Second if req.Since == "" { @@ -918,7 +825,7 @@ func (cli *Client) RegisterAvailable(ctx context.Context, username string) (resp return } -func (cli *Client) register(ctx context.Context, url string, req *ReqRegister[any]) (resp *RespRegister, uiaResp *RespUserInteractive, err error) { +func (cli *Client) register(ctx context.Context, url string, req *ReqRegister) (resp *RespRegister, uiaResp *RespUserInteractive, err error) { var bodyBytes []byte bodyBytes, err = cli.MakeFullRequest(ctx, FullRequest{ Method: http.MethodPost, @@ -942,7 +849,7 @@ func (cli *Client) register(ctx context.Context, url string, req *ReqRegister[an // Register makes an HTTP request according to https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3register // // Registers with kind=user. For kind=guest, see RegisterGuest. -func (cli *Client) Register(ctx context.Context, req *ReqRegister[any]) (*RespRegister, *RespUserInteractive, error) { +func (cli *Client) Register(ctx context.Context, req *ReqRegister) (*RespRegister, *RespUserInteractive, error) { u := cli.BuildClientURL("v3", "register") return cli.register(ctx, u, req) } @@ -951,7 +858,7 @@ func (cli *Client) Register(ctx context.Context, req *ReqRegister[any]) (*RespRe // with kind=guest. // // For kind=user, see Register. -func (cli *Client) RegisterGuest(ctx context.Context, req *ReqRegister[any]) (*RespRegister, *RespUserInteractive, error) { +func (cli *Client) RegisterGuest(ctx context.Context, req *ReqRegister) (*RespRegister, *RespUserInteractive, error) { query := map[string]string{ "kind": "guest", } @@ -974,8 +881,8 @@ func (cli *Client) RegisterGuest(ctx context.Context, req *ReqRegister[any]) (*R // panic(err) // } // token := res.AccessToken -func (cli *Client) RegisterDummy(ctx context.Context, req *ReqRegister[any]) (*RespRegister, error) { - _, uia, err := cli.Register(ctx, req) +func (cli *Client) RegisterDummy(ctx context.Context, req *ReqRegister) (*RespRegister, error) { + res, uia, err := cli.Register(ctx, req) if err != nil && uia == nil { return nil, err } else if uia == nil { @@ -984,7 +891,7 @@ func (cli *Client) RegisterDummy(ctx context.Context, req *ReqRegister[any]) (*R return nil, errors.New("server does not support m.login.dummy") } req.Auth = BaseAuthData{Type: AuthTypeDummy, Session: uia.Session} - res, _, err := cli.Register(ctx, req) + res, _, err = cli.Register(ctx, req) if err != nil { return nil, err } @@ -1148,19 +1055,8 @@ func (cli *Client) GetProfile(ctx context.Context, mxid id.UserID) (resp *RespUs return } -func (cli *Client) SearchUserDirectory(ctx context.Context, query string, limit int) (resp *RespSearchUserDirectory, err error) { - urlPath := cli.BuildClientURL("v3", "user_directory", "search") - _, err = cli.MakeRequest(ctx, http.MethodPost, urlPath, &ReqSearchUserDirectory{ - SearchTerm: query, - Limit: limit, - }, &resp) - return -} - func (cli *Client) GetMutualRooms(ctx context.Context, otherUserID id.UserID, extras ...ReqMutualRooms) (resp *RespMutualRooms, err error) { - supportsStable := cli.SpecVersions.Supports(FeatureStableMutualRooms) - supportsUnstable := cli.SpecVersions.Supports(FeatureUnstableMutualRooms) - if cli.SpecVersions != nil && !supportsUnstable && !supportsStable { + if cli.SpecVersions != nil && !cli.SpecVersions.Supports(FeatureMutualRooms) { err = fmt.Errorf("server does not support fetching mutual rooms") return } @@ -1170,10 +1066,7 @@ func (cli *Client) GetMutualRooms(ctx context.Context, otherUserID id.UserID, ex if len(extras) > 0 { query["from"] = extras[0].From } - urlPath := cli.BuildURLWithQuery(ClientURLPath{"v1", "user", "mutual_rooms"}, query) - if !supportsStable && supportsUnstable { - urlPath = cli.BuildURLWithQuery(ClientURLPath{"unstable", "uk.half-shot.msc2666", "user", "mutual_rooms"}, query) - } + urlPath := cli.BuildURLWithQuery(ClientURLPath{"unstable", "uk.half-shot.msc2666", "user", "mutual_rooms"}, query) _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp) return } @@ -1195,7 +1088,8 @@ func (cli *Client) GetRoomSummary(ctx context.Context, roomIDOrAlias string, via // GetDisplayName returns the display name of the user with the specified MXID. See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3profileuseriddisplayname func (cli *Client) GetDisplayName(ctx context.Context, mxid id.UserID) (resp *RespUserDisplayName, err error) { - err = cli.GetProfileField(ctx, mxid, "displayname", &resp) + urlPath := cli.BuildClientURL("v3", "profile", mxid, "displayname") + _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp) return } @@ -1206,47 +1100,41 @@ func (cli *Client) GetOwnDisplayName(ctx context.Context) (resp *RespUserDisplay // SetDisplayName sets the user's profile display name. See https://spec.matrix.org/v1.2/client-server-api/#put_matrixclientv3profileuseriddisplayname func (cli *Client) SetDisplayName(ctx context.Context, displayName string) (err error) { - return cli.SetProfileField(ctx, "displayname", displayName) + urlPath := cli.BuildClientURL("v3", "profile", cli.UserID, "displayname") + s := struct { + DisplayName string `json:"displayname"` + }{displayName} + _, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, &s, nil) + return } -// SetProfileField sets an arbitrary profile field. See https://spec.matrix.org/v1.16/client-server-api/#put_matrixclientv3profileuseridkeyname -func (cli *Client) SetProfileField(ctx context.Context, key string, value any) (err error) { - urlPath := cli.BuildClientURL("v3", "profile", cli.UserID, key) - if key != "displayname" && key != "avatar_url" && !cli.SpecVersions.Supports(FeatureArbitraryProfileFields) && cli.SpecVersions.Supports(FeatureUnstableProfileFields) { - urlPath = cli.BuildClientURL("unstable", "uk.tcpip.msc4133", "profile", cli.UserID, key) - } +// UnstableSetProfileField sets an arbitrary MSC4133 profile field. See https://github.com/matrix-org/matrix-spec-proposals/pull/4133 +func (cli *Client) UnstableSetProfileField(ctx context.Context, key string, value any) (err error) { + urlPath := cli.BuildClientURL("unstable", "uk.tcpip.msc4133", "profile", cli.UserID, key) _, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, map[string]any{ key: value, }, nil) return } -// DeleteProfileField deletes an arbitrary profile field. See https://spec.matrix.org/v1.16/client-server-api/#put_matrixclientv3profileuseridkeyname -func (cli *Client) DeleteProfileField(ctx context.Context, key string) (err error) { - urlPath := cli.BuildClientURL("v3", "profile", cli.UserID, key) - if key != "displayname" && key != "avatar_url" && !cli.SpecVersions.Supports(FeatureArbitraryProfileFields) && cli.SpecVersions.Supports(FeatureUnstableProfileFields) { - urlPath = cli.BuildClientURL("unstable", "uk.tcpip.msc4133", "profile", cli.UserID, key) - } +// UnstableDeleteProfileField deletes an arbitrary MSC4133 profile field. See https://github.com/matrix-org/matrix-spec-proposals/pull/4133 +func (cli *Client) UnstableDeleteProfileField(ctx context.Context, key string) (err error) { + urlPath := cli.BuildClientURL("unstable", "uk.tcpip.msc4133", "profile", cli.UserID, key) _, err = cli.MakeRequest(ctx, http.MethodDelete, urlPath, nil, nil) return } -// GetProfileField gets an arbitrary profile field and parses the response into the given struct. See https://spec.matrix.org/unstable/client-server-api/#get_matrixclientv3profileuseridkeyname -func (cli *Client) GetProfileField(ctx context.Context, userID id.UserID, key string, into any) (err error) { - urlPath := cli.BuildClientURL("v3", "profile", userID, key) - if key != "displayname" && key != "avatar_url" && !cli.SpecVersions.Supports(FeatureArbitraryProfileFields) && cli.SpecVersions.Supports(FeatureUnstableProfileFields) { - urlPath = cli.BuildClientURL("unstable", "uk.tcpip.msc4133", "profile", cli.UserID, key) - } - _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, into) - return -} - // GetAvatarURL gets the avatar URL of the user with the specified MXID. See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3profileuseridavatar_url func (cli *Client) GetAvatarURL(ctx context.Context, mxid id.UserID) (url id.ContentURI, err error) { + urlPath := cli.BuildClientURL("v3", "profile", mxid, "avatar_url") s := struct { AvatarURL id.ContentURI `json:"avatar_url"` }{} - err = cli.GetProfileField(ctx, mxid, "avatar_url", &s) + + _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &s) + if err != nil { + return + } url = s.AvatarURL return } @@ -1338,9 +1226,6 @@ func (cli *Client) SendMessageEvent(ctx context.Context, roomID id.RoomID, event if req.UnstableDelay > 0 { queryParams["org.matrix.msc4140.delay"] = strconv.FormatInt(req.UnstableDelay.Milliseconds(), 10) } - if req.UnstableStickyDuration > 0 { - queryParams["org.matrix.msc4354.sticky_duration_ms"] = strconv.FormatInt(req.UnstableStickyDuration.Milliseconds(), 10) - } if !req.DontEncrypt && cli != nil && cli.Crypto != nil && eventType != event.EventReaction && eventType != event.EventEncrypted { var isEncrypted bool @@ -1364,51 +1249,9 @@ func (cli *Client) SendMessageEvent(ctx context.Context, roomID id.RoomID, event return } -// BeeperSendEphemeralEvent sends an ephemeral event into a room using Beeper's unstable endpoint. -// contentJSON should be a value that can be encoded as JSON using json.Marshal. -func (cli *Client) BeeperSendEphemeralEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, contentJSON any, extra ...ReqSendEvent) (resp *RespSendEvent, err error) { - var req ReqSendEvent - if len(extra) > 0 { - req = extra[0] - } - - var txnID string - if len(req.TransactionID) > 0 { - txnID = req.TransactionID - } else { - txnID = cli.TxnID() - } - - queryParams := map[string]string{} - if req.Timestamp > 0 { - queryParams["ts"] = strconv.FormatInt(req.Timestamp, 10) - } - - if !req.DontEncrypt && cli != nil && cli.Crypto != nil && eventType != event.EventEncrypted { - var isEncrypted bool - isEncrypted, err = cli.StateStore.IsEncrypted(ctx, roomID) - if err != nil { - err = fmt.Errorf("failed to check if room is encrypted: %w", err) - return - } - if isEncrypted { - if contentJSON, err = cli.Crypto.Encrypt(ctx, roomID, eventType, contentJSON); err != nil { - err = fmt.Errorf("failed to encrypt event: %w", err) - return - } - eventType = event.EventEncrypted - } - } - - urlData := ClientURLPath{"unstable", "com.beeper.ephemeral", "rooms", roomID, "ephemeral", eventType.String(), txnID} - urlPath := cli.BuildURLWithQuery(urlData, queryParams) - _, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, contentJSON, &resp) - return -} - -// SendStateEvent sends a state event into a room. See https://spec.matrix.org/v1.16/client-server-api/#put_matrixclientv3roomsroomidstateeventtypestatekey +// SendStateEvent sends a state event into a room. See https://spec.matrix.org/v1.2/client-server-api/#put_matrixclientv3roomsroomidstateeventtypestatekey // contentJSON should be a pointer to something that can be encoded as JSON using json.Marshal. -func (cli *Client) SendStateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, contentJSON any, extra ...ReqSendEvent) (resp *RespSendEvent, err error) { +func (cli *Client) SendStateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, contentJSON interface{}, extra ...ReqSendEvent) (resp *RespSendEvent, err error) { var req ReqSendEvent if len(extra) > 0 { req = extra[0] @@ -1418,18 +1261,9 @@ func (cli *Client) SendStateEvent(ctx context.Context, roomID id.RoomID, eventTy if req.MeowEventID != "" { queryParams["fi.mau.event_id"] = req.MeowEventID.String() } - if req.TransactionID != "" { - queryParams["fi.mau.transaction_id"] = req.TransactionID - } if req.UnstableDelay > 0 { queryParams["org.matrix.msc4140.delay"] = strconv.FormatInt(req.UnstableDelay.Milliseconds(), 10) } - if req.UnstableStickyDuration > 0 { - queryParams["org.matrix.msc4354.sticky_duration_ms"] = strconv.FormatInt(req.UnstableStickyDuration.Milliseconds(), 10) - } - if req.Timestamp > 0 { - queryParams["ts"] = strconv.FormatInt(req.Timestamp, 10) - } urlData := ClientURLPath{"v3", "rooms", roomID, "state", eventType.String(), stateKey} urlPath := cli.BuildURLWithQuery(urlData, queryParams) @@ -1442,38 +1276,14 @@ func (cli *Client) SendStateEvent(ctx context.Context, roomID id.RoomID, eventTy // SendMassagedStateEvent sends a state event into a room with a custom timestamp. See https://spec.matrix.org/v1.2/client-server-api/#put_matrixclientv3roomsroomidstateeventtypestatekey // contentJSON should be a pointer to something that can be encoded as JSON using json.Marshal. -// -// Deprecated: SendStateEvent accepts a timestamp via ReqSendEvent and should be used instead. func (cli *Client) SendMassagedStateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, contentJSON interface{}, ts int64) (resp *RespSendEvent, err error) { - resp, err = cli.SendStateEvent(ctx, roomID, eventType, stateKey, contentJSON, ReqSendEvent{ - Timestamp: ts, + urlPath := cli.BuildURLWithQuery(ClientURLPath{"v3", "rooms", roomID, "state", eventType.String(), stateKey}, map[string]string{ + "ts": strconv.FormatInt(ts, 10), }) - return -} - -func (cli *Client) DelayedEvents(ctx context.Context, req *ReqDelayedEvents) (resp *RespDelayedEvents, err error) { - query := map[string]string{} - if req.DelayID != "" { - query["delay_id"] = string(req.DelayID) + _, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, contentJSON, &resp) + if err == nil && cli.StateStore != nil { + cli.updateStoreWithOutgoingEvent(ctx, roomID, eventType, stateKey, contentJSON) } - if req.Status != "" { - query["status"] = string(req.Status) - } - if req.NextBatch != "" { - query["next_batch"] = req.NextBatch - } - - urlPath := cli.BuildURLWithQuery(ClientURLPath{"unstable", "org.matrix.msc4140", "delayed_events"}, query) - _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, req, &resp) - - // Migration: merge old keys with new ones - if resp != nil { - resp.Scheduled = append(resp.Scheduled, resp.DelayedEvents...) - resp.DelayedEvents = nil - resp.Finalised = append(resp.Finalised, resp.FinalisedEvents...) - resp.FinalisedEvents = nil - } - return } @@ -1766,20 +1576,11 @@ func (cli *Client) FullStateEvent(ctx context.Context, roomID id.RoomID, eventTy } // parseRoomStateArray parses a JSON array as a stream and stores the events inside it in a room state map. -func parseRoomStateArray(req *http.Request, res *http.Response, responseJSON any, limit int64) ([]byte, error) { - if res.ContentLength > limit { - return nil, HTTPError{ - Request: req, - Response: res, - - Message: "not reading response", - WrappedError: fmt.Errorf("%w (%.2f MiB)", ErrResponseTooLong, float64(res.ContentLength)/1024/1024), - } - } +func parseRoomStateArray(_ *http.Request, res *http.Response, responseJSON interface{}) ([]byte, error) { response := make(RoomStateMap) responsePtr := responseJSON.(*map[event.Type]map[string]*event.Event) *responsePtr = response - dec := json.NewDecoder(io.LimitReader(res.Body, limit)) + dec := json.NewDecoder(res.Body) arrayStart, err := dec.Token() if err != nil { @@ -1813,8 +1614,6 @@ func parseRoomStateArray(req *http.Request, res *http.Response, responseJSON any return nil, nil } -type RoomStateMap = map[event.Type]map[string]*event.Event - // State gets all state in a room. // See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3roomsroomidstate func (cli *Client) State(ctx context.Context, roomID id.RoomID) (stateMap RoomStateMap, err error) { @@ -1897,9 +1696,6 @@ func (cli *Client) UploadLink(ctx context.Context, link string) (*RespMediaUploa } func (cli *Client) Download(ctx context.Context, mxcURL id.ContentURI) (*http.Response, error) { - if mxcURL.IsEmpty() { - return nil, fmt.Errorf("empty mxc uri provided to Download") - } _, resp, err := cli.MakeFullRequestWithResp(ctx, FullRequest{ Method: http.MethodGet, URL: cli.BuildClientURL("v1", "media", "download", mxcURL.Homeserver, mxcURL.FileID), @@ -1908,41 +1704,6 @@ func (cli *Client) Download(ctx context.Context, mxcURL id.ContentURI) (*http.Re return resp, err } -type DownloadThumbnailExtra struct { - Method string - Animated bool -} - -func (cli *Client) DownloadThumbnail(ctx context.Context, mxcURL id.ContentURI, height, width int, extras ...DownloadThumbnailExtra) (*http.Response, error) { - if mxcURL.IsEmpty() { - return nil, fmt.Errorf("empty mxc uri provided to DownloadThumbnail") - } - if len(extras) > 1 { - panic(fmt.Errorf("invalid number of arguments to DownloadThumbnail: %d", len(extras))) - } - var extra DownloadThumbnailExtra - if len(extras) == 1 { - extra = extras[0] - } - path := ClientURLPath{"v1", "media", "thumbnail", mxcURL.Homeserver, mxcURL.FileID} - query := map[string]string{ - "height": strconv.Itoa(height), - "width": strconv.Itoa(width), - } - if extra.Method != "" { - query["method"] = extra.Method - } - if extra.Animated { - query["animated"] = "true" - } - _, resp, err := cli.MakeFullRequestWithResp(ctx, FullRequest{ - Method: http.MethodGet, - URL: cli.BuildURLWithQuery(path, query), - DontReadResponse: true, - }) - return resp, err -} - func (cli *Client) DownloadBytes(ctx context.Context, mxcURL id.ContentURI) ([]byte, error) { resp, err := cli.Download(ctx, mxcURL) if err != nil { @@ -1989,15 +1750,10 @@ func (cli *Client) UploadAsync(ctx context.Context, req ReqUploadMedia) (*RespCr } req.MXC = resp.ContentURI req.UnstableUploadURL = resp.UnstableUploadURL - if req.AsyncContext == nil { - req.AsyncContext = cli.cliOrContextLog(ctx).WithContext(context.Background()) - } go func() { - _, err = cli.UploadMedia(req.AsyncContext, req) + _, err = cli.UploadMedia(ctx, req) if err != nil { - zerolog.Ctx(req.AsyncContext).Err(err). - Stringer("mxc", req.MXC). - Msg("Async upload of media failed") + cli.Log.Error().Str("mxc", req.MXC.String()).Err(err).Msg("Async upload of media failed") } }() return resp, nil @@ -2033,7 +1789,6 @@ type ReqUploadMedia struct { ContentType string FileName string - AsyncContext context.Context DoneCallback func() // MXC specifies an existing MXC URI which doesn't have content yet to upload into. @@ -2046,10 +1801,7 @@ type ReqUploadMedia struct { } func (cli *Client) tryUploadMediaToURL(ctx context.Context, url, contentType string, content io.Reader, contentLength int64) (*http.Response, error) { - cli.Log.Debug(). - Str("url", url). - Int64("content_length", contentLength). - Msg("Uploading media to external URL") + cli.Log.Debug().Str("url", url).Msg("Uploading media to external URL") req, err := http.NewRequestWithContext(ctx, http.MethodPut, url, content) if err != nil { return nil, err @@ -2098,16 +1850,8 @@ func (cli *Client) uploadMediaToURL(ctx context.Context, data ReqUploadMedia) (* Msg("Error uploading media to external URL, not retrying") return nil, err } - backoff := time.Second * time.Duration(cli.DefaultHTTPRetries-retries) - cli.Log.Warn().Err(err). - Str("url", data.UnstableUploadURL). - Int("retry_in_seconds", int(backoff.Seconds())). + cli.Log.Warn().Str("url", data.UnstableUploadURL).Err(err). Msg("Error uploading media to external URL, retrying") - select { - case <-time.After(backoff): - case <-ctx.Done(): - return nil, ctx.Err() - } retries-- _, err = readerSeeker.Seek(0, io.SeekStart) if err != nil { @@ -2687,15 +2431,15 @@ func (cli *Client) SetDeviceInfo(ctx context.Context, deviceID id.DeviceID, req return err } -func (cli *Client) DeleteDevice(ctx context.Context, deviceID id.DeviceID, req *ReqDeleteDevice[any]) error { +func (cli *Client) DeleteDevice(ctx context.Context, deviceID id.DeviceID, req *ReqDeleteDevice) error { urlPath := cli.BuildClientURL("v3", "devices", deviceID) _, err := cli.MakeRequest(ctx, http.MethodDelete, urlPath, req, nil) return err } -func (cli *Client) DeleteDevices(ctx context.Context, req *ReqDeleteDevices[any]) error { +func (cli *Client) DeleteDevices(ctx context.Context, req *ReqDeleteDevices) error { urlPath := cli.BuildClientURL("v3", "delete_devices") - _, err := cli.MakeRequest(ctx, http.MethodPost, urlPath, req, nil) + _, err := cli.MakeRequest(ctx, http.MethodDelete, urlPath, req, nil) return err } @@ -2704,7 +2448,7 @@ type UIACallback = func(*RespUserInteractive) interface{} // UploadCrossSigningKeys uploads the given cross-signing keys to the server. // Because the endpoint requires user-interactive authentication a callback must be provided that, // given the UI auth parameters, produces the required result (or nil to end the flow). -func (cli *Client) UploadCrossSigningKeys(ctx context.Context, keys *UploadCrossSigningKeysReq[any], uiaCallback UIACallback) error { +func (cli *Client) UploadCrossSigningKeys(ctx context.Context, keys *UploadCrossSigningKeysReq, uiaCallback UIACallback) error { content, err := cli.MakeFullRequest(ctx, FullRequest{ Method: http.MethodPost, URL: cli.BuildClientURL("v3", "keys", "device_signing", "upload"), @@ -2786,61 +2530,24 @@ func (cli *Client) ReportRoom(ctx context.Context, roomID id.RoomID, reason stri return err } -// AdminWhoIs fetches session information belonging to a specific user. Typically requires being a server admin. +// BatchSend sends a batch of historical events into a room. This is only available for appservices. // -// https://spec.matrix.org/v1.15/client-server-api/#get_matrixclientv3adminwhoisuserid -func (cli *Client) AdminWhoIs(ctx context.Context, userID id.UserID) (resp RespWhoIs, err error) { - urlPath := cli.BuildClientURL("v3", "admin", "whois", userID) - _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp) - return -} - -func (cli *Client) makeMSC4323URL(action string, target id.UserID) string { - if cli.SpecVersions.Supports(FeatureUnstableAccountModeration) { - return cli.BuildClientURL("unstable", "uk.timedout.msc4323", "admin", action, target) - } else if cli.SpecVersions.Supports(FeatureStableAccountModeration) { - return cli.BuildClientURL("v1", "admin", action, target) +// Deprecated: MSC2716 has been abandoned, so this is now Beeper-specific. BeeperBatchSend should be used instead. +func (cli *Client) BatchSend(ctx context.Context, roomID id.RoomID, req *ReqBatchSend) (resp *RespBatchSend, err error) { + path := ClientURLPath{"unstable", "org.matrix.msc2716", "rooms", roomID, "batch_send"} + query := map[string]string{ + "prev_event_id": req.PrevEventID.String(), } - return "" -} - -// GetSuspendedStatus uses MSC4323 to check if a user is suspended. -func (cli *Client) GetSuspendedStatus(ctx context.Context, userID id.UserID) (res *RespSuspended, err error) { - urlPath := cli.makeMSC4323URL("suspend", userID) - if urlPath == "" { - return nil, MUnrecognized.WithMessage("Homeserver does not advertise MSC4323 support") + if req.BeeperNewMessages { + query["com.beeper.new_messages"] = "true" } - _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, res) - return -} - -// GetLockStatus uses MSC4323 to check if a user is locked. -func (cli *Client) GetLockStatus(ctx context.Context, userID id.UserID) (res *RespLocked, err error) { - urlPath := cli.makeMSC4323URL("lock", userID) - if urlPath == "" { - return nil, MUnrecognized.WithMessage("Homeserver does not advertise MSC4323 support") + if req.BeeperMarkReadBy != "" { + query["com.beeper.mark_read_by"] = req.BeeperMarkReadBy.String() } - _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, res) - return -} - -// SetSuspendedStatus uses MSC4323 to set whether a user account is suspended. -func (cli *Client) SetSuspendedStatus(ctx context.Context, userID id.UserID, suspended bool) (res *RespSuspended, err error) { - urlPath := cli.makeMSC4323URL("suspend", userID) - if urlPath == "" { - return nil, MUnrecognized.WithMessage("Homeserver does not advertise MSC4323 support") + if len(req.BatchID) > 0 { + query["batch_id"] = req.BatchID.String() } - _, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, &ReqSuspend{Suspended: suspended}, res) - return -} - -// SetLockStatus uses MSC4323 to set whether a user account is locked. -func (cli *Client) SetLockStatus(ctx context.Context, userID id.UserID, locked bool) (res *RespLocked, err error) { - urlPath := cli.makeMSC4323URL("lock", userID) - if urlPath == "" { - return nil, MUnrecognized.WithMessage("Homeserver does not advertise MSC4323 support") - } - _, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, &ReqLocked{Locked: locked}, res) + _, err = cli.MakeRequest(ctx, http.MethodPost, cli.BuildURLWithQuery(path, query), req, &resp) return } diff --git a/client_ephemeral_test.go b/client_ephemeral_test.go deleted file mode 100644 index c2846427..00000000 --- a/client_ephemeral_test.go +++ /dev/null @@ -1,158 +0,0 @@ -// Copyright (c) 2026 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package mautrix_test - -import ( - "context" - "encoding/json" - "errors" - "net/http" - "net/http/httptest" - "strings" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "maunium.net/go/mautrix" - "maunium.net/go/mautrix/event" - "maunium.net/go/mautrix/id" -) - -func TestClient_SendEphemeralEvent_UsesUnstablePathTxnAndTS(t *testing.T) { - roomID := id.RoomID("!room:example.com") - evtType := event.Type{Type: "com.example.ephemeral", Class: event.EphemeralEventType} - txnID := "txn-123" - - var gotPath string - var gotQueryTS string - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - gotPath = r.URL.Path - gotQueryTS = r.URL.Query().Get("ts") - assert.Equal(t, http.MethodPut, r.Method) - _, _ = w.Write([]byte(`{"event_id":"$evt"}`)) - })) - defer ts.Close() - - cli, err := mautrix.NewClient(ts.URL, "", "") - require.NoError(t, err) - - _, err = cli.BeeperSendEphemeralEvent( - context.Background(), - roomID, - evtType, - map[string]any{"foo": "bar"}, - mautrix.ReqSendEvent{TransactionID: txnID, Timestamp: 1234}, - ) - require.NoError(t, err) - - assert.True(t, strings.Contains(gotPath, "/_matrix/client/unstable/com.beeper.ephemeral/rooms/")) - assert.True(t, strings.HasSuffix(gotPath, "/ephemeral/com.example.ephemeral/"+txnID)) - assert.Equal(t, "1234", gotQueryTS) -} - -func TestClient_SendEphemeralEvent_UnsupportedReturnsMUnrecognized(t *testing.T) { - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.WriteHeader(http.StatusNotFound) - _, _ = w.Write([]byte(`{"errcode":"M_UNRECOGNIZED","error":"Unrecognized endpoint"}`)) - })) - defer ts.Close() - - cli, err := mautrix.NewClient(ts.URL, "", "") - require.NoError(t, err) - - _, err = cli.BeeperSendEphemeralEvent( - context.Background(), - id.RoomID("!room:example.com"), - event.Type{Type: "com.example.ephemeral", Class: event.EphemeralEventType}, - map[string]any{"foo": "bar"}, - ) - require.Error(t, err) - assert.True(t, errors.Is(err, mautrix.MUnrecognized)) -} - -func TestClient_SendEphemeralEvent_EncryptsInEncryptedRooms(t *testing.T) { - roomID := id.RoomID("!room:example.com") - evtType := event.Type{Type: "com.example.ephemeral", Class: event.EphemeralEventType} - txnID := "txn-encrypted" - - stateStore := mautrix.NewMemoryStateStore() - err := stateStore.SetEncryptionEvent(context.Background(), roomID, &event.EncryptionEventContent{ - Algorithm: id.AlgorithmMegolmV1, - }) - require.NoError(t, err) - - fakeCrypto := &fakeCryptoHelper{ - encryptedContent: &event.EncryptedEventContent{ - Algorithm: id.AlgorithmMegolmV1, - MegolmCiphertext: []byte("ciphertext"), - }, - } - - var gotPath string - var gotBody map[string]any - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - gotPath = r.URL.Path - assert.Equal(t, http.MethodPut, r.Method) - err := json.NewDecoder(r.Body).Decode(&gotBody) - require.NoError(t, err) - _, _ = w.Write([]byte(`{"event_id":"$evt"}`)) - })) - defer ts.Close() - - cli, err := mautrix.NewClient(ts.URL, "", "") - require.NoError(t, err) - cli.StateStore = stateStore - cli.Crypto = fakeCrypto - - _, err = cli.BeeperSendEphemeralEvent( - context.Background(), - roomID, - evtType, - map[string]any{"foo": "bar"}, - mautrix.ReqSendEvent{TransactionID: txnID}, - ) - require.NoError(t, err) - - assert.True(t, strings.HasSuffix(gotPath, "/ephemeral/m.room.encrypted/"+txnID)) - assert.Equal(t, string(id.AlgorithmMegolmV1), gotBody["algorithm"]) - assert.Equal(t, 1, fakeCrypto.encryptCalls) - assert.Equal(t, roomID, fakeCrypto.lastRoomID) - assert.Equal(t, evtType, fakeCrypto.lastEventType) -} - -type fakeCryptoHelper struct { - encryptCalls int - lastRoomID id.RoomID - lastEventType event.Type - lastEncryptInput any - encryptedContent *event.EncryptedEventContent -} - -func (f *fakeCryptoHelper) Encrypt(_ context.Context, roomID id.RoomID, eventType event.Type, content any) (*event.EncryptedEventContent, error) { - f.encryptCalls++ - f.lastRoomID = roomID - f.lastEventType = eventType - f.lastEncryptInput = content - return f.encryptedContent, nil -} - -func (f *fakeCryptoHelper) Decrypt(context.Context, *event.Event) (*event.Event, error) { - return nil, nil -} - -func (f *fakeCryptoHelper) WaitForSession(context.Context, id.RoomID, id.SenderKey, id.SessionID, time.Duration) bool { - return false -} - -func (f *fakeCryptoHelper) RequestSession(context.Context, id.RoomID, id.SenderKey, id.SessionID, id.UserID, id.DeviceID) { -} - -func (f *fakeCryptoHelper) Init(context.Context) error { - return nil -} diff --git a/commands/container.go b/commands/container.go index 9b909b75..bc685b7b 100644 --- a/commands/container.go +++ b/commands/container.go @@ -1,4 +1,4 @@ -// Copyright (c) 2026 Tulir Asokan +// Copyright (c) 2025 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this @@ -8,20 +8,14 @@ package commands import ( "fmt" - "slices" "strings" "sync" - - "go.mau.fi/util/exmaps" - - "maunium.net/go/mautrix/event/cmdschema" ) type CommandContainer[MetaType any] struct { commands map[string]*Handler[MetaType] aliases map[string]string lock sync.RWMutex - parent *Handler[MetaType] } func NewCommandContainer[MetaType any]() *CommandContainer[MetaType] { @@ -31,29 +25,6 @@ func NewCommandContainer[MetaType any]() *CommandContainer[MetaType] { } } -func (cont *CommandContainer[MetaType]) AllSpecs() []*cmdschema.EventContent { - data := make(exmaps.Set[*Handler[MetaType]]) - cont.collectHandlers(data) - specs := make([]*cmdschema.EventContent, 0, data.Size()) - for handler := range data.Iter() { - if handler.Parameters != nil { - specs = append(specs, handler.Spec()) - } - } - return specs -} - -func (cont *CommandContainer[MetaType]) collectHandlers(into exmaps.Set[*Handler[MetaType]]) { - cont.lock.RLock() - defer cont.lock.RUnlock() - for _, handler := range cont.commands { - into.Add(handler) - if handler.subcommandContainer != nil { - handler.subcommandContainer.collectHandlers(into) - } - } -} - // Register registers the given command handlers. func (cont *CommandContainer[MetaType]) Register(handlers ...*Handler[MetaType]) { if cont == nil { @@ -61,10 +32,7 @@ func (cont *CommandContainer[MetaType]) Register(handlers ...*Handler[MetaType]) } cont.lock.Lock() defer cont.lock.Unlock() - for i, handler := range handlers { - if handler == nil { - panic(fmt.Errorf("handler #%d is nil", i+1)) - } + for _, handler := range handlers { cont.registerOne(handler) } } @@ -77,10 +45,6 @@ func (cont *CommandContainer[MetaType]) registerOne(handler *Handler[MetaType]) } else if aliasTarget, alreadyExists := cont.aliases[handler.Name]; alreadyExists { panic(fmt.Errorf("tried to register command %q, but it's already registered as an alias for %q", handler.Name, aliasTarget)) } - if !slices.Contains(handler.parents, cont.parent) { - handler.parents = append(handler.parents, cont.parent) - handler.nestedNameCache = nil - } cont.commands[handler.Name] = handler for _, alias := range handler.Aliases { if strings.ToLower(alias) != alias { diff --git a/commands/event.go b/commands/event.go index 76d6c9f0..77a3c0d2 100644 --- a/commands/event.go +++ b/commands/event.go @@ -1,4 +1,4 @@ -// Copyright (c) 2026 Tulir Asokan +// Copyright (c) 2025 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this @@ -8,7 +8,6 @@ package commands import ( "context" - "encoding/json" "fmt" "strings" @@ -36,8 +35,6 @@ type Event[MetaType any] struct { // RawArgs is the same as args, but without the splitting by whitespace. RawArgs string - StructuredArgs json.RawMessage - Ctx context.Context Log *zerolog.Logger Proc *Processor[MetaType] @@ -64,7 +61,7 @@ var IDHTMLParser = &format.HTMLParser{ } // ParseEvent parses a message into a command event struct. -func (proc *Processor[MetaType]) ParseEvent(ctx context.Context, evt *event.Event) *Event[MetaType] { +func ParseEvent[MetaType any](ctx context.Context, evt *event.Event) *Event[MetaType] { content, ok := evt.Content.Parsed.(*event.MessageEventContent) if !ok || content.MsgType == event.MsgNotice || content.RelatesTo.GetReplaceID() != "" { return nil @@ -73,34 +70,12 @@ func (proc *Processor[MetaType]) ParseEvent(ctx context.Context, evt *event.Even if content.Format == event.FormatHTML { text = IDHTMLParser.Parse(content.FormattedBody, format.NewContext(ctx)) } - if content.MSC4391BotCommand != nil { - if !content.Mentions.Has(proc.Client.UserID) || len(content.Mentions.UserIDs) != 1 { - return nil - } - wrapped := StructuredCommandToEvent[MetaType](ctx, evt, content.MSC4391BotCommand) - wrapped.RawInput = text - return wrapped - } if len(text) == 0 { return nil } return RawTextToEvent[MetaType](ctx, evt, text) } -func StructuredCommandToEvent[MetaType any](ctx context.Context, evt *event.Event, content *event.MSC4391BotCommandInput) *Event[MetaType] { - commandParts := strings.Split(content.Command, " ") - return &Event[MetaType]{ - Event: evt, - // Fake a command and args to let the subcommand finder in Process work. - Command: commandParts[0], - Args: commandParts[1:], - Ctx: ctx, - Log: zerolog.Ctx(ctx), - - StructuredArgs: content.Arguments, - } -} - func RawTextToEvent[MetaType any](ctx context.Context, evt *event.Event, text string) *Event[MetaType] { parts := strings.Fields(text) if len(parts) == 0 { @@ -213,25 +188,3 @@ func (evt *Event[MetaType]) UnshiftArg(arg string) { evt.RawArgs = arg + " " + evt.RawArgs evt.Args = append([]string{arg}, evt.Args...) } - -func (evt *Event[MetaType]) ParseArgs(into any) error { - return json.Unmarshal(evt.StructuredArgs, into) -} - -func ParseArgs[T, MetaType any](evt *Event[MetaType]) (into T, err error) { - err = evt.ParseArgs(&into) - return -} - -func WithParsedArgs[T, MetaType any](fn func(*Event[MetaType], T)) func(*Event[MetaType]) { - return func(evt *Event[MetaType]) { - parsed, err := ParseArgs[T, MetaType](evt) - if err != nil { - evt.Log.Debug().Err(err).Msg("Failed to parse structured args into struct") - // TODO better error, usage info? deduplicate with Process - evt.Reply("Failed to parse arguments: %v", err) - return - } - fn(evt, parsed) - } -} diff --git a/commands/handler.go b/commands/handler.go index 56f27f06..b01d594f 100644 --- a/commands/handler.go +++ b/commands/handler.go @@ -1,4 +1,4 @@ -// Copyright (c) 2026 Tulir Asokan +// Copyright (c) 2025 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this @@ -8,9 +8,6 @@ package commands import ( "strings" - - "maunium.net/go/mautrix/event" - "maunium.net/go/mautrix/event/cmdschema" ) type Handler[MetaType any] struct { @@ -28,63 +25,12 @@ type Handler[MetaType any] struct { // Event.ShiftArg will likely be useful for implementing such parameters. PreFunc func(ce *Event[MetaType]) - // Description is a short description of the command. - Description *event.ExtensibleTextContainer - // Parameters is a description of structured command parameters. - // If set, the StructuredArgs field of Event will be populated. - Parameters []*cmdschema.Parameter - TailParam string - - parents []*Handler[MetaType] - nestedNameCache []string subcommandContainer *CommandContainer[MetaType] } -func (h *Handler[MetaType]) NestedNames() []string { - if h.nestedNameCache != nil { - return h.nestedNameCache - } - nestedNames := make([]string, 0, (1+len(h.Aliases))*len(h.parents)) - for _, parent := range h.parents { - if parent == nil { - nestedNames = append(nestedNames, h.Name) - nestedNames = append(nestedNames, h.Aliases...) - } else { - for _, parentName := range parent.NestedNames() { - nestedNames = append(nestedNames, parentName+" "+h.Name) - for _, alias := range h.Aliases { - nestedNames = append(nestedNames, parentName+" "+alias) - } - } - } - } - h.nestedNameCache = nestedNames - return nestedNames -} - -func (h *Handler[MetaType]) Spec() *cmdschema.EventContent { - names := h.NestedNames() - return &cmdschema.EventContent{ - Command: names[0], - Aliases: names[1:], - Parameters: h.Parameters, - Description: h.Description, - TailParam: h.TailParam, - } -} - -func (h *Handler[MetaType]) CopyFrom(other *Handler[MetaType]) { - if h.Parameters == nil { - h.Parameters = other.Parameters - h.TailParam = other.TailParam - } - h.Func = other.Func -} - func (h *Handler[MetaType]) initSubcommandContainer() { if len(h.Subcommands) > 0 { h.subcommandContainer = NewCommandContainer[MetaType]() - h.subcommandContainer.parent = h h.subcommandContainer.Register(h.Subcommands...) } else { h.subcommandContainer = nil diff --git a/commands/processor.go b/commands/processor.go index 80f6745d..9341329b 100644 --- a/commands/processor.go +++ b/commands/processor.go @@ -1,4 +1,4 @@ -// Copyright (c) 2026 Tulir Asokan +// Copyright (c) 2025 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this @@ -72,9 +72,9 @@ func (proc *Processor[MetaType]) Process(ctx context.Context, evt *event.Event) case event.EventReaction: parsed = proc.ParseReaction(ctx, evt) case event.EventMessage: - parsed = proc.ParseEvent(ctx, evt) + parsed = ParseEvent[MetaType](ctx, evt) } - if parsed == nil || (!proc.PreValidator.Validate(parsed) && parsed.StructuredArgs == nil) { + if parsed == nil || !proc.PreValidator.Validate(parsed) { return } parsed.Proc = proc @@ -107,12 +107,6 @@ func (proc *Processor[MetaType]) Process(ctx context.Context, evt *event.Event) break } } - if parsed.StructuredArgs != nil && len(parsed.Args) > 0 { - // TODO allow unknown command handlers to be called? - // The client sent MSC4391 data, but the target command wasn't found - log.Debug().Msg("Didn't find handler for MSC4391 command") - return - } logWith := log.With(). Str("command", parsed.Command). @@ -122,31 +116,11 @@ func (proc *Processor[MetaType]) Process(ctx context.Context, evt *event.Event) } if proc.LogArgs { logWith = logWith.Strs("args", parsed.Args) - if parsed.StructuredArgs != nil { - logWith = logWith.RawJSON("structured_args", parsed.StructuredArgs) - } } log = logWith.Logger() parsed.Ctx = log.WithContext(ctx) parsed.Log = &log - if handler.Parameters != nil && parsed.StructuredArgs == nil { - // The handler wants structured parameters, but the client didn't send MSC4391 data - var err error - parsed.StructuredArgs, err = handler.Spec().ParseArguments(parsed.RawArgs) - if err != nil { - log.Debug().Err(err).Msg("Failed to parse structured arguments") - // TODO better error, usage info? deduplicate with WithParsedArgs - parsed.Reply("Failed to parse arguments: %v", err) - return - } - if proc.LogArgs { - log.UpdateContext(func(c zerolog.Context) zerolog.Context { - return c.RawJSON("structured_args", parsed.StructuredArgs) - }) - } - } - log.Debug().Msg("Processing command") handler.Func(parsed) } diff --git a/commands/reactions.go b/commands/reactions.go index 0d316219..0df372e5 100644 --- a/commands/reactions.go +++ b/commands/reactions.go @@ -1,4 +1,4 @@ -// Copyright (c) 2026 Tulir Asokan +// Copyright (c) 2025 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this @@ -8,7 +8,6 @@ package commands import ( "context" - "encoding/json" "strings" "github.com/rs/zerolog" @@ -20,11 +19,6 @@ import ( const ReactionCommandsKey = "fi.mau.reaction_commands" const ReactionMultiUseKey = "fi.mau.reaction_multi_use" -type ReactionCommandData struct { - Command string `json:"command"` - Args any `json:"args,omitempty"` -} - func (proc *Processor[MetaType]) ParseReaction(ctx context.Context, evt *event.Event) *Event[MetaType] { content, ok := evt.Content.Parsed.(*event.ReactionEventContent) if !ok { @@ -73,33 +67,21 @@ func (proc *Processor[MetaType]) ParseReaction(ctx context.Context, evt *event.E Msg("Reaction command not found in target event") return nil } - var wrappedEvt *Event[MetaType] - switch typedCmd := rawCmd.(type) { - case string: - wrappedEvt = RawTextToEvent[MetaType](ctx, evt, typedCmd) - case map[string]any: - var input event.MSC4391BotCommandInput - if marshaled, err := json.Marshal(typedCmd); err != nil { - - } else if err = json.Unmarshal(marshaled, &input); err != nil { - - } else { - wrappedEvt = StructuredCommandToEvent[MetaType](ctx, evt, &input) - } - } - if wrappedEvt == nil { + cmdString, ok := rawCmd.(string) + if !ok { zerolog.Ctx(ctx).Debug(). Stringer("target_event_id", evtID). Str("reaction_key", content.RelatesTo.Key). Msg("Reaction command data is invalid") return nil } + wrappedEvt := RawTextToEvent[MetaType](ctx, evt, cmdString) wrappedEvt.Proc = proc wrappedEvt.Redact() if !isMultiUse { DeleteAllReactions(ctx, proc.Client, evt) } - if wrappedEvt.Command == "" { + if cmdString == "" { return nil } return wrappedEvt diff --git a/crypto/attachment/attachments.go b/crypto/attachment/attachments.go index 727aacbf..155cca5c 100644 --- a/crypto/attachment/attachments.go +++ b/crypto/attachment/attachments.go @@ -21,24 +21,13 @@ import ( ) var ( - ErrHashMismatch = errors.New("mismatching SHA-256 digest") - ErrUnsupportedVersion = errors.New("unsupported Matrix file encryption version") - ErrUnsupportedAlgorithm = errors.New("unsupported JWK encryption algorithm") - ErrInvalidKey = errors.New("failed to decode key") - ErrInvalidInitVector = errors.New("failed to decode initialization vector") - ErrInvalidHash = errors.New("failed to decode SHA-256 hash") - ErrReaderClosed = errors.New("encrypting reader was already closed") -) - -// Deprecated: use variables prefixed with Err -var ( - HashMismatch = ErrHashMismatch - UnsupportedVersion = ErrUnsupportedVersion - UnsupportedAlgorithm = ErrUnsupportedAlgorithm - InvalidKey = ErrInvalidKey - InvalidInitVector = ErrInvalidInitVector - InvalidHash = ErrInvalidHash - ReaderClosed = ErrReaderClosed + HashMismatch = errors.New("mismatching SHA-256 digest") + UnsupportedVersion = errors.New("unsupported Matrix file encryption version") + UnsupportedAlgorithm = errors.New("unsupported JWK encryption algorithm") + InvalidKey = errors.New("failed to decode key") + InvalidInitVector = errors.New("failed to decode initialization vector") + InvalidHash = errors.New("failed to decode SHA-256 hash") + ReaderClosed = errors.New("encrypting reader was already closed") ) var ( @@ -96,25 +85,25 @@ func (ef *EncryptedFile) decodeKeys(includeHash bool) error { if ef.decoded != nil { return nil } else if len(ef.Key.Key) != keyBase64Length { - return ErrInvalidKey + return InvalidKey } else if len(ef.InitVector) != ivBase64Length { - return ErrInvalidInitVector + return InvalidInitVector } else if includeHash && len(ef.Hashes.SHA256) != hashBase64Length { - return ErrInvalidHash + return InvalidHash } ef.decoded = &decodedKeys{} _, err := base64.RawURLEncoding.Decode(ef.decoded.key[:], []byte(ef.Key.Key)) if err != nil { - return ErrInvalidKey + return InvalidKey } _, err = base64.RawStdEncoding.Decode(ef.decoded.iv[:], []byte(ef.InitVector)) if err != nil { - return ErrInvalidInitVector + return InvalidInitVector } if includeHash { _, err = base64.RawStdEncoding.Decode(ef.decoded.sha256[:], []byte(ef.Hashes.SHA256)) if err != nil { - return ErrInvalidHash + return InvalidHash } } return nil @@ -190,7 +179,7 @@ var _ io.ReadSeekCloser = (*encryptingReader)(nil) func (r *encryptingReader) Seek(offset int64, whence int) (int64, error) { if r.closed { - return 0, ErrReaderClosed + return 0, ReaderClosed } if offset != 0 || whence != io.SeekStart { return 0, fmt.Errorf("attachments.EncryptStream: only seeking to the beginning is supported") @@ -211,7 +200,7 @@ func (r *encryptingReader) Seek(offset int64, whence int) (int64, error) { func (r *encryptingReader) Read(dst []byte) (n int, err error) { if r.closed { - return 0, ErrReaderClosed + return 0, ReaderClosed } else if r.isDecrypting && r.file.decoded == nil { if err = r.file.PrepareForDecryption(); err != nil { return @@ -235,7 +224,7 @@ func (r *encryptingReader) Close() (err error) { } if r.isDecrypting { if !hmac.Equal(r.hash.Sum(nil), r.file.decoded.sha256[:]) { - return ErrHashMismatch + return HashMismatch } } else { r.file.Hashes.SHA256 = base64.RawStdEncoding.EncodeToString(r.hash.Sum(nil)) @@ -276,9 +265,9 @@ func (ef *EncryptedFile) Decrypt(ciphertext []byte) ([]byte, error) { // DecryptInPlace will always call this automatically, so calling this manually is not necessary when using that function. func (ef *EncryptedFile) PrepareForDecryption() error { if ef.Version != "v2" { - return ErrUnsupportedVersion + return UnsupportedVersion } else if ef.Key.Algorithm != "A256CTR" { - return ErrUnsupportedAlgorithm + return UnsupportedAlgorithm } else if err := ef.decodeKeys(true); err != nil { return err } @@ -292,7 +281,7 @@ func (ef *EncryptedFile) DecryptInPlace(data []byte) error { } dataHash := sha256.Sum256(data) if !hmac.Equal(ef.decoded.sha256[:], dataHash[:]) { - return ErrHashMismatch + return HashMismatch } utils.XorA256CTR(data, ef.decoded.key, ef.decoded.iv) return nil diff --git a/crypto/attachment/attachments_test.go b/crypto/attachment/attachments_test.go index 9fe929ab..d7f1394a 100644 --- a/crypto/attachment/attachments_test.go +++ b/crypto/attachment/attachments_test.go @@ -53,33 +53,33 @@ func TestUnsupportedVersion(t *testing.T) { file := parseHelloWorld() file.Version = "foo" err := file.DecryptInPlace([]byte(helloWorldCiphertext)) - assert.ErrorIs(t, err, ErrUnsupportedVersion) + assert.ErrorIs(t, err, UnsupportedVersion) } func TestUnsupportedAlgorithm(t *testing.T) { file := parseHelloWorld() file.Key.Algorithm = "bar" err := file.DecryptInPlace([]byte(helloWorldCiphertext)) - assert.ErrorIs(t, err, ErrUnsupportedAlgorithm) + assert.ErrorIs(t, err, UnsupportedAlgorithm) } func TestHashMismatch(t *testing.T) { file := parseHelloWorld() file.Hashes.SHA256 = base64.RawStdEncoding.EncodeToString([]byte(random32Bytes)) err := file.DecryptInPlace([]byte(helloWorldCiphertext)) - assert.ErrorIs(t, err, ErrHashMismatch) + assert.ErrorIs(t, err, HashMismatch) } func TestTooLongHash(t *testing.T) { file := parseHelloWorld() file.Hashes.SHA256 = "TG9yZW0gaXBzdW0gZG9sb3Igc2l0IGFtZXQsIGNvbnNlY3RldHVlciBhZGlwaXNjaW5nIGVsaXQuIFNlZCBwb3N1ZXJlIGludGVyZHVtIHNlbS4gUXVpc3F1ZSBsaWd1bGEgZXJvcyB1bGxhbWNvcnBlciBxdWlzLCBsYWNpbmlhIHF1aXMgZmFjaWxpc2lzIHNlZCBzYXBpZW4uCg" err := file.DecryptInPlace([]byte(helloWorldCiphertext)) - assert.ErrorIs(t, err, ErrInvalidHash) + assert.ErrorIs(t, err, InvalidHash) } func TestTooShortHash(t *testing.T) { file := parseHelloWorld() file.Hashes.SHA256 = "5/Gy1JftyyQ" err := file.DecryptInPlace([]byte(helloWorldCiphertext)) - assert.ErrorIs(t, err, ErrInvalidHash) + assert.ErrorIs(t, err, InvalidHash) } diff --git a/crypto/cross_sign_key.go b/crypto/cross_sign_key.go index 5d9bf5b3..4094f695 100644 --- a/crypto/cross_sign_key.go +++ b/crypto/cross_sign_key.go @@ -135,7 +135,7 @@ func (mach *OlmMachine) PublishCrossSigningKeys(ctx context.Context, keys *Cross } userKey.Signatures = signatures.NewSingleSignature(userID, id.KeyAlgorithmEd25519, keys.MasterKey.PublicKey().String(), userSig) - err = mach.Client.UploadCrossSigningKeys(ctx, &mautrix.UploadCrossSigningKeysReq[any]{ + err = mach.Client.UploadCrossSigningKeys(ctx, &mautrix.UploadCrossSigningKeysReq{ Master: masterKey, SelfSigning: selfKey, UserSigning: userKey, diff --git a/crypto/cross_sign_pubkey.go b/crypto/cross_sign_pubkey.go index 223fc7b5..77efab5b 100644 --- a/crypto/cross_sign_pubkey.go +++ b/crypto/cross_sign_pubkey.go @@ -20,20 +20,6 @@ type CrossSigningPublicKeysCache struct { UserSigningKey id.Ed25519 } -func (mach *OlmMachine) GetOwnVerificationStatus(ctx context.Context) (hasKeys, isVerified bool, err error) { - pubkeys := mach.GetOwnCrossSigningPublicKeys(ctx) - if pubkeys != nil { - hasKeys = true - isVerified, err = mach.CryptoStore.IsKeySignedBy( - ctx, mach.Client.UserID, mach.GetAccount().SigningKey(), mach.Client.UserID, pubkeys.SelfSigningKey, - ) - if err != nil { - err = fmt.Errorf("failed to check if current device is signed by own self-signing key: %w", err) - } - } - return -} - func (mach *OlmMachine) GetOwnCrossSigningPublicKeys(ctx context.Context) *CrossSigningPublicKeysCache { if mach.crossSigningPubkeys != nil { return mach.crossSigningPubkeys @@ -63,8 +49,8 @@ func (mach *OlmMachine) GetCrossSigningPublicKeys(ctx context.Context, userID id if len(dbKeys) > 0 { masterKey, ok := dbKeys[id.XSUsageMaster] if ok { - selfSigning := dbKeys[id.XSUsageSelfSigning] - userSigning := dbKeys[id.XSUsageUserSigning] + selfSigning, _ := dbKeys[id.XSUsageSelfSigning] + userSigning, _ := dbKeys[id.XSUsageUserSigning] return &CrossSigningPublicKeysCache{ MasterKey: masterKey.Key, SelfSigningKey: selfSigning.Key, diff --git a/crypto/cross_sign_ssss.go b/crypto/cross_sign_ssss.go index fd42880d..389a9fd2 100644 --- a/crypto/cross_sign_ssss.go +++ b/crypto/cross_sign_ssss.go @@ -8,7 +8,6 @@ package crypto import ( "context" - "errors" "fmt" "maunium.net/go/mautrix" @@ -72,46 +71,6 @@ func (mach *OlmMachine) GenerateAndUploadCrossSigningKeysWithPassword(ctx contex }, passphrase) } -func (mach *OlmMachine) VerifyWithRecoveryKey(ctx context.Context, recoveryKey string) error { - keyID, keyData, err := mach.SSSS.GetDefaultKeyData(ctx) - if err != nil { - return fmt.Errorf("failed to get default SSSS key data: %w", err) - } - key, err := keyData.VerifyRecoveryKey(keyID, recoveryKey) - if errors.Is(err, ssss.ErrUnverifiableKey) { - mach.machOrContextLog(ctx).Warn(). - Str("key_id", keyID). - Msg("SSSS key is unverifiable, trying to use without verifying") - } else if err != nil { - return err - } - err = mach.FetchCrossSigningKeysFromSSSS(ctx, key) - if err != nil { - return fmt.Errorf("failed to fetch cross-signing keys from SSSS: %w", err) - } - err = mach.SignOwnDevice(ctx, mach.OwnIdentity()) - if err != nil { - return fmt.Errorf("failed to sign own device: %w", err) - } - err = mach.SignOwnMasterKey(ctx) - if err != nil { - return fmt.Errorf("failed to sign own master key: %w", err) - } - return nil -} - -func (mach *OlmMachine) GenerateAndVerifyWithRecoveryKey(ctx context.Context) (recoveryKey string, err error) { - recoveryKey, _, err = mach.GenerateAndUploadCrossSigningKeys(ctx, nil, "") - if err != nil { - err = fmt.Errorf("failed to generate and upload cross-signing keys: %w", err) - } else if err = mach.SignOwnDevice(ctx, mach.OwnIdentity()); err != nil { - err = fmt.Errorf("failed to sign own device: %w", err) - } else if err = mach.SignOwnMasterKey(ctx); err != nil { - err = fmt.Errorf("failed to sign own master key: %w", err) - } - return -} - // GenerateAndUploadCrossSigningKeys generates a new key with all corresponding cross-signing keys. // // A passphrase can be provided to generate the SSSS key. If the passphrase is empty, a random key @@ -138,12 +97,12 @@ func (mach *OlmMachine) GenerateAndUploadCrossSigningKeys(ctx context.Context, u // Publish cross-signing keys err = mach.PublishCrossSigningKeys(ctx, keysCache, uiaCallback) if err != nil { - return key.RecoveryKey(), keysCache, fmt.Errorf("failed to publish cross-signing keys: %w", err) + return "", nil, fmt.Errorf("failed to publish cross-signing keys: %w", err) } err = mach.SSSS.SetDefaultKeyID(ctx, key.ID) if err != nil { - return key.RecoveryKey(), keysCache, fmt.Errorf("failed to mark %s as the default key: %w", key.ID, err) + return "", nil, fmt.Errorf("failed to mark %s as the default key: %w", key.ID, err) } return key.RecoveryKey(), keysCache, nil diff --git a/crypto/cross_sign_store.go b/crypto/cross_sign_store.go index 57406b11..b583bada 100644 --- a/crypto/cross_sign_store.go +++ b/crypto/cross_sign_store.go @@ -20,34 +20,36 @@ import ( func (mach *OlmMachine) storeCrossSigningKeys(ctx context.Context, crossSigningKeys map[id.UserID]mautrix.CrossSigningKeys, deviceKeys map[id.UserID]map[id.DeviceID]mautrix.DeviceKeys) { log := mach.machOrContextLog(ctx) for userID, userKeys := range crossSigningKeys { - log := log.With().Stringer("user_id", userID).Logger() + log := log.With().Str("user_id", userID.String()).Logger() currentKeys, err := mach.CryptoStore.GetCrossSigningKeys(ctx, userID) if err != nil { log.Error().Err(err). Msg("Error fetching current cross-signing keys of user") } - for curKeyUsage, curKey := range currentKeys { - log := log.With().Stringer("old_key", curKey.Key).Str("old_key_usage", string(curKeyUsage)).Logger() - // got a new key with the same usage as an existing key - for _, newKeyUsage := range userKeys.Usage { - if newKeyUsage == curKeyUsage { - if _, ok := userKeys.Keys[id.NewKeyID(id.KeyAlgorithmEd25519, curKey.Key.String())]; !ok { - // old key is not in the new key map, so we drop signatures made by it - if count, err := mach.CryptoStore.DropSignaturesByKey(ctx, userID, curKey.Key); err != nil { - log.Error().Err(err).Msg("Error deleting old signatures made by user") - } else { - log.Debug(). - Int64("signature_count", count). - Msg("Dropped signatures made by old key as it has been replaced") + if currentKeys != nil { + for curKeyUsage, curKey := range currentKeys { + log := log.With().Str("old_key", curKey.Key.String()).Str("old_key_usage", string(curKeyUsage)).Logger() + // got a new key with the same usage as an existing key + for _, newKeyUsage := range userKeys.Usage { + if newKeyUsage == curKeyUsage { + if _, ok := userKeys.Keys[id.NewKeyID(id.KeyAlgorithmEd25519, curKey.Key.String())]; !ok { + // old key is not in the new key map, so we drop signatures made by it + if count, err := mach.CryptoStore.DropSignaturesByKey(ctx, userID, curKey.Key); err != nil { + log.Error().Err(err).Msg("Error deleting old signatures made by user") + } else { + log.Debug(). + Int64("signature_count", count). + Msg("Dropped signatures made by old key as it has been replaced") + } } + break } - break } } } for _, key := range userKeys.Keys { - log := log.With().Stringer("key", key).Array("usages", exzerolog.ArrayOfStrs(userKeys.Usage)).Logger() + log := log.With().Str("key", key.String()).Array("usages", exzerolog.ArrayOfStrs(userKeys.Usage)).Logger() for _, usage := range userKeys.Usage { log.Trace().Str("usage", string(usage)).Msg("Storing cross-signing key") if err = mach.CryptoStore.PutCrossSigningKey(ctx, userID, usage, key); err != nil { diff --git a/crypto/cryptohelper/cryptohelper.go b/crypto/cryptohelper/cryptohelper.go index b62dc128..56f8b484 100644 --- a/crypto/cryptohelper/cryptohelper.go +++ b/crypto/cryptohelper/cryptohelper.go @@ -225,6 +225,13 @@ func (helper *CryptoHelper) Init(ctx context.Context) error { helper.ASEventProcessor.On(event.EventEncrypted, helper.HandleEncrypted) } + if helper.client.SetAppServiceDeviceID { + err = helper.mach.ShareKeys(ctx, -1) + if err != nil { + return fmt.Errorf("failed to share keys: %w", err) + } + } + return nil } @@ -261,24 +268,24 @@ func (helper *CryptoHelper) verifyDeviceKeysOnServer(ctx context.Context) error if !ok || len(device.Keys) == 0 { if isShared { return fmt.Errorf("olm account is marked as shared, keys seem to have disappeared from the server") + } else { + helper.log.Debug().Msg("Olm account not shared and keys not on server, so device is probably fine") + return nil } - helper.log.Debug().Msg("Olm account not shared and keys not on server, sharing initial keys") - err = helper.mach.ShareKeys(ctx, -1) - if err != nil { - return fmt.Errorf("failed to share keys: %w", err) - } - return nil } else if !isShared { return fmt.Errorf("olm account is not marked as shared, but there are keys on the server") } else if ed := device.Keys.GetEd25519(helper.client.DeviceID); ownID.SigningKey != ed { return fmt.Errorf("mismatching identity key on server (%q != %q)", ownID.SigningKey, ed) + } + if !isShared { + helper.log.Debug().Msg("Olm account not marked as shared, but keys on server match?") } else { helper.log.Debug().Msg("Olm account marked as shared and keys on server match, device is fine") - return nil } + return nil } -var NoSessionFound = crypto.ErrNoSessionFound +var NoSessionFound = crypto.NoSessionFound const initialSessionWaitTimeout = 3 * time.Second const extendedSessionWaitTimeout = 22 * time.Second @@ -297,14 +304,24 @@ func (helper *CryptoHelper) HandleEncrypted(ctx context.Context, evt *event.Even ctx = log.WithContext(ctx) decrypted, err := helper.Decrypt(ctx, evt) - if errors.Is(err, NoSessionFound) && ctx.Value(mautrix.SyncTokenContextKey) != "" { - go helper.waitForSession(ctx, evt) - } else if err != nil { + if errors.Is(err, NoSessionFound) { + log.Debug(). + Int("wait_seconds", int(initialSessionWaitTimeout.Seconds())). + Msg("Couldn't find session, waiting for keys to arrive...") + if helper.mach.WaitForSession(ctx, evt.RoomID, content.SenderKey, content.SessionID, initialSessionWaitTimeout) { + log.Debug().Msg("Got keys after waiting, trying to decrypt event again") + decrypted, err = helper.Decrypt(ctx, evt) + } else { + go helper.waitLongerForSession(ctx, log, evt) + return + } + } + if err != nil { log.Warn().Err(err).Msg("Failed to decrypt event") helper.DecryptErrorCallback(evt, err) - } else { - helper.postDecrypt(ctx, decrypted) + return } + helper.postDecrypt(ctx, decrypted) } func (helper *CryptoHelper) postDecrypt(ctx context.Context, decrypted *event.Event) { @@ -345,33 +362,10 @@ func (helper *CryptoHelper) RequestSession(ctx context.Context, roomID id.RoomID } } -func (helper *CryptoHelper) waitForSession(ctx context.Context, evt *event.Event) { - log := zerolog.Ctx(ctx) - content := evt.Content.AsEncrypted() - - log.Debug(). - Int("wait_seconds", int(initialSessionWaitTimeout.Seconds())). - Msg("Couldn't find session, waiting for keys to arrive...") - if helper.mach.WaitForSession(ctx, evt.RoomID, content.SenderKey, content.SessionID, initialSessionWaitTimeout) { - log.Debug().Msg("Got keys after waiting, trying to decrypt event again") - decrypted, err := helper.Decrypt(ctx, evt) - if err != nil { - log.Warn().Err(err).Msg("Failed to decrypt event") - helper.DecryptErrorCallback(evt, err) - } else { - helper.postDecrypt(ctx, decrypted) - } - } else { - go helper.waitLongerForSession(ctx, evt) - } -} - -func (helper *CryptoHelper) waitLongerForSession(ctx context.Context, evt *event.Event) { - log := zerolog.Ctx(ctx) +func (helper *CryptoHelper) waitLongerForSession(ctx context.Context, log zerolog.Logger, evt *event.Event) { content := evt.Content.AsEncrypted() log.Debug().Int("wait_seconds", int(extendedSessionWaitTimeout.Seconds())).Msg("Couldn't find session, requesting keys and waiting longer...") - //lint:ignore SA1019 RequestSession will gracefully request from all devices if DeviceID is blank go helper.RequestSession(context.TODO(), evt.RoomID, content.SenderKey, content.SessionID, evt.Sender, content.DeviceID) if !helper.mach.WaitForSession(ctx, evt.RoomID, content.SenderKey, content.SessionID, extendedSessionWaitTimeout) { @@ -419,7 +413,7 @@ func (helper *CryptoHelper) EncryptWithStateKey(ctx context.Context, roomID id.R defer helper.lock.RUnlock() encrypted, err = helper.mach.EncryptMegolmEventWithStateKey(ctx, roomID, evtType, stateKey, content) if err != nil { - if !errors.Is(err, crypto.ErrSessionExpired) && err != crypto.ErrNoGroupSession && !errors.Is(err, crypto.ErrSessionNotShared) { + if !errors.Is(err, crypto.SessionExpired) && err != crypto.NoGroupSession && !errors.Is(err, crypto.SessionNotShared) { return } helper.log.Debug(). diff --git a/crypto/decryptmegolm.go b/crypto/decryptmegolm.go index 457d5a0c..47279474 100644 --- a/crypto/decryptmegolm.go +++ b/crypto/decryptmegolm.go @@ -24,23 +24,13 @@ import ( ) var ( - ErrIncorrectEncryptedContentType = errors.New("event content is not instance of *event.EncryptedEventContent") - ErrNoSessionFound = errors.New("failed to decrypt megolm event: no session with given ID found") - ErrDuplicateMessageIndex = errors.New("duplicate megolm message index") - ErrWrongRoom = errors.New("encrypted megolm event is not intended for this room") - ErrDeviceKeyMismatch = errors.New("device keys in event and verified device info do not match") - ErrRatchetError = errors.New("failed to ratchet session after use") - ErrCorruptedMegolmPayload = errors.New("corrupted megolm payload") -) - -// Deprecated: use variables prefixed with Err -var ( - IncorrectEncryptedContentType = ErrIncorrectEncryptedContentType - NoSessionFound = ErrNoSessionFound - DuplicateMessageIndex = ErrDuplicateMessageIndex - WrongRoom = ErrWrongRoom - DeviceKeyMismatch = ErrDeviceKeyMismatch - RatchetError = ErrRatchetError + IncorrectEncryptedContentType = errors.New("event content is not instance of *event.EncryptedEventContent") + NoSessionFound = errors.New("failed to decrypt megolm event: no session with given ID found") + DuplicateMessageIndex = errors.New("duplicate megolm message index") + WrongRoom = errors.New("encrypted megolm event is not intended for this room") + DeviceKeyMismatch = errors.New("device keys in event and verified device info do not match") + SenderKeyMismatch = errors.New("sender keys in content and megolm session do not match") + RatchetError = errors.New("failed to ratchet session after use") ) type megolmEvent struct { @@ -55,30 +45,13 @@ var ( relatesToTopLevelPath = exgjson.Path("content", "m.relates_to") ) -const sessionIDLength = 43 - -func validateCiphertextCharacters(ciphertext []byte) bool { - for _, b := range ciphertext { - if (b < 'a' || b > 'z') && (b < 'A' || b > 'Z') && (b < '0' || b > '9') && b != '+' && b != '/' { - return false - } - } - return true -} - // DecryptMegolmEvent decrypts an m.room.encrypted event where the algorithm is m.megolm.v1.aes-sha2 func (mach *OlmMachine) DecryptMegolmEvent(ctx context.Context, evt *event.Event) (*event.Event, error) { content, ok := evt.Content.Parsed.(*event.EncryptedEventContent) if !ok { - return nil, ErrIncorrectEncryptedContentType + return nil, IncorrectEncryptedContentType } else if content.Algorithm != id.AlgorithmMegolmV1 { - return nil, ErrUnsupportedAlgorithm - } else if len(content.MegolmCiphertext) < 74 { - return nil, fmt.Errorf("%w: ciphertext too short (%d bytes)", ErrCorruptedMegolmPayload, len(content.MegolmCiphertext)) - } else if len(content.SessionID) != sessionIDLength { - return nil, fmt.Errorf("%w: invalid session ID length %d", ErrCorruptedMegolmPayload, len(content.SessionID)) - } else if !validateCiphertextCharacters(content.MegolmCiphertext) { - return nil, fmt.Errorf("%w: invalid characters in ciphertext", ErrCorruptedMegolmPayload) + return nil, UnsupportedAlgorithm } log := mach.machOrContextLog(ctx).With(). Str("action", "decrypt megolm event"). @@ -124,13 +97,7 @@ func (mach *OlmMachine) DecryptMegolmEvent(ctx context.Context, evt *event.Event Msg("Couldn't resolve trust level of session: sent by unknown device") trustLevel = id.TrustStateUnknownDevice } else if device.SigningKey != sess.SigningKey || device.IdentityKey != sess.SenderKey { - log.Debug(). - Stringer("session_sender_key", sess.SenderKey). - Stringer("device_sender_key", device.IdentityKey). - Stringer("session_signing_key", sess.SigningKey). - Stringer("device_signing_key", device.SigningKey). - Msg("Device keys don't match keys in session, marking as untrusted") - trustLevel = id.TrustStateDeviceKeyMismatch + return nil, DeviceKeyMismatch } else { trustLevel, err = mach.ResolveTrustContext(ctx, device) if err != nil { @@ -180,7 +147,7 @@ func (mach *OlmMachine) DecryptMegolmEvent(ctx context.Context, evt *event.Event if err != nil { return nil, fmt.Errorf("failed to parse megolm payload: %w", err) } else if megolmEvt.RoomID != encryptionRoomID { - return nil, ErrWrongRoom + return nil, WrongRoom } if evt.StateKey != nil && megolmEvt.StateKey != nil && mach.AllowEncryptedState { megolmEvt.Type.Class = event.StateEventType @@ -213,7 +180,6 @@ func (mach *OlmMachine) DecryptMegolmEvent(ctx context.Context, evt *event.Event TrustSource: device, ForwardedKeys: forwardedKeys, WasEncrypted: true, - EventSource: evt.Mautrix.EventSource | event.SourceDecrypted, ReceivedAt: evt.Mautrix.ReceivedAt, }, }, nil @@ -235,19 +201,19 @@ func (mach *OlmMachine) checkUndecryptableMessageIndexDuplication(ctx context.Co messageIndex, decodeErr := ParseMegolmMessageIndex(content.MegolmCiphertext) if decodeErr != nil { log.Warn().Err(decodeErr).Msg("Failed to parse message index to check if it's a duplicate for message that failed to decrypt") - return 0, fmt.Errorf("%w (also failed to parse message index)", olm.ErrUnknownMessageIndex) + return 0, fmt.Errorf("%w (also failed to parse message index)", olm.UnknownMessageIndex) } firstKnown := sess.Internal.FirstKnownIndex() log = log.With().Uint("message_index", messageIndex).Uint32("first_known_index", firstKnown).Logger() if ok, err := mach.CryptoStore.ValidateMessageIndex(ctx, sess.SenderKey, content.SessionID, evt.ID, messageIndex, evt.Timestamp); err != nil { log.Debug().Err(err).Msg("Failed to check if message index is duplicate") - return messageIndex, fmt.Errorf("%w (failed to check if index is duplicate; received: %d, earliest known: %d)", olm.ErrUnknownMessageIndex, messageIndex, firstKnown) + return messageIndex, fmt.Errorf("%w (failed to check if index is duplicate; received: %d, earliest known: %d)", olm.UnknownMessageIndex, messageIndex, firstKnown) } else if !ok { log.Debug().Msg("Failed to decrypt message due to unknown index and found duplicate") - return messageIndex, fmt.Errorf("%w %d (also failed to decrypt because earliest known index is %d)", ErrDuplicateMessageIndex, messageIndex, firstKnown) + return messageIndex, fmt.Errorf("%w %d (also failed to decrypt because earliest known index is %d)", DuplicateMessageIndex, messageIndex, firstKnown) } log.Debug().Msg("Failed to decrypt message due to unknown index, but index is not duplicate") - return messageIndex, fmt.Errorf("%w (not duplicate index; received: %d, earliest known: %d)", olm.ErrUnknownMessageIndex, messageIndex, firstKnown) + return messageIndex, fmt.Errorf("%w (not duplicate index; received: %d, earliest known: %d)", olm.UnknownMessageIndex, messageIndex, firstKnown) } func (mach *OlmMachine) actuallyDecryptMegolmEvent(ctx context.Context, evt *event.Event, encryptionRoomID id.RoomID, content *event.EncryptedEventContent) (*InboundGroupSession, []byte, uint, error) { @@ -258,11 +224,13 @@ func (mach *OlmMachine) actuallyDecryptMegolmEvent(ctx context.Context, evt *eve if err != nil { return nil, nil, 0, fmt.Errorf("failed to get group session: %w", err) } else if sess == nil { - return nil, nil, 0, fmt.Errorf("%w (ID %s)", ErrNoSessionFound, content.SessionID) + return nil, nil, 0, fmt.Errorf("%w (ID %s)", NoSessionFound, content.SessionID) + } else if content.SenderKey != "" && content.SenderKey != sess.SenderKey { + return sess, nil, 0, SenderKeyMismatch } plaintext, messageIndex, err := sess.Internal.Decrypt(content.MegolmCiphertext) if err != nil { - if errors.Is(err, olm.ErrUnknownMessageIndex) && mach.RatchetKeysOnDecrypt { + if errors.Is(err, olm.UnknownMessageIndex) && mach.RatchetKeysOnDecrypt { messageIndex, err = mach.checkUndecryptableMessageIndexDuplication(ctx, sess, evt, content) return sess, nil, messageIndex, fmt.Errorf("failed to decrypt megolm event: %w", err) } @@ -270,7 +238,7 @@ func (mach *OlmMachine) actuallyDecryptMegolmEvent(ctx context.Context, evt *eve } else if ok, err := mach.CryptoStore.ValidateMessageIndex(ctx, sess.SenderKey, content.SessionID, evt.ID, messageIndex, evt.Timestamp); err != nil { return sess, nil, messageIndex, fmt.Errorf("failed to check if message index is duplicate: %w", err) } else if !ok { - return sess, nil, messageIndex, fmt.Errorf("%w %d", ErrDuplicateMessageIndex, messageIndex) + return sess, nil, messageIndex, fmt.Errorf("%w %d", DuplicateMessageIndex, messageIndex) } // Normal clients don't care about tracking the ratchet state, so let them bypass the rest of the function @@ -322,24 +290,24 @@ func (mach *OlmMachine) actuallyDecryptMegolmEvent(ctx context.Context, evt *eve err = mach.CryptoStore.RedactGroupSession(ctx, sess.RoomID, sess.ID(), "maximum messages reached") if err != nil { log.Err(err).Msg("Failed to delete fully used session") - return sess, plaintext, messageIndex, ErrRatchetError + return sess, plaintext, messageIndex, RatchetError } else { log.Info().Msg("Deleted fully used session") } } else if ratchetCurrentIndex < ratchetTargetIndex && mach.RatchetKeysOnDecrypt { if err = sess.RatchetTo(ratchetTargetIndex); err != nil { log.Err(err).Msg("Failed to ratchet session") - return sess, plaintext, messageIndex, ErrRatchetError + return sess, plaintext, messageIndex, RatchetError } else if err = mach.CryptoStore.PutGroupSession(ctx, sess); err != nil { log.Err(err).Msg("Failed to store ratcheted session") - return sess, plaintext, messageIndex, ErrRatchetError + return sess, plaintext, messageIndex, RatchetError } else { log.Info().Msg("Ratcheted session forward") } } else if didModify { if err = mach.CryptoStore.PutGroupSession(ctx, sess); err != nil { log.Err(err).Msg("Failed to store updated ratchet safety data") - return sess, plaintext, messageIndex, ErrRatchetError + return sess, plaintext, messageIndex, RatchetError } else { log.Debug().Msg("Ratchet safety data changed (ratchet state didn't change)") } diff --git a/crypto/decryptolm.go b/crypto/decryptolm.go index aea5e6dc..b737e4e1 100644 --- a/crypto/decryptolm.go +++ b/crypto/decryptolm.go @@ -17,36 +17,21 @@ import ( "time" "github.com/rs/zerolog" - "go.mau.fi/util/exerrors" - "go.mau.fi/util/ptr" - "maunium.net/go/mautrix/crypto/goolm/account" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" ) var ( - ErrUnsupportedAlgorithm = errors.New("unsupported event encryption algorithm") - ErrNotEncryptedForMe = errors.New("olm event doesn't contain ciphertext for this device") - ErrUnsupportedOlmMessageType = errors.New("unsupported olm message type") - ErrDecryptionFailedWithMatchingSession = errors.New("decryption failed with matching session") - ErrDecryptionFailedForNormalMessage = errors.New("decryption failed for normal message") - ErrSenderMismatch = errors.New("mismatched sender in olm payload") - ErrRecipientMismatch = errors.New("mismatched recipient in olm payload") - ErrRecipientKeyMismatch = errors.New("mismatched recipient key in olm payload") - ErrDuplicateMessage = errors.New("duplicate olm message") -) - -// Deprecated: use variables prefixed with Err -var ( - UnsupportedAlgorithm = ErrUnsupportedAlgorithm - NotEncryptedForMe = ErrNotEncryptedForMe - UnsupportedOlmMessageType = ErrUnsupportedOlmMessageType - DecryptionFailedWithMatchingSession = ErrDecryptionFailedWithMatchingSession - DecryptionFailedForNormalMessage = ErrDecryptionFailedForNormalMessage - SenderMismatch = ErrSenderMismatch - RecipientMismatch = ErrRecipientMismatch - RecipientKeyMismatch = ErrRecipientKeyMismatch + UnsupportedAlgorithm = errors.New("unsupported event encryption algorithm") + NotEncryptedForMe = errors.New("olm event doesn't contain ciphertext for this device") + UnsupportedOlmMessageType = errors.New("unsupported olm message type") + DecryptionFailedWithMatchingSession = errors.New("decryption failed with matching session") + DecryptionFailedForNormalMessage = errors.New("decryption failed for normal message") + SenderMismatch = errors.New("mismatched sender in olm payload") + RecipientMismatch = errors.New("mismatched recipient in olm payload") + RecipientKeyMismatch = errors.New("mismatched recipient key in olm payload") + ErrDuplicateMessage = errors.New("duplicate olm message") ) // DecryptedOlmEvent represents an event that was decrypted from an event encrypted with the m.olm.v1.curve25519-aes-sha2 algorithm. @@ -68,13 +53,13 @@ type DecryptedOlmEvent struct { func (mach *OlmMachine) decryptOlmEvent(ctx context.Context, evt *event.Event) (*DecryptedOlmEvent, error) { content, ok := evt.Content.Parsed.(*event.EncryptedEventContent) if !ok { - return nil, ErrIncorrectEncryptedContentType + return nil, IncorrectEncryptedContentType } else if content.Algorithm != id.AlgorithmOlmV1 { - return nil, ErrUnsupportedAlgorithm + return nil, UnsupportedAlgorithm } ownContent, ok := content.OlmCiphertext[mach.account.IdentityKey()] if !ok { - return nil, ErrNotEncryptedForMe + return nil, NotEncryptedForMe } decrypted, err := mach.decryptAndParseOlmCiphertext(ctx, evt, content.SenderKey, ownContent.Type, ownContent.Body) if err != nil { @@ -90,7 +75,7 @@ type OlmEventKeys struct { func (mach *OlmMachine) decryptAndParseOlmCiphertext(ctx context.Context, evt *event.Event, senderKey id.SenderKey, olmType id.OlmMsgType, ciphertext string) (*DecryptedOlmEvent, error) { if olmType != id.OlmMsgTypePreKey && olmType != id.OlmMsgTypeMsg { - return nil, ErrUnsupportedOlmMessageType + return nil, UnsupportedOlmMessageType } log := mach.machOrContextLog(ctx).With(). @@ -114,11 +99,11 @@ func (mach *OlmMachine) decryptAndParseOlmCiphertext(ctx context.Context, evt *e } olmEvt.Type.Class = evt.Type.Class if evt.Sender != olmEvt.Sender { - return nil, ErrSenderMismatch + return nil, SenderMismatch } else if mach.Client.UserID != olmEvt.Recipient { - return nil, ErrRecipientMismatch + return nil, RecipientMismatch } else if mach.account.SigningKey() != olmEvt.RecipientKeys.Ed25519 { - return nil, ErrRecipientKeyMismatch + return nil, RecipientKeyMismatch } if len(olmEvt.Content.VeryRaw) > 0 { @@ -134,9 +119,6 @@ func (mach *OlmMachine) decryptAndParseOlmCiphertext(ctx context.Context, evt *e } func olmMessageHash(ciphertext string) ([32]byte, error) { - if ciphertext == "" { - return [32]byte{}, fmt.Errorf("empty ciphertext") - } ciphertextBytes, err := base64.RawStdEncoding.DecodeString(ciphertext) return sha256.Sum256(ciphertextBytes), err } @@ -166,7 +148,7 @@ func (mach *OlmMachine) tryDecryptOlmCiphertext(ctx context.Context, sender id.U plaintext, err := mach.tryDecryptOlmCiphertextWithExistingSession(ctx, senderKey, olmType, ciphertext, ciphertextHash) if err != nil { - if err == ErrDecryptionFailedWithMatchingSession { + if err == DecryptionFailedWithMatchingSession { log.Warn().Msg("Found matching session, but decryption failed") go mach.unwedgeDevice(log, sender, senderKey) } @@ -184,10 +166,9 @@ func (mach *OlmMachine) tryDecryptOlmCiphertext(ctx context.Context, sender id.U // if it isn't one at this point in time anymore, so return early. if olmType != id.OlmMsgTypePreKey { go mach.unwedgeDevice(log, sender, senderKey) - return nil, ErrDecryptionFailedForNormalMessage + return nil, DecryptionFailedForNormalMessage } - accountBackup, _ := mach.account.Internal.Pickle([]byte("tmp")) log.Trace().Msg("Trying to create inbound session") endTimeTrace = mach.timeTrace(ctx, "creating inbound olm session", time.Second) session, err := mach.createInboundSession(ctx, senderKey, ciphertext) @@ -199,7 +180,6 @@ func (mach *OlmMachine) tryDecryptOlmCiphertext(ctx context.Context, sender id.U log = log.With().Str("new_olm_session_id", session.ID().String()).Logger() log.Debug(). Hex("ciphertext_hash", ciphertextHash[:]). - Hex("ciphertext_hash_repeat", ptr.Ptr(exerrors.Must(olmMessageHash(ciphertext)))[:]). Str("olm_session_description", session.Describe()). Msg("Created inbound olm session") ctx = log.WithContext(ctx) @@ -208,19 +188,6 @@ func (mach *OlmMachine) tryDecryptOlmCiphertext(ctx context.Context, sender id.U plaintext, err = session.Decrypt(ciphertext, olmType) endTimeTrace() if err != nil { - log.Debug(). - Hex("ciphertext_hash", ciphertextHash[:]). - Hex("ciphertext_hash_repeat", ptr.Ptr(exerrors.Must(olmMessageHash(ciphertext)))[:]). - Str("ciphertext", ciphertext). - Str("olm_session_description", session.Describe()). - Msg("DEBUG: Failed to decrypt prekey olm message with newly created session") - err2 := mach.goolmRetryHack(ctx, senderKey, ciphertext, accountBackup) - if err2 != nil { - log.Debug().Err(err2).Msg("Goolm confirmed decryption failure") - } else { - log.Warn().Msg("Goolm decryption was successful after libolm failure?") - } - go mach.unwedgeDevice(log, sender, senderKey) return nil, fmt.Errorf("failed to decrypt olm event with session created from prekey message: %w", err) } @@ -238,23 +205,6 @@ func (mach *OlmMachine) tryDecryptOlmCiphertext(ctx context.Context, sender id.U return plaintext, nil } -func (mach *OlmMachine) goolmRetryHack(ctx context.Context, senderKey id.SenderKey, ciphertext string, accountBackup []byte) error { - acc, err := account.AccountFromPickled(accountBackup, []byte("tmp")) - if err != nil { - return fmt.Errorf("failed to unpickle olm account: %w", err) - } - sess, err := acc.NewInboundSessionFrom(&senderKey, ciphertext) - if err != nil { - return fmt.Errorf("failed to create inbound session: %w", err) - } - _, err = sess.Decrypt(ciphertext, id.OlmMsgTypePreKey) - if err != nil { - // This is the expected result if libolm failed - return fmt.Errorf("failed to decrypt with new session: %w", err) - } - return nil -} - const MaxOlmSessionsPerDevice = 5 func (mach *OlmMachine) tryDecryptOlmCiphertextWithExistingSession( @@ -313,11 +263,10 @@ func (mach *OlmMachine) tryDecryptOlmCiphertextWithExistingSession( if err != nil { log.Warn().Err(err). Hex("ciphertext_hash", ciphertextHash[:]). - Hex("ciphertext_hash_repeat", ptr.Ptr(exerrors.Must(olmMessageHash(ciphertext)))[:]). Str("session_description", session.Describe()). Msg("Failed to decrypt olm message") if olmType == id.OlmMsgTypePreKey { - return nil, ErrDecryptionFailedWithMatchingSession + return nil, DecryptionFailedWithMatchingSession } } else { endTimeTrace = mach.timeTrace(ctx, "updating session in database", time.Second) @@ -357,10 +306,10 @@ const MinUnwedgeInterval = 1 * time.Hour func (mach *OlmMachine) unwedgeDevice(log zerolog.Logger, sender id.UserID, senderKey id.SenderKey) { log = log.With().Str("action", "unwedge olm session").Logger() - ctx := log.WithContext(mach.backgroundCtx) + ctx := log.WithContext(mach.BackgroundCtx) mach.recentlyUnwedgedLock.Lock() prevUnwedge, ok := mach.recentlyUnwedged[senderKey] - delta := time.Since(prevUnwedge) + delta := time.Now().Sub(prevUnwedge) if ok && delta < MinUnwedgeInterval { log.Debug(). Str("previous_recreation", delta.String()). @@ -391,10 +340,7 @@ func (mach *OlmMachine) unwedgeDevice(log zerolog.Logger, sender id.UserID, send return } - log.Debug(). - Time("last_created", lastCreatedAt). - Stringer("device_id", deviceIdentity.DeviceID). - Msg("Creating new Olm session") + log.Debug().Str("device_id", deviceIdentity.DeviceID.String()).Msg("Creating new Olm session") mach.devicesToUnwedgeLock.Lock() mach.devicesToUnwedge[senderKey] = true mach.devicesToUnwedgeLock.Unlock() diff --git a/crypto/devicelist.go b/crypto/devicelist.go index f0d2b129..a2116ed5 100644 --- a/crypto/devicelist.go +++ b/crypto/devicelist.go @@ -22,23 +22,14 @@ import ( ) var ( - ErrMismatchingDeviceID = errors.New("mismatching device ID in parameter and keys object") - ErrMismatchingUserID = errors.New("mismatching user ID in parameter and keys object") - ErrMismatchingSigningKey = errors.New("received update for device with different signing key") - ErrNoSigningKeyFound = errors.New("didn't find ed25519 signing key") - ErrNoIdentityKeyFound = errors.New("didn't find curve25519 identity key") - ErrInvalidKeySignature = errors.New("invalid signature on device keys") - ErrUserNotTracked = errors.New("user is not tracked") -) + MismatchingDeviceID = errors.New("mismatching device ID in parameter and keys object") + MismatchingUserID = errors.New("mismatching user ID in parameter and keys object") + MismatchingSigningKey = errors.New("received update for device with different signing key") + NoSigningKeyFound = errors.New("didn't find ed25519 signing key") + NoIdentityKeyFound = errors.New("didn't find curve25519 identity key") + InvalidKeySignature = errors.New("invalid signature on device keys") -// Deprecated: use variables prefixed with Err -var ( - MismatchingDeviceID = ErrMismatchingDeviceID - MismatchingUserID = ErrMismatchingUserID - MismatchingSigningKey = ErrMismatchingSigningKey - NoSigningKeyFound = ErrNoSigningKeyFound - NoIdentityKeyFound = ErrNoIdentityKeyFound - InvalidKeySignature = ErrInvalidKeySignature + ErrUserNotTracked = errors.New("user is not tracked") ) func (mach *OlmMachine) LoadDevices(ctx context.Context, user id.UserID) (keys map[id.DeviceID]*id.Device) { @@ -215,7 +206,7 @@ func (mach *OlmMachine) FetchKeys(ctx context.Context, users []id.UserID, includ log.Trace().Int("user_count", len(resp.DeviceKeys)).Msg("Query key result received") data = make(map[id.UserID]map[id.DeviceID]*id.Device) for userID, devices := range resp.DeviceKeys { - log := log.With().Stringer("user_id", userID).Logger() + log := log.With().Str("user_id", userID.String()).Logger() delete(req.DeviceKeys, userID) newDevices := make(map[id.DeviceID]*id.Device) @@ -231,7 +222,7 @@ func (mach *OlmMachine) FetchKeys(ctx context.Context, users []id.UserID, includ Msg("Updating devices in store") changed := false for deviceID, deviceKeys := range devices { - log := log.With().Stringer("device_id", deviceID).Logger() + log := log.With().Str("device_id", deviceID.String()).Logger() existing, ok := existingDevices[deviceID] if !ok { // New device @@ -279,7 +270,7 @@ func (mach *OlmMachine) FetchKeys(ctx context.Context, users []id.UserID, includ } } for userID := range req.DeviceKeys { - log.Warn().Stringer("user_id", userID).Msg("Didn't get any keys for user") + log.Warn().Str("user_id", userID.String()).Msg("Didn't get any keys for user") } mach.storeCrossSigningKeys(ctx, resp.MasterKeys, resp.DeviceKeys) @@ -321,28 +312,28 @@ func (mach *OlmMachine) OnDevicesChanged(ctx context.Context, userID id.UserID) func (mach *OlmMachine) validateDevice(userID id.UserID, deviceID id.DeviceID, deviceKeys mautrix.DeviceKeys, existing *id.Device) (*id.Device, error) { if deviceID != deviceKeys.DeviceID { - return nil, fmt.Errorf("%w (expected %s, got %s)", ErrMismatchingDeviceID, deviceID, deviceKeys.DeviceID) + return nil, fmt.Errorf("%w (expected %s, got %s)", MismatchingDeviceID, deviceID, deviceKeys.DeviceID) } else if userID != deviceKeys.UserID { - return nil, fmt.Errorf("%w (expected %s, got %s)", ErrMismatchingUserID, userID, deviceKeys.UserID) + return nil, fmt.Errorf("%w (expected %s, got %s)", MismatchingUserID, userID, deviceKeys.UserID) } signingKey := deviceKeys.Keys.GetEd25519(deviceID) identityKey := deviceKeys.Keys.GetCurve25519(deviceID) if signingKey == "" { - return nil, ErrNoSigningKeyFound + return nil, NoSigningKeyFound } else if identityKey == "" { - return nil, ErrNoIdentityKeyFound + return nil, NoIdentityKeyFound } if existing != nil && existing.SigningKey != signingKey { - return existing, fmt.Errorf("%w (expected %s, got %s)", ErrMismatchingSigningKey, existing.SigningKey, signingKey) + return existing, fmt.Errorf("%w (expected %s, got %s)", MismatchingSigningKey, existing.SigningKey, signingKey) } ok, err := signatures.VerifySignatureJSON(deviceKeys, userID, deviceID.String(), signingKey) if err != nil { return existing, fmt.Errorf("failed to verify signature: %w", err) } else if !ok { - return existing, ErrInvalidKeySignature + return existing, InvalidKeySignature } name, ok := deviceKeys.Unsigned["device_display_name"].(string) diff --git a/crypto/encryptmegolm.go b/crypto/encryptmegolm.go index 88f9c8d4..14ba2449 100644 --- a/crypto/encryptmegolm.go +++ b/crypto/encryptmegolm.go @@ -25,12 +25,7 @@ import ( ) var ( - ErrNoGroupSession = errors.New("no group session created") -) - -// Deprecated: use variables prefixed with Err -var ( - NoGroupSession = ErrNoGroupSession + NoGroupSession = errors.New("no group session created") ) func getRawJSON[T any](content json.RawMessage, path ...string) *T { @@ -46,7 +41,7 @@ func getRawJSON[T any](content json.RawMessage, path ...string) *T { return &result } -func getRelatesTo(content any, plaintext json.RawMessage) *event.RelatesTo { +func getRelatesTo(content any) *event.RelatesTo { contentJSON, ok := content.(json.RawMessage) if ok { return getRawJSON[event.RelatesTo](contentJSON, "m.relates_to") @@ -59,7 +54,7 @@ func getRelatesTo(content any, plaintext json.RawMessage) *event.RelatesTo { if ok { return relatable.OptionalGetRelatesTo() } - return getRawJSON[event.RelatesTo](plaintext, "content", "m.relates_to") + return nil } func getMentions(content any) *event.Mentions { @@ -87,20 +82,15 @@ type rawMegolmEvent struct { // IsShareError returns true if the error is caused by the lack of an outgoing megolm session and can be solved with OlmMachine.ShareGroupSession func IsShareError(err error) bool { - return err == ErrSessionExpired || err == ErrSessionNotShared || err == ErrNoGroupSession + return err == SessionExpired || err == SessionNotShared || err == NoGroupSession } func ParseMegolmMessageIndex(ciphertext []byte) (uint, error) { - if len(ciphertext) == 0 { - return 0, fmt.Errorf("empty ciphertext") - } decoded := make([]byte, base64.RawStdEncoding.DecodedLen(len(ciphertext))) var err error _, err = base64.RawStdEncoding.Decode(decoded, ciphertext) if err != nil { return 0, err - } else if len(decoded) < 2+binary.MaxVarintLen64 { - return 0, fmt.Errorf("decoded ciphertext too short: %d bytes", len(decoded)) } else if decoded[0] != 3 || decoded[1] != 8 { return 0, fmt.Errorf("unexpected initial bytes %d and %d", decoded[0], decoded[1]) } @@ -130,7 +120,7 @@ func (mach *OlmMachine) EncryptMegolmEventWithStateKey(ctx context.Context, room if err != nil { return nil, fmt.Errorf("failed to get outbound group session: %w", err) } else if session == nil { - return nil, ErrNoGroupSession + return nil, NoGroupSession } plaintext, err := json.Marshal(&rawMegolmEvent{ RoomID: roomID, @@ -168,21 +158,12 @@ func (mach *OlmMachine) EncryptMegolmEventWithStateKey(ctx context.Context, room Algorithm: id.AlgorithmMegolmV1, SessionID: session.ID(), MegolmCiphertext: ciphertext, - RelatesTo: getRelatesTo(content, plaintext), + RelatesTo: getRelatesTo(content), // These are deprecated SenderKey: mach.account.IdentityKey(), DeviceID: mach.Client.DeviceID, } - if mach.MSC4392Relations && encrypted.RelatesTo != nil { - // When MSC4392 mode is enabled, reply and reaction metadata is stripped from the unencrypted content. - // Other relations like threads are still left unencrypted. - encrypted.RelatesTo.InReplyTo = nil - encrypted.RelatesTo.IsFallingBack = false - if evtType == event.EventReaction || encrypted.RelatesTo.Type == "" { - encrypted.RelatesTo = nil - } - } if mach.PlaintextMentions { encrypted.Mentions = getMentions(content) } @@ -252,7 +233,7 @@ func (mach *OlmMachine) ShareGroupSession(ctx context.Context, roomID id.RoomID, var fetchKeysForUsers []id.UserID for _, userID := range users { - log := log.With().Stringer("target_user_id", userID).Logger() + log := log.With().Str("target_user_id", userID.String()).Logger() devices, err := mach.CryptoStore.GetDevices(ctx, userID) if err != nil { log.Err(err).Msg("Failed to get devices of user") @@ -324,7 +305,7 @@ func (mach *OlmMachine) ShareGroupSession(ctx context.Context, roomID id.RoomID, toDeviceWithheld.Messages[userID] = withheld } - log := log.With().Stringer("target_user_id", userID).Logger() + log := log.With().Str("target_user_id", userID.String()).Logger() log.Trace().Msg("Trying to find olm session to encrypt megolm session for user (post-fetch retry)") mach.findOlmSessionsForUser(ctx, session, userID, devices, output, withheld, nil) log.Debug(). @@ -370,19 +351,26 @@ func (mach *OlmMachine) encryptAndSendGroupSession(ctx context.Context, session log.Trace().Msg("Encrypting group session for all found devices") deviceCount := 0 toDevice := &mautrix.ReqSendToDevice{Messages: make(map[id.UserID]map[id.DeviceID]*event.Content)} - logUsers := zerolog.Dict() for userID, sessions := range olmSessions { if len(sessions) == 0 { continue } - logDevices := zerolog.Dict() output := make(map[id.DeviceID]*event.Content) toDevice.Messages[userID] = output for deviceID, device := range sessions { + log.Trace(). + Stringer("target_user_id", userID). + Stringer("target_device_id", deviceID). + Stringer("target_identity_key", device.identity.IdentityKey). + Msg("Encrypting group session for device") content := mach.encryptOlmEvent(ctx, device.session, device.identity, event.ToDeviceRoomKey, session.ShareContent()) output[deviceID] = &event.Content{Parsed: content} - logDevices.Str(string(deviceID), string(device.identity.IdentityKey)) deviceCount++ + log.Debug(). + Stringer("target_user_id", userID). + Stringer("target_device_id", deviceID). + Stringer("target_identity_key", device.identity.IdentityKey). + Msg("Encrypted group session for device") if !mach.DisableSharedGroupSessionTracking { err := mach.CryptoStore.MarkOutboundGroupSessionShared(ctx, userID, device.identity.IdentityKey, session.id) if err != nil { @@ -396,13 +384,11 @@ func (mach *OlmMachine) encryptAndSendGroupSession(ctx context.Context, session } } } - logUsers.Dict(string(userID), logDevices) } log.Debug(). Int("device_count", deviceCount). Int("user_count", len(toDevice.Messages)). - Dict("destination_map", logUsers). Msg("Sending to-device messages to share group session") _, err := mach.Client.SendToDevice(ctx, event.ToDeviceEncrypted, toDevice) return err diff --git a/crypto/encryptolm.go b/crypto/encryptolm.go index 765307af..80b76dc5 100644 --- a/crypto/encryptolm.go +++ b/crypto/encryptolm.go @@ -96,19 +96,15 @@ func (mach *OlmMachine) encryptOlmEvent(ctx context.Context, session *OlmSession panic(err) } log := mach.machOrContextLog(ctx) + log.Debug(). + Str("recipient_identity_key", recipient.IdentityKey.String()). + Str("olm_session_id", session.ID().String()). + Str("olm_session_description", session.Describe()). + Msg("Encrypting olm message") msgType, ciphertext, err := session.Encrypt(plaintext) if err != nil { panic(err) } - ciphertextStr := string(ciphertext) - ciphertextHash, _ := olmMessageHash(ciphertextStr) - log.Debug(). - Stringer("event_type", evtType). - Str("recipient_identity_key", recipient.IdentityKey.String()). - Str("olm_session_id", session.ID().String()). - Str("olm_session_description", session.Describe()). - Hex("ciphertext_hash", ciphertextHash[:]). - Msg("Encrypted olm message") err = mach.CryptoStore.UpdateSession(ctx, recipient.IdentityKey, session) if err != nil { log.Error().Err(err).Msg("Failed to update olm session in crypto store after encrypting") @@ -119,7 +115,7 @@ func (mach *OlmMachine) encryptOlmEvent(ctx context.Context, session *OlmSession OlmCiphertext: event.OlmCiphertexts{ recipient.IdentityKey: { Type: msgType, - Body: ciphertextStr, + Body: string(ciphertext), }, }, } diff --git a/crypto/goolm/account/account.go b/crypto/goolm/account/account.go index b48843a4..4da08a73 100644 --- a/crypto/goolm/account/account.go +++ b/crypto/goolm/account/account.go @@ -334,7 +334,7 @@ func (a *Account) UnpickleLibOlm(buf []byte) error { if err != nil { return err } else if pickledVersion != accountPickleVersionLibOLM && pickledVersion != 3 && pickledVersion != 2 { - return fmt.Errorf("unpickle account: %w (found version %d)", olm.ErrUnknownOlmPickleVersion, pickledVersion) + return fmt.Errorf("unpickle account: %w (found version %d)", olm.ErrBadVersion, pickledVersion) } else if err = a.IdKeys.Ed25519.UnpickleLibOlm(decoder); err != nil { // read the ed25519 key pair return err } else if err = a.IdKeys.Curve25519.UnpickleLibOlm(decoder); err != nil { // read curve25519 key pair diff --git a/crypto/goolm/account/account_test.go b/crypto/goolm/account/account_test.go index d0dec5f0..e1c9b452 100644 --- a/crypto/goolm/account/account_test.go +++ b/crypto/goolm/account/account_test.go @@ -124,7 +124,7 @@ func TestOldAccountPickle(t *testing.T) { account, err := account.NewAccount() assert.NoError(t, err) err = account.Unpickle(pickled, pickleKey) - assert.ErrorIs(t, err, olm.ErrUnknownOlmPickleVersion) + assert.ErrorIs(t, err, olm.ErrBadVersion) } func TestLoopback(t *testing.T) { diff --git a/crypto/goolm/account/register.go b/crypto/goolm/account/register.go index ec392d7e..c6b9e523 100644 --- a/crypto/goolm/account/register.go +++ b/crypto/goolm/account/register.go @@ -10,7 +10,7 @@ import ( "maunium.net/go/mautrix/crypto/olm" ) -func Register() { +func init() { olm.InitNewAccount = func() (olm.Account, error) { return NewAccount() } diff --git a/crypto/goolm/crypto/curve25519.go b/crypto/goolm/crypto/curve25519.go index 6e42d886..e9759501 100644 --- a/crypto/goolm/crypto/curve25519.go +++ b/crypto/goolm/crypto/curve25519.go @@ -53,7 +53,6 @@ func (c Curve25519KeyPair) B64Encoded() id.Curve25519 { // SharedSecret returns the shared secret between the key pair and the given public key. func (c Curve25519KeyPair) SharedSecret(pubKey Curve25519PublicKey) ([]byte, error) { - // Note: the standard library checks that the output is non-zero return c.PrivateKey.SharedSecret(pubKey) } diff --git a/crypto/goolm/crypto/curve25519_test.go b/crypto/goolm/crypto/curve25519_test.go index 2550f15e..9039c126 100644 --- a/crypto/goolm/crypto/curve25519_test.go +++ b/crypto/goolm/crypto/curve25519_test.go @@ -25,8 +25,6 @@ func TestCurve25519(t *testing.T) { fromPrivate, err := crypto.Curve25519GenerateFromPrivate(firstKeypair.PrivateKey) assert.NoError(t, err) assert.Equal(t, fromPrivate, firstKeypair) - _, err = secondKeypair.SharedSecret(make([]byte, crypto.Curve25519PublicKeyLength)) - assert.Error(t, err) } func TestCurve25519Case1(t *testing.T) { diff --git a/crypto/goolm/goolmbase64/base64.go b/crypto/goolm/goolmbase64/base64.go index 58ee26f7..061a052a 100644 --- a/crypto/goolm/goolmbase64/base64.go +++ b/crypto/goolm/goolmbase64/base64.go @@ -4,8 +4,7 @@ import ( "encoding/base64" ) -// These methods should only be used for raw byte operations, never with string conversion - +// Deprecated: base64.RawStdEncoding should be used directly func Decode(input []byte) ([]byte, error) { decoded := make([]byte, base64.RawStdEncoding.DecodedLen(len(input))) writtenBytes, err := base64.RawStdEncoding.Decode(decoded, input) @@ -15,6 +14,7 @@ func Decode(input []byte) ([]byte, error) { return decoded[:writtenBytes], nil } +// Deprecated: base64.RawStdEncoding should be used directly func Encode(input []byte) []byte { encoded := make([]byte, base64.RawStdEncoding.EncodedLen(len(input))) base64.RawStdEncoding.Encode(encoded, input) diff --git a/crypto/goolm/libolmpickle/picklejson.go b/crypto/goolm/libolmpickle/picklejson.go index f765391f..308e472c 100644 --- a/crypto/goolm/libolmpickle/picklejson.go +++ b/crypto/goolm/libolmpickle/picklejson.go @@ -50,7 +50,7 @@ func UnpickleAsJSON(object any, pickled, key []byte, pickleVersion byte) error { } } if decrypted[0] != pickleVersion { - return fmt.Errorf("unpickle: %w", olm.ErrUnknownJSONPickleVersion) + return fmt.Errorf("unpickle: %w", olm.ErrWrongPickleVersion) } err = json.Unmarshal(decrypted[1:], object) if err != nil { diff --git a/crypto/goolm/message/decoder.go b/crypto/goolm/message/decoder.go index b06756a9..a71cf302 100644 --- a/crypto/goolm/message/decoder.go +++ b/crypto/goolm/message/decoder.go @@ -3,9 +3,6 @@ package message import ( "bytes" "encoding/binary" - "fmt" - - "maunium.net/go/mautrix/crypto/olm" ) type Decoder struct { @@ -23,8 +20,6 @@ func (d *Decoder) ReadVarInt() (uint64, error) { func (d *Decoder) ReadVarBytes() ([]byte, error) { if n, err := d.ReadVarInt(); err != nil { return nil, err - } else if n > uint64(d.Len()) { - return nil, fmt.Errorf("%w: var bytes length says %d, but only %d bytes left", olm.ErrInputToSmall, n, d.Available()) } else { out := make([]byte, n) _, err = d.Read(out) diff --git a/crypto/goolm/message/group_message.go b/crypto/goolm/message/group_message.go index c83540c1..c2a43b1f 100644 --- a/crypto/goolm/message/group_message.go +++ b/crypto/goolm/message/group_message.go @@ -2,12 +2,10 @@ package message import ( "bytes" - "fmt" "io" "maunium.net/go/mautrix/crypto/goolm/aessha2" "maunium.net/go/mautrix/crypto/goolm/crypto" - "maunium.net/go/mautrix/crypto/olm" ) const ( @@ -38,9 +36,6 @@ func (r *GroupMessage) Decode(input []byte) (err error) { if err != nil { return } - if r.Version != protocolVersion { - return fmt.Errorf("GroupMessage.Decode: %w (got %d, expected %d)", olm.ErrWrongProtocolVersion, r.Version, protocolVersion) - } for { // Read Key diff --git a/crypto/goolm/message/message.go b/crypto/goolm/message/message.go index b161a2d1..8bb6e0cd 100644 --- a/crypto/goolm/message/message.go +++ b/crypto/goolm/message/message.go @@ -2,12 +2,10 @@ package message import ( "bytes" - "fmt" "io" "maunium.net/go/mautrix/crypto/goolm/aessha2" "maunium.net/go/mautrix/crypto/goolm/crypto" - "maunium.net/go/mautrix/crypto/olm" ) const ( @@ -42,9 +40,6 @@ func (r *Message) Decode(input []byte) (err error) { if err != nil { return } - if r.Version != protocolVersion { - return fmt.Errorf("Message.Decode: %w (got %d, expected %d)", olm.ErrWrongProtocolVersion, r.Version, protocolVersion) - } for { // Read Key diff --git a/crypto/goolm/message/prekey_message.go b/crypto/goolm/message/prekey_message.go index 4e3d495d..22ebf9c3 100644 --- a/crypto/goolm/message/prekey_message.go +++ b/crypto/goolm/message/prekey_message.go @@ -1,7 +1,6 @@ package message import ( - "fmt" "io" "maunium.net/go/mautrix/crypto/goolm/crypto" @@ -23,11 +22,6 @@ type PreKeyMessage struct { Message []byte `json:"message"` } -// TODO deduplicate constant with one in session/olm_session.go -const ( - protocolVersion = 0x3 -) - // Decodes decodes the input and populates the corresponding fileds. func (r *PreKeyMessage) Decode(input []byte) (err error) { r.Version = 0 @@ -47,9 +41,6 @@ func (r *PreKeyMessage) Decode(input []byte) (err error) { } return } - if r.Version != protocolVersion { - return fmt.Errorf("PreKeyMessage.Decode: %w (got %d, expected %d)", olm.ErrWrongProtocolVersion, r.Version, protocolVersion) - } for { // Read Key diff --git a/crypto/goolm/message/session_export.go b/crypto/goolm/message/session_export.go index d58dbb21..956868b2 100644 --- a/crypto/goolm/message/session_export.go +++ b/crypto/goolm/message/session_export.go @@ -35,7 +35,7 @@ func (s *MegolmSessionExport) Decode(input []byte) error { return fmt.Errorf("decrypt: %w", olm.ErrBadInput) } if input[0] != sessionExportVersion { - return fmt.Errorf("decrypt: %w", olm.ErrUnknownOlmPickleVersion) + return fmt.Errorf("decrypt: %w", olm.ErrBadVersion) } s.Counter = binary.BigEndian.Uint32(input[1:5]) copy(s.RatchetData[:], input[5:133]) diff --git a/crypto/goolm/message/session_sharing.go b/crypto/goolm/message/session_sharing.go index d04ef15a..16240945 100644 --- a/crypto/goolm/message/session_sharing.go +++ b/crypto/goolm/message/session_sharing.go @@ -42,7 +42,7 @@ func (s *MegolmSessionSharing) VerifyAndDecode(input []byte) error { } s.PublicKey = publicKey if input[0] != sessionSharingVersion { - return fmt.Errorf("verify: %w", olm.ErrUnknownOlmPickleVersion) + return fmt.Errorf("verify: %w", olm.ErrBadVersion) } s.Counter = binary.BigEndian.Uint32(input[1:5]) copy(s.RatchetData[:], input[5:133]) diff --git a/crypto/goolm/pk/decryption.go b/crypto/goolm/pk/decryption.go index cdb20eb1..afb01f74 100644 --- a/crypto/goolm/pk/decryption.go +++ b/crypto/goolm/pk/decryption.go @@ -103,7 +103,7 @@ func (a *Decryption) UnpickleLibOlm(unpickled []byte) error { if pickledVersion == decryptionPickleVersionLibOlm { return a.KeyPair.UnpickleLibOlm(decoder) } else { - return fmt.Errorf("unpickle olmSession: %w (found %d, expected %d)", olm.ErrUnknownOlmPickleVersion, pickledVersion, decryptionPickleVersionLibOlm) + return fmt.Errorf("unpickle olmSession: %w (found %d, expected %d)", olm.ErrBadVersion, pickledVersion, decryptionPickleVersionLibOlm) } } diff --git a/crypto/goolm/pk/encryption.go b/crypto/goolm/pk/encryption.go index 2897d9b0..23f67ddf 100644 --- a/crypto/goolm/pk/encryption.go +++ b/crypto/goolm/pk/encryption.go @@ -37,9 +37,6 @@ func (e Encryption) Encrypt(plaintext []byte, privateKey crypto.Curve25519Privat return nil, nil, err } cipher, err := aessha2.NewAESSHA2(sharedSecret, nil) - if err != nil { - return nil, nil, err - } ciphertext, err = cipher.Encrypt(plaintext) if err != nil { return nil, nil, err diff --git a/crypto/goolm/pk/register.go b/crypto/goolm/pk/register.go index 0e27b568..b7af6a5b 100644 --- a/crypto/goolm/pk/register.go +++ b/crypto/goolm/pk/register.go @@ -8,7 +8,7 @@ package pk import "maunium.net/go/mautrix/crypto/olm" -func Register() { +func init() { olm.InitNewPKSigningFromSeed = func(seed []byte) (olm.PKSigning, error) { return NewSigningFromSeed(seed) } diff --git a/crypto/goolm/ratchet/olm.go b/crypto/goolm/ratchet/olm.go index 9901ada8..229c9bd2 100644 --- a/crypto/goolm/ratchet/olm.go +++ b/crypto/goolm/ratchet/olm.go @@ -142,7 +142,7 @@ func (r *Ratchet) Decrypt(input []byte) ([]byte, error) { return nil, err } if message.Version != protocolVersion { - return nil, fmt.Errorf("decrypt: %w (got %d, expected %d)", olm.ErrWrongProtocolVersion, message.Version, protocolVersion) + return nil, fmt.Errorf("decrypt: %w", olm.ErrWrongProtocolVersion) } if !message.HasCounter || len(message.RatchetKey) == 0 || len(message.Ciphertext) == 0 { return nil, fmt.Errorf("decrypt: %w", olm.ErrBadMessageFormat) diff --git a/crypto/goolm/register.go b/crypto/goolm/register.go index 800f567f..80ed206b 100644 --- a/crypto/goolm/register.go +++ b/crypto/goolm/register.go @@ -7,23 +7,19 @@ package goolm import ( - "maunium.net/go/mautrix/crypto/goolm/account" - "maunium.net/go/mautrix/crypto/goolm/pk" - "maunium.net/go/mautrix/crypto/goolm/session" + // Need to import these subpackages to ensure they are registered + _ "maunium.net/go/mautrix/crypto/goolm/account" + _ "maunium.net/go/mautrix/crypto/goolm/pk" + _ "maunium.net/go/mautrix/crypto/goolm/session" + "maunium.net/go/mautrix/crypto/olm" ) -func Register() { - olm.Driver = "goolm" - +func init() { olm.GetVersion = func() (major, minor, patch uint8) { return 3, 2, 15 } olm.SetPickleKeyImpl = func(key []byte) { panic("gob and json encoding is deprecated and not supported with goolm") } - - account.Register() - pk.Register() - session.Register() } diff --git a/crypto/goolm/session/megolm_inbound_session.go b/crypto/goolm/session/megolm_inbound_session.go index 7ccbd26d..80dd71cc 100644 --- a/crypto/goolm/session/megolm_inbound_session.go +++ b/crypto/goolm/session/megolm_inbound_session.go @@ -99,7 +99,7 @@ func (o *MegolmInboundSession) getRatchet(messageIndex uint32) (*megolm.Ratchet, } if (messageIndex - o.InitialRatchet.Counter) >= uint32(1<<31) { // the counter is before our initial ratchet - we can't decode this - return nil, fmt.Errorf("decrypt: %w", olm.ErrUnknownMessageIndex) + return nil, fmt.Errorf("decrypt: %w", olm.ErrRatchetNotAvailable) } // otherwise, start from the initial ratchet. Take a copy so that we don't overwrite the initial ratchet copiedRatchet := o.InitialRatchet @@ -126,7 +126,7 @@ func (o *MegolmInboundSession) Decrypt(ciphertext []byte) ([]byte, uint, error) return nil, 0, err } if msg.Version != protocolVersion { - return nil, 0, fmt.Errorf("decrypt: %w (got %d, expected %d)", olm.ErrWrongProtocolVersion, msg.Version, protocolVersion) + return nil, 0, fmt.Errorf("decrypt: %w", olm.ErrWrongProtocolVersion) } if msg.Ciphertext == nil || !msg.HasMessageIndex { return nil, 0, fmt.Errorf("decrypt: %w", olm.ErrBadMessageFormat) @@ -206,7 +206,7 @@ func (o *MegolmInboundSession) UnpickleLibOlm(value []byte) error { return err } if pickledVersion != megolmInboundSessionPickleVersionLibOlm && pickledVersion != 1 { - return fmt.Errorf("unpickle MegolmInboundSession: %w (found version %d)", olm.ErrUnknownOlmPickleVersion, pickledVersion) + return fmt.Errorf("unpickle MegolmInboundSession: %w (found version %d)", olm.ErrBadVersion, pickledVersion) } if err = o.InitialRatchet.UnpickleLibOlm(decoder); err != nil { diff --git a/crypto/goolm/session/megolm_outbound_session.go b/crypto/goolm/session/megolm_outbound_session.go index 7f923534..2b8e1c84 100644 --- a/crypto/goolm/session/megolm_outbound_session.go +++ b/crypto/goolm/session/megolm_outbound_session.go @@ -101,10 +101,8 @@ func (o *MegolmOutboundSession) Unpickle(pickled, key []byte) error { func (o *MegolmOutboundSession) UnpickleLibOlm(buf []byte) error { decoder := libolmpickle.NewDecoder(buf) pickledVersion, err := decoder.ReadUInt32() - if err != nil { - return fmt.Errorf("unpickle MegolmOutboundSession: failed to read version: %w", err) - } else if pickledVersion != megolmOutboundSessionPickleVersionLibOlm { - return fmt.Errorf("unpickle MegolmInboundSession: %w (found version %d)", olm.ErrUnknownOlmPickleVersion, pickledVersion) + if pickledVersion != megolmOutboundSessionPickleVersionLibOlm { + return fmt.Errorf("unpickle MegolmInboundSession: %w (found version %d)", olm.ErrBadVersion, pickledVersion) } if err = o.Ratchet.UnpickleLibOlm(decoder); err != nil { return err diff --git a/crypto/goolm/session/olm_session.go b/crypto/goolm/session/olm_session.go index a1cb8d66..b99ab630 100644 --- a/crypto/goolm/session/olm_session.go +++ b/crypto/goolm/session/olm_session.go @@ -168,11 +168,11 @@ func NewInboundOlmSession(identityKeyAlice *crypto.Curve25519PublicKey, received msg := message.Message{} err = msg.Decode(oneTimeMsg.Message) if err != nil { - return nil, fmt.Errorf("message decode: %w", err) + return nil, fmt.Errorf("Message decode: %w", err) } if len(msg.RatchetKey) == 0 { - return nil, fmt.Errorf("message missing ratchet key: %w", olm.ErrBadMessageFormat) + return nil, fmt.Errorf("Message missing ratchet key: %w", olm.ErrBadMessageFormat) } //Init Ratchet s.Ratchet.InitializeAsBob(secret, msg.RatchetKey) @@ -203,7 +203,7 @@ func (s *OlmSession) ID() id.SessionID { copy(message[crypto.Curve25519PrivateKeyLength:], s.AliceBaseKey) copy(message[2*crypto.Curve25519PrivateKeyLength:], s.BobOneTimeKey) hash := sha256.Sum256(message) - res := id.SessionID(base64.RawStdEncoding.EncodeToString(hash[:])) + res := id.SessionID(goolmbase64.Encode(hash[:])) return res } @@ -325,7 +325,7 @@ func (s *OlmSession) Decrypt(crypttext string, msgType id.OlmMsgType) ([]byte, e if len(crypttext) == 0 { return nil, fmt.Errorf("decrypt: %w", olm.ErrEmptyInput) } - decodedCrypttext, err := base64.RawStdEncoding.DecodeString(crypttext) + decodedCrypttext, err := goolmbase64.Decode([]byte(crypttext)) if err != nil { return nil, err } @@ -365,9 +365,6 @@ func (o *OlmSession) Unpickle(pickled, key []byte) error { func (o *OlmSession) UnpickleLibOlm(buf []byte) error { decoder := libolmpickle.NewDecoder(buf) pickledVersion, err := decoder.ReadUInt32() - if err != nil { - return fmt.Errorf("unpickle olmSession: failed to read version: %w", err) - } var includesChainIndex bool switch pickledVersion { @@ -376,7 +373,7 @@ func (o *OlmSession) UnpickleLibOlm(buf []byte) error { case uint32(0x80000001): includesChainIndex = true default: - return fmt.Errorf("unpickle olmSession: %w (found version %d)", olm.ErrUnknownOlmPickleVersion, pickledVersion) + return fmt.Errorf("unpickle olmSession: %w (found version %d)", olm.ErrBadVersion, pickledVersion) } if o.ReceivedMessage, err = decoder.ReadBool(); err != nil { diff --git a/crypto/goolm/session/register.go b/crypto/goolm/session/register.go index b95a44ac..09ed42d4 100644 --- a/crypto/goolm/session/register.go +++ b/crypto/goolm/session/register.go @@ -10,11 +10,11 @@ import ( "maunium.net/go/mautrix/crypto/olm" ) -func Register() { +func init() { // Inbound Session olm.InitInboundGroupSessionFromPickled = func(pickled, key []byte) (olm.InboundGroupSession, error) { if len(pickled) == 0 { - return nil, olm.ErrEmptyInput + return nil, olm.EmptyInput } if len(key) == 0 { key = []byte(" ") @@ -23,13 +23,13 @@ func Register() { } olm.InitNewInboundGroupSession = func(sessionKey []byte) (olm.InboundGroupSession, error) { if len(sessionKey) == 0 { - return nil, olm.ErrEmptyInput + return nil, olm.EmptyInput } return NewMegolmInboundSession(sessionKey) } olm.InitInboundGroupSessionImport = func(sessionKey []byte) (olm.InboundGroupSession, error) { if len(sessionKey) == 0 { - return nil, olm.ErrEmptyInput + return nil, olm.EmptyInput } return NewMegolmInboundSessionFromExport(sessionKey) } @@ -40,7 +40,7 @@ func Register() { // Outbound Session olm.InitNewOutboundGroupSessionFromPickled = func(pickled, key []byte) (olm.OutboundGroupSession, error) { if len(pickled) == 0 { - return nil, olm.ErrEmptyInput + return nil, olm.EmptyInput } lenKey := len(key) if lenKey == 0 { diff --git a/crypto/keybackup.go b/crypto/keybackup.go index 7b3c30db..d8b3d715 100644 --- a/crypto/keybackup.go +++ b/crypto/keybackup.go @@ -56,12 +56,11 @@ func (mach *OlmMachine) GetAndVerifyLatestKeyBackupVersion(ctx context.Context, // ...by deriving the public key from a private key that it obtained from a trusted source. Trusted sources for the private // key include the user entering the key, retrieving the key stored in secret storage, or obtaining the key via secret sharing // from a verified device belonging to the same user." - if megolmBackupKey != nil { - megolmBackupDerivedPublicKey := id.Ed25519(base64.RawStdEncoding.EncodeToString(megolmBackupKey.PublicKey().Bytes())) - if versionInfo.AuthData.PublicKey == megolmBackupDerivedPublicKey { - log.Debug().Msg("Key backup is trusted based on derived public key") - return versionInfo, nil - } + megolmBackupDerivedPublicKey := id.Ed25519(base64.RawStdEncoding.EncodeToString(megolmBackupKey.PublicKey().Bytes())) + if megolmBackupKey != nil && versionInfo.AuthData.PublicKey == megolmBackupDerivedPublicKey { + log.Debug().Msg("key backup is trusted based on derived public key") + return versionInfo, nil + } else { log.Debug(). Stringer("expected_key", megolmBackupDerivedPublicKey). Stringer("actual_key", versionInfo.AuthData.PublicKey). @@ -200,14 +199,13 @@ func (mach *OlmMachine) ImportRoomKeyFromBackupWithoutSaving( SigningKey: keyBackupData.SenderClaimedKeys.Ed25519, SenderKey: keyBackupData.SenderKey, RoomID: roomID, - ForwardingChains: keyBackupData.ForwardingKeyChain, + ForwardingChains: append(keyBackupData.ForwardingKeyChain, keyBackupData.SenderKey.String()), id: sessionID, ReceivedAt: time.Now().UTC(), MaxAge: maxAge.Milliseconds(), MaxMessages: maxMessages, KeyBackupVersion: version, - KeySource: id.KeySourceBackup, }, nil } diff --git a/crypto/keyexport_test.go b/crypto/keyexport_test.go index fd6f105d..47616a20 100644 --- a/crypto/keyexport_test.go +++ b/crypto/keyexport_test.go @@ -31,5 +31,5 @@ func TestExportKeys(t *testing.T) { )) data, err := crypto.ExportKeys("meow", []*crypto.InboundGroupSession{sess}) assert.NoError(t, err) - assert.Len(t, data, 893) + assert.Len(t, data, 836) } diff --git a/crypto/keyimport.go b/crypto/keyimport.go index 3ffc74a5..36ad6b9c 100644 --- a/crypto/keyimport.go +++ b/crypto/keyimport.go @@ -108,20 +108,19 @@ func (mach *OlmMachine) importExportedRoomKey(ctx context.Context, session Expor return false, ErrMismatchingExportedSessionID } igs := &InboundGroupSession{ - Internal: igsInternal, - SigningKey: session.SenderClaimedKeys.Ed25519, - SenderKey: session.SenderKey, - RoomID: session.RoomID, + Internal: igsInternal, + SigningKey: session.SenderClaimedKeys.Ed25519, + SenderKey: session.SenderKey, + RoomID: session.RoomID, + // TODO should we add something here to mark the signing key as unverified like key requests do? ForwardingChains: session.ForwardingChains, - KeySource: id.KeySourceImport, - ReceivedAt: time.Now().UTC(), + + ReceivedAt: time.Now().UTC(), } existingIGS, _ := mach.CryptoStore.GetGroupSession(ctx, igs.RoomID, igs.ID()) firstKnownIndex := igs.Internal.FirstKnownIndex() if existingIGS != nil && existingIGS.Internal.FirstKnownIndex() <= firstKnownIndex { - // We already have an equivalent or better session in the store, so don't override it, - // but do notify the session received callback just in case. - mach.MarkSessionReceived(ctx, session.RoomID, igs.ID(), existingIGS.Internal.FirstKnownIndex()) + // We already have an equivalent or better session in the store, so don't override it. return false, nil } err = mach.CryptoStore.PutGroupSession(ctx, igs) diff --git a/crypto/keysharing.go b/crypto/keysharing.go index 19a68c87..f1d427af 100644 --- a/crypto/keysharing.go +++ b/crypto/keysharing.go @@ -189,7 +189,6 @@ func (mach *OlmMachine) importForwardedRoomKey(ctx context.Context, evt *Decrypt MaxAge: maxAge.Milliseconds(), MaxMessages: maxMessages, IsScheduled: content.IsScheduled, - KeySource: id.KeySourceForward, } existingIGS, _ := mach.CryptoStore.GetGroupSession(ctx, igs.RoomID, igs.ID()) if existingIGS != nil && existingIGS.Internal.FirstKnownIndex() <= igs.Internal.FirstKnownIndex() { @@ -215,7 +214,6 @@ func (mach *OlmMachine) rejectKeyRequest(ctx context.Context, rejection KeyShare RoomID: request.RoomID, Algorithm: request.Algorithm, SessionID: request.SessionID, - //lint:ignore SA1019 This is just echoing back the deprecated field SenderKey: request.SenderKey, Code: rejection.Code, Reason: rejection.Reason, @@ -265,14 +263,9 @@ func (mach *OlmMachine) defaultAllowKeyShare(ctx context.Context, device *id.Dev log.Err(err).Msg("Rejecting key request due to internal error when checking session sharing") return &KeyShareRejectNoResponse } else if !isShared { - igs, _ := mach.CryptoStore.GetGroupSession(ctx, evt.RoomID, evt.SessionID) - if igs != nil && igs.SenderKey == mach.OwnIdentity().IdentityKey { - log.Debug().Msg("Rejecting key request for unshared session") - return &KeyShareRejectNotRecipient - } - // Note: this case will also happen for redacted sessions and database errors - log.Debug().Msg("Rejecting key request for session created by another device") - return &KeyShareRejectNoResponse + // TODO differentiate session not shared with requester vs session not created by this device? + log.Debug().Msg("Rejecting key request for unshared session") + return &KeyShareRejectNotRecipient } log.Debug().Msg("Accepting key request for shared session") return nil @@ -330,9 +323,7 @@ func (mach *OlmMachine) HandleRoomKeyRequest(ctx context.Context, sender id.User if err != nil { if errors.Is(err, ErrGroupSessionWithheld) { log.Debug().Err(err).Msg("Requested group session not available") - if sender != mach.Client.UserID { - mach.rejectKeyRequest(ctx, KeyShareRejectUnavailable, device, content.Body) - } + mach.rejectKeyRequest(ctx, KeyShareRejectUnavailable, device, content.Body) } else { log.Error().Err(err).Msg("Failed to get group session to forward") mach.rejectKeyRequest(ctx, KeyShareRejectInternalError, device, content.Body) @@ -340,9 +331,7 @@ func (mach *OlmMachine) HandleRoomKeyRequest(ctx context.Context, sender id.User return } else if igs == nil { log.Error().Msg("Didn't find group session to forward") - if sender != mach.Client.UserID { - mach.rejectKeyRequest(ctx, KeyShareRejectUnavailable, device, content.Body) - } + mach.rejectKeyRequest(ctx, KeyShareRejectUnavailable, device, content.Body) return } if internalID := igs.ID(); internalID != content.Body.SessionID { @@ -367,7 +356,7 @@ func (mach *OlmMachine) HandleRoomKeyRequest(ctx context.Context, sender id.User SessionID: igs.ID(), SessionKey: string(exportedKey), }, - SenderKey: igs.SenderKey, + SenderKey: content.Body.SenderKey, ForwardingKeyChain: igs.ForwardingChains, SenderClaimedKey: igs.SigningKey, }, diff --git a/crypto/libolm/account.go b/crypto/libolm/account.go index 0350f083..cddce7ce 100644 --- a/crypto/libolm/account.go +++ b/crypto/libolm/account.go @@ -8,7 +8,6 @@ import ( "crypto/rand" "encoding/base64" "encoding/json" - "runtime" "unsafe" "github.com/tidwall/gjson" @@ -23,6 +22,18 @@ type Account struct { mem []byte } +func init() { + olm.InitNewAccount = func() (olm.Account, error) { + return NewAccount() + } + olm.InitBlankAccount = func() olm.Account { + return NewBlankAccount() + } + olm.InitNewAccountFromPickled = func(pickled, key []byte) (olm.Account, error) { + return AccountFromPickled(pickled, key) + } +} + // Ensure that [Account] implements [olm.Account]. var _ olm.Account = (*Account)(nil) @@ -33,7 +44,7 @@ var _ olm.Account = (*Account)(nil) // "INVALID_BASE64". func AccountFromPickled(pickled, key []byte) (*Account, error) { if len(pickled) == 0 { - return nil, olm.ErrEmptyInput + return nil, olm.EmptyInput } a := NewBlankAccount() return a, a.Unpickle(pickled, key) @@ -42,7 +53,7 @@ func AccountFromPickled(pickled, key []byte) (*Account, error) { func NewBlankAccount() *Account { memory := make([]byte, accountSize()) return &Account{ - int: C.olm_account(unsafe.Pointer(unsafe.SliceData(memory))), + int: C.olm_account(unsafe.Pointer(&memory[0])), mem: memory, } } @@ -53,13 +64,12 @@ func NewAccount() (*Account, error) { random := make([]byte, a.createRandomLen()+1) _, err := rand.Read(random) if err != nil { - panic(olm.ErrNotEnoughGoRandom) + panic(olm.NotEnoughGoRandom) } ret := C.olm_create_account( (*C.OlmAccount)(a.int), - unsafe.Pointer(unsafe.SliceData(random)), + unsafe.Pointer(&random[0]), C.size_t(len(random))) - runtime.KeepAlive(random) if ret == errorVal() { return nil, a.lastError() } else { @@ -128,14 +138,14 @@ func (a *Account) genOneTimeKeysRandomLen(num uint) uint { // supplied key. func (a *Account) Pickle(key []byte) ([]byte, error) { if len(key) == 0 { - return nil, olm.ErrNoKeyProvided + return nil, olm.NoKeyProvided } pickled := make([]byte, a.pickleLen()) r := C.olm_pickle_account( (*C.OlmAccount)(a.int), - unsafe.Pointer(unsafe.SliceData(key)), + unsafe.Pointer(&key[0]), C.size_t(len(key)), - unsafe.Pointer(unsafe.SliceData(pickled)), + unsafe.Pointer(&pickled[0]), C.size_t(len(pickled))) if r == errorVal() { return nil, a.lastError() @@ -145,13 +155,13 @@ func (a *Account) Pickle(key []byte) ([]byte, error) { func (a *Account) Unpickle(pickled, key []byte) error { if len(key) == 0 { - return olm.ErrNoKeyProvided + return olm.NoKeyProvided } r := C.olm_unpickle_account( (*C.OlmAccount)(a.int), - unsafe.Pointer(unsafe.SliceData(key)), + unsafe.Pointer(&key[0]), C.size_t(len(key)), - unsafe.Pointer(unsafe.SliceData(pickled)), + unsafe.Pointer(&pickled[0]), C.size_t(len(pickled))) if r == errorVal() { return a.lastError() @@ -198,7 +208,7 @@ func (a *Account) MarshalJSON() ([]byte, error) { // Deprecated func (a *Account) UnmarshalJSON(data []byte) error { if len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' { - return olm.ErrInputNotJSONString + return olm.InputNotJSONString } if a.int == nil { *a = *NewBlankAccount() @@ -211,7 +221,7 @@ func (a *Account) IdentityKeysJSON() ([]byte, error) { identityKeys := make([]byte, a.identityKeysLen()) r := C.olm_account_identity_keys( (*C.OlmAccount)(a.int), - unsafe.Pointer(unsafe.SliceData(identityKeys)), + unsafe.Pointer(&identityKeys[0]), C.size_t(len(identityKeys))) if r == errorVal() { return nil, a.lastError() @@ -235,16 +245,15 @@ func (a *Account) IdentityKeys() (id.Ed25519, id.Curve25519, error) { // Account. func (a *Account) Sign(message []byte) ([]byte, error) { if len(message) == 0 { - panic(olm.ErrEmptyInput) + panic(olm.EmptyInput) } signature := make([]byte, a.signatureLen()) r := C.olm_account_sign( (*C.OlmAccount)(a.int), - unsafe.Pointer(unsafe.SliceData(message)), + unsafe.Pointer(&message[0]), C.size_t(len(message)), - unsafe.Pointer(unsafe.SliceData(signature)), + unsafe.Pointer(&signature[0]), C.size_t(len(signature))) - runtime.KeepAlive(message) if r == errorVal() { panic(a.lastError()) } @@ -268,9 +277,8 @@ func (a *Account) OneTimeKeys() (map[string]id.Curve25519, error) { oneTimeKeysJSON := make([]byte, a.oneTimeKeysLen()) r := C.olm_account_one_time_keys( (*C.OlmAccount)(a.int), - unsafe.Pointer(unsafe.SliceData(oneTimeKeysJSON)), - C.size_t(len(oneTimeKeysJSON)), - ) + unsafe.Pointer(&oneTimeKeysJSON[0]), + C.size_t(len(oneTimeKeysJSON))) if r == errorVal() { return nil, a.lastError() } @@ -299,15 +307,13 @@ func (a *Account) GenOneTimeKeys(num uint) error { random := make([]byte, a.genOneTimeKeysRandomLen(num)+1) _, err := rand.Read(random) if err != nil { - return olm.ErrNotEnoughGoRandom + return olm.NotEnoughGoRandom } r := C.olm_account_generate_one_time_keys( (*C.OlmAccount)(a.int), C.size_t(num), - unsafe.Pointer(unsafe.SliceData(random)), - C.size_t(len(random)), - ) - runtime.KeepAlive(random) + unsafe.Pointer(&random[0]), + C.size_t(len(random))) if r == errorVal() { return a.lastError() } @@ -319,29 +325,23 @@ func (a *Account) GenOneTimeKeys(num uint) error { // keys couldn't be decoded as base64 then the error will be "INVALID_BASE64" func (a *Account) NewOutboundSession(theirIdentityKey, theirOneTimeKey id.Curve25519) (olm.Session, error) { if len(theirIdentityKey) == 0 || len(theirOneTimeKey) == 0 { - return nil, olm.ErrEmptyInput + return nil, olm.EmptyInput } s := NewBlankSession() random := make([]byte, s.createOutboundRandomLen()+1) _, err := rand.Read(random) if err != nil { - panic(olm.ErrNotEnoughGoRandom) + panic(olm.NotEnoughGoRandom) } - theirIdentityKeyCopy := []byte(theirIdentityKey) - theirOneTimeKeyCopy := []byte(theirOneTimeKey) r := C.olm_create_outbound_session( (*C.OlmSession)(s.int), (*C.OlmAccount)(a.int), - unsafe.Pointer(unsafe.SliceData(theirIdentityKeyCopy)), - C.size_t(len(theirIdentityKeyCopy)), - unsafe.Pointer(unsafe.SliceData(theirOneTimeKeyCopy)), - C.size_t(len(theirOneTimeKeyCopy)), - unsafe.Pointer(unsafe.SliceData(random)), - C.size_t(len(random)), - ) - runtime.KeepAlive(random) - runtime.KeepAlive(theirIdentityKeyCopy) - runtime.KeepAlive(theirOneTimeKeyCopy) + unsafe.Pointer(&([]byte(theirIdentityKey)[0])), + C.size_t(len(theirIdentityKey)), + unsafe.Pointer(&([]byte(theirOneTimeKey)[0])), + C.size_t(len(theirOneTimeKey)), + unsafe.Pointer(&random[0]), + C.size_t(len(random))) if r == errorVal() { return nil, s.lastError() } @@ -357,17 +357,14 @@ func (a *Account) NewOutboundSession(theirIdentityKey, theirOneTimeKey id.Curve2 // time key then the error will be "BAD_MESSAGE_KEY_ID". func (a *Account) NewInboundSession(oneTimeKeyMsg string) (olm.Session, error) { if len(oneTimeKeyMsg) == 0 { - return nil, olm.ErrEmptyInput + return nil, olm.EmptyInput } s := NewBlankSession() - oneTimeKeyMsgCopy := []byte(oneTimeKeyMsg) r := C.olm_create_inbound_session( (*C.OlmSession)(s.int), (*C.OlmAccount)(a.int), - unsafe.Pointer(unsafe.SliceData(oneTimeKeyMsgCopy)), - C.size_t(len(oneTimeKeyMsgCopy)), - ) - runtime.KeepAlive(oneTimeKeyMsgCopy) + unsafe.Pointer(&([]byte(oneTimeKeyMsg)[0])), + C.size_t(len(oneTimeKeyMsg))) if r == errorVal() { return nil, s.lastError() } @@ -383,21 +380,16 @@ func (a *Account) NewInboundSession(oneTimeKeyMsg string) (olm.Session, error) { // time key then the error will be "BAD_MESSAGE_KEY_ID". func (a *Account) NewInboundSessionFrom(theirIdentityKey *id.Curve25519, oneTimeKeyMsg string) (olm.Session, error) { if theirIdentityKey == nil || len(oneTimeKeyMsg) == 0 { - return nil, olm.ErrEmptyInput + return nil, olm.EmptyInput } - theirIdentityKeyCopy := []byte(*theirIdentityKey) - oneTimeKeyMsgCopy := []byte(oneTimeKeyMsg) s := NewBlankSession() r := C.olm_create_inbound_session_from( (*C.OlmSession)(s.int), (*C.OlmAccount)(a.int), - unsafe.Pointer(unsafe.SliceData(theirIdentityKeyCopy)), - C.size_t(len(theirIdentityKeyCopy)), - unsafe.Pointer(unsafe.SliceData(oneTimeKeyMsgCopy)), - C.size_t(len(oneTimeKeyMsgCopy)), - ) - runtime.KeepAlive(theirIdentityKeyCopy) - runtime.KeepAlive(oneTimeKeyMsgCopy) + unsafe.Pointer(&([]byte(*theirIdentityKey)[0])), + C.size_t(len(*theirIdentityKey)), + unsafe.Pointer(&([]byte(oneTimeKeyMsg)[0])), + C.size_t(len(oneTimeKeyMsg))) if r == errorVal() { return nil, s.lastError() } @@ -410,8 +402,7 @@ func (a *Account) NewInboundSessionFrom(theirIdentityKey *id.Curve25519, oneTime func (a *Account) RemoveOneTimeKeys(s olm.Session) error { r := C.olm_remove_one_time_keys( (*C.OlmAccount)(a.int), - (*C.OlmSession)(s.(*Session).int), - ) + (*C.OlmSession)(s.(*Session).int)) if r == errorVal() { return a.lastError() } diff --git a/crypto/libolm/error.go b/crypto/libolm/error.go index 6fb5512b..9ca415ee 100644 --- a/crypto/libolm/error.go +++ b/crypto/libolm/error.go @@ -11,21 +11,21 @@ import ( ) var errorMap = map[string]error{ - "NOT_ENOUGH_RANDOM": olm.ErrLibolmNotEnoughRandom, - "OUTPUT_BUFFER_TOO_SMALL": olm.ErrLibolmOutputBufferTooSmall, - "BAD_MESSAGE_VERSION": olm.ErrWrongProtocolVersion, - "BAD_MESSAGE_FORMAT": olm.ErrBadMessageFormat, - "BAD_MESSAGE_MAC": olm.ErrBadMAC, - "BAD_MESSAGE_KEY_ID": olm.ErrBadMessageKeyID, - "INVALID_BASE64": olm.ErrLibolmInvalidBase64, - "BAD_ACCOUNT_KEY": olm.ErrLibolmBadAccountKey, - "UNKNOWN_PICKLE_VERSION": olm.ErrUnknownOlmPickleVersion, - "CORRUPTED_PICKLE": olm.ErrLibolmCorruptedPickle, - "BAD_SESSION_KEY": olm.ErrLibolmBadSessionKey, - "UNKNOWN_MESSAGE_INDEX": olm.ErrUnknownMessageIndex, - "BAD_LEGACY_ACCOUNT_PICKLE": olm.ErrLibolmBadLegacyAccountPickle, - "BAD_SIGNATURE": olm.ErrBadSignature, - "INPUT_BUFFER_TOO_SMALL": olm.ErrInputToSmall, + "NOT_ENOUGH_RANDOM": olm.NotEnoughRandom, + "OUTPUT_BUFFER_TOO_SMALL": olm.OutputBufferTooSmall, + "BAD_MESSAGE_VERSION": olm.BadMessageVersion, + "BAD_MESSAGE_FORMAT": olm.BadMessageFormat, + "BAD_MESSAGE_MAC": olm.BadMessageMAC, + "BAD_MESSAGE_KEY_ID": olm.BadMessageKeyID, + "INVALID_BASE64": olm.InvalidBase64, + "BAD_ACCOUNT_KEY": olm.BadAccountKey, + "UNKNOWN_PICKLE_VERSION": olm.UnknownPickleVersion, + "CORRUPTED_PICKLE": olm.CorruptedPickle, + "BAD_SESSION_KEY": olm.BadSessionKey, + "UNKNOWN_MESSAGE_INDEX": olm.UnknownMessageIndex, + "BAD_LEGACY_ACCOUNT_PICKLE": olm.BadLegacyAccountPickle, + "BAD_SIGNATURE": olm.BadSignature, + "INPUT_BUFFER_TOO_SMALL": olm.InputBufferTooSmall, } func convertError(errCode string) error { diff --git a/crypto/libolm/inboundgroupsession.go b/crypto/libolm/inboundgroupsession.go index 8815ac32..1e25748d 100644 --- a/crypto/libolm/inboundgroupsession.go +++ b/crypto/libolm/inboundgroupsession.go @@ -7,7 +7,6 @@ import "C" import ( "bytes" "encoding/base64" - "runtime" "unsafe" "maunium.net/go/mautrix/crypto/olm" @@ -21,6 +20,21 @@ type InboundGroupSession struct { mem []byte } +func init() { + olm.InitInboundGroupSessionFromPickled = func(pickled, key []byte) (olm.InboundGroupSession, error) { + return InboundGroupSessionFromPickled(pickled, key) + } + olm.InitNewInboundGroupSession = func(sessionKey []byte) (olm.InboundGroupSession, error) { + return NewInboundGroupSession(sessionKey) + } + olm.InitInboundGroupSessionImport = func(sessionKey []byte) (olm.InboundGroupSession, error) { + return InboundGroupSessionImport(sessionKey) + } + olm.InitBlankInboundGroupSession = func() olm.InboundGroupSession { + return NewBlankInboundGroupSession() + } +} + // Ensure that [InboundGroupSession] implements [olm.InboundGroupSession]. var _ olm.InboundGroupSession = (*InboundGroupSession)(nil) @@ -31,7 +45,7 @@ var _ olm.InboundGroupSession = (*InboundGroupSession)(nil) // base64 couldn't be decoded then the error will be "INVALID_BASE64". func InboundGroupSessionFromPickled(pickled, key []byte) (*InboundGroupSession, error) { if len(pickled) == 0 { - return nil, olm.ErrEmptyInput + return nil, olm.EmptyInput } lenKey := len(key) if lenKey == 0 { @@ -48,15 +62,13 @@ func InboundGroupSessionFromPickled(pickled, key []byte) (*InboundGroupSession, // "OLM_BAD_SESSION_KEY". func NewInboundGroupSession(sessionKey []byte) (*InboundGroupSession, error) { if len(sessionKey) == 0 { - return nil, olm.ErrEmptyInput + return nil, olm.EmptyInput } s := NewBlankInboundGroupSession() r := C.olm_init_inbound_group_session( (*C.OlmInboundGroupSession)(s.int), - (*C.uint8_t)(unsafe.Pointer(unsafe.SliceData(sessionKey))), - C.size_t(len(sessionKey)), - ) - runtime.KeepAlive(sessionKey) + (*C.uint8_t)(&sessionKey[0]), + C.size_t(len(sessionKey))) if r == errorVal() { return nil, s.lastError() } @@ -69,15 +81,13 @@ func NewInboundGroupSession(sessionKey []byte) (*InboundGroupSession, error) { // error will be "OLM_BAD_SESSION_KEY". func InboundGroupSessionImport(sessionKey []byte) (*InboundGroupSession, error) { if len(sessionKey) == 0 { - return nil, olm.ErrEmptyInput + return nil, olm.EmptyInput } s := NewBlankInboundGroupSession() r := C.olm_import_inbound_group_session( (*C.OlmInboundGroupSession)(s.int), - (*C.uint8_t)(unsafe.Pointer(unsafe.SliceData(sessionKey))), - C.size_t(len(sessionKey)), - ) - runtime.KeepAlive(sessionKey) + (*C.uint8_t)(&sessionKey[0]), + C.size_t(len(sessionKey))) if r == errorVal() { return nil, s.lastError() } @@ -94,7 +104,7 @@ func inboundGroupSessionSize() uint { func NewBlankInboundGroupSession() *InboundGroupSession { memory := make([]byte, inboundGroupSessionSize()) return &InboundGroupSession{ - int: C.olm_inbound_group_session(unsafe.Pointer(unsafe.SliceData(memory))), + int: C.olm_inbound_group_session(unsafe.Pointer(&memory[0])), mem: memory, } } @@ -124,17 +134,15 @@ func (s *InboundGroupSession) pickleLen() uint { // InboundGroupSession using the supplied key. func (s *InboundGroupSession) Pickle(key []byte) ([]byte, error) { if len(key) == 0 { - return nil, olm.ErrNoKeyProvided + return nil, olm.NoKeyProvided } pickled := make([]byte, s.pickleLen()) r := C.olm_pickle_inbound_group_session( (*C.OlmInboundGroupSession)(s.int), - unsafe.Pointer(unsafe.SliceData(key)), + unsafe.Pointer(&key[0]), C.size_t(len(key)), - unsafe.Pointer(unsafe.SliceData(pickled)), - C.size_t(len(pickled)), - ) - runtime.KeepAlive(key) + unsafe.Pointer(&pickled[0]), + C.size_t(len(pickled))) if r == errorVal() { return nil, s.lastError() } @@ -143,18 +151,16 @@ func (s *InboundGroupSession) Pickle(key []byte) ([]byte, error) { func (s *InboundGroupSession) Unpickle(pickled, key []byte) error { if len(key) == 0 { - return olm.ErrNoKeyProvided + return olm.NoKeyProvided } else if len(pickled) == 0 { - return olm.ErrEmptyInput + return olm.EmptyInput } r := C.olm_unpickle_inbound_group_session( (*C.OlmInboundGroupSession)(s.int), - unsafe.Pointer(unsafe.SliceData(key)), + unsafe.Pointer(&key[0]), C.size_t(len(key)), - unsafe.Pointer(unsafe.SliceData(pickled)), - C.size_t(len(pickled)), - ) - runtime.KeepAlive(key) + unsafe.Pointer(&pickled[0]), + C.size_t(len(pickled))) if r == errorVal() { return s.lastError() } @@ -200,7 +206,7 @@ func (s *InboundGroupSession) MarshalJSON() ([]byte, error) { // Deprecated func (s *InboundGroupSession) UnmarshalJSON(data []byte) error { if len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' { - return olm.ErrInputNotJSONString + return olm.InputNotJSONString } if s == nil || s.int == nil { *s = *NewBlankInboundGroupSession() @@ -217,16 +223,14 @@ func (s *InboundGroupSession) UnmarshalJSON(data []byte) error { // will be "BAD_MESSAGE_FORMAT". func (s *InboundGroupSession) decryptMaxPlaintextLen(message []byte) (uint, error) { if len(message) == 0 { - return 0, olm.ErrEmptyInput + return 0, olm.EmptyInput } // olm_group_decrypt_max_plaintext_length destroys the input, so we have to clone it - messageCopy := bytes.Clone(message) + message = bytes.Clone(message) r := C.olm_group_decrypt_max_plaintext_length( (*C.OlmInboundGroupSession)(s.int), - (*C.uint8_t)(unsafe.Pointer(unsafe.SliceData(messageCopy))), - C.size_t(len(messageCopy)), - ) - runtime.KeepAlive(messageCopy) + (*C.uint8_t)(&message[0]), + C.size_t(len(message))) if r == errorVal() { return 0, s.lastError() } @@ -244,24 +248,23 @@ func (s *InboundGroupSession) decryptMaxPlaintextLen(message []byte) (uint, erro // was shared with us) the error will be "OLM_UNKNOWN_MESSAGE_INDEX". func (s *InboundGroupSession) Decrypt(message []byte) ([]byte, uint, error) { if len(message) == 0 { - return nil, 0, olm.ErrEmptyInput + return nil, 0, olm.EmptyInput } decryptMaxPlaintextLen, err := s.decryptMaxPlaintextLen(message) if err != nil { return nil, 0, err } - messageCopy := bytes.Clone(message) + messageCopy := make([]byte, len(message)) + copy(messageCopy, message) plaintext := make([]byte, decryptMaxPlaintextLen) var messageIndex uint32 r := C.olm_group_decrypt( (*C.OlmInboundGroupSession)(s.int), - (*C.uint8_t)(unsafe.Pointer(unsafe.SliceData(messageCopy))), + (*C.uint8_t)(&messageCopy[0]), C.size_t(len(messageCopy)), - (*C.uint8_t)(unsafe.Pointer(unsafe.SliceData(plaintext))), + (*C.uint8_t)(&plaintext[0]), C.size_t(len(plaintext)), - (*C.uint32_t)(unsafe.Pointer(&messageIndex)), - ) - runtime.KeepAlive(messageCopy) + (*C.uint32_t)(&messageIndex)) if r == errorVal() { return nil, 0, s.lastError() } @@ -278,9 +281,8 @@ func (s *InboundGroupSession) ID() id.SessionID { sessionID := make([]byte, s.sessionIdLen()) r := C.olm_inbound_group_session_id( (*C.OlmInboundGroupSession)(s.int), - (*C.uint8_t)(unsafe.Pointer(unsafe.SliceData(sessionID))), - C.size_t(len(sessionID)), - ) + (*C.uint8_t)(&sessionID[0]), + C.size_t(len(sessionID))) if r == errorVal() { panic(s.lastError()) } @@ -316,10 +318,9 @@ func (s *InboundGroupSession) Export(messageIndex uint32) ([]byte, error) { key := make([]byte, s.exportLen()) r := C.olm_export_inbound_group_session( (*C.OlmInboundGroupSession)(s.int), - (*C.uint8_t)(unsafe.Pointer(unsafe.SliceData(key))), + (*C.uint8_t)(&key[0]), C.size_t(len(key)), - C.uint32_t(messageIndex), - ) + C.uint32_t(messageIndex)) if r == errorVal() { return nil, s.lastError() } diff --git a/crypto/libolm/outboundgroupsession.go b/crypto/libolm/outboundgroupsession.go index ca5b68f7..a21f8d4a 100644 --- a/crypto/libolm/outboundgroupsession.go +++ b/crypto/libolm/outboundgroupsession.go @@ -7,7 +7,6 @@ import "C" import ( "crypto/rand" "encoding/base64" - "runtime" "unsafe" "maunium.net/go/mautrix/crypto/olm" @@ -21,6 +20,18 @@ type OutboundGroupSession struct { mem []byte } +func init() { + olm.InitNewOutboundGroupSessionFromPickled = func(pickled, key []byte) (olm.OutboundGroupSession, error) { + if len(pickled) == 0 { + return nil, olm.EmptyInput + } + s := NewBlankOutboundGroupSession() + return s, s.Unpickle(pickled, key) + } + olm.InitNewOutboundGroupSession = func() (olm.OutboundGroupSession, error) { return NewOutboundGroupSession() } + olm.InitNewBlankOutboundGroupSession = func() olm.OutboundGroupSession { return NewBlankOutboundGroupSession() } +} + // Ensure that [OutboundGroupSession] implements [olm.OutboundGroupSession]. var _ olm.OutboundGroupSession = (*OutboundGroupSession)(nil) @@ -33,10 +44,8 @@ func NewOutboundGroupSession() (*OutboundGroupSession, error) { } r := C.olm_init_outbound_group_session( (*C.OlmOutboundGroupSession)(s.int), - (*C.uint8_t)(unsafe.Pointer(unsafe.SliceData(random))), - C.size_t(len(random)), - ) - runtime.KeepAlive(random) + (*C.uint8_t)(&random[0]), + C.size_t(len(random))) if r == errorVal() { return nil, s.lastError() } @@ -53,7 +62,7 @@ func outboundGroupSessionSize() uint { func NewBlankOutboundGroupSession() *OutboundGroupSession { memory := make([]byte, outboundGroupSessionSize()) return &OutboundGroupSession{ - int: C.olm_outbound_group_session(unsafe.Pointer(unsafe.SliceData(memory))), + int: C.olm_outbound_group_session(unsafe.Pointer(&memory[0])), mem: memory, } } @@ -84,17 +93,15 @@ func (s *OutboundGroupSession) pickleLen() uint { // OutboundGroupSession using the supplied key. func (s *OutboundGroupSession) Pickle(key []byte) ([]byte, error) { if len(key) == 0 { - return nil, olm.ErrNoKeyProvided + return nil, olm.NoKeyProvided } pickled := make([]byte, s.pickleLen()) r := C.olm_pickle_outbound_group_session( (*C.OlmOutboundGroupSession)(s.int), - unsafe.Pointer(unsafe.SliceData(key)), + unsafe.Pointer(&key[0]), C.size_t(len(key)), - unsafe.Pointer(unsafe.SliceData(pickled)), - C.size_t(len(pickled)), - ) - runtime.KeepAlive(key) + unsafe.Pointer(&pickled[0]), + C.size_t(len(pickled))) if r == errorVal() { return nil, s.lastError() } @@ -103,17 +110,14 @@ func (s *OutboundGroupSession) Pickle(key []byte) ([]byte, error) { func (s *OutboundGroupSession) Unpickle(pickled, key []byte) error { if len(key) == 0 { - return olm.ErrNoKeyProvided + return olm.NoKeyProvided } r := C.olm_unpickle_outbound_group_session( (*C.OlmOutboundGroupSession)(s.int), - unsafe.Pointer(unsafe.SliceData(key)), + unsafe.Pointer(&key[0]), C.size_t(len(key)), - unsafe.Pointer(unsafe.SliceData(pickled)), - C.size_t(len(pickled)), - ) - runtime.KeepAlive(pickled) - runtime.KeepAlive(key) + unsafe.Pointer(&pickled[0]), + C.size_t(len(pickled))) if r == errorVal() { return s.lastError() } @@ -159,7 +163,7 @@ func (s *OutboundGroupSession) MarshalJSON() ([]byte, error) { // Deprecated func (s *OutboundGroupSession) UnmarshalJSON(data []byte) error { if len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' { - return olm.ErrInputNotJSONString + return olm.InputNotJSONString } if s == nil || s.int == nil { *s = *NewBlankOutboundGroupSession() @@ -183,17 +187,15 @@ func (s *OutboundGroupSession) encryptMsgLen(plainTextLen int) uint { // as base64. func (s *OutboundGroupSession) Encrypt(plaintext []byte) ([]byte, error) { if len(plaintext) == 0 { - return nil, olm.ErrEmptyInput + return nil, olm.EmptyInput } message := make([]byte, s.encryptMsgLen(len(plaintext))) r := C.olm_group_encrypt( (*C.OlmOutboundGroupSession)(s.int), - (*C.uint8_t)(unsafe.Pointer(unsafe.SliceData(plaintext))), + (*C.uint8_t)(&plaintext[0]), C.size_t(len(plaintext)), - (*C.uint8_t)(unsafe.Pointer(unsafe.SliceData(message))), - C.size_t(len(message)), - ) - runtime.KeepAlive(plaintext) + (*C.uint8_t)(&message[0]), + C.size_t(len(message))) if r == errorVal() { return nil, s.lastError() } @@ -210,9 +212,8 @@ func (s *OutboundGroupSession) ID() id.SessionID { sessionID := make([]byte, s.sessionIdLen()) r := C.olm_outbound_group_session_id( (*C.OlmOutboundGroupSession)(s.int), - (*C.uint8_t)(unsafe.Pointer(unsafe.SliceData(sessionID))), - C.size_t(len(sessionID)), - ) + (*C.uint8_t)(&sessionID[0]), + C.size_t(len(sessionID))) if r == errorVal() { panic(s.lastError()) } @@ -235,9 +236,8 @@ func (s *OutboundGroupSession) Key() string { sessionKey := make([]byte, s.sessionKeyLen()) r := C.olm_outbound_group_session_key( (*C.OlmOutboundGroupSession)(s.int), - (*C.uint8_t)(unsafe.Pointer(unsafe.SliceData(sessionKey))), - C.size_t(len(sessionKey)), - ) + (*C.uint8_t)(&sessionKey[0]), + C.size_t(len(sessionKey))) if r == errorVal() { panic(s.lastError()) } diff --git a/crypto/libolm/pk.go b/crypto/libolm/pk.go index 2683cf15..db8d35c5 100644 --- a/crypto/libolm/pk.go +++ b/crypto/libolm/pk.go @@ -14,7 +14,6 @@ import "C" import ( "crypto/rand" "encoding/json" - "runtime" "unsafe" "github.com/tidwall/sjson" @@ -35,6 +34,16 @@ type PKSigning struct { // Ensure that [PKSigning] implements [olm.PKSigning]. var _ olm.PKSigning = (*PKSigning)(nil) +func init() { + olm.InitNewPKSigning = func() (olm.PKSigning, error) { return NewPKSigning() } + olm.InitNewPKSigningFromSeed = func(seed []byte) (olm.PKSigning, error) { + return NewPKSigningFromSeed(seed) + } + olm.InitNewPKDecryptionFromPrivateKey = func(privateKey []byte) (olm.PKDecryption, error) { + return NewPkDecryption(privateKey) + } +} + func pkSigningSize() uint { return uint(C.olm_pk_signing_size()) } @@ -54,7 +63,7 @@ func pkSigningSignatureLength() uint { func newBlankPKSigning() *PKSigning { memory := make([]byte, pkSigningSize()) return &PKSigning{ - int: C.olm_pk_signing(unsafe.Pointer(unsafe.SliceData(memory))), + int: C.olm_pk_signing(unsafe.Pointer(&memory[0])), mem: memory, } } @@ -64,14 +73,9 @@ func NewPKSigningFromSeed(seed []byte) (*PKSigning, error) { p := newBlankPKSigning() p.clear() pubKey := make([]byte, pkSigningPublicKeyLength()) - r := C.olm_pk_signing_key_from_seed( - (*C.OlmPkSigning)(p.int), - unsafe.Pointer(unsafe.SliceData(pubKey)), - C.size_t(len(pubKey)), - unsafe.Pointer(unsafe.SliceData(seed)), - C.size_t(len(seed)), - ) - if r == errorVal() { + if C.olm_pk_signing_key_from_seed((*C.OlmPkSigning)(p.int), + unsafe.Pointer(&pubKey[0]), C.size_t(len(pubKey)), + unsafe.Pointer(&seed[0]), C.size_t(len(seed))) == errorVal() { return nil, p.lastError() } p.publicKey = id.Ed25519(pubKey) @@ -86,7 +90,7 @@ func NewPKSigning() (*PKSigning, error) { seed := make([]byte, pkSigningSeedLength()) _, err := rand.Read(seed) if err != nil { - panic(olm.ErrNotEnoughGoRandom) + panic(olm.NotEnoughGoRandom) } pk, err := NewPKSigningFromSeed(seed) return pk, err @@ -108,15 +112,8 @@ func (p *PKSigning) clear() { // Sign creates a signature for the given message using this key. func (p *PKSigning) Sign(message []byte) ([]byte, error) { signature := make([]byte, pkSigningSignatureLength()) - r := C.olm_pk_sign( - (*C.OlmPkSigning)(p.int), - (*C.uint8_t)(unsafe.Pointer(unsafe.SliceData(message))), - C.size_t(len(message)), - (*C.uint8_t)(unsafe.Pointer(unsafe.SliceData(signature))), - C.size_t(len(signature)), - ) - runtime.KeepAlive(message) - if r == errorVal() { + if C.olm_pk_sign((*C.OlmPkSigning)(p.int), (*C.uint8_t)(unsafe.Pointer(&message[0])), C.size_t(len(message)), + (*C.uint8_t)(unsafe.Pointer(&signature[0])), C.size_t(len(signature))) == errorVal() { return nil, p.lastError() } return signature, nil @@ -160,21 +157,15 @@ func pkDecryptionPublicKeySize() uint { func NewPkDecryption(privateKey []byte) (*PKDecryption, error) { memory := make([]byte, pkDecryptionSize()) p := &PKDecryption{ - int: C.olm_pk_decryption(unsafe.Pointer(unsafe.SliceData(memory))), + int: C.olm_pk_decryption(unsafe.Pointer(&memory[0])), mem: memory, } p.clear() pubKey := make([]byte, pkDecryptionPublicKeySize()) - r := C.olm_pk_key_from_private( - (*C.OlmPkDecryption)(p.int), - unsafe.Pointer(unsafe.SliceData(pubKey)), - C.size_t(len(pubKey)), - unsafe.Pointer(unsafe.SliceData(privateKey)), - C.size_t(len(privateKey)), - ) - runtime.KeepAlive(privateKey) - if r == errorVal() { + if C.olm_pk_key_from_private((*C.OlmPkDecryption)(p.int), + unsafe.Pointer(&pubKey[0]), C.size_t(len(pubKey)), + unsafe.Pointer(&privateKey[0]), C.size_t(len(privateKey))) == errorVal() { return nil, p.lastError() } p.publicKey = pubKey @@ -187,26 +178,14 @@ func (p *PKDecryption) PublicKey() id.Curve25519 { } func (p *PKDecryption) Decrypt(ephemeralKey []byte, mac []byte, ciphertext []byte) ([]byte, error) { - maxPlaintextLength := uint(C.olm_pk_max_plaintext_length( - (*C.OlmPkDecryption)(p.int), - C.size_t(len(ciphertext)), - )) + maxPlaintextLength := uint(C.olm_pk_max_plaintext_length((*C.OlmPkDecryption)(p.int), C.size_t(len(ciphertext)))) plaintext := make([]byte, maxPlaintextLength) - size := C.olm_pk_decrypt( - (*C.OlmPkDecryption)(p.int), - unsafe.Pointer(unsafe.SliceData(ephemeralKey)), - C.size_t(len(ephemeralKey)), - unsafe.Pointer(unsafe.SliceData(mac)), - C.size_t(len(mac)), - unsafe.Pointer(unsafe.SliceData(ciphertext)), - C.size_t(len(ciphertext)), - unsafe.Pointer(unsafe.SliceData(plaintext)), - C.size_t(len(plaintext)), - ) - runtime.KeepAlive(ephemeralKey) - runtime.KeepAlive(mac) - runtime.KeepAlive(ciphertext) + size := C.olm_pk_decrypt((*C.OlmPkDecryption)(p.int), + unsafe.Pointer(&ephemeralKey[0]), C.size_t(len(ephemeralKey)), + unsafe.Pointer(&mac[0]), C.size_t(len(mac)), + unsafe.Pointer(&ciphertext[0]), C.size_t(len(ciphertext)), + unsafe.Pointer(&plaintext[0]), C.size_t(len(plaintext))) if size == errorVal() { return nil, p.lastError() } diff --git a/crypto/libolm/register.go b/crypto/libolm/register.go index ddf84613..a423a7d0 100644 --- a/crypto/libolm/register.go +++ b/crypto/libolm/register.go @@ -3,73 +3,19 @@ package libolm // #cgo LDFLAGS: -lolm -lstdc++ // #include import "C" -import ( - "unsafe" - - "maunium.net/go/mautrix/crypto/olm" -) +import "maunium.net/go/mautrix/crypto/olm" var pickleKey = []byte("maunium.net/go/mautrix/crypto/olm") -func Register() { - olm.Driver = "libolm" - +func init() { olm.GetVersion = func() (major, minor, patch uint8) { C.olm_get_library_version( - (*C.uint8_t)(unsafe.Pointer(&major)), - (*C.uint8_t)(unsafe.Pointer(&minor)), - (*C.uint8_t)(unsafe.Pointer(&patch))) + (*C.uint8_t)(&major), + (*C.uint8_t)(&minor), + (*C.uint8_t)(&patch)) return 3, 2, 15 } olm.SetPickleKeyImpl = func(key []byte) { pickleKey = key } - - olm.InitNewAccount = func() (olm.Account, error) { - return NewAccount() - } - olm.InitBlankAccount = func() olm.Account { - return NewBlankAccount() - } - olm.InitNewAccountFromPickled = func(pickled, key []byte) (olm.Account, error) { - return AccountFromPickled(pickled, key) - } - - olm.InitSessionFromPickled = func(pickled, key []byte) (olm.Session, error) { - return SessionFromPickled(pickled, key) - } - olm.InitNewBlankSession = func() olm.Session { - return NewBlankSession() - } - - olm.InitNewPKSigning = func() (olm.PKSigning, error) { return NewPKSigning() } - olm.InitNewPKSigningFromSeed = func(seed []byte) (olm.PKSigning, error) { - return NewPKSigningFromSeed(seed) - } - olm.InitNewPKDecryptionFromPrivateKey = func(privateKey []byte) (olm.PKDecryption, error) { - return NewPkDecryption(privateKey) - } - - olm.InitInboundGroupSessionFromPickled = func(pickled, key []byte) (olm.InboundGroupSession, error) { - return InboundGroupSessionFromPickled(pickled, key) - } - olm.InitNewInboundGroupSession = func(sessionKey []byte) (olm.InboundGroupSession, error) { - return NewInboundGroupSession(sessionKey) - } - olm.InitInboundGroupSessionImport = func(sessionKey []byte) (olm.InboundGroupSession, error) { - return InboundGroupSessionImport(sessionKey) - } - olm.InitBlankInboundGroupSession = func() olm.InboundGroupSession { - return NewBlankInboundGroupSession() - } - - olm.InitNewOutboundGroupSessionFromPickled = func(pickled, key []byte) (olm.OutboundGroupSession, error) { - if len(pickled) == 0 { - return nil, olm.ErrEmptyInput - } - s := NewBlankOutboundGroupSession() - return s, s.Unpickle(pickled, key) - } - olm.InitNewOutboundGroupSession = func() (olm.OutboundGroupSession, error) { return NewOutboundGroupSession() } - olm.InitNewBlankOutboundGroupSession = func() olm.OutboundGroupSession { return NewBlankOutboundGroupSession() } } diff --git a/crypto/libolm/session.go b/crypto/libolm/session.go index 1441df26..4cc22809 100644 --- a/crypto/libolm/session.go +++ b/crypto/libolm/session.go @@ -23,7 +23,6 @@ import "C" import ( "crypto/rand" "encoding/base64" - "runtime" "unsafe" "maunium.net/go/mautrix/crypto/olm" @@ -39,6 +38,15 @@ type Session struct { // Ensure that [Session] implements [olm.Session]. var _ olm.Session = (*Session)(nil) +func init() { + olm.InitSessionFromPickled = func(pickled, key []byte) (olm.Session, error) { + return SessionFromPickled(pickled, key) + } + olm.InitNewBlankSession = func() olm.Session { + return NewBlankSession() + } +} + // sessionSize is the size of a session object in bytes. func sessionSize() uint { return uint(C.olm_session_size()) @@ -51,7 +59,7 @@ func sessionSize() uint { // "INVALID_BASE64". func SessionFromPickled(pickled, key []byte) (*Session, error) { if len(pickled) == 0 { - return nil, olm.ErrEmptyInput + return nil, olm.EmptyInput } s := NewBlankSession() return s, s.Unpickle(pickled, key) @@ -60,7 +68,7 @@ func SessionFromPickled(pickled, key []byte) (*Session, error) { func NewBlankSession() *Session { memory := make([]byte, sessionSize()) return &Session{ - int: C.olm_session(unsafe.Pointer(unsafe.SliceData(memory))), + int: C.olm_session(unsafe.Pointer(&memory[0])), mem: memory, } } @@ -118,16 +126,13 @@ func (s *Session) encryptMsgLen(plainTextLen int) uint { // will be "BAD_MESSAGE_FORMAT". func (s *Session) decryptMaxPlaintextLen(message string, msgType id.OlmMsgType) (uint, error) { if len(message) == 0 { - return 0, olm.ErrEmptyInput + return 0, olm.EmptyInput } - messageCopy := []byte(message) r := C.olm_decrypt_max_plaintext_length( (*C.OlmSession)(s.int), C.size_t(msgType), - unsafe.Pointer(unsafe.SliceData((messageCopy))), - C.size_t(len(messageCopy)), - ) - runtime.KeepAlive(messageCopy) + unsafe.Pointer(C.CString(message)), + C.size_t(len(message))) if r == errorVal() { return 0, s.lastError() } @@ -138,16 +143,15 @@ func (s *Session) decryptMaxPlaintextLen(message string, msgType id.OlmMsgType) // supplied key. func (s *Session) Pickle(key []byte) ([]byte, error) { if len(key) == 0 { - return nil, olm.ErrNoKeyProvided + return nil, olm.NoKeyProvided } pickled := make([]byte, s.pickleLen()) r := C.olm_pickle_session( (*C.OlmSession)(s.int), - unsafe.Pointer(unsafe.SliceData(key)), + unsafe.Pointer(&key[0]), C.size_t(len(key)), - unsafe.Pointer(unsafe.SliceData(pickled)), + unsafe.Pointer(&pickled[0]), C.size_t(len(pickled))) - runtime.KeepAlive(key) if r == errorVal() { panic(s.lastError()) } @@ -158,16 +162,14 @@ func (s *Session) Pickle(key []byte) ([]byte, error) { // provided key. This function mutates the input pickled data slice. func (s *Session) Unpickle(pickled, key []byte) error { if len(key) == 0 { - return olm.ErrNoKeyProvided + return olm.NoKeyProvided } r := C.olm_unpickle_session( (*C.OlmSession)(s.int), - unsafe.Pointer(unsafe.SliceData(key)), + unsafe.Pointer(&key[0]), C.size_t(len(key)), - unsafe.Pointer(unsafe.SliceData(pickled)), + unsafe.Pointer(&pickled[0]), C.size_t(len(pickled))) - runtime.KeepAlive(pickled) - runtime.KeepAlive(key) if r == errorVal() { return s.lastError() } @@ -213,7 +215,7 @@ func (s *Session) MarshalJSON() ([]byte, error) { // Deprecated func (s *Session) UnmarshalJSON(data []byte) error { if len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' { - return olm.ErrInputNotJSONString + return olm.InputNotJSONString } if s == nil || s.int == nil { *s = *NewBlankSession() @@ -227,9 +229,8 @@ func (s *Session) ID() id.SessionID { sessionID := make([]byte, s.idLen()) r := C.olm_session_id( (*C.OlmSession)(s.int), - unsafe.Pointer(unsafe.SliceData(sessionID)), - C.size_t(len(sessionID)), - ) + unsafe.Pointer(&sessionID[0]), + C.size_t(len(sessionID))) if r == errorVal() { panic(s.lastError()) } @@ -256,15 +257,12 @@ func (s *Session) HasReceivedMessage() bool { // decoded then then the error will be "BAD_MESSAGE_FORMAT". func (s *Session) MatchesInboundSession(oneTimeKeyMsg string) (bool, error) { if len(oneTimeKeyMsg) == 0 { - return false, olm.ErrEmptyInput + return false, olm.EmptyInput } - oneTimeKeyMsgCopy := []byte(oneTimeKeyMsg) r := C.olm_matches_inbound_session( (*C.OlmSession)(s.int), - unsafe.Pointer(unsafe.SliceData(oneTimeKeyMsgCopy)), - C.size_t(len(oneTimeKeyMsgCopy)), - ) - runtime.KeepAlive(oneTimeKeyMsgCopy) + unsafe.Pointer(&([]byte(oneTimeKeyMsg))[0]), + C.size_t(len(oneTimeKeyMsg))) if r == 1 { return true, nil } else if r == 0 { @@ -284,19 +282,14 @@ func (s *Session) MatchesInboundSession(oneTimeKeyMsg string) (bool, error) { // decoded then then the error will be "BAD_MESSAGE_FORMAT". func (s *Session) MatchesInboundSessionFrom(theirIdentityKey, oneTimeKeyMsg string) (bool, error) { if len(theirIdentityKey) == 0 || len(oneTimeKeyMsg) == 0 { - return false, olm.ErrEmptyInput + return false, olm.EmptyInput } - theirIdentityKeyCopy := []byte(theirIdentityKey) - oneTimeKeyMsgCopy := []byte(oneTimeKeyMsg) r := C.olm_matches_inbound_session_from( (*C.OlmSession)(s.int), - unsafe.Pointer(unsafe.SliceData(theirIdentityKeyCopy)), - C.size_t(len(theirIdentityKeyCopy)), - unsafe.Pointer(unsafe.SliceData(oneTimeKeyMsgCopy)), - C.size_t(len(oneTimeKeyMsgCopy)), - ) - runtime.KeepAlive(theirIdentityKeyCopy) - runtime.KeepAlive(oneTimeKeyMsgCopy) + unsafe.Pointer(&([]byte(theirIdentityKey))[0]), + C.size_t(len(theirIdentityKey)), + unsafe.Pointer(&([]byte(oneTimeKeyMsg))[0]), + C.size_t(len(oneTimeKeyMsg))) if r == 1 { return true, nil } else if r == 0 { @@ -325,28 +318,25 @@ func (s *Session) EncryptMsgType() id.OlmMsgType { // as base64. func (s *Session) Encrypt(plaintext []byte) (id.OlmMsgType, []byte, error) { if len(plaintext) == 0 { - return 0, nil, olm.ErrEmptyInput + return 0, nil, olm.EmptyInput } // Make the slice be at least length 1 random := make([]byte, s.encryptRandomLen()+1) _, err := rand.Read(random) if err != nil { // TODO can we just return err here? - return 0, nil, olm.ErrNotEnoughGoRandom + return 0, nil, olm.NotEnoughGoRandom } messageType := s.EncryptMsgType() message := make([]byte, s.encryptMsgLen(len(plaintext))) r := C.olm_encrypt( (*C.OlmSession)(s.int), - unsafe.Pointer(unsafe.SliceData(plaintext)), + unsafe.Pointer(&plaintext[0]), C.size_t(len(plaintext)), - unsafe.Pointer(unsafe.SliceData(random)), + unsafe.Pointer(&random[0]), C.size_t(len(random)), - unsafe.Pointer(unsafe.SliceData(message)), - C.size_t(len(message)), - ) - runtime.KeepAlive(plaintext) - runtime.KeepAlive(random) + unsafe.Pointer(&message[0]), + C.size_t(len(message))) if r == errorVal() { return 0, nil, s.lastError() } @@ -362,7 +352,7 @@ func (s *Session) Encrypt(plaintext []byte) (id.OlmMsgType, []byte, error) { // "BAD_MESSAGE_MAC". func (s *Session) Decrypt(message string, msgType id.OlmMsgType) ([]byte, error) { if len(message) == 0 { - return nil, olm.ErrEmptyInput + return nil, olm.EmptyInput } decryptMaxPlaintextLen, err := s.decryptMaxPlaintextLen(message, msgType) if err != nil { @@ -373,12 +363,10 @@ func (s *Session) Decrypt(message string, msgType id.OlmMsgType) ([]byte, error) r := C.olm_decrypt( (*C.OlmSession)(s.int), C.size_t(msgType), - unsafe.Pointer(unsafe.SliceData(messageCopy)), + unsafe.Pointer(&(messageCopy)[0]), C.size_t(len(messageCopy)), - unsafe.Pointer(unsafe.SliceData(plaintext)), - C.size_t(len(plaintext)), - ) - runtime.KeepAlive(messageCopy) + unsafe.Pointer(&plaintext[0]), + C.size_t(len(plaintext))) if r == errorVal() { return nil, s.lastError() } @@ -395,7 +383,6 @@ func (s *Session) Describe() string { C.meowlm_session_describe( (*C.OlmSession)(s.int), desc, - C.size_t(maxDescribeSize), - ) + C.size_t(maxDescribeSize)) return C.GoString(desc) } diff --git a/crypto/machine.go b/crypto/machine.go index fa051f94..cac91bf8 100644 --- a/crypto/machine.go +++ b/crypto/machine.go @@ -15,12 +15,10 @@ import ( "time" "github.com/rs/zerolog" - "go.mau.fi/util/ptr" "go.mau.fi/util/exzerolog" "maunium.net/go/mautrix" - "maunium.net/go/mautrix/crypto/olm" "maunium.net/go/mautrix/crypto/ssss" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" @@ -35,11 +33,9 @@ type OlmMachine struct { CryptoStore Store StateStore StateStore - backgroundCtx context.Context - cancelBackgroundCtx context.CancelFunc + BackgroundCtx context.Context PlaintextMentions bool - MSC4392Relations bool AllowEncryptedState bool // Never ask the server for keys automatically as a side effect during Megolm decryption. @@ -124,6 +120,8 @@ func NewOlmMachine(client *mautrix.Client, log *zerolog.Logger, cryptoStore Stor CryptoStore: cryptoStore, StateStore: stateStore, + BackgroundCtx: context.Background(), + SendKeysMinTrust: id.TrustStateUnset, ShareKeysMinTrust: id.TrustStateCrossSignedTOFU, @@ -136,7 +134,6 @@ func NewOlmMachine(client *mautrix.Client, log *zerolog.Logger, cryptoStore Stor recentlyUnwedged: make(map[id.IdentityKey]time.Time), secretListeners: make(map[string]chan<- string), } - mach.backgroundCtx, mach.cancelBackgroundCtx = context.WithCancel(context.Background()) mach.AllowKeyShare = mach.defaultAllowKeyShare return mach } @@ -149,11 +146,6 @@ func (mach *OlmMachine) machOrContextLog(ctx context.Context) *zerolog.Logger { return log } -func (mach *OlmMachine) SetBackgroundCtx(ctx context.Context) { - mach.cancelBackgroundCtx() - mach.backgroundCtx, mach.cancelBackgroundCtx = context.WithCancel(ctx) -} - // Load loads the Olm account information from the crypto store. If there's no olm account, a new one is created. // This must be called before using the machine. func (mach *OlmMachine) Load(ctx context.Context) (err error) { @@ -164,23 +156,9 @@ func (mach *OlmMachine) Load(ctx context.Context) (err error) { if mach.account == nil { mach.account = NewOlmAccount() } - zerolog.Ctx(ctx).Debug(). - Str("machine_ptr", fmt.Sprintf("%p", mach)). - Str("account_ptr", fmt.Sprintf("%p", mach.account.Internal)). - Str("olm_driver", olm.Driver). - Msg("Loaded olm account") return nil } -func (mach *OlmMachine) Destroy() { - mach.Log.Debug(). - Str("machine_ptr", fmt.Sprintf("%p", mach)). - Str("account_ptr", fmt.Sprintf("%p", ptr.Val(mach.account).Internal)). - Msg("Destroying olm machine") - mach.cancelBackgroundCtx() - // TODO actually destroy something? -} - func (mach *OlmMachine) saveAccount(ctx context.Context) error { err := mach.CryptoStore.PutAccount(ctx, mach.account) if err != nil { @@ -206,7 +184,7 @@ func (mach *OlmMachine) FlushStore(ctx context.Context) error { func (mach *OlmMachine) timeTrace(ctx context.Context, thing string, expectedDuration time.Duration) func() { start := time.Now() return func() { - duration := time.Since(start) + duration := time.Now().Sub(start) if duration > expectedDuration { zerolog.Ctx(ctx).Warn(). Str("action", thing). @@ -383,7 +361,7 @@ func (mach *OlmMachine) HandleMemberEvent(ctx context.Context, evt *event.Event) Msg("Got membership state change, invalidating group session in room") err := mach.CryptoStore.RemoveOutboundGroupSession(ctx, evt.RoomID) if err != nil { - mach.Log.Warn().Stringer("room_id", evt.RoomID).Msg("Failed to invalidate outbound group session") + mach.Log.Warn().Str("room_id", evt.RoomID.String()).Msg("Failed to invalidate outbound group session") } } @@ -603,7 +581,7 @@ func (mach *OlmMachine) createGroupSession(ctx context.Context, senderKey id.Sen } err = mach.CryptoStore.PutGroupSession(ctx, igs) if err != nil { - log.Err(err).Stringer("session_id", sessionID).Msg("Failed to store new inbound group session") + log.Err(err).Str("session_id", sessionID.String()).Msg("Failed to store new inbound group session") return fmt.Errorf("failed to store new inbound group session: %w", err) } mach.MarkSessionReceived(ctx, roomID, sessionID, igs.Internal.FirstKnownIndex()) @@ -730,7 +708,7 @@ func (mach *OlmMachine) ShareKeys(ctx context.Context, currentOTKCount int) erro start := time.Now() mach.otkUploadLock.Lock() defer mach.otkUploadLock.Unlock() - if mach.lastOTKUpload.Add(1*time.Minute).After(start) || (currentOTKCount < 0 && mach.account.Shared) { + if mach.lastOTKUpload.Add(1*time.Minute).After(start) || currentOTKCount < 0 { log.Debug().Msg("Checking OTK count from server due to suspiciously close share keys requests or negative OTK count") resp, err := mach.Client.UploadKeys(ctx, &mautrix.ReqUploadKeys{}) if err != nil { diff --git a/crypto/machine_bench_test.go b/crypto/machine_bench_test.go deleted file mode 100644 index fd40d795..00000000 --- a/crypto/machine_bench_test.go +++ /dev/null @@ -1,67 +0,0 @@ -// Copyright (c) 2025 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package crypto_test - -import ( - "context" - "fmt" - "math/rand/v2" - "testing" - - "github.com/rs/zerolog" - globallog "github.com/rs/zerolog/log" // zerolog-allow-global-log - "github.com/stretchr/testify/require" - - "maunium.net/go/mautrix/crypto/cryptohelper" - "maunium.net/go/mautrix/id" - "maunium.net/go/mautrix/mockserver" -) - -func randomDeviceCount(r *rand.Rand) int { - k := 1 - for k < 10 && r.IntN(3) > 0 { - k++ - } - return k -} - -func BenchmarkOlmMachine_ShareGroupSession(b *testing.B) { - globallog.Logger = zerolog.Nop() - server := mockserver.Create(b) - server.PopOTKs = false - server.MemoryStore = false - var i int - var shareTargets []id.UserID - r := rand.New(rand.NewPCG(293, 0)) - var totalDeviceCount int - for i = 1; i < 1000; i++ { - userID := id.UserID(fmt.Sprintf("@user%d:localhost", i)) - deviceCount := randomDeviceCount(r) - for j := 0; j < deviceCount; j++ { - client, _ := server.Login(b, nil, userID, id.DeviceID(fmt.Sprintf("u%d_d%d", i, j))) - mach := client.Crypto.(*cryptohelper.CryptoHelper).Machine() - keysCache, err := mach.GenerateCrossSigningKeys() - require.NoError(b, err) - err = mach.PublishCrossSigningKeys(context.TODO(), keysCache, nil) - require.NoError(b, err) - } - totalDeviceCount += deviceCount - shareTargets = append(shareTargets, userID) - } - for b.Loop() { - client, _ := server.Login(b, nil, id.UserID(fmt.Sprintf("@benchuser%d:localhost", i)), id.DeviceID(fmt.Sprintf("u%d_d1", i))) - mach := client.Crypto.(*cryptohelper.CryptoHelper).Machine() - keysCache, err := mach.GenerateCrossSigningKeys() - require.NoError(b, err) - err = mach.PublishCrossSigningKeys(context.TODO(), keysCache, nil) - require.NoError(b, err) - err = mach.ShareGroupSession(context.TODO(), "!room:localhost", shareTargets) - require.NoError(b, err) - i++ - } - fmt.Println(totalDeviceCount, "devices total") -} diff --git a/crypto/olm/account.go b/crypto/olm/account.go index 2ec5dd70..68393e8a 100644 --- a/crypto/olm/account.go +++ b/crypto/olm/account.go @@ -87,8 +87,6 @@ type Account interface { RemoveOneTimeKeys(s Session) error } -var Driver = "none" - var InitBlankAccount func() Account var InitNewAccount func() (Account, error) var InitNewAccountFromPickled func(pickled, key []byte) (Account, error) diff --git a/crypto/olm/errors.go b/crypto/olm/errors.go index 9e522b2a..957d7928 100644 --- a/crypto/olm/errors.go +++ b/crypto/olm/errors.go @@ -10,67 +10,50 @@ import "errors" // Those are the most common used errors var ( - ErrBadSignature = errors.New("bad signature") - ErrBadMAC = errors.New("the message couldn't be decrypted (bad mac)") - ErrBadMessageFormat = errors.New("the message couldn't be decoded") - ErrBadVerification = errors.New("bad verification") - ErrWrongProtocolVersion = errors.New("wrong protocol version") - ErrEmptyInput = errors.New("empty input") - ErrNoKeyProvided = errors.New("no key provided") - ErrBadMessageKeyID = errors.New("the message references an unknown key ID") - ErrUnknownMessageIndex = errors.New("attempt to decode a message whose index is earlier than our earliest known session key") - ErrMsgIndexTooHigh = errors.New("message index too high") - ErrProtocolViolation = errors.New("not protocol message order") - ErrMessageKeyNotFound = errors.New("message key not found") - ErrChainTooHigh = errors.New("chain index too high") - ErrBadInput = errors.New("bad input") - ErrUnknownOlmPickleVersion = errors.New("unknown olm pickle version") - ErrUnknownJSONPickleVersion = errors.New("unknown JSON pickle version") - ErrInputToSmall = errors.New("input too small (truncated?)") + ErrBadSignature = errors.New("bad signature") + ErrBadMAC = errors.New("bad mac") + ErrBadMessageFormat = errors.New("bad message format") + ErrBadVerification = errors.New("bad verification") + ErrWrongProtocolVersion = errors.New("wrong protocol version") + ErrEmptyInput = errors.New("empty input") + ErrNoKeyProvided = errors.New("no key") + ErrBadMessageKeyID = errors.New("bad message key id") + ErrRatchetNotAvailable = errors.New("ratchet not available: attempt to decode a message whose index is earlier than our earliest known session key") + ErrMsgIndexTooHigh = errors.New("message index too high") + ErrProtocolViolation = errors.New("not protocol message order") + ErrMessageKeyNotFound = errors.New("message key not found") + ErrChainTooHigh = errors.New("chain index too high") + ErrBadInput = errors.New("bad input") + ErrBadVersion = errors.New("wrong version") + ErrWrongPickleVersion = errors.New("wrong pickle version") + ErrInputToSmall = errors.New("input too small (truncated?)") + ErrOverflow = errors.New("overflow") ) // Error codes from go-olm var ( - ErrNotEnoughGoRandom = errors.New("couldn't get enough randomness from crypto/rand") - ErrInputNotJSONString = errors.New("input doesn't look like a JSON string") + EmptyInput = errors.New("empty input") + NoKeyProvided = errors.New("no pickle key provided") + NotEnoughGoRandom = errors.New("couldn't get enough randomness from crypto/rand") + SignatureNotFound = errors.New("input JSON doesn't contain signature from specified device") + InputNotJSONString = errors.New("input doesn't look like a JSON string") ) // Error codes from olm code var ( - ErrLibolmInvalidBase64 = errors.New("the input base64 was invalid") - - ErrLibolmNotEnoughRandom = errors.New("not enough entropy was supplied") - ErrLibolmOutputBufferTooSmall = errors.New("supplied output buffer is too small") - ErrLibolmBadAccountKey = errors.New("the supplied account key is invalid") - ErrLibolmCorruptedPickle = errors.New("the pickled object couldn't be decoded") - ErrLibolmBadSessionKey = errors.New("attempt to initialise an inbound group session from an invalid session key") - ErrLibolmBadLegacyAccountPickle = errors.New("attempt to unpickle an account which uses pickle version 1") -) - -// Deprecated: use variables prefixed with Err -var ( - EmptyInput = ErrEmptyInput - BadSignature = ErrBadSignature - InvalidBase64 = ErrLibolmInvalidBase64 - BadMessageKeyID = ErrBadMessageKeyID - BadMessageFormat = ErrBadMessageFormat - BadMessageVersion = ErrWrongProtocolVersion - BadMessageMAC = ErrBadMAC - UnknownPickleVersion = ErrUnknownOlmPickleVersion - NotEnoughRandom = ErrLibolmNotEnoughRandom - OutputBufferTooSmall = ErrLibolmOutputBufferTooSmall - BadAccountKey = ErrLibolmBadAccountKey - CorruptedPickle = ErrLibolmCorruptedPickle - BadSessionKey = ErrLibolmBadSessionKey - UnknownMessageIndex = ErrUnknownMessageIndex - BadLegacyAccountPickle = ErrLibolmBadLegacyAccountPickle - InputBufferTooSmall = ErrInputToSmall - NoKeyProvided = ErrNoKeyProvided - - NotEnoughGoRandom = ErrNotEnoughGoRandom - InputNotJSONString = ErrInputNotJSONString - - ErrBadVersion = ErrUnknownJSONPickleVersion - ErrWrongPickleVersion = ErrUnknownJSONPickleVersion - ErrRatchetNotAvailable = ErrUnknownMessageIndex + NotEnoughRandom = errors.New("not enough entropy was supplied") + OutputBufferTooSmall = errors.New("supplied output buffer is too small") + BadMessageVersion = errors.New("the message version is unsupported") + BadMessageFormat = errors.New("the message couldn't be decoded") + BadMessageMAC = errors.New("the message couldn't be decrypted") + BadMessageKeyID = errors.New("the message references an unknown key ID") + InvalidBase64 = errors.New("the input base64 was invalid") + BadAccountKey = errors.New("the supplied account key is invalid") + UnknownPickleVersion = errors.New("the pickled object is too new") + CorruptedPickle = errors.New("the pickled object couldn't be decoded") + BadSessionKey = errors.New("attempt to initialise an inbound group session from an invalid session key") + UnknownMessageIndex = errors.New("attempt to decode a message whose index is earlier than our earliest known session key") + BadLegacyAccountPickle = errors.New("attempt to unpickle an account which uses pickle version 1") + BadSignature = errors.New("received message had a bad signature") + InputBufferTooSmall = errors.New("the input data was too small to be valid") ) diff --git a/crypto/registergoolm.go b/crypto/registergoolm.go index 6b5b65fd..f5cecafc 100644 --- a/crypto/registergoolm.go +++ b/crypto/registergoolm.go @@ -2,10 +2,4 @@ package crypto -import ( - "maunium.net/go/mautrix/crypto/goolm" -) - -func init() { - goolm.Register() -} +import _ "maunium.net/go/mautrix/crypto/goolm" diff --git a/crypto/registerlibolm.go b/crypto/registerlibolm.go index ef78b6b5..ab388a5c 100644 --- a/crypto/registerlibolm.go +++ b/crypto/registerlibolm.go @@ -2,8 +2,4 @@ package crypto -import "maunium.net/go/mautrix/crypto/libolm" - -func init() { - libolm.Register() -} +import _ "maunium.net/go/mautrix/crypto/libolm" diff --git a/crypto/sessions.go b/crypto/sessions.go index ccc7b784..aecb0416 100644 --- a/crypto/sessions.go +++ b/crypto/sessions.go @@ -18,14 +18,8 @@ import ( ) var ( - ErrSessionNotShared = errors.New("session has not been shared") - ErrSessionExpired = errors.New("session has expired") -) - -// Deprecated: use variables prefixed with Err -var ( - SessionNotShared = ErrSessionNotShared - SessionExpired = ErrSessionExpired + SessionNotShared = errors.New("session has not been shared") + SessionExpired = errors.New("session has expired") ) // OlmSessionList is a list of OlmSessions. @@ -117,7 +111,6 @@ type InboundGroupSession struct { MaxMessages int IsScheduled bool KeyBackupVersion id.KeyBackupVersion - KeySource id.KeySource id id.SessionID } @@ -137,7 +130,6 @@ func NewInboundGroupSession(senderKey id.SenderKey, signingKey id.Ed25519, roomI MaxAge: maxAge.Milliseconds(), MaxMessages: maxMessages, IsScheduled: isScheduled, - KeySource: id.KeySourceDirect, }, nil } @@ -171,7 +163,7 @@ func (igs *InboundGroupSession) export() (*ExportedSession, error) { ForwardingChains: igs.ForwardingChains, RoomID: igs.RoomID, SenderKey: igs.SenderKey, - SenderClaimedKeys: SenderClaimedKeys{Ed25519: igs.SigningKey}, + SenderClaimedKeys: SenderClaimedKeys{}, SessionID: igs.ID(), SessionKey: string(key), }, nil @@ -263,9 +255,9 @@ func (ogs *OutboundGroupSession) Expired() bool { func (ogs *OutboundGroupSession) Encrypt(plaintext []byte) ([]byte, error) { if !ogs.Shared { - return nil, ErrSessionNotShared + return nil, SessionNotShared } else if ogs.Expired() { - return nil, ErrSessionExpired + return nil, SessionExpired } ogs.MessageCount++ ogs.LastEncryptedTime = time.Now() diff --git a/crypto/sql_store.go b/crypto/sql_store.go index 138cc557..b0625763 100644 --- a/crypto/sql_store.go +++ b/crypto/sql_store.go @@ -251,9 +251,8 @@ func (store *SQLCryptoStore) GetLatestSession(ctx context.Context, key id.Sender } // GetNewestSessionCreationTS gets the creation timestamp of the most recently created session with the given sender key. -// This will exclude sessions that have never been used to encrypt or decrypt a message. func (store *SQLCryptoStore) GetNewestSessionCreationTS(ctx context.Context, key id.SenderKey) (createdAt time.Time, err error) { - err = store.DB.QueryRow(ctx, "SELECT created_at FROM crypto_olm_session WHERE sender_key=$1 AND account_id=$2 AND (last_encrypted <> created_at OR last_decrypted <> created_at) ORDER BY created_at DESC LIMIT 1", + err = store.DB.QueryRow(ctx, "SELECT created_at FROM crypto_olm_session WHERE sender_key=$1 AND account_id=$2 ORDER BY created_at DESC LIMIT 1", key, store.AccountID).Scan(&createdAt) if errors.Is(err, sql.ErrNoRows) { err = nil @@ -346,23 +345,22 @@ func (store *SQLCryptoStore) PutGroupSession(ctx context.Context, session *Inbou Int("max_messages", session.MaxMessages). Bool("is_scheduled", session.IsScheduled). Stringer("key_backup_version", session.KeyBackupVersion). - Stringer("key_source", session.KeySource). Msg("Upserting megolm inbound group session") _, err = store.DB.Exec(ctx, ` INSERT INTO crypto_megolm_inbound_session ( session_id, sender_key, signing_key, room_id, session, forwarding_chains, - ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version, key_source, account_id - ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14) + ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version, account_id + ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13) ON CONFLICT (session_id, account_id) DO UPDATE SET withheld_code=NULL, withheld_reason=NULL, sender_key=excluded.sender_key, signing_key=excluded.signing_key, room_id=excluded.room_id, session=excluded.session, forwarding_chains=excluded.forwarding_chains, ratchet_safety=excluded.ratchet_safety, received_at=excluded.received_at, max_age=excluded.max_age, max_messages=excluded.max_messages, is_scheduled=excluded.is_scheduled, - key_backup_version=excluded.key_backup_version, key_source=excluded.key_source + key_backup_version=excluded.key_backup_version `, session.ID(), session.SenderKey, session.SigningKey, session.RoomID, sessionBytes, forwardingChains, ratchetSafety, datePtr(session.ReceivedAt), dbutil.NumPtr(session.MaxAge), dbutil.NumPtr(session.MaxMessages), - session.IsScheduled, session.KeyBackupVersion, session.KeySource, store.AccountID, + session.IsScheduled, session.KeyBackupVersion, store.AccountID, ) return err } @@ -375,13 +373,12 @@ func (store *SQLCryptoStore) GetGroupSession(ctx context.Context, roomID id.Room var maxAge, maxMessages sql.NullInt64 var isScheduled bool var version id.KeyBackupVersion - var keySource id.KeySource err := store.DB.QueryRow(ctx, ` - SELECT sender_key, signing_key, session, forwarding_chains, withheld_code, withheld_reason, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version, key_source + SELECT sender_key, signing_key, session, forwarding_chains, withheld_code, withheld_reason, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version FROM crypto_megolm_inbound_session WHERE room_id=$1 AND session_id=$2 AND account_id=$3`, roomID, sessionID, store.AccountID, - ).Scan(&senderKey, &signingKey, &sessionBytes, &forwardingChains, &withheldCode, &withheldReason, &ratchetSafetyBytes, &receivedAt, &maxAge, &maxMessages, &isScheduled, &version, &keySource) + ).Scan(&senderKey, &signingKey, &sessionBytes, &forwardingChains, &withheldCode, &withheldReason, &ratchetSafetyBytes, &receivedAt, &maxAge, &maxMessages, &isScheduled, &version) if errors.Is(err, sql.ErrNoRows) { return nil, nil } else if err != nil { @@ -412,7 +409,6 @@ func (store *SQLCryptoStore) GetGroupSession(ctx context.Context, roomID id.Room MaxMessages: int(maxMessages.Int64), IsScheduled: isScheduled, KeyBackupVersion: version, - KeySource: keySource, }, nil } @@ -537,8 +533,7 @@ func (store *SQLCryptoStore) scanInboundGroupSession(rows dbutil.Scannable) (*In var maxAge, maxMessages sql.NullInt64 var isScheduled bool var version id.KeyBackupVersion - var keySource id.KeySource - err := rows.Scan(&roomID, &senderKey, &signingKey, &sessionBytes, &forwardingChains, &ratchetSafetyBytes, &receivedAt, &maxAge, &maxMessages, &isScheduled, &version, &keySource) + err := rows.Scan(&roomID, &senderKey, &signingKey, &sessionBytes, &forwardingChains, &ratchetSafetyBytes, &receivedAt, &maxAge, &maxMessages, &isScheduled, &version) if err != nil { return nil, err } @@ -558,13 +553,12 @@ func (store *SQLCryptoStore) scanInboundGroupSession(rows dbutil.Scannable) (*In MaxMessages: int(maxMessages.Int64), IsScheduled: isScheduled, KeyBackupVersion: version, - KeySource: keySource, }, nil } func (store *SQLCryptoStore) GetGroupSessionsForRoom(ctx context.Context, roomID id.RoomID) dbutil.RowIter[*InboundGroupSession] { rows, err := store.DB.Query(ctx, ` - SELECT room_id, sender_key, signing_key, session, forwarding_chains, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version, key_source + SELECT room_id, sender_key, signing_key, session, forwarding_chains, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version FROM crypto_megolm_inbound_session WHERE room_id=$1 AND account_id=$2 AND session IS NOT NULL`, roomID, store.AccountID, ) @@ -573,7 +567,7 @@ func (store *SQLCryptoStore) GetGroupSessionsForRoom(ctx context.Context, roomID func (store *SQLCryptoStore) GetAllGroupSessions(ctx context.Context) dbutil.RowIter[*InboundGroupSession] { rows, err := store.DB.Query(ctx, ` - SELECT room_id, sender_key, signing_key, session, forwarding_chains, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version, key_source + SELECT room_id, sender_key, signing_key, session, forwarding_chains, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version FROM crypto_megolm_inbound_session WHERE account_id=$1 AND session IS NOT NULL`, store.AccountID, ) @@ -582,7 +576,7 @@ func (store *SQLCryptoStore) GetAllGroupSessions(ctx context.Context) dbutil.Row func (store *SQLCryptoStore) GetGroupSessionsWithoutKeyBackupVersion(ctx context.Context, version id.KeyBackupVersion) dbutil.RowIter[*InboundGroupSession] { rows, err := store.DB.Query(ctx, ` - SELECT room_id, sender_key, signing_key, session, forwarding_chains, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version, key_source + SELECT room_id, sender_key, signing_key, session, forwarding_chains, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version FROM crypto_megolm_inbound_session WHERE account_id=$1 AND session IS NOT NULL AND key_backup_version != $2`, store.AccountID, version, ) @@ -669,20 +663,6 @@ func (store *SQLCryptoStore) IsOutboundGroupSessionShared(ctx context.Context, u // ValidateMessageIndex returns whether the given event information match the ones stored in the database // for the given sender key, session ID and index. If the index hasn't been stored, this will store it. func (store *SQLCryptoStore) ValidateMessageIndex(ctx context.Context, senderKey id.SenderKey, sessionID id.SessionID, eventID id.EventID, index uint, timestamp int64) (bool, error) { - if eventID == "" && timestamp == 0 { - var notOK bool - const validateEmptyQuery = ` - SELECT EXISTS(SELECT 1 FROM crypto_message_index WHERE sender_key=$1 AND session_id=$2 AND "index"=$3) - ` - err := store.DB.QueryRow(ctx, validateEmptyQuery, senderKey, sessionID, index).Scan(¬OK) - if notOK { - zerolog.Ctx(ctx).Debug(). - Uint("message_index", index). - Msg("Rejecting event without event ID and timestamp due to already knowing them") - } - return !notOK, err - } - const validateQuery = ` INSERT INTO crypto_message_index (sender_key, session_id, "index", event_id, timestamp) VALUES ($1, $2, $3, $4, $5) diff --git a/crypto/sql_store_upgrade/00-latest-revision.sql b/crypto/sql_store_upgrade/00-latest-revision.sql index 3709f1e5..00dd1387 100644 --- a/crypto/sql_store_upgrade/00-latest-revision.sql +++ b/crypto/sql_store_upgrade/00-latest-revision.sql @@ -1,4 +1,4 @@ --- v0 -> v19 (compatible with v15+): Latest revision +-- v0 -> v17 (compatible with v15+): Latest revision CREATE TABLE IF NOT EXISTS crypto_account ( account_id TEXT PRIMARY KEY, device_id TEXT NOT NULL, @@ -71,11 +71,8 @@ CREATE TABLE IF NOT EXISTS crypto_megolm_inbound_session ( max_messages INTEGER, is_scheduled BOOLEAN NOT NULL DEFAULT false, key_backup_version TEXT NOT NULL DEFAULT '', - key_source TEXT NOT NULL DEFAULT '', PRIMARY KEY (account_id, session_id) ); --- Useful index to find keys that need backing up -CREATE INDEX crypto_megolm_inbound_session_backup_idx ON crypto_megolm_inbound_session(account_id, key_backup_version) WHERE session IS NOT NULL; CREATE TABLE IF NOT EXISTS crypto_megolm_outbound_session ( account_id TEXT, diff --git a/crypto/sql_store_upgrade/18-megolm-inbound-session-backup-index.sql b/crypto/sql_store_upgrade/18-megolm-inbound-session-backup-index.sql deleted file mode 100644 index da26da0f..00000000 --- a/crypto/sql_store_upgrade/18-megolm-inbound-session-backup-index.sql +++ /dev/null @@ -1,2 +0,0 @@ --- v18 (compatible with v15+): Add an index to the megolm_inbound_session table to make finding sessions to backup faster -CREATE INDEX crypto_megolm_inbound_session_backup_idx ON crypto_megolm_inbound_session(account_id, key_backup_version) WHERE session IS NOT NULL; diff --git a/crypto/sql_store_upgrade/19-megolm-session-source.sql b/crypto/sql_store_upgrade/19-megolm-session-source.sql deleted file mode 100644 index f624222f..00000000 --- a/crypto/sql_store_upgrade/19-megolm-session-source.sql +++ /dev/null @@ -1,2 +0,0 @@ --- v19 (compatible with v15+): Store megolm session source -ALTER TABLE crypto_megolm_inbound_session ADD COLUMN key_source TEXT NOT NULL DEFAULT ''; diff --git a/crypto/ssss/client.go b/crypto/ssss/client.go index 8691d032..e30925d9 100644 --- a/crypto/ssss/client.go +++ b/crypto/ssss/client.go @@ -95,22 +95,6 @@ func (mach *Machine) SetEncryptedAccountData(ctx context.Context, eventType even return mach.Client.SetAccountData(ctx, eventType.Type, &EncryptedAccountDataEventContent{Encrypted: encrypted}) } -// SetEncryptedAccountDataWithMetadata encrypts the given data with the given keys and stores it, -// alongside the unencrypted metadata, on the server. -func (mach *Machine) SetEncryptedAccountDataWithMetadata(ctx context.Context, eventType event.Type, data []byte, metadata map[string]any, keys ...*Key) error { - if len(keys) == 0 { - return ErrNoKeyGiven - } - encrypted := make(map[string]EncryptedKeyData, len(keys)) - for _, key := range keys { - encrypted[key.ID] = key.Encrypt(eventType.Type, data) - } - return mach.Client.SetAccountData(ctx, eventType.Type, &EncryptedAccountDataEventContent{ - Encrypted: encrypted, - Metadata: metadata, - }) -} - // GenerateAndUploadKey generates a new SSSS key and stores the metadata on the server. func (mach *Machine) GenerateAndUploadKey(ctx context.Context, passphrase string) (key *Key, err error) { key, err = NewKey(passphrase) diff --git a/crypto/ssss/key.go b/crypto/ssss/key.go index 78ebd8f3..aa22360a 100644 --- a/crypto/ssss/key.go +++ b/crypto/ssss/key.go @@ -7,8 +7,6 @@ package ssss import ( - "crypto/hmac" - "crypto/sha256" "encoding/base64" "fmt" "strings" @@ -59,12 +57,12 @@ func NewKey(passphrase string) (*Key, error) { // We store a certain hash in the key metadata so that clients can check if the user entered the correct key. ivBytes := random.Bytes(utils.AESCTRIVLength) keyData.IV = base64.RawStdEncoding.EncodeToString(ivBytes) - macBytes, err := keyData.calculateHash(ssssKey) + var err error + keyData.MAC, err = keyData.calculateHash(ssssKey) if err != nil { // This should never happen because we just generated the IV and key. return nil, fmt.Errorf("failed to calculate hash: %w", err) } - keyData.MAC = base64.RawStdEncoding.EncodeToString(macBytes) return &Key{ Key: ssssKey, @@ -110,18 +108,12 @@ func (key *Key) Decrypt(eventType string, data EncryptedKeyData) ([]byte, error) return nil, err } - mac, err := base64.RawStdEncoding.DecodeString(strings.TrimRight(data.MAC, "=")) - if err != nil { - return nil, err - } - // derive the AES and HMAC keys for the requested event type using the SSSS key aesKey, hmacKey := utils.DeriveKeysSHA256(key.Key, eventType) // compare the stored MAC with the one we calculated from the ciphertext - h := hmac.New(sha256.New, hmacKey[:]) - h.Write(payload) - if !hmac.Equal(h.Sum(nil), mac) { + calcMac := utils.HMACSHA256B64(payload, hmacKey) + if strings.TrimRight(data.MAC, "=") != calcMac { return nil, ErrKeyDataMACMismatch } diff --git a/crypto/ssss/meta.go b/crypto/ssss/meta.go index 34775fa7..474c85d8 100644 --- a/crypto/ssss/meta.go +++ b/crypto/ssss/meta.go @@ -7,10 +7,7 @@ package ssss import ( - "crypto/hmac" - "crypto/sha256" "encoding/base64" - "errors" "fmt" "strings" @@ -36,9 +33,7 @@ func (kd *KeyMetadata) VerifyPassphrase(keyID, passphrase string) (*Key, error) ssssKey, err := kd.Passphrase.GetKey(passphrase) if err != nil { return nil, err - } - err = kd.verifyKey(ssssKey) - if err != nil && !errors.Is(err, ErrUnverifiableKey) { + } else if err = kd.verifyKey(ssssKey); err != nil { return nil, err } @@ -54,9 +49,7 @@ func (kd *KeyMetadata) VerifyRecoveryKey(keyID, recoveryKey string) (*Key, error ssssKey := utils.DecodeBase58RecoveryKey(recoveryKey) if ssssKey == nil { return nil, ErrInvalidRecoveryKey - } - err := kd.verifyKey(ssssKey) - if err != nil && !errors.Is(err, ErrUnverifiableKey) { + } else if err := kd.verifyKey(ssssKey); err != nil { return nil, err } @@ -64,28 +57,20 @@ func (kd *KeyMetadata) VerifyRecoveryKey(keyID, recoveryKey string) (*Key, error ID: keyID, Key: ssssKey, Metadata: kd, - }, err + }, nil } func (kd *KeyMetadata) verifyKey(key []byte) error { - if kd.MAC == "" || kd.IV == "" { - return ErrUnverifiableKey - } unpaddedMAC := strings.TrimRight(kd.MAC, "=") expectedMACLength := base64.RawStdEncoding.EncodedLen(utils.SHAHashLength) if len(unpaddedMAC) != expectedMACLength { return fmt.Errorf("%w: invalid mac length %d (expected %d)", ErrCorruptedKeyMetadata, len(unpaddedMAC), expectedMACLength) } - expectedMAC, err := base64.RawStdEncoding.DecodeString(unpaddedMAC) - if err != nil { - return fmt.Errorf("%w: failed to decode mac: %w", ErrCorruptedKeyMetadata, err) - } - calculatedMAC, err := kd.calculateHash(key) + hash, err := kd.calculateHash(key) if err != nil { return err } - // This doesn't really need to be constant time since it's fully local, but might as well be. - if !hmac.Equal(expectedMAC, calculatedMAC) { + if unpaddedMAC != hash { return ErrIncorrectSSSSKey } return nil @@ -98,26 +83,23 @@ func (kd *KeyMetadata) VerifyKey(key []byte) bool { // calculateHash calculates the hash used for checking if the key is entered correctly as described // in the spec: https://matrix.org/docs/spec/client_server/unstable#m-secret-storage-v1-aes-hmac-sha2 -func (kd *KeyMetadata) calculateHash(key []byte) ([]byte, error) { +func (kd *KeyMetadata) calculateHash(key []byte) (string, error) { aesKey, hmacKey := utils.DeriveKeysSHA256(key, "") unpaddedIV := strings.TrimRight(kd.IV, "=") expectedIVLength := base64.RawStdEncoding.EncodedLen(utils.AESCTRIVLength) - if len(unpaddedIV) < expectedIVLength || len(unpaddedIV) > expectedIVLength*3 { - return nil, fmt.Errorf("%w: invalid iv length %d (expected %d)", ErrCorruptedKeyMetadata, len(unpaddedIV), expectedIVLength) + if len(unpaddedIV) != expectedIVLength { + return "", fmt.Errorf("%w: invalid iv length %d (expected %d)", ErrCorruptedKeyMetadata, len(unpaddedIV), expectedIVLength) } - rawIVBytes, err := base64.RawStdEncoding.DecodeString(unpaddedIV) - if err != nil { - return nil, fmt.Errorf("%w: failed to decode iv: %w", ErrCorruptedKeyMetadata, err) - } - // TODO log a warning for non-16 byte IVs? - // Certain broken clients like nheko generated 32-byte IVs where only the first 16 bytes were used. - ivBytes := *(*[utils.AESCTRIVLength]byte)(rawIVBytes[:utils.AESCTRIVLength]) - zeroes := make([]byte, utils.AESCTRKeyLength) - encryptedZeroes := utils.XorA256CTR(zeroes, aesKey, ivBytes) - h := hmac.New(sha256.New, hmacKey[:]) - h.Write(encryptedZeroes) - return h.Sum(nil), nil + var ivBytes [utils.AESCTRIVLength]byte + _, err := base64.RawStdEncoding.Decode(ivBytes[:], []byte(unpaddedIV)) + if err != nil { + return "", fmt.Errorf("%w: failed to decode iv: %w", ErrCorruptedKeyMetadata, err) + } + + cipher := utils.XorA256CTR(make([]byte, utils.AESCTRKeyLength), aesKey, ivBytes) + + return utils.HMACSHA256B64(cipher, hmacKey), nil } // PassphraseMetadata represents server-side metadata about a SSSS key passphrase. diff --git a/crypto/ssss/meta_test.go b/crypto/ssss/meta_test.go index d59809c7..4f2ff378 100644 --- a/crypto/ssss/meta_test.go +++ b/crypto/ssss/meta_test.go @@ -8,10 +8,10 @@ package ssss_test import ( "encoding/json" + "errors" "testing" "github.com/stretchr/testify/assert" - "go.mau.fi/util/exerrors" "maunium.net/go/mautrix/crypto/ssss" ) @@ -41,24 +41,10 @@ const key2Meta = ` } ` -const key2MetaUnverified = ` -{ - "algorithm": "m.secret_storage.v1.aes-hmac-sha2" -} -` - -const key2MetaLongIV = ` -{ - "algorithm": "m.secret_storage.v1.aes-hmac-sha2", - "iv": "O0BOvTqiIAYjC+RMcyHfW2f/gdxjceTxoYtNlpPduJ8=", - "mac": "7k6OruQlWg0UmQjxGZ0ad4Q6DdwkgnoI7G6X3IjBYtI=" -} -` - const key2MetaBrokenIV = ` { "algorithm": "m.secret_storage.v1.aes-hmac-sha2", - "iv": "MeowMeowMeow", + "iv": "O0BOvTqiIAYjC+RMcyHfWwMeowMeowMeow", "mac": "7k6OruQlWg0UmQjxGZ0ad4Q6DdwkgnoI7G6X3IjBYtI=" } ` @@ -84,11 +70,23 @@ func getKeyMeta(meta string) *ssss.KeyMetadata { } func getKey1() *ssss.Key { - return exerrors.Must(getKeyMeta(key1Meta).VerifyRecoveryKey(key1ID, key1RecoveryKey)) + km := getKeyMeta(key1Meta) + key, err := km.VerifyRecoveryKey(key1ID, key1RecoveryKey) + if err != nil { + panic(err) + } + key.ID = key1ID + return key } func getKey2() *ssss.Key { - return exerrors.Must(getKeyMeta(key2Meta).VerifyRecoveryKey(key2ID, key2RecoveryKey)) + km := getKeyMeta(key2Meta) + key, err := km.VerifyRecoveryKey(key2ID, key2RecoveryKey) + if err != nil { + panic(err) + } + key.ID = key2ID + return key } func TestKeyMetadata_VerifyRecoveryKey_Correct(t *testing.T) { @@ -107,33 +105,17 @@ func TestKeyMetadata_VerifyRecoveryKey_Correct2(t *testing.T) { assert.Equal(t, key2RecoveryKey, key.RecoveryKey()) } -func TestKeyMetadata_VerifyRecoveryKey_NonCompliant_LongIV(t *testing.T) { - km := getKeyMeta(key2MetaLongIV) - key, err := km.VerifyRecoveryKey(key2ID, key2RecoveryKey) - assert.NoError(t, err) - assert.NotNil(t, key) - assert.Equal(t, key2RecoveryKey, key.RecoveryKey()) -} - -func TestKeyMetadata_VerifyRecoveryKey_Unverified(t *testing.T) { - km := getKeyMeta(key2MetaUnverified) - key, err := km.VerifyRecoveryKey(key2ID, key2RecoveryKey) - assert.ErrorIs(t, err, ssss.ErrUnverifiableKey) - assert.NotNil(t, key) - assert.Equal(t, key2RecoveryKey, key.RecoveryKey()) -} - func TestKeyMetadata_VerifyRecoveryKey_Invalid(t *testing.T) { km := getKeyMeta(key1Meta) key, err := km.VerifyRecoveryKey(key1ID, "foo") - assert.ErrorIs(t, err, ssss.ErrInvalidRecoveryKey) + assert.True(t, errors.Is(err, ssss.ErrInvalidRecoveryKey), "unexpected error: %v", err) assert.Nil(t, key) } func TestKeyMetadata_VerifyRecoveryKey_Incorrect(t *testing.T) { km := getKeyMeta(key1Meta) key, err := km.VerifyRecoveryKey(key2ID, key2RecoveryKey) - assert.ErrorIs(t, err, ssss.ErrIncorrectSSSSKey) + assert.True(t, errors.Is(err, ssss.ErrIncorrectSSSSKey), "unexpected error: %v", err) assert.Nil(t, key) } @@ -148,27 +130,27 @@ func TestKeyMetadata_VerifyPassphrase_Correct(t *testing.T) { func TestKeyMetadata_VerifyPassphrase_Incorrect(t *testing.T) { km := getKeyMeta(key1Meta) key, err := km.VerifyPassphrase(key1ID, "incorrect horse battery staple") - assert.ErrorIs(t, err, ssss.ErrIncorrectSSSSKey) + assert.True(t, errors.Is(err, ssss.ErrIncorrectSSSSKey), "unexpected error %v", err) assert.Nil(t, key) } func TestKeyMetadata_VerifyPassphrase_NotSet(t *testing.T) { km := getKeyMeta(key2Meta) key, err := km.VerifyPassphrase(key2ID, "hmm") - assert.ErrorIs(t, err, ssss.ErrNoPassphrase) + assert.True(t, errors.Is(err, ssss.ErrNoPassphrase), "unexpected error %v", err) assert.Nil(t, key) } func TestKeyMetadata_VerifyRecoveryKey_CorruptedIV(t *testing.T) { km := getKeyMeta(key2MetaBrokenIV) key, err := km.VerifyRecoveryKey(key2ID, key2RecoveryKey) - assert.ErrorIs(t, err, ssss.ErrCorruptedKeyMetadata) + assert.True(t, errors.Is(err, ssss.ErrCorruptedKeyMetadata), "unexpected error %v", err) assert.Nil(t, key) } func TestKeyMetadata_VerifyRecoveryKey_CorruptedMAC(t *testing.T) { km := getKeyMeta(key2MetaBrokenMAC) key, err := km.VerifyRecoveryKey(key2ID, key2RecoveryKey) - assert.ErrorIs(t, err, ssss.ErrCorruptedKeyMetadata) + assert.True(t, errors.Is(err, ssss.ErrCorruptedKeyMetadata), "unexpected error %v", err) assert.Nil(t, key) } diff --git a/crypto/ssss/types.go b/crypto/ssss/types.go index b7465d3e..345393b0 100644 --- a/crypto/ssss/types.go +++ b/crypto/ssss/types.go @@ -26,8 +26,7 @@ var ( ErrUnsupportedPassphraseAlgorithm = errors.New("unsupported passphrase KDF algorithm") ErrIncorrectSSSSKey = errors.New("incorrect SSSS key") ErrInvalidRecoveryKey = errors.New("invalid recovery key") - ErrCorruptedKeyMetadata = errors.New("corrupted recovery key metadata") - ErrUnverifiableKey = errors.New("cannot verify recovery key: missing MAC or IV in metadata") + ErrCorruptedKeyMetadata = errors.New("corrupted key metadata") ) // Algorithm is the identifier for an SSSS encryption algorithm. @@ -58,7 +57,6 @@ type EncryptedKeyData struct { type EncryptedAccountDataEventContent struct { Encrypted map[string]EncryptedKeyData `json:"encrypted"` - Metadata map[string]any `json:"com.beeper.metadata,omitzero"` } func (ed *EncryptedAccountDataEventContent) Decrypt(eventType string, key *Key) ([]byte, error) { diff --git a/crypto/store.go b/crypto/store.go index 7620cf35..8b7c0a96 100644 --- a/crypto/store.go +++ b/crypto/store.go @@ -525,9 +525,6 @@ func (gs *MemoryStore) ValidateMessageIndex(_ context.Context, senderKey id.Send } val, ok := gs.MessageIndices[key] if !ok { - if eventID == "" && timestamp == 0 { - return true, nil - } gs.MessageIndices[key] = messageIndexValue{ EventID: eventID, Timestamp: timestamp, diff --git a/crypto/store_test.go b/crypto/store_test.go index 7a47243e..8aeae7af 100644 --- a/crypto/store_test.go +++ b/crypto/store_test.go @@ -75,13 +75,8 @@ func TestValidateMessageIndex(t *testing.T) { t.Run(storeName, func(t *testing.T) { acc := NewOlmAccount() - // Validating without event ID and timestamp before we have them should work - ok, err := store.ValidateMessageIndex(context.TODO(), acc.IdentityKey(), "sess1", "", 0, 0) - require.NoError(t, err, "Error validating message index") - assert.True(t, ok, "First message validation should be valid") - // First message should validate successfully - ok, err = store.ValidateMessageIndex(context.TODO(), acc.IdentityKey(), "sess1", "event1", 0, 1000) + ok, err := store.ValidateMessageIndex(context.TODO(), acc.IdentityKey(), "sess1", "event1", 0, 1000) require.NoError(t, err, "Error validating message index") assert.True(t, ok, "First message validation should be valid") @@ -99,11 +94,6 @@ func TestValidateMessageIndex(t *testing.T) { ok, err = store.ValidateMessageIndex(context.TODO(), acc.IdentityKey(), "sess1", "event1", 0, 1000) require.NoError(t, err, "Error validating message index") assert.True(t, ok, "First message validation should be valid") - - // Validating without event ID and timestamp must fail if we already know them - ok, err = store.ValidateMessageIndex(context.TODO(), acc.IdentityKey(), "sess1", "", 0, 0) - require.NoError(t, err, "Error validating message index") - assert.False(t, ok, "First message validation should be invalid") }) } } diff --git a/crypto/verificationhelper/mockserver_test.go b/crypto/verificationhelper/mockserver_test.go new file mode 100644 index 00000000..45ca7781 --- /dev/null +++ b/crypto/verificationhelper/mockserver_test.go @@ -0,0 +1,236 @@ +// Copyright (c) 2024 Sumner Evans +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package verificationhelper_test + +import ( + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/rs/zerolog/log" // zerolog-allow-global-log + "github.com/stretchr/testify/require" + "go.mau.fi/util/random" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/crypto" + "maunium.net/go/mautrix/crypto/cryptohelper" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +// mockServer is a mock Matrix server that wraps an [httptest.Server] to allow +// testing of the interactive verification process. +type mockServer struct { + *httptest.Server + + AccessTokenToUserID map[string]id.UserID + DeviceInbox map[id.UserID]map[id.DeviceID][]event.Event + AccountData map[id.UserID]map[event.Type]json.RawMessage + DeviceKeys map[id.UserID]map[id.DeviceID]mautrix.DeviceKeys + MasterKeys map[id.UserID]mautrix.CrossSigningKeys + SelfSigningKeys map[id.UserID]mautrix.CrossSigningKeys + UserSigningKeys map[id.UserID]mautrix.CrossSigningKeys +} + +func createMockServer(t *testing.T) *mockServer { + t.Helper() + + server := mockServer{ + AccessTokenToUserID: map[string]id.UserID{}, + DeviceInbox: map[id.UserID]map[id.DeviceID][]event.Event{}, + AccountData: map[id.UserID]map[event.Type]json.RawMessage{}, + DeviceKeys: map[id.UserID]map[id.DeviceID]mautrix.DeviceKeys{}, + MasterKeys: map[id.UserID]mautrix.CrossSigningKeys{}, + SelfSigningKeys: map[id.UserID]mautrix.CrossSigningKeys{}, + UserSigningKeys: map[id.UserID]mautrix.CrossSigningKeys{}, + } + + router := http.NewServeMux() + router.HandleFunc("POST /_matrix/client/v3/login", server.postLogin) + router.HandleFunc("POST /_matrix/client/v3/keys/query", server.postKeysQuery) + router.HandleFunc("PUT /_matrix/client/v3/sendToDevice/{type}/{txn}", server.putSendToDevice) + router.HandleFunc("PUT /_matrix/client/v3/user/{userID}/account_data/{type}", server.putAccountData) + router.HandleFunc("POST /_matrix/client/v3/keys/device_signing/upload", server.postDeviceSigningUpload) + router.HandleFunc("POST /_matrix/client/v3/keys/signatures/upload", server.emptyResp) + router.HandleFunc("POST /_matrix/client/v3/keys/upload", server.postKeysUpload) + + server.Server = httptest.NewServer(router) + return &server +} + +func (ms *mockServer) getUserID(r *http.Request) id.UserID { + authHeader := r.Header.Get("Authorization") + authHeader = strings.TrimPrefix(authHeader, "Bearer ") + userID, ok := ms.AccessTokenToUserID[authHeader] + if !ok { + panic("no user ID found for access token " + authHeader) + } + return userID +} + +func (s *mockServer) emptyResp(w http.ResponseWriter, _ *http.Request) { + w.Write([]byte("{}")) +} + +func (s *mockServer) postLogin(w http.ResponseWriter, r *http.Request) { + var loginReq mautrix.ReqLogin + json.NewDecoder(r.Body).Decode(&loginReq) + + deviceID := loginReq.DeviceID + if deviceID == "" { + deviceID = id.DeviceID(random.String(10)) + } + + accessToken := random.String(30) + userID := id.UserID(loginReq.Identifier.User) + s.AccessTokenToUserID[accessToken] = userID + + json.NewEncoder(w).Encode(&mautrix.RespLogin{ + AccessToken: accessToken, + DeviceID: deviceID, + UserID: userID, + }) +} + +func (s *mockServer) putSendToDevice(w http.ResponseWriter, r *http.Request) { + var req mautrix.ReqSendToDevice + json.NewDecoder(r.Body).Decode(&req) + evtType := event.Type{Type: r.PathValue("type"), Class: event.ToDeviceEventType} + + for user, devices := range req.Messages { + for device, content := range devices { + if _, ok := s.DeviceInbox[user]; !ok { + s.DeviceInbox[user] = map[id.DeviceID][]event.Event{} + } + content.ParseRaw(evtType) + s.DeviceInbox[user][device] = append(s.DeviceInbox[user][device], event.Event{ + Sender: s.getUserID(r), + Type: evtType, + Content: *content, + }) + } + } + s.emptyResp(w, r) +} + +func (s *mockServer) putAccountData(w http.ResponseWriter, r *http.Request) { + userID := id.UserID(r.PathValue("userID")) + eventType := event.Type{Type: r.PathValue("type"), Class: event.AccountDataEventType} + + jsonData, _ := io.ReadAll(r.Body) + if _, ok := s.AccountData[userID]; !ok { + s.AccountData[userID] = map[event.Type]json.RawMessage{} + } + s.AccountData[userID][eventType] = json.RawMessage(jsonData) + s.emptyResp(w, r) +} + +func (s *mockServer) postKeysQuery(w http.ResponseWriter, r *http.Request) { + var req mautrix.ReqQueryKeys + json.NewDecoder(r.Body).Decode(&req) + resp := mautrix.RespQueryKeys{ + MasterKeys: map[id.UserID]mautrix.CrossSigningKeys{}, + UserSigningKeys: map[id.UserID]mautrix.CrossSigningKeys{}, + SelfSigningKeys: map[id.UserID]mautrix.CrossSigningKeys{}, + DeviceKeys: map[id.UserID]map[id.DeviceID]mautrix.DeviceKeys{}, + } + for user := range req.DeviceKeys { + resp.MasterKeys[user] = s.MasterKeys[user] + resp.UserSigningKeys[user] = s.UserSigningKeys[user] + resp.SelfSigningKeys[user] = s.SelfSigningKeys[user] + resp.DeviceKeys[user] = s.DeviceKeys[user] + } + json.NewEncoder(w).Encode(&resp) +} + +func (s *mockServer) postKeysUpload(w http.ResponseWriter, r *http.Request) { + var req mautrix.ReqUploadKeys + json.NewDecoder(r.Body).Decode(&req) + + userID := s.getUserID(r) + if _, ok := s.DeviceKeys[userID]; !ok { + s.DeviceKeys[userID] = map[id.DeviceID]mautrix.DeviceKeys{} + } + s.DeviceKeys[userID][req.DeviceKeys.DeviceID] = *req.DeviceKeys + + json.NewEncoder(w).Encode(&mautrix.RespUploadKeys{ + OneTimeKeyCounts: mautrix.OTKCount{SignedCurve25519: 50}, + }) +} + +func (s *mockServer) postDeviceSigningUpload(w http.ResponseWriter, r *http.Request) { + var req mautrix.UploadCrossSigningKeysReq + json.NewDecoder(r.Body).Decode(&req) + + userID := s.getUserID(r) + s.MasterKeys[userID] = req.Master + s.SelfSigningKeys[userID] = req.SelfSigning + s.UserSigningKeys[userID] = req.UserSigning + + s.emptyResp(w, r) +} + +func (ms *mockServer) Login(t *testing.T, ctx context.Context, userID id.UserID, deviceID id.DeviceID) (*mautrix.Client, crypto.Store) { + t.Helper() + client, err := mautrix.NewClient(ms.URL, "", "") + require.NoError(t, err) + client.StateStore = mautrix.NewMemoryStateStore() + + _, err = client.Login(ctx, &mautrix.ReqLogin{ + Type: mautrix.AuthTypePassword, + Identifier: mautrix.UserIdentifier{ + Type: mautrix.IdentifierTypeUser, + User: userID.String(), + }, + DeviceID: deviceID, + Password: "password", + StoreCredentials: true, + }) + require.NoError(t, err) + + cryptoStore := crypto.NewMemoryStore(nil) + cryptoHelper, err := cryptohelper.NewCryptoHelper(client, []byte("test"), cryptoStore) + require.NoError(t, err) + client.Crypto = cryptoHelper + + err = cryptoHelper.Init(ctx) + require.NoError(t, err) + + machineLog := log.Logger.With(). + Stringer("my_user_id", userID). + Stringer("my_device_id", deviceID). + Logger() + cryptoHelper.Machine().Log = &machineLog + + err = cryptoHelper.Machine().ShareKeys(ctx, 50) + require.NoError(t, err) + + return client, cryptoStore +} + +func (ms *mockServer) dispatchToDevice(t *testing.T, ctx context.Context, client *mautrix.Client) { + t.Helper() + + for _, evt := range ms.DeviceInbox[client.UserID][client.DeviceID] { + client.Syncer.(*mautrix.DefaultSyncer).Dispatch(ctx, &evt) + ms.DeviceInbox[client.UserID][client.DeviceID] = ms.DeviceInbox[client.UserID][client.DeviceID][1:] + } +} + +func addDeviceID(ctx context.Context, cryptoStore crypto.Store, userID id.UserID, deviceID id.DeviceID) { + err := cryptoStore.PutDevice(ctx, userID, &id.Device{ + UserID: userID, + DeviceID: deviceID, + }) + if err != nil { + panic(err) + } +} diff --git a/crypto/verificationhelper/sas.go b/crypto/verificationhelper/sas.go index e6392c79..1313a613 100644 --- a/crypto/verificationhelper/sas.go +++ b/crypto/verificationhelper/sas.go @@ -695,7 +695,7 @@ func (vh *VerificationHelper) onVerificationMAC(ctx context.Context, txn Verific // Verify the MAC for each key var theirDevice *id.Device for keyID, mac := range macEvt.MAC { - log.Info().Stringer("key_id", keyID).Msg("Received MAC for key") + log.Info().Str("key_id", keyID.String()).Msg("Received MAC for key") alg, kID := keyID.Parse() if alg != id.KeyAlgorithmEd25519 { diff --git a/crypto/verificationhelper/verificationhelper_qr_crosssign_test.go b/crypto/verificationhelper/verificationhelper_qr_crosssign_test.go index 5e3f146b..aace2230 100644 --- a/crypto/verificationhelper/verificationhelper_qr_crosssign_test.go +++ b/crypto/verificationhelper/verificationhelper_qr_crosssign_test.go @@ -32,6 +32,7 @@ func TestCrossSignVerification_ScanQRAndConfirmScan(t *testing.T) { for _, tc := range testCases { t.Run(fmt.Sprintf("sendingScansQR=%t", tc.sendingScansQR), func(t *testing.T) { ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLoginAliceBob(t, ctx) + defer ts.Close() sendingCallbacks, receivingCallbacks, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine) var err error @@ -50,10 +51,10 @@ func TestCrossSignVerification_ScanQRAndConfirmScan(t *testing.T) { // event on the sending device. txnID, err := sendingHelper.StartVerification(ctx, bobUserID) require.NoError(t, err) - ts.DispatchToDevice(t, ctx, receivingClient) + ts.dispatchToDevice(t, ctx, receivingClient) err = receivingHelper.AcceptVerification(ctx, txnID) require.NoError(t, err) - ts.DispatchToDevice(t, ctx, sendingClient) + ts.dispatchToDevice(t, ctx, sendingClient) receivingShownQRCode := receivingCallbacks.GetQRCodeShown(txnID) require.NotNil(t, receivingShownQRCode) @@ -82,7 +83,7 @@ func TestCrossSignVerification_ScanQRAndConfirmScan(t *testing.T) { // Handle the start and done events on the receiving client and // confirm the scan. - ts.DispatchToDevice(t, ctx, receivingClient) + ts.dispatchToDevice(t, ctx, receivingClient) // Ensure that the receiving device detected that its QR code // was scanned. @@ -97,7 +98,7 @@ func TestCrossSignVerification_ScanQRAndConfirmScan(t *testing.T) { doneEvt = sendingInbox[0].Content.AsVerificationDone() assert.Equal(t, txnID, doneEvt.TransactionID) - ts.DispatchToDevice(t, ctx, sendingClient) + ts.dispatchToDevice(t, ctx, sendingClient) } else { // receiving scans QR // Emulate scanning the QR code shown by the sending device on // the receiving device. @@ -120,7 +121,7 @@ func TestCrossSignVerification_ScanQRAndConfirmScan(t *testing.T) { // Handle the start and done events on the receiving client and // confirm the scan. - ts.DispatchToDevice(t, ctx, sendingClient) + ts.dispatchToDevice(t, ctx, sendingClient) // Ensure that the sending device detected that its QR code was // scanned. @@ -135,7 +136,7 @@ func TestCrossSignVerification_ScanQRAndConfirmScan(t *testing.T) { doneEvt = receivingInbox[0].Content.AsVerificationDone() assert.Equal(t, txnID, doneEvt.TransactionID) - ts.DispatchToDevice(t, ctx, receivingClient) + ts.dispatchToDevice(t, ctx, receivingClient) } // Ensure that both devices have marked the verification as done. diff --git a/crypto/verificationhelper/verificationhelper_qr_self_test.go b/crypto/verificationhelper/verificationhelper_qr_self_test.go index ea918cd4..937cc414 100644 --- a/crypto/verificationhelper/verificationhelper_qr_self_test.go +++ b/crypto/verificationhelper/verificationhelper_qr_self_test.go @@ -36,6 +36,7 @@ func TestSelfVerification_Accept_QRContents(t *testing.T) { for _, tc := range testCases { t.Run(fmt.Sprintf("sendingGenerated=%t receivingGenerated=%t err=%s", tc.sendingGeneratedCrossSigningKeys, tc.receivingGeneratedCrossSigningKeys, tc.expectedAcceptError), func(t *testing.T) { ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLoginTwoAlice(t, ctx) + defer ts.Close() sendingCallbacks, receivingCallbacks, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine) var err error @@ -61,7 +62,7 @@ func TestSelfVerification_Accept_QRContents(t *testing.T) { // event on the sending device. txnID, err := sendingHelper.StartVerification(ctx, aliceUserID) require.NoError(t, err) - ts.DispatchToDevice(t, ctx, receivingClient) + ts.dispatchToDevice(t, ctx, receivingClient) err = receivingHelper.AcceptVerification(ctx, txnID) if tc.expectedAcceptError != "" { @@ -71,7 +72,7 @@ func TestSelfVerification_Accept_QRContents(t *testing.T) { require.NoError(t, err) } - ts.DispatchToDevice(t, ctx, sendingClient) + ts.dispatchToDevice(t, ctx, sendingClient) receivingShownQRCode := receivingCallbacks.GetQRCodeShown(txnID) require.NotNil(t, receivingShownQRCode) @@ -134,6 +135,7 @@ func TestSelfVerification_ScanQRAndConfirmScan(t *testing.T) { for _, tc := range testCases { t.Run(fmt.Sprintf("sendingGeneratedCrossSigningKeys=%t sendingScansQR=%t", tc.sendingGeneratedCrossSigningKeys, tc.sendingScansQR), func(t *testing.T) { ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLoginTwoAlice(t, ctx) + defer ts.Close() sendingCallbacks, receivingCallbacks, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine) var err error @@ -150,10 +152,10 @@ func TestSelfVerification_ScanQRAndConfirmScan(t *testing.T) { // event on the sending device. txnID, err := sendingHelper.StartVerification(ctx, aliceUserID) require.NoError(t, err) - ts.DispatchToDevice(t, ctx, receivingClient) + ts.dispatchToDevice(t, ctx, receivingClient) err = receivingHelper.AcceptVerification(ctx, txnID) require.NoError(t, err) - ts.DispatchToDevice(t, ctx, sendingClient) + ts.dispatchToDevice(t, ctx, sendingClient) receivingShownQRCode := receivingCallbacks.GetQRCodeShown(txnID) require.NotNil(t, receivingShownQRCode) @@ -182,7 +184,7 @@ func TestSelfVerification_ScanQRAndConfirmScan(t *testing.T) { // Handle the start and done events on the receiving client and // confirm the scan. - ts.DispatchToDevice(t, ctx, receivingClient) + ts.dispatchToDevice(t, ctx, receivingClient) // Ensure that the receiving device detected that its QR code // was scanned. @@ -197,7 +199,7 @@ func TestSelfVerification_ScanQRAndConfirmScan(t *testing.T) { doneEvt = sendingInbox[0].Content.AsVerificationDone() assert.Equal(t, txnID, doneEvt.TransactionID) - ts.DispatchToDevice(t, ctx, sendingClient) + ts.dispatchToDevice(t, ctx, sendingClient) } else { // receiving scans QR // Emulate scanning the QR code shown by the sending device on // the receiving device. @@ -220,7 +222,7 @@ func TestSelfVerification_ScanQRAndConfirmScan(t *testing.T) { // Handle the start and done events on the receiving client and // confirm the scan. - ts.DispatchToDevice(t, ctx, sendingClient) + ts.dispatchToDevice(t, ctx, sendingClient) // Ensure that the sending device detected that its QR code was // scanned. @@ -235,7 +237,7 @@ func TestSelfVerification_ScanQRAndConfirmScan(t *testing.T) { doneEvt = receivingInbox[0].Content.AsVerificationDone() assert.Equal(t, txnID, doneEvt.TransactionID) - ts.DispatchToDevice(t, ctx, receivingClient) + ts.dispatchToDevice(t, ctx, receivingClient) } // Ensure that both devices have marked the verification as done. @@ -249,6 +251,7 @@ func TestSelfVerification_ScanQRTransactionIDCorrupted(t *testing.T) { ctx := log.Logger.WithContext(context.TODO()) ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLoginTwoAlice(t, ctx) + defer ts.Close() sendingCallbacks, receivingCallbacks, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine) var err error @@ -260,10 +263,10 @@ func TestSelfVerification_ScanQRTransactionIDCorrupted(t *testing.T) { // event on the sending device. txnID, err := sendingHelper.StartVerification(ctx, aliceUserID) require.NoError(t, err) - ts.DispatchToDevice(t, ctx, receivingClient) + ts.dispatchToDevice(t, ctx, receivingClient) err = receivingHelper.AcceptVerification(ctx, txnID) require.NoError(t, err) - ts.DispatchToDevice(t, ctx, sendingClient) + ts.dispatchToDevice(t, ctx, sendingClient) receivingShownQRCodeBytes := receivingCallbacks.GetQRCodeShown(txnID).Bytes() sendingShownQRCodeBytes := sendingCallbacks.GetQRCodeShown(txnID).Bytes() @@ -307,6 +310,7 @@ func TestSelfVerification_ScanQRKeyCorrupted(t *testing.T) { for _, tc := range testCases { t.Run(fmt.Sprintf("sendingGeneratedCrossSigningKeys=%t sendingScansQR=%t corrupt=%d", tc.sendingGeneratedCrossSigningKeys, tc.sendingScansQR, tc.corruptByte), func(t *testing.T) { ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLoginTwoAlice(t, ctx) + defer ts.Close() sendingCallbacks, receivingCallbacks, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine) var err error @@ -323,10 +327,10 @@ func TestSelfVerification_ScanQRKeyCorrupted(t *testing.T) { // event on the sending device. txnID, err := sendingHelper.StartVerification(ctx, aliceUserID) require.NoError(t, err) - ts.DispatchToDevice(t, ctx, receivingClient) + ts.dispatchToDevice(t, ctx, receivingClient) err = receivingHelper.AcceptVerification(ctx, txnID) require.NoError(t, err) - ts.DispatchToDevice(t, ctx, sendingClient) + ts.dispatchToDevice(t, ctx, sendingClient) receivingShownQRCodeBytes := receivingCallbacks.GetQRCodeShown(txnID).Bytes() sendingShownQRCodeBytes := sendingCallbacks.GetQRCodeShown(txnID).Bytes() @@ -344,7 +348,7 @@ func TestSelfVerification_ScanQRKeyCorrupted(t *testing.T) { // Ensure that the receiving device received a cancellation. receivingInbox := ts.DeviceInbox[aliceUserID][receivingDeviceID] assert.Len(t, receivingInbox, 1) - ts.DispatchToDevice(t, ctx, receivingClient) + ts.dispatchToDevice(t, ctx, receivingClient) cancellation := receivingCallbacks.GetVerificationCancellation(txnID) require.NotNil(t, cancellation) assert.Equal(t, event.VerificationCancelCodeKeyMismatch, cancellation.Code) @@ -358,7 +362,7 @@ func TestSelfVerification_ScanQRKeyCorrupted(t *testing.T) { // Ensure that the sending device received a cancellation. sendingInbox := ts.DeviceInbox[aliceUserID][sendingDeviceID] assert.Len(t, sendingInbox, 1) - ts.DispatchToDevice(t, ctx, sendingClient) + ts.dispatchToDevice(t, ctx, sendingClient) cancellation := sendingCallbacks.GetVerificationCancellation(txnID) require.NotNil(t, cancellation) assert.Equal(t, event.VerificationCancelCodeKeyMismatch, cancellation.Code) diff --git a/crypto/verificationhelper/verificationhelper_sas_test.go b/crypto/verificationhelper/verificationhelper_sas_test.go index 283eca84..5747ac34 100644 --- a/crypto/verificationhelper/verificationhelper_sas_test.go +++ b/crypto/verificationhelper/verificationhelper_sas_test.go @@ -36,6 +36,7 @@ func TestVerification_SAS(t *testing.T) { for _, tc := range testCases { t.Run(fmt.Sprintf("sendingGenerated=%t sendingStartsSAS=%t sendingConfirmsFirst=%t", tc.sendingGeneratedCrossSigningKeys, tc.sendingStartsSAS, tc.sendingConfirmsFirst), func(t *testing.T) { ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLoginTwoAlice(t, ctx) + defer ts.Close() sendingCallbacks, receivingCallbacks, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine) var err error @@ -59,10 +60,10 @@ func TestVerification_SAS(t *testing.T) { // event on the sending device. txnID, err := sendingHelper.StartVerification(ctx, aliceUserID) require.NoError(t, err) - ts.DispatchToDevice(t, ctx, receivingClient) + ts.dispatchToDevice(t, ctx, receivingClient) err = receivingHelper.AcceptVerification(ctx, txnID) require.NoError(t, err) - ts.DispatchToDevice(t, ctx, sendingClient) + ts.dispatchToDevice(t, ctx, sendingClient) // Test that the start event is correct var startEvt *event.VerificationStartEventContent @@ -101,7 +102,7 @@ func TestVerification_SAS(t *testing.T) { if tc.sendingStartsSAS { // Process the verification start event on the receiving // device. - ts.DispatchToDevice(t, ctx, receivingClient) + ts.dispatchToDevice(t, ctx, receivingClient) // Receiving device sent the accept event to the sending device sendingInbox := ts.DeviceInbox[aliceUserID][sendingDeviceID] @@ -109,7 +110,7 @@ func TestVerification_SAS(t *testing.T) { acceptEvt = sendingInbox[0].Content.AsVerificationAccept() } else { // Process the verification start event on the sending device. - ts.DispatchToDevice(t, ctx, sendingClient) + ts.dispatchToDevice(t, ctx, sendingClient) // Sending device sent the accept event to the receiving device receivingInbox := ts.DeviceInbox[aliceUserID][receivingDeviceID] @@ -128,7 +129,7 @@ func TestVerification_SAS(t *testing.T) { var firstKeyEvt *event.VerificationKeyEventContent if tc.sendingStartsSAS { // Process the verification accept event on the sending device. - ts.DispatchToDevice(t, ctx, sendingClient) + ts.dispatchToDevice(t, ctx, sendingClient) // Sending device sends first key event to the receiving // device. @@ -138,7 +139,7 @@ func TestVerification_SAS(t *testing.T) { } else { // Process the verification accept event on the receiving // device. - ts.DispatchToDevice(t, ctx, receivingClient) + ts.dispatchToDevice(t, ctx, receivingClient) // Receiving device sends first key event to the sending // device. @@ -154,7 +155,7 @@ func TestVerification_SAS(t *testing.T) { var secondKeyEvt *event.VerificationKeyEventContent if tc.sendingStartsSAS { // Process the first key event on the receiving device. - ts.DispatchToDevice(t, ctx, receivingClient) + ts.dispatchToDevice(t, ctx, receivingClient) // Receiving device sends second key event to the sending // device. @@ -169,7 +170,7 @@ func TestVerification_SAS(t *testing.T) { assert.Len(t, descriptions, 7) } else { // Process the first key event on the sending device. - ts.DispatchToDevice(t, ctx, sendingClient) + ts.dispatchToDevice(t, ctx, sendingClient) // Sending device sends second key event to the receiving // device. @@ -190,10 +191,10 @@ func TestVerification_SAS(t *testing.T) { // Ensure that the SAS codes are the same. if tc.sendingStartsSAS { // Process the second key event on the sending device. - ts.DispatchToDevice(t, ctx, sendingClient) + ts.dispatchToDevice(t, ctx, sendingClient) } else { // Process the second key event on the receiving device. - ts.DispatchToDevice(t, ctx, receivingClient) + ts.dispatchToDevice(t, ctx, receivingClient) } assert.Equal(t, sendingCallbacks.GetDecimalsShown(txnID), receivingCallbacks.GetDecimalsShown(txnID)) sendingEmojis, sendingDescriptions := sendingCallbacks.GetEmojisAndDescriptionsShown(txnID) @@ -273,10 +274,10 @@ func TestVerification_SAS(t *testing.T) { // Test the transaction is done on both sides. We have to dispatch // twice to process and drain all of the events. - ts.DispatchToDevice(t, ctx, sendingClient) - ts.DispatchToDevice(t, ctx, receivingClient) - ts.DispatchToDevice(t, ctx, sendingClient) - ts.DispatchToDevice(t, ctx, receivingClient) + ts.dispatchToDevice(t, ctx, sendingClient) + ts.dispatchToDevice(t, ctx, receivingClient) + ts.dispatchToDevice(t, ctx, sendingClient) + ts.dispatchToDevice(t, ctx, receivingClient) assert.True(t, sendingCallbacks.IsVerificationDone(txnID)) assert.True(t, receivingCallbacks.IsVerificationDone(txnID)) }) @@ -287,6 +288,7 @@ func TestVerification_SAS_BothCallStart(t *testing.T) { ctx := log.Logger.WithContext(context.TODO()) ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLoginTwoAlice(t, ctx) + defer ts.Close() sendingCallbacks, receivingCallbacks, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine) var err error @@ -303,10 +305,10 @@ func TestVerification_SAS_BothCallStart(t *testing.T) { // event on the sending device. txnID, err := sendingHelper.StartVerification(ctx, aliceUserID) require.NoError(t, err) - ts.DispatchToDevice(t, ctx, receivingClient) + ts.dispatchToDevice(t, ctx, receivingClient) err = receivingHelper.AcceptVerification(ctx, txnID) require.NoError(t, err) - ts.DispatchToDevice(t, ctx, sendingClient) + ts.dispatchToDevice(t, ctx, sendingClient) err = sendingHelper.StartSAS(ctx, txnID) require.NoError(t, err) @@ -323,7 +325,7 @@ func TestVerification_SAS_BothCallStart(t *testing.T) { assert.Equal(t, txnID, sendingInbox[0].Content.AsVerificationStart().TransactionID) // Process the start event from the receiving client to the sending client. - ts.DispatchToDevice(t, ctx, sendingClient) + ts.dispatchToDevice(t, ctx, sendingClient) receivingInbox = ts.DeviceInbox[aliceUserID][receivingDeviceID] assert.Len(t, receivingInbox, 2) assert.Equal(t, txnID, receivingInbox[0].Content.AsVerificationStart().TransactionID) @@ -331,13 +333,13 @@ func TestVerification_SAS_BothCallStart(t *testing.T) { // Process the rest of the events until we need to confirm the SAS. for len(ts.DeviceInbox[aliceUserID][sendingDeviceID]) > 0 || len(ts.DeviceInbox[aliceUserID][receivingDeviceID]) > 0 { - ts.DispatchToDevice(t, ctx, receivingClient) - ts.DispatchToDevice(t, ctx, sendingClient) + ts.dispatchToDevice(t, ctx, receivingClient) + ts.dispatchToDevice(t, ctx, sendingClient) } // Confirm the SAS only the receiving device. receivingHelper.ConfirmSAS(ctx, txnID) - ts.DispatchToDevice(t, ctx, sendingClient) + ts.dispatchToDevice(t, ctx, sendingClient) // Verification is not done until both devices confirm the SAS. assert.False(t, sendingCallbacks.IsVerificationDone(txnID)) @@ -348,13 +350,13 @@ func TestVerification_SAS_BothCallStart(t *testing.T) { // Dispatching the events to the receiving device should get us to the done // state on the receiving device. - ts.DispatchToDevice(t, ctx, receivingClient) + ts.dispatchToDevice(t, ctx, receivingClient) assert.False(t, sendingCallbacks.IsVerificationDone(txnID)) assert.True(t, receivingCallbacks.IsVerificationDone(txnID)) // Dispatching the events to the sending client should get us to the done // state on the sending device. - ts.DispatchToDevice(t, ctx, sendingClient) + ts.dispatchToDevice(t, ctx, sendingClient) assert.True(t, sendingCallbacks.IsVerificationDone(txnID)) assert.True(t, receivingCallbacks.IsVerificationDone(txnID)) } diff --git a/crypto/verificationhelper/verificationhelper_test.go b/crypto/verificationhelper/verificationhelper_test.go index ce5ec5b4..b4c21c18 100644 --- a/crypto/verificationhelper/verificationhelper_test.go +++ b/crypto/verificationhelper/verificationhelper_test.go @@ -19,7 +19,6 @@ import ( "maunium.net/go/mautrix/crypto/verificationhelper" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" - "maunium.net/go/mautrix/mockserver" ) var aliceUserID = id.UserID("@alice:example.org") @@ -32,19 +31,9 @@ func init() { zerolog.DefaultContextLogger = &log.Logger } -func addDeviceID(ctx context.Context, cryptoStore crypto.Store, userID id.UserID, deviceID id.DeviceID) { - err := cryptoStore.PutDevice(ctx, userID, &id.Device{ - UserID: userID, - DeviceID: deviceID, - }) - if err != nil { - panic(err) - } -} - -func initServerAndLoginTwoAlice(t *testing.T, ctx context.Context) (ts *mockserver.MockServer, sendingClient, receivingClient *mautrix.Client, sendingCryptoStore, receivingCryptoStore crypto.Store, sendingMachine, receivingMachine *crypto.OlmMachine) { +func initServerAndLoginTwoAlice(t *testing.T, ctx context.Context) (ts *mockServer, sendingClient, receivingClient *mautrix.Client, sendingCryptoStore, receivingCryptoStore crypto.Store, sendingMachine, receivingMachine *crypto.OlmMachine) { t.Helper() - ts = mockserver.Create(t) + ts = createMockServer(t) sendingClient, sendingCryptoStore = ts.Login(t, ctx, aliceUserID, sendingDeviceID) sendingMachine = sendingClient.Crypto.(*cryptohelper.CryptoHelper).Machine() @@ -58,9 +47,9 @@ func initServerAndLoginTwoAlice(t *testing.T, ctx context.Context) (ts *mockserv return } -func initServerAndLoginAliceBob(t *testing.T, ctx context.Context) (ts *mockserver.MockServer, sendingClient, receivingClient *mautrix.Client, sendingCryptoStore, receivingCryptoStore crypto.Store, sendingMachine, receivingMachine *crypto.OlmMachine) { +func initServerAndLoginAliceBob(t *testing.T, ctx context.Context) (ts *mockServer, sendingClient, receivingClient *mautrix.Client, sendingCryptoStore, receivingCryptoStore crypto.Store, sendingMachine, receivingMachine *crypto.OlmMachine) { t.Helper() - ts = mockserver.Create(t) + ts = createMockServer(t) sendingClient, sendingCryptoStore = ts.Login(t, ctx, aliceUserID, sendingDeviceID) sendingMachine = sendingClient.Crypto.(*cryptohelper.CryptoHelper).Machine() @@ -127,7 +116,8 @@ func TestVerification_Start(t *testing.T) { for i, tc := range testCases { t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { - ts := mockserver.Create(t) + ts := createMockServer(t) + defer ts.Close() client, cryptoStore := ts.Login(t, ctx, aliceUserID, sendingDeviceID) addDeviceID(ctx, cryptoStore, aliceUserID, sendingDeviceID) @@ -176,6 +166,7 @@ func TestVerification_StartThenCancel(t *testing.T) { for _, sendingCancels := range []bool{true, false} { t.Run(fmt.Sprintf("sendingCancels=%t", sendingCancels), func(t *testing.T) { ts, sendingClient, receivingClient, sendingCryptoStore, receivingCryptoStore, sendingMachine, receivingMachine := initServerAndLoginTwoAlice(t, ctx) + defer ts.Close() _, _, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine) bystanderClient, _ := ts.Login(t, ctx, aliceUserID, bystanderDeviceID) @@ -195,13 +186,13 @@ func TestVerification_StartThenCancel(t *testing.T) { receivingInbox := ts.DeviceInbox[aliceUserID][receivingDeviceID] assert.Len(t, receivingInbox, 1) assert.Equal(t, txnID, receivingInbox[0].Content.AsVerificationRequest().TransactionID) - ts.DispatchToDevice(t, ctx, receivingClient) + ts.dispatchToDevice(t, ctx, receivingClient) // Process the request event on the bystander device. bystanderInbox := ts.DeviceInbox[aliceUserID][bystanderDeviceID] assert.Len(t, bystanderInbox, 1) assert.Equal(t, txnID, bystanderInbox[0].Content.AsVerificationRequest().TransactionID) - ts.DispatchToDevice(t, ctx, bystanderClient) + ts.dispatchToDevice(t, ctx, bystanderClient) // Cancel the verification request. var cancelEvt *event.VerificationCancelEventContent @@ -240,7 +231,7 @@ func TestVerification_StartThenCancel(t *testing.T) { if !sendingCancels { // Process the cancellation event on the sending device. - ts.DispatchToDevice(t, ctx, sendingClient) + ts.dispatchToDevice(t, ctx, sendingClient) // Ensure that the cancellation event was sent to the bystander device. assert.Len(t, ts.DeviceInbox[aliceUserID][bystanderDeviceID], 1) @@ -256,7 +247,8 @@ func TestVerification_StartThenCancel(t *testing.T) { func TestVerification_Accept_NoSupportedMethods(t *testing.T) { ctx := log.Logger.WithContext(context.TODO()) - ts := mockserver.Create(t) + ts := createMockServer(t) + defer ts.Close() sendingClient, sendingCryptoStore := ts.Login(t, ctx, aliceUserID, sendingDeviceID) receivingClient, _ := ts.Login(t, ctx, aliceUserID, receivingDeviceID) @@ -282,7 +274,7 @@ func TestVerification_Accept_NoSupportedMethods(t *testing.T) { require.NoError(t, err) require.NotEmpty(t, txnID) - ts.DispatchToDevice(t, ctx, receivingClient) + ts.dispatchToDevice(t, ctx, receivingClient) // Ensure that the receiver ignored the request because it // doesn't support any of the verification methods in the @@ -322,6 +314,7 @@ func TestVerification_Accept_CorrectMethodsPresented(t *testing.T) { for i, tc := range testCases { t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLoginTwoAlice(t, ctx) + defer ts.Close() recoveryKey, sendingCrossSigningKeysCache, err := sendingMachine.GenerateAndUploadCrossSigningKeys(ctx, nil, "") assert.NoError(t, err) @@ -340,7 +333,7 @@ func TestVerification_Accept_CorrectMethodsPresented(t *testing.T) { require.NoError(t, err) // Process the verification request on the receiving device. - ts.DispatchToDevice(t, ctx, receivingClient) + ts.dispatchToDevice(t, ctx, receivingClient) // Ensure that the receiving device received a verification // request with the correct transaction ID. @@ -380,7 +373,7 @@ func TestVerification_Accept_CorrectMethodsPresented(t *testing.T) { // Receive the m.key.verification.ready event on the sending // device. - ts.DispatchToDevice(t, ctx, sendingClient) + ts.dispatchToDevice(t, ctx, sendingClient) // Ensure that the sending device got a notification about the // transaction being ready. @@ -409,6 +402,7 @@ func TestVerification_Accept_CorrectMethodsPresented(t *testing.T) { func TestVerification_Accept_CancelOnNonParticipatingDevices(t *testing.T) { ctx := log.Logger.WithContext(context.TODO()) ts, sendingClient, receivingClient, sendingCryptoStore, receivingCryptoStore, sendingMachine, receivingMachine := initServerAndLoginTwoAlice(t, ctx) + defer ts.Close() _, _, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine) nonParticipatingDeviceID1 := id.DeviceID("non-participating1") @@ -425,12 +419,12 @@ func TestVerification_Accept_CancelOnNonParticipatingDevices(t *testing.T) { // the receiving device. txnID, err := sendingHelper.StartVerification(ctx, aliceUserID) require.NoError(t, err) - ts.DispatchToDevice(t, ctx, receivingClient) + ts.dispatchToDevice(t, ctx, receivingClient) err = receivingHelper.AcceptVerification(ctx, txnID) require.NoError(t, err) // Receive the m.key.verification.ready event on the sending device. - ts.DispatchToDevice(t, ctx, sendingClient) + ts.dispatchToDevice(t, ctx, sendingClient) // The sending and receiving devices should not have any cancellation // events in their inboxes. @@ -450,6 +444,7 @@ func TestVerification_Accept_CancelOnNonParticipatingDevices(t *testing.T) { func TestVerification_ErrorOnDoubleAccept(t *testing.T) { ctx := log.Logger.WithContext(context.TODO()) ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLoginTwoAlice(t, ctx) + defer ts.Close() _, _, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine) _, _, err := sendingMachine.GenerateAndUploadCrossSigningKeys(ctx, nil, "") @@ -457,7 +452,7 @@ func TestVerification_ErrorOnDoubleAccept(t *testing.T) { txnID, err := sendingHelper.StartVerification(ctx, aliceUserID) require.NoError(t, err) - ts.DispatchToDevice(t, ctx, receivingClient) + ts.dispatchToDevice(t, ctx, receivingClient) err = receivingHelper.AcceptVerification(ctx, txnID) require.NoError(t, err) err = receivingHelper.AcceptVerification(ctx, txnID) @@ -477,6 +472,7 @@ func TestVerification_ErrorOnDoubleAccept(t *testing.T) { func TestVerification_CancelOnDoubleStart(t *testing.T) { ctx := log.Logger.WithContext(context.TODO()) ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLoginTwoAlice(t, ctx) + defer ts.Close() sendingCallbacks, receivingCallbacks, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine) _, _, err := sendingMachine.GenerateAndUploadCrossSigningKeys(ctx, nil, "") @@ -485,15 +481,15 @@ func TestVerification_CancelOnDoubleStart(t *testing.T) { // Send and accept the first verification request. txnID1, err := sendingHelper.StartVerification(ctx, aliceUserID) require.NoError(t, err) - ts.DispatchToDevice(t, ctx, receivingClient) + ts.dispatchToDevice(t, ctx, receivingClient) err = receivingHelper.AcceptVerification(ctx, txnID1) require.NoError(t, err) - ts.DispatchToDevice(t, ctx, sendingClient) // Process the m.key.verification.ready event + ts.dispatchToDevice(t, ctx, sendingClient) // Process the m.key.verification.ready event // Send a second verification request txnID2, err := sendingHelper.StartVerification(ctx, aliceUserID) require.NoError(t, err) - ts.DispatchToDevice(t, ctx, receivingClient) + ts.dispatchToDevice(t, ctx, receivingClient) // Ensure that the sending device received a cancellation event for both of // the ongoing transactions. @@ -511,7 +507,7 @@ func TestVerification_CancelOnDoubleStart(t *testing.T) { assert.NotNil(t, receivingCallbacks.GetVerificationCancellation(txnID1)) assert.NotNil(t, receivingCallbacks.GetVerificationCancellation(txnID2)) - ts.DispatchToDevice(t, ctx, sendingClient) // Process the m.key.verification.cancel events + ts.dispatchToDevice(t, ctx, sendingClient) // Process the m.key.verification.cancel events assert.NotNil(t, sendingCallbacks.GetVerificationCancellation(txnID1)) assert.NotNil(t, sendingCallbacks.GetVerificationCancellation(txnID2)) } diff --git a/error.go b/error.go index 4711b3dc..6f4880df 100644 --- a/error.go +++ b/error.go @@ -13,7 +13,6 @@ import ( "net/http" "go.mau.fi/util/exhttp" - "go.mau.fi/util/exmaps" "golang.org/x/exp/maps" ) @@ -67,8 +66,6 @@ var ( MIncompatibleRoomVersion = RespError{ErrCode: "M_INCOMPATIBLE_ROOM_VERSION"} // The client specified a parameter that has the wrong value. MInvalidParam = RespError{ErrCode: "M_INVALID_PARAM", StatusCode: http.StatusBadRequest} - // The client specified a room key backup version that is not the current room key backup version for the user. - MWrongRoomKeysVersion = RespError{ErrCode: "M_WRONG_ROOM_KEYS_VERSION", StatusCode: http.StatusForbidden} MURLNotSet = RespError{ErrCode: "M_URL_NOT_SET"} MBadStatus = RespError{ErrCode: "M_BAD_STATUS"} @@ -82,13 +79,6 @@ var ( var ( ErrClientIsNil = errors.New("client is nil") ErrClientHasNoHomeserver = errors.New("client has no homeserver set") - - ErrResponseTooLong = errors.New("response content length too long") - ErrBodyReadReachedLimit = errors.New("reached response size limit while reading body") - - // Special error that indicates we should retry canceled contexts. Note that on it's own this - // is useless, the context itself must also be replaced. - ErrContextCancelRetry = errors.New("retry canceled context") ) // HTTPError An HTTP Error response, which may wrap an underlying native Go Error. @@ -140,10 +130,7 @@ type RespError struct { Err string ExtraData map[string]any - StatusCode int - ExtraHeader map[string]string - - CanRetry bool + StatusCode int } func (e *RespError) UnmarshalJSON(data []byte) error { @@ -153,17 +140,16 @@ func (e *RespError) UnmarshalJSON(data []byte) error { } e.ErrCode, _ = e.ExtraData["errcode"].(string) e.Err, _ = e.ExtraData["error"].(string) - e.CanRetry, _ = e.ExtraData["com.beeper.can_retry"].(bool) return nil } func (e *RespError) MarshalJSON() ([]byte, error) { - data := exmaps.NonNilClone(e.ExtraData) + data := maps.Clone(e.ExtraData) + if data == nil { + data = make(map[string]any) + } data["errcode"] = e.ErrCode data["error"] = e.Err - if e.CanRetry { - data["com.beeper.can_retry"] = e.CanRetry - } return json.Marshal(data) } @@ -175,9 +161,6 @@ func (e RespError) Write(w http.ResponseWriter) { if statusCode == 0 { statusCode = http.StatusInternalServerError } - for key, value := range e.ExtraHeader { - w.Header().Set(key, value) - } exhttp.WriteJSONResponse(w, statusCode, &e) } @@ -194,29 +177,6 @@ func (e RespError) WithStatus(status int) RespError { return e } -func (e RespError) WithCanRetry(canRetry bool) RespError { - e.CanRetry = canRetry - return e -} - -func (e RespError) WithExtraData(extraData map[string]any) RespError { - e.ExtraData = exmaps.NonNilClone(e.ExtraData) - maps.Copy(e.ExtraData, extraData) - return e -} - -func (e RespError) WithExtraHeader(key, value string) RespError { - e.ExtraHeader = exmaps.NonNilClone(e.ExtraHeader) - e.ExtraHeader[key] = value - return e -} - -func (e RespError) WithExtraHeaders(headers map[string]string) RespError { - e.ExtraHeader = exmaps.NonNilClone(e.ExtraHeader) - maps.Copy(e.ExtraHeader, headers) - return e -} - // Error returns the errcode and error message. func (e RespError) Error() string { return e.ErrCode + ": " + e.Err diff --git a/event/accountdata.go b/event/accountdata.go index 223919a1..30ca35a2 100644 --- a/event/accountdata.go +++ b/event/accountdata.go @@ -105,15 +105,3 @@ func (bmec *BeeperMuteEventContent) GetMutedUntilTime() time.Time { } return time.Time{} } - -func (bmec *BeeperMuteEventContent) GetMuteDuration() time.Duration { - ts := bmec.GetMutedUntilTime() - now := time.Now() - if ts.Before(now) { - return 0 - } else if ts == MutedForever { - return -1 - } else { - return ts.Sub(now) - } -} diff --git a/event/beeper.go b/event/beeper.go index a1a60b35..921e3466 100644 --- a/event/beeper.go +++ b/event/beeper.go @@ -53,8 +53,6 @@ type BeeperMessageStatusEventContent struct { LastRetry id.EventID `json:"last_retry,omitempty"` - TargetTxnID string `json:"relates_to_txn_id,omitempty"` - MutateEventKey string `json:"mutate_event_key,omitempty"` // Indicates the set of users to whom the event was delivered. If nil, then @@ -88,22 +86,6 @@ type BeeperRoomKeyAckEventContent struct { FirstMessageIndex int `json:"first_message_index"` } -type BeeperChatDeleteEventContent struct { - DeleteForEveryone bool `json:"delete_for_everyone,omitempty"` - FromMessageRequest bool `json:"from_message_request,omitempty"` -} - -type BeeperAcceptMessageRequestEventContent struct { - // Whether this was triggered by a message rather than an explicit event - IsImplicit bool `json:"-"` -} - -type BeeperSendStateEventContent struct { - Type string `json:"type"` - StateKey string `json:"state_key"` - Content Content `json:"content"` -} - type IntOrString int func (ios *IntOrString) UnmarshalJSON(data []byte) error { @@ -146,7 +128,6 @@ type BeeperLinkPreview struct { MatchedURL string `json:"matched_url,omitempty"` ImageEncryption *EncryptedFileInfo `json:"beeper:image:encryption,omitempty"` - ImageBlurhash string `json:"matrix:image:blurhash,omitempty"` } type BeeperProfileExtra struct { @@ -166,24 +147,6 @@ type BeeperPerMessageProfile struct { HasFallback bool `json:"has_fallback,omitempty"` } -type BeeperActionMessageType string - -const ( - BeeperActionMessageCall BeeperActionMessageType = "call" -) - -type BeeperActionMessageCallType string - -const ( - BeeperActionMessageCallTypeVoice BeeperActionMessageCallType = "voice" - BeeperActionMessageCallTypeVideo BeeperActionMessageCallType = "video" -) - -type BeeperActionMessage struct { - Type BeeperActionMessageType `json:"type"` - CallType BeeperActionMessageCallType `json:"call_type,omitempty"` -} - func (content *MessageEventContent) AddPerMessageProfileFallback() { if content.BeeperPerMessageProfile == nil || content.BeeperPerMessageProfile.HasFallback || content.BeeperPerMessageProfile.Displayname == "" { return @@ -214,15 +177,6 @@ func (content *MessageEventContent) RemovePerMessageProfileFallback() { } } -type BeeperAIStreamEventContent struct { - TurnID string `json:"turn_id"` - Seq int `json:"seq"` - Part map[string]any `json:"part"` - TargetEvent id.EventID `json:"target_event,omitempty"` - AgentID string `json:"agent_id,omitempty"` - RelatesTo *RelatesTo `json:"m.relates_to,omitempty"` -} - type BeeperEncodedOrder struct { order int64 suborder int16 diff --git a/event/capabilities.d.ts b/event/capabilities.d.ts index 26aeb347..4cf29de7 100644 --- a/event/capabilities.d.ts +++ b/event/capabilities.d.ts @@ -16,23 +16,6 @@ export interface RoomFeatures { * If a message type isn't listed here, it should be treated as support level -2 (will be rejected). */ file?: Record - /** - * Supported state event types and their parameters. Currently, there are no parameters, - * but it is likely there will be some in the future (like max name/topic length, avatar mime types, etc.). - * - * Events that are not listed or have a support level of zero or below should be treated as unsupported. - * - * Clients should at least check `m.room.name`, `m.room.topic`, and `m.room.avatar` here. - * `m.room.member` will not be listed here, as it's controlled by the member_actions field. - * `com.beeper.disappearing_timer` should be listed here, but the parameters are in the disappearing_timer field for now. - */ - state?: Record - /** - * Supported member actions and their support levels. - * - * Actions that are not listed or have a support level of zero or below should be treated as unsupported. - */ - member_actions?: Record /** Maximum length of normal text messages. */ max_text_length?: integer @@ -58,8 +41,6 @@ export interface RoomFeatures { delete_max_age?: seconds /** Whether deleting messages just for yourself is supported. No message age limit. */ delete_for_me?: boolean - /** Allowed configuration options for disappearing timers. */ - disappearing_timer?: DisappearingTimerCapability /** Whether reactions are supported. */ reaction?: CapabilitySupportLevel @@ -72,21 +53,10 @@ export interface RoomFeatures { allowed_reactions?: string[] /** Whether custom emoji reactions are allowed. */ custom_emoji_reactions?: boolean - - /** Whether deleting the chat for yourself is supported. */ - delete_chat?: boolean - /** Whether deleting the chat for all participants is supported. */ - delete_chat_for_everyone?: boolean - /** What can be done with message requests? */ - message_request?: { - accept_with_message?: CapabilitySupportLevel - accept_with_button?: CapabilitySupportLevel - } } declare type integer = number declare type seconds = integer -declare type milliseconds = integer declare type MIMEClass = "image" | "audio" | "video" | "text" | "font" | "model" | "application" declare type MIMETypeOrPattern = "*/*" @@ -94,21 +64,6 @@ declare type MIMETypeOrPattern = | `${MIMEClass}/${string}` | `${MIMEClass}/${string}; ${string}` -export enum MemberAction { - Ban = "ban", - Kick = "kick", - Leave = "leave", - RevokeInvite = "revoke_invite", - Invite = "invite", -} - -declare type EventType = string - -// This is an object for future extensibility (e.g. max name/topic length) -export interface StateFeatures { - level: CapabilitySupportLevel -} - export enum CapabilityMsgType { // Real message types used in the `msgtype` field Image = "m.image", @@ -151,25 +106,6 @@ export interface FileFeatures { view_once?: boolean } -export enum DisappearingType { - None = "", - AfterRead = "after_read", - AfterSend = "after_send", -} - -export interface DisappearingTimerCapability { - types: DisappearingType[] - /** Allowed timer values. If omitted, any timer is allowed. */ - timers?: milliseconds[] - /** - * Whether clients should omit the empty disappearing_timer object in messages that they don't want to disappear - * - * Generally, bridged rooms will want the object to be always present, while native Matrix rooms don't, - * so the hardcoded features for Matrix rooms should set this to true, while bridges will not. - */ - omit_empty_timer?: true -} - /** * The support level for a feature. These are integers rather than booleans * to accurately represent what the bridge is doing and hopefully make the diff --git a/event/capabilities.go b/event/capabilities.go index a86c726b..9c9eb09a 100644 --- a/event/capabilities.go +++ b/event/capabilities.go @@ -18,7 +18,6 @@ import ( "go.mau.fi/util/exerrors" "go.mau.fi/util/jsontime" - "go.mau.fi/util/ptr" "golang.org/x/exp/constraints" "golang.org/x/exp/maps" ) @@ -28,10 +27,8 @@ type RoomFeatures struct { // N.B. New fields need to be added to the Hash function to be included in the deduplication hash. - Formatting FormattingFeatureMap `json:"formatting,omitempty"` - File FileFeatureMap `json:"file,omitempty"` - State StateFeatureMap `json:"state,omitempty"` - MemberActions MemberFeatureMap `json:"member_actions,omitempty"` + Formatting FormattingFeatureMap `json:"formatting,omitempty"` + File FileFeatureMap `json:"file,omitempty"` MaxTextLength int `json:"max_text_length,omitempty"` @@ -47,23 +44,16 @@ type RoomFeatures struct { DeleteForMe bool `json:"delete_for_me,omitempty"` DeleteMaxAge *jsontime.Seconds `json:"delete_max_age,omitempty"` - DisappearingTimer *DisappearingTimerCapability `json:"disappearing_timer,omitempty"` - Reaction CapabilitySupportLevel `json:"reaction,omitempty"` ReactionCount int `json:"reaction_count,omitempty"` AllowedReactions []string `json:"allowed_reactions,omitempty"` CustomEmojiReactions bool `json:"custom_emoji_reactions,omitempty"` - ReadReceipts bool `json:"read_receipts,omitempty"` - TypingNotifications bool `json:"typing_notifications,omitempty"` - Archive bool `json:"archive,omitempty"` - MarkAsUnread bool `json:"mark_as_unread,omitempty"` - DeleteChat bool `json:"delete_chat,omitempty"` - DeleteChatForEveryone bool `json:"delete_chat_for_everyone,omitempty"` - - MessageRequest *MessageRequestFeatures `json:"message_request,omitempty"` - - PerMessageProfileRelay bool `json:"-"` + ReadReceipts bool `json:"read_receipts,omitempty"` + TypingNotifications bool `json:"typing_notifications,omitempty"` + Archive bool `json:"archive,omitempty"` + MarkAsUnread bool `json:"mark_as_unread,omitempty"` + DeleteChat bool `json:"delete_chat,omitempty"` } func (rf *RoomFeatures) GetID() string { @@ -73,120 +63,10 @@ func (rf *RoomFeatures) GetID() string { return base64.RawURLEncoding.EncodeToString(rf.Hash()) } -func (rf *RoomFeatures) Clone() *RoomFeatures { - if rf == nil { - return nil - } - clone := *rf - clone.File = clone.File.Clone() - clone.Formatting = maps.Clone(clone.Formatting) - clone.State = clone.State.Clone() - clone.MemberActions = clone.MemberActions.Clone() - clone.EditMaxAge = ptr.Clone(clone.EditMaxAge) - clone.DeleteMaxAge = ptr.Clone(clone.DeleteMaxAge) - clone.DisappearingTimer = clone.DisappearingTimer.Clone() - clone.AllowedReactions = slices.Clone(clone.AllowedReactions) - clone.MessageRequest = clone.MessageRequest.Clone() - return &clone -} - -type MemberFeatureMap map[MemberAction]CapabilitySupportLevel - -func (mfm MemberFeatureMap) Clone() MemberFeatureMap { - return maps.Clone(mfm) -} - -type MemberAction string - -const ( - MemberActionBan MemberAction = "ban" - MemberActionKick MemberAction = "kick" - MemberActionLeave MemberAction = "leave" - MemberActionRevokeInvite MemberAction = "revoke_invite" - MemberActionInvite MemberAction = "invite" -) - -type StateFeatureMap map[string]*StateFeatures - -func (sfm StateFeatureMap) Clone() StateFeatureMap { - dup := maps.Clone(sfm) - for key, value := range dup { - dup[key] = value.Clone() - } - return dup -} - -type StateFeatures struct { - Level CapabilitySupportLevel `json:"level"` -} - -func (sf *StateFeatures) Clone() *StateFeatures { - if sf == nil { - return nil - } - clone := *sf - return &clone -} - -func (sf *StateFeatures) Hash() []byte { - return sf.Level.Hash() -} - type FormattingFeatureMap map[FormattingFeature]CapabilitySupportLevel type FileFeatureMap map[CapabilityMsgType]*FileFeatures -func (ffm FileFeatureMap) Clone() FileFeatureMap { - dup := maps.Clone(ffm) - for key, value := range dup { - dup[key] = value.Clone() - } - return dup -} - -type DisappearingTimerCapability struct { - Types []DisappearingType `json:"types"` - Timers []jsontime.Milliseconds `json:"timers,omitempty"` - - OmitEmptyTimer bool `json:"omit_empty_timer,omitempty"` -} - -func (dtc *DisappearingTimerCapability) Clone() *DisappearingTimerCapability { - if dtc == nil { - return nil - } - clone := *dtc - clone.Types = slices.Clone(clone.Types) - clone.Timers = slices.Clone(clone.Timers) - return &clone -} - -func (dtc *DisappearingTimerCapability) Supports(content *BeeperDisappearingTimer) bool { - if dtc == nil || content == nil || content.Type == DisappearingTypeNone { - return true - } - return slices.Contains(dtc.Types, content.Type) && (dtc.Timers == nil || slices.Contains(dtc.Timers, content.Timer)) -} - -type MessageRequestFeatures struct { - AcceptWithMessage CapabilitySupportLevel `json:"accept_with_message,omitempty"` - AcceptWithButton CapabilitySupportLevel `json:"accept_with_button,omitempty"` -} - -func (mrf *MessageRequestFeatures) Clone() *MessageRequestFeatures { - return ptr.Clone(mrf) -} - -func (mrf *MessageRequestFeatures) Hash() []byte { - if mrf == nil { - return nil - } - hasher := sha256.New() - hashValue(hasher, "accept_with_message", mrf.AcceptWithMessage) - hashValue(hasher, "accept_with_button", mrf.AcceptWithButton) - return hasher.Sum(nil) -} - type CapabilityMsgType = MessageType // Message types which are used for event capability signaling, but aren't real values for the msgtype field. @@ -336,8 +216,6 @@ func (rf *RoomFeatures) Hash() []byte { hashMap(hasher, "formatting", rf.Formatting) hashMap(hasher, "file", rf.File) - hashMap(hasher, "state", rf.State) - hashMap(hasher, "member_actions", rf.MemberActions) hashInt(hasher, "max_text_length", rf.MaxTextLength) @@ -353,7 +231,6 @@ func (rf *RoomFeatures) Hash() []byte { hashValue(hasher, "delete", rf.Delete) hashBool(hasher, "delete_for_me", rf.DeleteForMe) hashInt(hasher, "delete_max_age", rf.DeleteMaxAge.Get()) - hashValue(hasher, "disappearing_timer", rf.DisappearingTimer) hashValue(hasher, "reaction", rf.Reaction) hashInt(hasher, "reaction_count", rf.ReactionCount) @@ -368,28 +245,10 @@ func (rf *RoomFeatures) Hash() []byte { hashBool(hasher, "archive", rf.Archive) hashBool(hasher, "mark_as_unread", rf.MarkAsUnread) hashBool(hasher, "delete_chat", rf.DeleteChat) - hashBool(hasher, "delete_chat_for_everyone", rf.DeleteChatForEveryone) - hashValue(hasher, "message_request", rf.MessageRequest) return hasher.Sum(nil) } -func (dtc *DisappearingTimerCapability) Hash() []byte { - if dtc == nil { - return nil - } - hasher := sha256.New() - hasher.Write([]byte("types")) - for _, t := range dtc.Types { - hasher.Write([]byte(t)) - } - hasher.Write([]byte("timers")) - for _, timer := range dtc.Timers { - hashInt(hasher, "", timer.Milliseconds()) - } - return hasher.Sum(nil) -} - func (ff *FileFeatures) Hash() []byte { hasher := sha256.New() hashMap(hasher, "mime_types", ff.MimeTypes) @@ -402,13 +261,3 @@ func (ff *FileFeatures) Hash() []byte { hashBool(hasher, "view_once", ff.ViewOnce) return hasher.Sum(nil) } - -func (ff *FileFeatures) Clone() *FileFeatures { - if ff == nil { - return nil - } - clone := *ff - clone.MimeTypes = maps.Clone(clone.MimeTypes) - clone.MaxDuration = ptr.Clone(clone.MaxDuration) - return &clone -} diff --git a/event/cmdschema/content.go b/event/cmdschema/content.go deleted file mode 100644 index ce07c4c0..00000000 --- a/event/cmdschema/content.go +++ /dev/null @@ -1,78 +0,0 @@ -// Copyright (c) 2026 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package cmdschema - -import ( - "crypto/sha256" - "encoding/base64" - "fmt" - "reflect" - "slices" - - "go.mau.fi/util/exsync" - "go.mau.fi/util/ptr" - - "maunium.net/go/mautrix/event" - "maunium.net/go/mautrix/id" -) - -type EventContent struct { - Command string `json:"command"` - Aliases []string `json:"aliases,omitempty"` - Parameters []*Parameter `json:"parameters,omitempty"` - Description *event.ExtensibleTextContainer `json:"description,omitempty"` - TailParam string `json:"fi.mau.tail_parameter,omitempty"` -} - -func (ec *EventContent) Validate() error { - if ec == nil { - return fmt.Errorf("event content is nil") - } else if ec.Command == "" { - return fmt.Errorf("command is empty") - } - var tailFound bool - dupMap := exsync.NewSet[string]() - for i, p := range ec.Parameters { - if err := p.Validate(); err != nil { - return fmt.Errorf("parameter %q (#%d) is invalid: %w", ptr.Val(p).Key, i+1, err) - } else if !dupMap.Add(p.Key) { - return fmt.Errorf("duplicate parameter key %q at #%d", p.Key, i+1) - } else if p.Key == ec.TailParam { - tailFound = true - } else if tailFound && !p.Optional { - return fmt.Errorf("required parameter %q (#%d) is after tail parameter %q", p.Key, i+1, ec.TailParam) - } - } - if ec.TailParam != "" && !tailFound { - return fmt.Errorf("tail parameter %q not found in parameters", ec.TailParam) - } - return nil -} - -func (ec *EventContent) IsValid() bool { - return ec.Validate() == nil -} - -func (ec *EventContent) StateKey(owner id.UserID) string { - hash := sha256.Sum256([]byte(ec.Command + owner.String())) - return base64.StdEncoding.EncodeToString(hash[:]) -} - -func (ec *EventContent) Equals(other *EventContent) bool { - if ec == nil || other == nil { - return ec == other - } - return ec.Command == other.Command && - slices.Equal(ec.Aliases, other.Aliases) && - slices.EqualFunc(ec.Parameters, other.Parameters, (*Parameter).Equals) && - ec.Description.Equals(other.Description) && - ec.TailParam == other.TailParam -} - -func init() { - event.TypeMap[event.StateMSC4391BotCommand] = reflect.TypeOf(EventContent{}) -} diff --git a/event/cmdschema/parameter.go b/event/cmdschema/parameter.go deleted file mode 100644 index 4193b297..00000000 --- a/event/cmdschema/parameter.go +++ /dev/null @@ -1,286 +0,0 @@ -// Copyright (c) 2026 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package cmdschema - -import ( - "encoding/json" - "fmt" - "slices" - - "go.mau.fi/util/exslices" - - "maunium.net/go/mautrix/event" -) - -type Parameter struct { - Key string `json:"key"` - Schema *ParameterSchema `json:"schema"` - Optional bool `json:"optional,omitempty"` - Description *event.ExtensibleTextContainer `json:"description,omitempty"` - DefaultValue any `json:"fi.mau.default_value,omitempty"` -} - -func (p *Parameter) Equals(other *Parameter) bool { - if p == nil || other == nil { - return p == other - } - return p.Key == other.Key && - p.Schema.Equals(other.Schema) && - p.Optional == other.Optional && - p.Description.Equals(other.Description) && - p.DefaultValue == other.DefaultValue // TODO this won't work for room/event ID values -} - -func (p *Parameter) Validate() error { - if p == nil { - return fmt.Errorf("parameter is nil") - } else if p.Key == "" { - return fmt.Errorf("key is empty") - } - return p.Schema.Validate() -} - -func (p *Parameter) IsValid() bool { - return p.Validate() == nil -} - -func (p *Parameter) GetDefaultValue() any { - if p != nil && p.DefaultValue != nil { - return p.DefaultValue - } else if p == nil || p.Optional { - return nil - } - return p.Schema.GetDefaultValue() -} - -type PrimitiveType string - -const ( - PrimitiveTypeString PrimitiveType = "string" - PrimitiveTypeInteger PrimitiveType = "integer" - PrimitiveTypeBoolean PrimitiveType = "boolean" - PrimitiveTypeServerName PrimitiveType = "server_name" - PrimitiveTypeUserID PrimitiveType = "user_id" - PrimitiveTypeRoomID PrimitiveType = "room_id" - PrimitiveTypeRoomAlias PrimitiveType = "room_alias" - PrimitiveTypeEventID PrimitiveType = "event_id" -) - -func (pt PrimitiveType) Schema() *ParameterSchema { - return &ParameterSchema{ - SchemaType: SchemaTypePrimitive, - Type: pt, - } -} - -func (pt PrimitiveType) IsValid() bool { - switch pt { - case PrimitiveTypeString, - PrimitiveTypeInteger, - PrimitiveTypeBoolean, - PrimitiveTypeServerName, - PrimitiveTypeUserID, - PrimitiveTypeRoomID, - PrimitiveTypeRoomAlias, - PrimitiveTypeEventID: - return true - default: - return false - } -} - -type SchemaType string - -const ( - SchemaTypePrimitive SchemaType = "primitive" - SchemaTypeArray SchemaType = "array" - SchemaTypeUnion SchemaType = "union" - SchemaTypeLiteral SchemaType = "literal" -) - -type ParameterSchema struct { - SchemaType SchemaType `json:"schema_type"` - Type PrimitiveType `json:"type,omitempty"` // Only for primitive - Items *ParameterSchema `json:"items,omitempty"` // Only for array - Variants []*ParameterSchema `json:"variants,omitempty"` // Only for union - Value any `json:"value,omitempty"` // Only for literal -} - -func Literal(value any) *ParameterSchema { - return &ParameterSchema{ - SchemaType: SchemaTypeLiteral, - Value: value, - } -} - -func Enum(values ...any) *ParameterSchema { - return Union(exslices.CastFunc(values, Literal)...) -} - -func flattenUnion(variants []*ParameterSchema) []*ParameterSchema { - var flattened []*ParameterSchema - for _, variant := range variants { - switch variant.SchemaType { - case SchemaTypeArray: - panic(fmt.Errorf("illegal array schema in union")) - case SchemaTypeUnion: - flattened = append(flattened, flattenUnion(variant.Variants)...) - default: - flattened = append(flattened, variant) - } - } - return flattened -} - -func Union(variants ...*ParameterSchema) *ParameterSchema { - needsFlattening := false - for _, variant := range variants { - if variant.SchemaType == SchemaTypeArray { - panic(fmt.Errorf("illegal array schema in union")) - } else if variant.SchemaType == SchemaTypeUnion { - needsFlattening = true - } - } - if needsFlattening { - variants = flattenUnion(variants) - } - return &ParameterSchema{ - SchemaType: SchemaTypeUnion, - Variants: variants, - } -} - -func Array(items *ParameterSchema) *ParameterSchema { - if items.SchemaType == SchemaTypeArray { - panic(fmt.Errorf("illegal array schema in array")) - } - return &ParameterSchema{ - SchemaType: SchemaTypeArray, - Items: items, - } -} - -func (ps *ParameterSchema) GetDefaultValue() any { - if ps == nil { - return nil - } - switch ps.SchemaType { - case SchemaTypePrimitive: - switch ps.Type { - case PrimitiveTypeInteger: - return 0 - case PrimitiveTypeBoolean: - return false - default: - return "" - } - case SchemaTypeArray: - return []any{} - case SchemaTypeUnion: - if len(ps.Variants) > 0 { - return ps.Variants[0].GetDefaultValue() - } - return nil - case SchemaTypeLiteral: - return ps.Value - default: - return nil - } -} - -func (ps *ParameterSchema) IsValid() bool { - return ps.validate("") == nil -} - -func (ps *ParameterSchema) Validate() error { - return ps.validate("") -} - -func (ps *ParameterSchema) validate(parent SchemaType) error { - if ps == nil { - return fmt.Errorf("schema is nil") - } - switch ps.SchemaType { - case SchemaTypePrimitive: - if !ps.Type.IsValid() { - return fmt.Errorf("invalid primitive type %s", ps.Type) - } else if ps.Items != nil || ps.Variants != nil || ps.Value != nil { - return fmt.Errorf("primitive schema has extra fields") - } - return nil - case SchemaTypeArray: - if parent != "" { - return fmt.Errorf("arrays can't be nested in other types") - } else if err := ps.Items.validate(ps.SchemaType); err != nil { - return fmt.Errorf("item schema is invalid: %w", err) - } else if ps.Type != "" || ps.Variants != nil || ps.Value != nil { - return fmt.Errorf("array schema has extra fields") - } - return nil - case SchemaTypeUnion: - if len(ps.Variants) == 0 { - return fmt.Errorf("no variants specified for union") - } else if parent != "" && parent != SchemaTypeArray { - return fmt.Errorf("unions can't be nested in anything other than arrays") - } - for i, v := range ps.Variants { - if err := v.validate(ps.SchemaType); err != nil { - return fmt.Errorf("variant #%d is invalid: %w", i+1, err) - } - } - if ps.Type != "" || ps.Items != nil || ps.Value != nil { - return fmt.Errorf("union schema has extra fields") - } - return nil - case SchemaTypeLiteral: - switch typedVal := ps.Value.(type) { - case string, float64, int, int64, json.Number, bool, RoomIDValue, *RoomIDValue: - // ok - case map[string]any: - if typedVal["type"] != "event_id" && typedVal["type"] != "room_id" { - return fmt.Errorf("literal value has invalid map data") - } - default: - return fmt.Errorf("literal value has unsupported type %T", ps.Value) - } - if ps.Type != "" || ps.Items != nil || ps.Variants != nil { - return fmt.Errorf("literal schema has extra fields") - } - return nil - default: - return fmt.Errorf("invalid schema type %s", ps.SchemaType) - } -} - -func (ps *ParameterSchema) Equals(other *ParameterSchema) bool { - if ps == nil || other == nil { - return ps == other - } - return ps.SchemaType == other.SchemaType && - ps.Type == other.Type && - ps.Items.Equals(other.Items) && - slices.EqualFunc(ps.Variants, other.Variants, (*ParameterSchema).Equals) && - ps.Value == other.Value // TODO this won't work for room/event ID values -} - -func (ps *ParameterSchema) AllowsPrimitive(prim PrimitiveType) bool { - switch ps.SchemaType { - case SchemaTypePrimitive: - return ps.Type == prim - case SchemaTypeUnion: - for _, variant := range ps.Variants { - if variant.AllowsPrimitive(prim) { - return true - } - } - return false - case SchemaTypeArray: - return ps.Items.AllowsPrimitive(prim) - default: - return false - } -} diff --git a/event/cmdschema/parse.go b/event/cmdschema/parse.go deleted file mode 100644 index 92e69b60..00000000 --- a/event/cmdschema/parse.go +++ /dev/null @@ -1,478 +0,0 @@ -// Copyright (c) 2026 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package cmdschema - -import ( - "encoding/json" - "errors" - "fmt" - "regexp" - "strconv" - "strings" - - "maunium.net/go/mautrix/event" - "maunium.net/go/mautrix/id" -) - -const botArrayOpener = "<" -const botArrayCloser = ">" - -func parseQuoted(val string) (parsed, remaining string, quoted bool) { - if len(val) == 0 { - return - } - if !strings.HasPrefix(val, `"`) { - spaceIdx := strings.IndexByte(val, ' ') - if spaceIdx == -1 { - parsed = val - } else { - parsed = val[:spaceIdx] - remaining = strings.TrimLeft(val[spaceIdx+1:], " ") - } - return - } - val = val[1:] - var buf strings.Builder - for { - quoteIdx := strings.IndexByte(val, '"') - var valUntilQuote string - if quoteIdx == -1 { - valUntilQuote = val - } else { - valUntilQuote = val[:quoteIdx] - } - escapeIdx := strings.IndexByte(valUntilQuote, '\\') - if escapeIdx >= 0 { - buf.WriteString(val[:escapeIdx]) - if len(val) > escapeIdx+1 { - buf.WriteByte(val[escapeIdx+1]) - } - val = val[min(escapeIdx+2, len(val)):] - } else if quoteIdx >= 0 { - buf.WriteString(val[:quoteIdx]) - val = val[quoteIdx+1:] - break - } else if buf.Len() == 0 { - // Unterminated quote, no escape characters, val is the whole input - return val, "", true - } else { - // Unterminated quote, but there were escape characters previously - buf.WriteString(val) - val = "" - break - } - } - return buf.String(), strings.TrimLeft(val, " "), true -} - -// ParseInput tries to parse the given text into a bot command event matching this command definition. -// -// If the prefix doesn't match, this will return a nil content and nil error. -// If the prefix does match, some content is always returned, but there may still be an error if parsing failed. -func (ec *EventContent) ParseInput(owner id.UserID, sigils []string, input string) (content *event.MessageEventContent, err error) { - prefix := ec.parsePrefix(input, sigils, owner.String()) - if prefix == "" { - return nil, nil - } - content = &event.MessageEventContent{ - MsgType: event.MsgText, - Body: input, - Mentions: &event.Mentions{UserIDs: []id.UserID{owner}}, - MSC4391BotCommand: &event.MSC4391BotCommandInput{ - Command: ec.Command, - }, - } - content.MSC4391BotCommand.Arguments, err = ec.ParseArguments(input[len(prefix):]) - return content, err -} - -func (ec *EventContent) ParseArguments(input string) (json.RawMessage, error) { - args := make(map[string]any) - var retErr error - setError := func(err error) { - if err != nil && retErr == nil { - retErr = err - } - } - processParameter := func(param *Parameter, isLast, isTail, isNamed bool) { - origInput := input - var nextVal string - var wasQuoted bool - if param.Schema.SchemaType == SchemaTypeArray { - hasOpener := strings.HasPrefix(input, botArrayOpener) - arrayClosed := false - if hasOpener { - input = input[len(botArrayOpener):] - if strings.HasPrefix(input, botArrayCloser) { - input = strings.TrimLeft(input[len(botArrayCloser):], " ") - arrayClosed = true - } - } - var collector []any - for len(input) > 0 && !arrayClosed { - //origInput = input - nextVal, input, wasQuoted = parseQuoted(input) - if !wasQuoted && hasOpener && strings.HasSuffix(nextVal, botArrayCloser) { - // The value wasn't quoted and has the array delimiter at the end, close the array - nextVal = strings.TrimRight(nextVal, botArrayCloser) - arrayClosed = true - } else if hasOpener && strings.HasPrefix(input, botArrayCloser) { - // The value was quoted or there was a space, and the next character is the - // array delimiter, close the array - input = strings.TrimLeft(input[len(botArrayCloser):], " ") - arrayClosed = true - } else if !hasOpener && !isLast { - // For array arguments in the middle without the <> delimiters, stop after the first item - arrayClosed = true - } - parsedVal, err := param.Schema.Items.ParseString(nextVal) - if err == nil { - collector = append(collector, parsedVal) - } else if hasOpener || isLast { - setError(fmt.Errorf("failed to parse item #%d of array %s: %w", len(collector)+1, param.Key, err)) - } else { - //input = origInput - } - } - args[param.Key] = collector - } else { - nextVal, input, wasQuoted = parseQuoted(input) - if (isLast || isTail) && !wasQuoted && len(input) > 0 { - // If the last argument is not quoted, just treat the rest of the string - // as the argument without escapes (arguments with escapes should be quoted). - nextVal += " " + input - input = "" - } - // Special case for named boolean parameters: if no value is given, treat it as true - if nextVal == "" && !wasQuoted && isNamed && param.Schema.AllowsPrimitive(PrimitiveTypeBoolean) { - args[param.Key] = true - return - } - if nextVal == "" && !wasQuoted && !isNamed && !param.Optional { - setError(fmt.Errorf("missing value for required parameter %s", param.Key)) - } - parsedVal, err := param.Schema.ParseString(nextVal) - if err != nil { - args[param.Key] = param.GetDefaultValue() - // For optional parameters that fail to parse, restore the input and try passing it as the next parameter - if param.Optional && !isLast && !isNamed { - input = strings.TrimLeft(origInput, " ") - } else if !param.Optional || isNamed { - setError(fmt.Errorf("failed to parse %s: %w", param.Key, err)) - } - } else { - args[param.Key] = parsedVal - } - } - } - skipParams := make([]bool, len(ec.Parameters)) - for i, param := range ec.Parameters { - for strings.HasPrefix(input, "--") { - nameEndIdx := strings.IndexAny(input, " =") - if nameEndIdx == -1 { - nameEndIdx = len(input) - } - overrideParam, paramIdx := ec.parameterByName(input[2:nameEndIdx]) - if overrideParam != nil { - // Trim the equals sign, but leave spaces alone to let parseQuoted treat it as empty input - input = strings.TrimPrefix(input[nameEndIdx:], "=") - skipParams[paramIdx] = true - processParameter(overrideParam, false, false, true) - } else { - break - } - } - isTail := param.Key == ec.TailParam - if skipParams[i] || (param.Optional && !isTail) { - continue - } - processParameter(param, i == len(ec.Parameters)-1, isTail, false) - } - jsonArgs, marshalErr := json.Marshal(args) - if marshalErr != nil { - return nil, fmt.Errorf("failed to marshal arguments: %w", marshalErr) - } - return jsonArgs, retErr -} - -func (ec *EventContent) parameterByName(name string) (*Parameter, int) { - for i, param := range ec.Parameters { - if strings.EqualFold(param.Key, name) { - return param, i - } - } - return nil, -1 -} - -func (ec *EventContent) parsePrefix(origInput string, sigils []string, owner string) (prefix string) { - input := origInput - var chosenSigil string - for _, sigil := range sigils { - if strings.HasPrefix(input, sigil) { - chosenSigil = sigil - break - } - } - if chosenSigil == "" { - return "" - } - input = input[len(chosenSigil):] - var chosenAlias string - if !strings.HasPrefix(input, ec.Command) { - for _, alias := range ec.Aliases { - if strings.HasPrefix(input, alias) { - chosenAlias = alias - break - } - } - if chosenAlias == "" { - return "" - } - } else { - chosenAlias = ec.Command - } - input = strings.TrimPrefix(input[len(chosenAlias):], owner) - if input == "" || input[0] == ' ' { - input = strings.TrimLeft(input, " ") - return origInput[:len(origInput)-len(input)] - } - return "" -} - -func (pt PrimitiveType) ValidateValue(value any) bool { - _, err := pt.NormalizeValue(value) - return err == nil -} - -func normalizeNumber(value any) (int, error) { - switch typedValue := value.(type) { - case int: - return typedValue, nil - case int64: - return int(typedValue), nil - case float64: - return int(typedValue), nil - case json.Number: - if i, err := typedValue.Int64(); err != nil { - return 0, fmt.Errorf("failed to parse json.Number: %w", err) - } else { - return int(i), nil - } - default: - return 0, fmt.Errorf("unsupported type %T for integer", value) - } -} - -func (pt PrimitiveType) NormalizeValue(value any) (any, error) { - switch pt { - case PrimitiveTypeInteger: - return normalizeNumber(value) - case PrimitiveTypeBoolean: - bv, ok := value.(bool) - if !ok { - return nil, fmt.Errorf("unsupported type %T for boolean", value) - } - return bv, nil - case PrimitiveTypeString, PrimitiveTypeServerName: - str, ok := value.(string) - if !ok { - return nil, fmt.Errorf("unsupported type %T for string", value) - } - return str, pt.validateStringValue(str) - case PrimitiveTypeUserID, PrimitiveTypeRoomAlias: - str, ok := value.(string) - if !ok { - return nil, fmt.Errorf("unsupported type %T for user ID or room alias", value) - } else if plainErr := pt.validateStringValue(str); plainErr == nil { - return str, nil - } else if parsed, err := id.ParseMatrixURIOrMatrixToURL(str); err != nil { - return nil, fmt.Errorf("couldn't parse %q as plain ID nor matrix URI: %w / %w", value, plainErr, err) - } else if parsed.Sigil1 == '@' && pt == PrimitiveTypeUserID { - return parsed.UserID(), nil - } else if parsed.Sigil1 == '#' && pt == PrimitiveTypeRoomAlias { - return parsed.RoomAlias(), nil - } else { - return nil, fmt.Errorf("unexpected sigil %c for user ID or room alias", parsed.Sigil1) - } - case PrimitiveTypeRoomID, PrimitiveTypeEventID: - riv, err := NormalizeRoomIDValue(value) - if err != nil { - return nil, err - } - return riv, riv.Validate() - default: - return nil, fmt.Errorf("cannot normalize value for argument type %s", pt) - } -} - -func (pt PrimitiveType) validateStringValue(value string) error { - switch pt { - case PrimitiveTypeString: - return nil - case PrimitiveTypeServerName: - if !id.ValidateServerName(value) { - return fmt.Errorf("invalid server name: %q", value) - } - return nil - case PrimitiveTypeUserID: - _, _, err := id.UserID(value).ParseAndValidateRelaxed() - return err - case PrimitiveTypeRoomAlias: - sigil, localpart, serverName := id.ParseCommonIdentifier(value) - if sigil != '#' || localpart == "" || serverName == "" { - return fmt.Errorf("invalid room alias: %q", value) - } else if !id.ValidateServerName(serverName) { - return fmt.Errorf("invalid server name in room alias: %q", serverName) - } - return nil - default: - panic(fmt.Errorf("validateStringValue called with invalid type %s", pt)) - } -} - -func parseBoolean(val string) (bool, error) { - if len(val) == 0 { - return false, fmt.Errorf("cannot parse empty string as boolean") - } - switch strings.ToLower(val) { - case "t", "true", "y", "yes", "1": - return true, nil - case "f", "false", "n", "no", "0": - return false, nil - default: - return false, fmt.Errorf("invalid boolean string: %q", val) - } -} - -var markdownLinkRegex = regexp.MustCompile(`^\[.+]\(([^)]+)\)$`) - -func parseRoomOrEventID(value string) (*RoomIDValue, error) { - if strings.HasPrefix(value, "[") && strings.Contains(value, "](") && strings.HasSuffix(value, ")") { - matches := markdownLinkRegex.FindStringSubmatch(value) - if len(matches) == 2 { - value = matches[1] - } - } - parsed, err := id.ParseMatrixURIOrMatrixToURL(value) - if err != nil && strings.HasPrefix(value, "!") { - return &RoomIDValue{ - Type: PrimitiveTypeRoomID, - RoomID: id.RoomID(value), - }, nil - } - if err != nil { - return nil, err - } else if parsed.Sigil1 != '!' { - return nil, fmt.Errorf("unexpected sigil %c for room ID", parsed.Sigil1) - } else if parsed.MXID2 != "" && parsed.Sigil2 != '$' { - return nil, fmt.Errorf("unexpected sigil %c for event ID", parsed.Sigil2) - } - valType := PrimitiveTypeRoomID - if parsed.MXID2 != "" { - valType = PrimitiveTypeEventID - } - return &RoomIDValue{ - Type: valType, - RoomID: parsed.RoomID(), - Via: parsed.Via, - EventID: parsed.EventID(), - }, nil -} - -func (pt PrimitiveType) ParseString(value string) (any, error) { - switch pt { - case PrimitiveTypeInteger: - return strconv.Atoi(value) - case PrimitiveTypeBoolean: - return parseBoolean(value) - case PrimitiveTypeString, PrimitiveTypeServerName, PrimitiveTypeUserID: - return value, pt.validateStringValue(value) - case PrimitiveTypeRoomAlias: - plainErr := pt.validateStringValue(value) - if plainErr == nil { - return value, nil - } - parsed, err := id.ParseMatrixURIOrMatrixToURL(value) - if err != nil { - return nil, fmt.Errorf("couldn't parse %q as plain room alias nor matrix URI: %w / %w", value, plainErr, err) - } else if parsed.Sigil1 != '#' { - return nil, fmt.Errorf("unexpected sigil %c for room alias", parsed.Sigil1) - } - return parsed.RoomAlias(), nil - case PrimitiveTypeRoomID, PrimitiveTypeEventID: - parsed, err := parseRoomOrEventID(value) - if err != nil { - return nil, err - } else if pt != parsed.Type { - return nil, fmt.Errorf("mismatching argument type: expected %s but got %s", pt, parsed.Type) - } - return parsed, nil - default: - return nil, fmt.Errorf("cannot parse string for argument type %s", pt) - } -} - -func (ps *ParameterSchema) ParseString(value string) (any, error) { - if ps == nil { - return nil, fmt.Errorf("parameter schema is nil") - } - switch ps.SchemaType { - case SchemaTypePrimitive: - return ps.Type.ParseString(value) - case SchemaTypeLiteral: - switch typedValue := ps.Value.(type) { - case string: - if value == typedValue { - return typedValue, nil - } else { - return nil, fmt.Errorf("literal value %q does not match %q", typedValue, value) - } - case int, int64, float64, json.Number: - expectedVal, _ := normalizeNumber(typedValue) - intVal, err := strconv.Atoi(value) - if err != nil { - return nil, fmt.Errorf("failed to parse integer literal: %w", err) - } else if intVal != expectedVal { - return nil, fmt.Errorf("literal value %d does not match %d", expectedVal, intVal) - } - return intVal, nil - case bool: - boolVal, err := parseBoolean(value) - if err != nil { - return nil, fmt.Errorf("failed to parse boolean literal: %w", err) - } else if boolVal != typedValue { - return nil, fmt.Errorf("literal value %t does not match %t", typedValue, boolVal) - } - return boolVal, nil - case RoomIDValue, *RoomIDValue, map[string]any, json.RawMessage: - expectedVal, _ := NormalizeRoomIDValue(typedValue) - parsed, err := parseRoomOrEventID(value) - if err != nil { - return nil, fmt.Errorf("failed to parse room or event ID literal: %w", err) - } else if !parsed.Equals(expectedVal) { - return nil, fmt.Errorf("literal value %s does not match %s", expectedVal, parsed) - } - return parsed, nil - default: - return nil, fmt.Errorf("unsupported literal type %T", ps.Value) - } - case SchemaTypeUnion: - var errs []error - for _, variant := range ps.Variants { - if parsed, err := variant.ParseString(value); err == nil { - return parsed, nil - } else { - errs = append(errs, err) - } - } - return nil, fmt.Errorf("no union variant matched: %w", errors.Join(errs...)) - case SchemaTypeArray: - return nil, fmt.Errorf("cannot parse string for array schema type") - default: - return nil, fmt.Errorf("unknown schema type %s", ps.SchemaType) - } -} diff --git a/event/cmdschema/parse_test.go b/event/cmdschema/parse_test.go deleted file mode 100644 index 1e0d1817..00000000 --- a/event/cmdschema/parse_test.go +++ /dev/null @@ -1,118 +0,0 @@ -// Copyright (c) 2026 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package cmdschema - -import ( - "bytes" - "encoding/json" - "strings" - "testing" - - "github.com/stretchr/testify/assert" - "go.mau.fi/util/exbytes" - "go.mau.fi/util/exerrors" - - "maunium.net/go/mautrix/event/cmdschema/testdata" -) - -type QuoteParseOutput struct { - Parsed string - Remaining string - Quoted bool -} - -func (qpo *QuoteParseOutput) UnmarshalJSON(data []byte) error { - var arr []any - if err := json.Unmarshal(data, &arr); err != nil { - return err - } - qpo.Parsed = arr[0].(string) - qpo.Remaining = arr[1].(string) - qpo.Quoted = arr[2].(bool) - return nil -} - -type QuoteParseTestData struct { - Name string `json:"name"` - Input string `json:"input"` - Output QuoteParseOutput `json:"output"` -} - -func loadFile[T any](name string) (into T) { - quoteData := exerrors.Must(testdata.FS.ReadFile(name)) - exerrors.PanicIfNotNil(json.Unmarshal(quoteData, &into)) - return -} - -func TestParseQuoted(t *testing.T) { - qptd := loadFile[[]QuoteParseTestData]("parse_quote.json") - for _, test := range qptd { - t.Run(test.Name, func(t *testing.T) { - parsed, remaining, quoted := parseQuoted(test.Input) - assert.Equalf(t, test.Output, QuoteParseOutput{ - Parsed: parsed, - Remaining: remaining, - Quoted: quoted, - }, "Failed with input `%s`", test.Input) - // Note: can't just test that requoted == input, because some inputs - // have unnecessary escapes which won't survive roundtripping - t.Run("roundtrip", func(t *testing.T) { - requoted := quoteString(parsed) + " " + remaining - reparsed, newRemaining, _ := parseQuoted(requoted) - assert.Equal(t, parsed, reparsed) - assert.Equal(t, remaining, newRemaining) - }) - }) - } -} - -type CommandTestData struct { - Spec *EventContent - Tests []*CommandTestUnit -} - -type CommandTestUnit struct { - Name string `json:"name"` - Input string `json:"input"` - Broken string `json:"broken,omitempty"` - Error bool `json:"error"` - Output json.RawMessage `json:"output"` -} - -func compactJSON(input json.RawMessage) json.RawMessage { - var buf bytes.Buffer - exerrors.PanicIfNotNil(json.Compact(&buf, input)) - return buf.Bytes() -} - -func TestMSC4391BotCommandEventContent_ParseInput(t *testing.T) { - for _, cmd := range exerrors.Must(testdata.FS.ReadDir("commands")) { - t.Run(strings.TrimSuffix(cmd.Name(), ".json"), func(t *testing.T) { - ctd := loadFile[CommandTestData]("commands/" + cmd.Name()) - for _, test := range ctd.Tests { - outputStr := exbytes.UnsafeString(compactJSON(test.Output)) - t.Run(test.Name, func(t *testing.T) { - if test.Broken != "" { - t.Skip(test.Broken) - } - output, err := ctd.Spec.ParseInput("@testbot", []string{"/"}, test.Input) - if test.Error { - assert.Error(t, err) - } else { - assert.NoError(t, err) - } - if outputStr == "null" { - assert.Nil(t, output) - } else { - assert.Equal(t, ctd.Spec.Command, output.MSC4391BotCommand.Command) - assert.Equalf(t, outputStr, exbytes.UnsafeString(output.MSC4391BotCommand.Arguments), "Input: %s", test.Input) - } - }) - } - }) - } -} diff --git a/event/cmdschema/roomid.go b/event/cmdschema/roomid.go deleted file mode 100644 index 98c421fc..00000000 --- a/event/cmdschema/roomid.go +++ /dev/null @@ -1,135 +0,0 @@ -// Copyright (c) 2026 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package cmdschema - -import ( - "encoding/json" - "fmt" - "slices" - "strings" - - "maunium.net/go/mautrix/id" -) - -var ParameterSchemaJoinableRoom = Union( - PrimitiveTypeRoomID.Schema(), - PrimitiveTypeRoomAlias.Schema(), -) - -type RoomIDValue struct { - Type PrimitiveType `json:"type"` - RoomID id.RoomID `json:"id"` - Via []string `json:"via,omitempty"` - EventID id.EventID `json:"event_id,omitempty"` -} - -func NormalizeRoomIDValue(input any) (riv *RoomIDValue, err error) { - switch typedValue := input.(type) { - case map[string]any, json.RawMessage: - var raw json.RawMessage - if raw, err = json.Marshal(input); err != nil { - err = fmt.Errorf("failed to roundtrip room ID value: %w", err) - } else if err = json.Unmarshal(raw, &riv); err != nil { - err = fmt.Errorf("failed to roundtrip room ID value: %w", err) - } - case *RoomIDValue: - riv = typedValue - case RoomIDValue: - riv = &typedValue - default: - err = fmt.Errorf("unsupported type %T for room or event ID", input) - } - return -} - -func (riv *RoomIDValue) String() string { - return riv.URI().String() -} - -func (riv *RoomIDValue) URI() *id.MatrixURI { - if riv == nil { - return nil - } - switch riv.Type { - case PrimitiveTypeRoomID: - return riv.RoomID.URI(riv.Via...) - case PrimitiveTypeEventID: - return riv.RoomID.EventURI(riv.EventID, riv.Via...) - default: - return nil - } -} - -func (riv *RoomIDValue) Equals(other *RoomIDValue) bool { - if riv == nil || other == nil { - return riv == other - } - return riv.Type == other.Type && - riv.RoomID == other.RoomID && - riv.EventID == other.EventID && - slices.Equal(riv.Via, other.Via) -} - -func (riv *RoomIDValue) Validate() error { - if riv == nil { - return fmt.Errorf("value is nil") - } - switch riv.Type { - case PrimitiveTypeRoomID: - if riv.EventID != "" { - return fmt.Errorf("event ID must be empty for room ID type") - } - case PrimitiveTypeEventID: - if !strings.HasPrefix(riv.EventID.String(), "$") { - return fmt.Errorf("event ID not valid: %q", riv.EventID) - } - default: - return fmt.Errorf("unexpected type %s for room/event ID value", riv.Type) - } - for _, via := range riv.Via { - if !id.ValidateServerName(via) { - return fmt.Errorf("invalid server name %q in vias", via) - } - } - sigil, localpart, serverName := id.ParseCommonIdentifier(riv.RoomID) - if sigil != '!' { - return fmt.Errorf("room ID does not start with !: %q", riv.RoomID) - } else if localpart == "" && serverName == "" { - return fmt.Errorf("room ID has empty localpart and server name: %q", riv.RoomID) - } else if serverName != "" && !id.ValidateServerName(serverName) { - return fmt.Errorf("invalid server name %q in room ID", serverName) - } - return nil -} - -func (riv *RoomIDValue) IsValid() bool { - return riv.Validate() == nil -} - -type RoomIDOrString string - -func (ros *RoomIDOrString) UnmarshalJSON(data []byte) error { - if len(data) == 0 { - return fmt.Errorf("empty data for room ID or string") - } - if data[0] == '"' { - var str string - if err := json.Unmarshal(data, &str); err != nil { - return err - } - *ros = RoomIDOrString(str) - return nil - } - var riv RoomIDValue - if err := json.Unmarshal(data, &riv); err != nil { - return err - } else if err = riv.Validate(); err != nil { - return err - } - *ros = RoomIDOrString(riv.String()) - return nil -} diff --git a/event/cmdschema/stringify.go b/event/cmdschema/stringify.go deleted file mode 100644 index c5c57c53..00000000 --- a/event/cmdschema/stringify.go +++ /dev/null @@ -1,122 +0,0 @@ -// Copyright (c) 2026 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package cmdschema - -import ( - "encoding/json" - "strconv" - "strings" -) - -var quoteEscaper = strings.NewReplacer( - `"`, `\"`, - `\`, `\\`, -) - -const charsToQuote = ` \` + botArrayOpener + botArrayCloser - -func quoteString(val string) string { - if val == "" { - return `""` - } - val = quoteEscaper.Replace(val) - if strings.ContainsAny(val, charsToQuote) { - return `"` + val + `"` - } - return val -} - -func (ec *EventContent) StringifyArgs(args any) string { - var argMap map[string]any - switch typedArgs := args.(type) { - case json.RawMessage: - err := json.Unmarshal(typedArgs, &argMap) - if err != nil { - return "" - } - case map[string]any: - argMap = typedArgs - default: - if b, err := json.Marshal(args); err != nil { - return "" - } else if err = json.Unmarshal(b, &argMap); err != nil { - return "" - } - } - parts := make([]string, 0, len(ec.Parameters)) - for i, param := range ec.Parameters { - isLast := i == len(ec.Parameters)-1 - val := argMap[param.Key] - if val == nil { - val = param.DefaultValue - if val == nil && !param.Optional { - val = param.Schema.GetDefaultValue() - } - } - if val == nil { - continue - } - var stringified string - if param.Schema.SchemaType == SchemaTypeArray { - stringified = arrayArgumentToString(val, isLast) - } else { - stringified = singleArgumentToString(val) - } - if stringified != "" { - parts = append(parts, stringified) - } - } - return strings.Join(parts, " ") -} - -func arrayArgumentToString(val any, isLast bool) string { - valArr, ok := val.([]any) - if !ok { - return "" - } - parts := make([]string, 0, len(valArr)) - for _, elem := range valArr { - stringified := singleArgumentToString(elem) - if stringified != "" { - parts = append(parts, stringified) - } - } - joinedParts := strings.Join(parts, " ") - if isLast && len(parts) > 0 { - return joinedParts - } - return botArrayOpener + joinedParts + botArrayCloser -} - -func singleArgumentToString(val any) string { - switch typedVal := val.(type) { - case string: - return quoteString(typedVal) - case json.Number: - return typedVal.String() - case bool: - return strconv.FormatBool(typedVal) - case int: - return strconv.Itoa(typedVal) - case int64: - return strconv.FormatInt(typedVal, 10) - case float64: - return strconv.FormatInt(int64(typedVal), 10) - case map[string]any, json.RawMessage, RoomIDValue, *RoomIDValue: - normalized, err := NormalizeRoomIDValue(typedVal) - if err != nil { - return "" - } - uri := normalized.URI() - if uri == nil { - return "" - } - return quoteString(uri.String()) - default: - return "" - } -} diff --git a/event/cmdschema/testdata/commands.schema.json b/event/cmdschema/testdata/commands.schema.json deleted file mode 100644 index e53382db..00000000 --- a/event/cmdschema/testdata/commands.schema.json +++ /dev/null @@ -1,281 +0,0 @@ -{ - "$schema": "https://json-schema.org/draft/2020-12/schema#", - "$id": "commands.schema.json", - "title": "ParseInput test cases", - "description": "JSON schema for test case files containing command specifications and test cases", - "type": "object", - "required": [ - "spec", - "tests" - ], - "additionalProperties": false, - "properties": { - "spec": { - "title": "MSC4391 Command Description", - "description": "JSON schema defining the structure of a bot command event content", - "type": "object", - "required": [ - "command" - ], - "additionalProperties": false, - "properties": { - "command": { - "type": "string", - "description": "The command name that triggers this bot command" - }, - "aliases": { - "type": "array", - "description": "Alternative names/aliases for this command", - "items": { - "type": "string" - } - }, - "parameters": { - "type": "array", - "description": "List of parameters accepted by this command", - "items": { - "$ref": "#/$defs/Parameter" - } - }, - "description": { - "$ref": "#/$defs/ExtensibleTextContainer", - "description": "Human-readable description of the command" - }, - "fi.mau.tail_parameter": { - "type": "string", - "description": "The key of the parameter that accepts remaining arguments as tail text" - }, - "source": { - "type": "string", - "description": "The user ID of the bot that responds to this command" - } - } - }, - "tests": { - "type": "array", - "description": "Array of test cases for the command", - "items": { - "type": "object", - "description": "A single test case for command parsing", - "required": [ - "name", - "input" - ], - "additionalProperties": false, - "properties": { - "name": { - "type": "string", - "description": "The name of the test case" - }, - "input": { - "type": "string", - "description": "The command input string to parse" - }, - "output": { - "description": "The expected parsed parameter values, or null if the parsing is expected to fail", - "oneOf": [ - { - "type": "object", - "additionalProperties": true - }, - { - "type": "null" - } - ] - }, - "error": { - "type": "boolean", - "description": "Whether parsing should result in an error. May still produce output.", - "default": false - } - } - } - } - }, - "$defs": { - "ExtensibleTextContainer": { - "type": "object", - "description": "Container for text that can have multiple representations", - "required": [ - "m.text" - ], - "properties": { - "m.text": { - "type": "array", - "description": "Array of text representations in different formats", - "items": { - "$ref": "#/$defs/ExtensibleText" - } - } - } - }, - "ExtensibleText": { - "type": "object", - "description": "A text representation with a specific MIME type", - "required": [ - "body" - ], - "properties": { - "body": { - "type": "string", - "description": "The text content" - }, - "mimetype": { - "type": "string", - "description": "The MIME type of the text (e.g., text/plain, text/html)", - "default": "text/plain", - "examples": [ - "text/plain", - "text/html" - ] - } - } - }, - "Parameter": { - "type": "object", - "description": "A parameter definition for a command", - "required": [ - "key", - "schema" - ], - "additionalProperties": false, - "properties": { - "key": { - "type": "string", - "description": "The identifier for this parameter" - }, - "schema": { - "$ref": "#/$defs/ParameterSchema", - "description": "The schema defining the type and structure of this parameter" - }, - "optional": { - "type": "boolean", - "description": "Whether this parameter is optional", - "default": false - }, - "description": { - "$ref": "#/$defs/ExtensibleTextContainer", - "description": "Human-readable description of this parameter" - }, - "fi.mau.default_value": { - "description": "Default value for this parameter if not provided" - } - } - }, - "ParameterSchema": { - "type": "object", - "description": "Schema definition for a parameter value", - "required": [ - "schema_type" - ], - "additionalProperties": false, - "properties": { - "schema_type": { - "type": "string", - "enum": [ - "primitive", - "array", - "union", - "literal" - ], - "description": "The type of schema" - } - }, - "allOf": [ - { - "if": { - "properties": { - "schema_type": { - "const": "primitive" - } - } - }, - "then": { - "required": [ - "type" - ], - "properties": { - "type": { - "type": "string", - "enum": [ - "string", - "integer", - "boolean", - "server_name", - "user_id", - "room_id", - "room_alias", - "event_id" - ], - "description": "The primitive type (only for schema_type: primitive)" - } - } - } - }, - { - "if": { - "properties": { - "schema_type": { - "const": "array" - } - } - }, - "then": { - "required": [ - "items" - ], - "properties": { - "items": { - "$ref": "#/$defs/ParameterSchema", - "description": "The schema for array items (only for schema_type: array)" - } - } - } - }, - { - "if": { - "properties": { - "schema_type": { - "const": "union" - } - } - }, - "then": { - "required": [ - "variants" - ], - "properties": { - "variants": { - "type": "array", - "description": "The possible variants (only for schema_type: union)", - "items": { - "$ref": "#/$defs/ParameterSchema" - }, - "minItems": 1 - } - } - } - }, - { - "if": { - "properties": { - "schema_type": { - "const": "literal" - } - } - }, - "then": { - "required": [ - "value" - ], - "properties": { - "value": { - "description": "The literal value (only for schema_type: literal)" - } - } - } - } - ] - } - } -} diff --git a/event/cmdschema/testdata/commands/flags.json b/event/cmdschema/testdata/commands/flags.json deleted file mode 100644 index 6ce1f4da..00000000 --- a/event/cmdschema/testdata/commands/flags.json +++ /dev/null @@ -1,126 +0,0 @@ -{ - "$schema": "../commands.schema.json#", - "spec": { - "command": "flag", - "source": "@testbot", - "parameters": [ - { - "key": "meow", - "schema": { - "schema_type": "primitive", - "type": "string" - } - }, - { - "key": "user", - "schema": { - "schema_type": "primitive", - "type": "user_id" - }, - "optional": true - }, - { - "key": "woof", - "schema": { - "schema_type": "primitive", - "type": "boolean" - }, - "optional": true, - "fi.mau.default_value": false - } - ], - "fi.mau.tail_parameter": "user" - }, - "tests": [ - { - "name": "no flags", - "input": "/flag mrrp", - "output": { - "meow": "mrrp", - "user": null - } - }, - { - "name": "no flags, has tail", - "input": "/flag mrrp @user:example.com", - "output": { - "meow": "mrrp", - "user": "@user:example.com" - } - }, - { - "name": "named flag at start", - "input": "/flag --woof=yes mrrp @user:example.com", - "output": { - "meow": "mrrp", - "user": "@user:example.com", - "woof": true - } - }, - { - "name": "boolean flag without value", - "input": "/flag --woof mrrp @user:example.com", - "output": { - "meow": "mrrp", - "user": "@user:example.com", - "woof": true - } - }, - { - "name": "user id flag without value", - "input": "/flag --user --woof mrrp", - "error": true, - "output": { - "meow": "mrrp", - "user": null, - "woof": true - } - }, - { - "name": "named flag in the middle", - "input": "/flag mrrp --woof=yes @user:example.com", - "output": { - "meow": "mrrp", - "user": "@user:example.com", - "woof": true - } - }, - { - "name": "named flag in the middle with different value", - "input": "/flag mrrp --woof=no @user:example.com", - "output": { - "meow": "mrrp", - "user": "@user:example.com", - "woof": false - } - }, - { - "name": "all variables named", - "input": "/flag --woof=no --meow=mrrp --user=@user:example.com", - "output": { - "meow": "mrrp", - "user": "@user:example.com", - "woof": false - } - }, - { - "name": "all variables named with quotes", - "input": "/flag --woof --meow=\"meow meow mrrp\" --user=\"@user:example.com\"", - "output": { - "meow": "meow meow mrrp", - "user": "@user:example.com", - "woof": true - } - }, - { - "name": "invalid value for named parameter", - "input": "/flag --user=meowings mrrp --woof", - "error": true, - "output": { - "meow": "mrrp", - "user": null, - "woof": true - } - } - ] -} diff --git a/event/cmdschema/testdata/commands/room_id_or_alias.json b/event/cmdschema/testdata/commands/room_id_or_alias.json deleted file mode 100644 index 1351c292..00000000 --- a/event/cmdschema/testdata/commands/room_id_or_alias.json +++ /dev/null @@ -1,85 +0,0 @@ -{ - "$schema": "../commands.schema.json#", - "spec": { - "command": "test room reference", - "source": "@testbot", - "parameters": [ - { - "key": "room", - "schema": { - "schema_type": "union", - "variants": [ - { - "schema_type": "primitive", - "type": "room_id" - }, - { - "schema_type": "primitive", - "type": "room_alias" - } - ] - } - } - ] - }, - "tests": [ - { - "name": "room alias", - "input": "/test room reference #test:matrix.org", - "output": { - "room": "#test:matrix.org" - } - }, - { - "name": "room id", - "input": "/test room reference !aiwVrNhPwbGBNjqlNu:matrix.org", - "output": { - "room": { - "type": "room_id", - "id": "!aiwVrNhPwbGBNjqlNu:matrix.org" - } - } - }, - { - "name": "room id matrix.to link", - "input": "/test room reference https://matrix.to/#/!aiwVrNhPwbGBNjqlNu:matrix.org?via=example.com", - "output": { - "room": { - "type": "room_id", - "id": "!aiwVrNhPwbGBNjqlNu:matrix.org", - "via": [ - "example.com" - ] - } - } - }, - { - "name": "room id matrix.to link with url encoding", - "input": "/test room reference https://matrix.to/#/!%23test%2Froom%0Aversion%20%3Cu%3E11%3C%2Fu%3E%2C%20with%20%40%F0%9F%90%88%EF%B8%8F%3Amaunium.net?via=maunium.net", - "broken": "Go's url.URL does url decoding on the fragment, which breaks splitting the path segments properly", - "output": { - "room": { - "type": "room_id", - "id": "!#test/room\nversion 11, with @🐈️:maunium.net", - "via": [ - "maunium.net" - ] - } - } - }, - { - "name": "room id matrix: URI", - "input": "/test room reference matrix:roomid/mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ?via=maunium.net&via=matrix.org", - "output": { - "room": { - "type": "room_id", - "id": "!mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ", - "via": [ - "maunium.net", - "matrix.org" - ] - } - } - } - ] -} diff --git a/event/cmdschema/testdata/commands/room_reference_list.json b/event/cmdschema/testdata/commands/room_reference_list.json deleted file mode 100644 index aa266054..00000000 --- a/event/cmdschema/testdata/commands/room_reference_list.json +++ /dev/null @@ -1,106 +0,0 @@ -{ - "$schema": "../commands.schema.json#", - "spec": { - "command": "test room reference", - "source": "@testbot", - "parameters": [ - { - "key": "rooms", - "schema": { - "schema_type": "array", - "items": { - "schema_type": "union", - "variants": [ - { - "schema_type": "primitive", - "type": "room_id" - }, - { - "schema_type": "primitive", - "type": "room_alias" - } - ] - } - } - } - ] - }, - "tests": [ - { - "name": "room alias", - "input": "/test room reference #test:matrix.org", - "output": { - "rooms": [ - "#test:matrix.org" - ] - } - }, - { - "name": "room id", - "input": "/test room reference !aiwVrNhPwbGBNjqlNu:matrix.org", - "output": { - "rooms": [ - { - "type": "room_id", - "id": "!aiwVrNhPwbGBNjqlNu:matrix.org" - } - ] - } - }, - { - "name": "two room ids", - "input": "/test room reference !mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ !aiwVrNhPwbGBNjqlNu:matrix.org", - "output": { - "rooms": [ - { - "type": "room_id", - "id": "!mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ" - }, - { - "type": "room_id", - "id": "!aiwVrNhPwbGBNjqlNu:matrix.org" - } - ] - } - }, - { - "name": "room id matrix: URI", - "input": "/test room reference matrix:roomid/mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ?via=maunium.net&via=matrix.org", - "output": { - "rooms": [ - { - "type": "room_id", - "id": "!mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ", - "via": [ - "maunium.net", - "matrix.org" - ] - } - ] - } - }, - { - "name": "room id matrix: URI and matrix.to URL", - "input": "/test room reference https://matrix.to/#/!aiwVrNhPwbGBNjqlNu:matrix.org?via=example.com matrix:roomid/mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ?via=maunium.net&via=matrix.org", - "output": { - "rooms": [ - { - "type": "room_id", - "id": "!aiwVrNhPwbGBNjqlNu:matrix.org", - "via": [ - "example.com" - ] - }, - { - "type": "room_id", - "id": "!mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ", - "via": [ - "maunium.net", - "matrix.org" - ] - } - ] - } - } - ] -} diff --git a/event/cmdschema/testdata/commands/simple.json b/event/cmdschema/testdata/commands/simple.json deleted file mode 100644 index 94667323..00000000 --- a/event/cmdschema/testdata/commands/simple.json +++ /dev/null @@ -1,46 +0,0 @@ -{ - "$schema": "../commands.schema.json#", - "spec": { - "command": "test simple", - "source": "@testbot", - "parameters": [ - { - "key": "meow", - "schema": { - "schema_type": "primitive", - "type": "string" - } - } - ] - }, - "tests": [ - { - "name": "success", - "input": "/test simple mrrp", - "output": { - "meow": "mrrp" - } - }, - { - "name": "directed success", - "input": "/test simple@testbot mrrp", - "output": { - "meow": "mrrp" - } - }, - { - "name": "missing parameter", - "input": "/test simple", - "error": true, - "output": { - "meow": "" - } - }, - { - "name": "directed at another bot", - "input": "/test simple@anotherbot mrrp", - "error": false, - "output": null - } - ] -} diff --git a/event/cmdschema/testdata/commands/tail.json b/event/cmdschema/testdata/commands/tail.json deleted file mode 100644 index 9782f8ec..00000000 --- a/event/cmdschema/testdata/commands/tail.json +++ /dev/null @@ -1,60 +0,0 @@ -{ - "$schema": "../commands.schema.json#", - "spec": { - "command": "tail", - "source": "@testbot", - "parameters": [ - { - "key": "meow", - "schema": { - "schema_type": "primitive", - "type": "string" - } - }, - { - "key": "reason", - "schema": { - "schema_type": "primitive", - "type": "string" - }, - "optional": true - }, - { - "key": "woof", - "schema": { - "schema_type": "primitive", - "type": "boolean" - }, - "optional": true - } - ], - "fi.mau.tail_parameter": "reason" - }, - "tests": [ - { - "name": "no tail or flag", - "input": "/tail mrrp", - "output": { - "meow": "mrrp", - "reason": "" - } - }, - { - "name": "tail, no flag", - "input": "/tail mrrp meow meow", - "output": { - "meow": "mrrp", - "reason": "meow meow" - } - }, - { - "name": "flag before tail", - "input": "/tail mrrp --woof meow meow", - "output": { - "meow": "mrrp", - "reason": "meow meow", - "woof": true - } - } - ] -} diff --git a/event/cmdschema/testdata/data.go b/event/cmdschema/testdata/data.go deleted file mode 100644 index eceea3d2..00000000 --- a/event/cmdschema/testdata/data.go +++ /dev/null @@ -1,14 +0,0 @@ -// Copyright (c) 2026 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package testdata - -import ( - "embed" -) - -//go:embed * -var FS embed.FS diff --git a/event/cmdschema/testdata/parse_quote.json b/event/cmdschema/testdata/parse_quote.json deleted file mode 100644 index 8f52b7f5..00000000 --- a/event/cmdschema/testdata/parse_quote.json +++ /dev/null @@ -1,30 +0,0 @@ -[ - {"name": "empty string", "input": "", "output": ["", "", false]}, - {"name": "single word", "input": "meow", "output": ["meow", "", false]}, - {"name": "two words", "input": "meow woof", "output": ["meow", "woof", false]}, - {"name": "many words", "input": "meow meow mrrp", "output": ["meow", "meow mrrp", false]}, - {"name": "extra spaces", "input": "meow meow mrrp", "output": ["meow", "meow mrrp", false]}, - {"name": "trailing space", "input": "meow ", "output": ["meow", "", false]}, - {"name": "only spaces", "input": " ", "output": ["", "", false]}, - {"name": "leading spaces", "input": " meow woof", "output": ["", "meow woof", false]}, - {"name": "backslash at end unquoted", "input": "meow\\ woof", "output": ["meow\\", "woof", false]}, - {"name": "quoted word", "input": "\"meow\" meow mrrp", "output": ["meow", "meow mrrp", true]}, - {"name": "quoted words", "input": "\"meow meow\" mrrp", "output": ["meow meow", "mrrp", true]}, - {"name": "spaces in quotes", "input": "\" meow meow \" mrrp", "output": [" meow meow ", "mrrp", true]}, - {"name": "empty quoted string", "input": "\"\"", "output": ["", "", true]}, - {"name": "empty quoted with trailing", "input": "\"\" meow", "output": ["", "meow", true]}, - {"name": "quote no space before next", "input": "\"meow\"woof", "output": ["meow", "woof", true]}, - {"name": "just opening quote", "input": "\"", "output": ["", "", true]}, - {"name": "quote then space then text", "input": "\" meow", "output": [" meow", "", true]}, - {"name": "quotes after word", "input": "meow \" meow mrrp \"", "output": ["meow", "\" meow mrrp \"", false]}, - {"name": "escaped quote", "input": "\"meow\\\" meow\" mrrp", "output": ["meow\" meow", "mrrp", true]}, - {"name": "missing end quote", "input": "\"meow meow mrrp", "output": ["meow meow mrrp", "", true]}, - {"name": "missing end quote with escaped quote", "input": "\"meow\\\" meow mrrp", "output": ["meow\" meow mrrp", "", true]}, - {"name": "quote in the middle", "input": "me\"ow meow mrrp", "output": ["me\"ow", "meow mrrp", false]}, - {"name": "backslash in the middle", "input": "me\\ow meow mrrp", "output": ["me\\ow", "meow mrrp", false]}, - {"name": "other escaped character", "input": "\"m\\eow\" meow mrrp", "output": ["meow", "meow mrrp", true]}, - {"name": "escaped backslashes", "input": "\"m\\\\e\\\"ow\\\\\" meow mrrp", "output": ["m\\e\"ow\\", "meow mrrp", true]}, - {"name": "just quotes", "input": "\"\\\"\\\"\\\\\\\"\" meow", "output": ["\"\"\\\"", "meow", true]}, - {"name": "escape at eof", "input": "\"meow\\", "output": ["meow", "", true]}, - {"name": "escaped backslash at eof", "input": "\"meow\\\\", "output": ["meow\\", "", true]} -] diff --git a/event/cmdschema/testdata/parse_quote.schema.json b/event/cmdschema/testdata/parse_quote.schema.json deleted file mode 100644 index 9f249116..00000000 --- a/event/cmdschema/testdata/parse_quote.schema.json +++ /dev/null @@ -1,46 +0,0 @@ -{ - "$schema": "https://json-schema.org/draft/2020-12/schema#", - "$id": "parse_quote.schema.json", - "title": "parseQuote test cases", - "description": "Test cases for the parseQuoted function", - "type": "array", - "items": { - "type": "object", - "required": [ - "name", - "input", - "output" - ], - "properties": { - "name": { - "type": "string", - "description": "Name of the test case" - }, - "input": { - "type": "string", - "description": "Input string to be parsed" - }, - "output": { - "type": "array", - "description": "Expected output of parsing: [first word, remaining text, was quoted]", - "minItems": 3, - "maxItems": 3, - "prefixItems": [ - { - "type": "string", - "description": "First parsed word" - }, - { - "type": "string", - "description": "Remaining text after the first word" - }, - { - "type": "boolean", - "description": "Whether the first word was quoted" - } - ] - } - }, - "additionalProperties": false - } -} diff --git a/event/content.go b/event/content.go index 814aeec4..b56e35f2 100644 --- a/event/content.go +++ b/event/content.go @@ -18,7 +18,6 @@ import ( // This is used by Content.ParseRaw() for creating the correct type of struct. var TypeMap = map[Type]reflect.Type{ StateMember: reflect.TypeOf(MemberEventContent{}), - StateThirdPartyInvite: reflect.TypeOf(ThirdPartyInviteEventContent{}), StatePowerLevels: reflect.TypeOf(PowerLevelsEventContent{}), StateCanonicalAlias: reflect.TypeOf(CanonicalAliasEventContent{}), StateRoomName: reflect.TypeOf(RoomNameEventContent{}), @@ -39,9 +38,7 @@ var TypeMap = map[Type]reflect.Type{ StateHalfShotBridge: reflect.TypeOf(BridgeEventContent{}), StateSpaceParent: reflect.TypeOf(SpaceParentEventContent{}), StateSpaceChild: reflect.TypeOf(SpaceChildEventContent{}), - - StateRoomPolicy: reflect.TypeOf(RoomPolicyEventContent{}), - StateUnstableRoomPolicy: reflect.TypeOf(RoomPolicyEventContent{}), + StateInsertionMarker: reflect.TypeOf(InsertionMarkerContent{}), StateLegacyPolicyRoom: reflect.TypeOf(ModPolicyContent{}), StateLegacyPolicyServer: reflect.TypeOf(ModPolicyContent{}), @@ -52,7 +49,6 @@ var TypeMap = map[Type]reflect.Type{ StateElementFunctionalMembers: reflect.TypeOf(ElementFunctionalMembersContent{}), StateBeeperRoomFeatures: reflect.TypeOf(RoomFeatures{}), - StateBeeperDisappearingTimer: reflect.TypeOf(BeeperDisappearingTimer{}), EventMessage: reflect.TypeOf(MessageEventContent{}), EventSticker: reflect.TypeOf(MessageEventContent{}), @@ -63,11 +59,8 @@ var TypeMap = map[Type]reflect.Type{ EventUnstablePollStart: reflect.TypeOf(PollStartEventContent{}), EventUnstablePollResponse: reflect.TypeOf(PollResponseEventContent{}), - BeeperMessageStatus: reflect.TypeOf(BeeperMessageStatusEventContent{}), - BeeperTranscription: reflect.TypeOf(BeeperTranscriptionEventContent{}), - BeeperDeleteChat: reflect.TypeOf(BeeperChatDeleteEventContent{}), - BeeperAcceptMessageRequest: reflect.TypeOf(BeeperAcceptMessageRequestEventContent{}), - BeeperSendState: reflect.TypeOf(BeeperSendStateEventContent{}), + BeeperMessageStatus: reflect.TypeOf(BeeperMessageStatusEventContent{}), + BeeperTranscription: reflect.TypeOf(BeeperTranscriptionEventContent{}), AccountDataRoomTags: reflect.TypeOf(TagEventContent{}), AccountDataDirectChats: reflect.TypeOf(DirectChatsEventContent{}), @@ -76,11 +69,9 @@ var TypeMap = map[Type]reflect.Type{ AccountDataMarkedUnread: reflect.TypeOf(MarkedUnreadEventContent{}), AccountDataBeeperMute: reflect.TypeOf(BeeperMuteEventContent{}), - EphemeralEventTyping: reflect.TypeOf(TypingEventContent{}), - EphemeralEventReceipt: reflect.TypeOf(ReceiptEventContent{}), - EphemeralEventPresence: reflect.TypeOf(PresenceEventContent{}), - EphemeralEventEncrypted: reflect.TypeOf(EncryptedEventContent{}), - BeeperEphemeralEventAIStream: reflect.TypeOf(BeeperAIStreamEventContent{}), + EphemeralEventTyping: reflect.TypeOf(TypingEventContent{}), + EphemeralEventReceipt: reflect.TypeOf(ReceiptEventContent{}), + EphemeralEventPresence: reflect.TypeOf(PresenceEventContent{}), InRoomVerificationReady: reflect.TypeOf(VerificationReadyEventContent{}), InRoomVerificationStart: reflect.TypeOf(VerificationStartEventContent{}), diff --git a/event/delayed.go b/event/delayed.go deleted file mode 100644 index fefb62af..00000000 --- a/event/delayed.go +++ /dev/null @@ -1,70 +0,0 @@ -package event - -import ( - "encoding/json" - - "go.mau.fi/util/jsontime" - - "maunium.net/go/mautrix/id" -) - -type ScheduledDelayedEvent struct { - DelayID id.DelayID `json:"delay_id"` - RoomID id.RoomID `json:"room_id"` - Type Type `json:"type"` - StateKey *string `json:"state_key,omitempty"` - Delay int64 `json:"delay"` - RunningSince jsontime.UnixMilli `json:"running_since"` - Content Content `json:"content"` -} - -func (e ScheduledDelayedEvent) AsEvent(eventID id.EventID, ts jsontime.UnixMilli) (*Event, error) { - evt := &Event{ - ID: eventID, - RoomID: e.RoomID, - Type: e.Type, - StateKey: e.StateKey, - Content: e.Content, - Timestamp: ts.UnixMilli(), - } - return evt, evt.Content.ParseRaw(evt.Type) -} - -type FinalisedDelayedEvent struct { - DelayedEvent *ScheduledDelayedEvent `json:"scheduled_event"` - Outcome DelayOutcome `json:"outcome"` - Reason DelayReason `json:"reason"` - Error json.RawMessage `json:"error,omitempty"` - EventID id.EventID `json:"event_id,omitempty"` - Timestamp jsontime.UnixMilli `json:"origin_server_ts"` -} - -type DelayStatus string - -var ( - DelayStatusScheduled DelayStatus = "scheduled" - DelayStatusFinalised DelayStatus = "finalised" -) - -type DelayAction string - -var ( - DelayActionSend DelayAction = "send" - DelayActionCancel DelayAction = "cancel" - DelayActionRestart DelayAction = "restart" -) - -type DelayOutcome string - -var ( - DelayOutcomeSend DelayOutcome = "send" - DelayOutcomeCancel DelayOutcome = "cancel" -) - -type DelayReason string - -var ( - DelayReasonAction DelayReason = "action" - DelayReasonError DelayReason = "error" - DelayReasonDelay DelayReason = "delay" -) diff --git a/event/encryption.go b/event/encryption.go index c60cb91a..cf9c2814 100644 --- a/event/encryption.go +++ b/event/encryption.go @@ -63,7 +63,7 @@ func (content *EncryptedEventContent) UnmarshalJSON(data []byte) error { return json.Unmarshal(content.Ciphertext, &content.OlmCiphertext) case id.AlgorithmMegolmV1: if len(content.Ciphertext) == 0 || content.Ciphertext[0] != '"' || content.Ciphertext[len(content.Ciphertext)-1] != '"' { - return fmt.Errorf("ciphertext %w", id.ErrInputNotJSONString) + return id.InputNotJSONString } content.MegolmCiphertext = content.Ciphertext[1 : len(content.Ciphertext)-1] } @@ -132,9 +132,8 @@ type RoomKeyRequestEventContent struct { type RequestedKeyInfo struct { Algorithm id.Algorithm `json:"algorithm"` RoomID id.RoomID `json:"room_id"` - SessionID id.SessionID `json:"session_id"` - // Deprecated: Matrix v1.3 SenderKey id.SenderKey `json:"sender_key"` + SessionID id.SessionID `json:"session_id"` } type RoomKeyWithheldCode string diff --git a/event/member.go b/event/member.go index 9956a36b..53387e8b 100644 --- a/event/member.go +++ b/event/member.go @@ -7,6 +7,8 @@ package event import ( + "encoding/json" + "maunium.net/go/mautrix/id" ) @@ -45,25 +47,11 @@ type MemberEventContent struct { MSC4293RedactEvents bool `json:"org.matrix.msc4293.redact_events,omitempty"` } -type SignedThirdPartyInvite struct { - Token string `json:"token"` - Signatures map[string]map[id.KeyID]string `json:"signatures,omitempty"` - MXID string `json:"mxid"` -} - type ThirdPartyInvite struct { - DisplayName string `json:"display_name"` - Signed SignedThirdPartyInvite `json:"signed"` -} - -type ThirdPartyInviteEventContent struct { - DisplayName string `json:"display_name"` - KeyValidityURL string `json:"key_validity_url"` - PublicKey id.Ed25519 `json:"public_key"` - PublicKeys []ThirdPartyInviteKey `json:"public_keys,omitempty"` -} - -type ThirdPartyInviteKey struct { - KeyValidityURL string `json:"key_validity_url,omitempty"` - PublicKey id.Ed25519 `json:"public_key"` + DisplayName string `json:"display_name"` + Signed struct { + Token string `json:"token"` + Signatures json.RawMessage `json:"signatures"` + MXID string `json:"mxid"` + } } diff --git a/event/message.go b/event/message.go index 3fb3dc82..51403889 100644 --- a/event/message.go +++ b/event/message.go @@ -135,16 +135,11 @@ type MessageEventContent struct { BeeperGalleryCaption string `json:"com.beeper.gallery.caption,omitempty"` BeeperGalleryCaptionHTML string `json:"com.beeper.gallery.caption_html,omitempty"` BeeperPerMessageProfile *BeeperPerMessageProfile `json:"com.beeper.per_message_profile,omitempty"` - BeeperActionMessage *BeeperActionMessage `json:"com.beeper.action_message,omitempty"` BeeperLinkPreviews []*BeeperLinkPreview `json:"com.beeper.linkpreviews,omitempty"` - BeeperDisappearingTimer *BeeperDisappearingTimer `json:"com.beeper.disappearing_timer,omitempty"` - MSC1767Audio *MSC1767Audio `json:"org.matrix.msc1767.audio,omitempty"` MSC3245Voice *MSC3245Voice `json:"org.matrix.msc3245.voice,omitempty"` - - MSC4391BotCommand *MSC4391BotCommandInput `json:"org.matrix.msc4391.command,omitempty"` } func (content *MessageEventContent) GetCapMsgType() CapabilityMsgType { @@ -276,25 +271,6 @@ func (m *Mentions) Has(userID id.UserID) bool { return m != nil && slices.Contains(m.UserIDs, userID) } -func (m *Mentions) Merge(other *Mentions) *Mentions { - if m == nil { - return other - } else if other == nil { - return m - } - return &Mentions{ - UserIDs: slices.Concat(m.UserIDs, other.UserIDs), - Room: m.Room || other.Room, - } -} - -type MSC4391BotCommandInputCustom[T any] struct { - Command string `json:"command"` - Arguments T `json:"arguments,omitempty"` -} - -type MSC4391BotCommandInput = MSC4391BotCommandInputCustom[json.RawMessage] - type EncryptedFileInfo struct { attachment.EncryptedFile URL id.ContentURIString `json:"url"` @@ -309,8 +285,7 @@ type FileInfo struct { Blurhash string AnoaBlurhash string - MauGIF bool - IsAnimated bool + MauGIF bool Width int Height int @@ -327,8 +302,7 @@ type serializableFileInfo struct { Blurhash string `json:"blurhash,omitempty"` AnoaBlurhash string `json:"xyz.amorgan.blurhash,omitempty"` - MauGIF bool `json:"fi.mau.gif,omitempty"` - IsAnimated bool `json:"is_animated,omitempty"` + MauGIF bool `json:"fi.mau.gif,omitempty"` Width json.Number `json:"w,omitempty"` Height json.Number `json:"h,omitempty"` @@ -346,8 +320,7 @@ func (sfi *serializableFileInfo) CopyFrom(fileInfo *FileInfo) *serializableFileI ThumbnailInfo: (&serializableFileInfo{}).CopyFrom(fileInfo.ThumbnailInfo), ThumbnailFile: fileInfo.ThumbnailFile, - MauGIF: fileInfo.MauGIF, - IsAnimated: fileInfo.IsAnimated, + MauGIF: fileInfo.MauGIF, Blurhash: fileInfo.Blurhash, AnoaBlurhash: fileInfo.AnoaBlurhash, @@ -378,7 +351,6 @@ func (sfi *serializableFileInfo) CopyTo(fileInfo *FileInfo) { ThumbnailURL: sfi.ThumbnailURL, ThumbnailFile: sfi.ThumbnailFile, MauGIF: sfi.MauGIF, - IsAnimated: sfi.IsAnimated, Blurhash: sfi.Blurhash, AnoaBlurhash: sfi.AnoaBlurhash, } diff --git a/event/message_test.go b/event/message_test.go index c721df35..562a6622 100644 --- a/event/message_test.go +++ b/event/message_test.go @@ -33,7 +33,7 @@ const invalidMessageEvent = `{ func TestMessageEventContent__ParseInvalid(t *testing.T) { var evt *event.Event err := json.Unmarshal([]byte(invalidMessageEvent), &evt) - assert.NoError(t, err) + assert.Nil(t, err) assert.Equal(t, id.UserID("@tulir:maunium.net"), evt.Sender) assert.Equal(t, event.EventMessage, evt.Type) @@ -42,7 +42,7 @@ func TestMessageEventContent__ParseInvalid(t *testing.T) { assert.Equal(t, id.RoomID("!bar"), evt.RoomID) err = evt.Content.ParseRaw(evt.Type) - assert.Error(t, err) + assert.NotNil(t, err) } const messageEvent = `{ @@ -68,7 +68,7 @@ const messageEvent = `{ func TestMessageEventContent__ParseEdit(t *testing.T) { var evt *event.Event err := json.Unmarshal([]byte(messageEvent), &evt) - assert.NoError(t, err) + assert.Nil(t, err) assert.Equal(t, id.UserID("@tulir:maunium.net"), evt.Sender) assert.Equal(t, event.EventMessage, evt.Type) @@ -110,7 +110,7 @@ const imageMessageEvent = `{ func TestMessageEventContent__ParseMedia(t *testing.T) { var evt *event.Event err := json.Unmarshal([]byte(imageMessageEvent), &evt) - assert.NoError(t, err) + assert.Nil(t, err) assert.Equal(t, id.UserID("@tulir:maunium.net"), evt.Sender) assert.Equal(t, event.EventMessage, evt.Type) @@ -125,7 +125,7 @@ func TestMessageEventContent__ParseMedia(t *testing.T) { content := evt.Content.Parsed.(*event.MessageEventContent) assert.Equal(t, event.MsgImage, content.MsgType) parsedURL, err := content.URL.Parse() - assert.NoError(t, err) + assert.Nil(t, err) assert.Equal(t, id.ContentURI{Homeserver: "example.com", FileID: "image"}, parsedURL) assert.Nil(t, content.NewContent) assert.Equal(t, "image/png", content.GetInfo().MimeType) @@ -145,7 +145,7 @@ const expectedMarshalResult = `{"msgtype":"m.text","body":"test"}` func TestMessageEventContent__Marshal(t *testing.T) { data, err := json.Marshal(parsedMessage) - assert.NoError(t, err) + assert.Nil(t, err) assert.Equal(t, expectedMarshalResult, string(data)) } @@ -163,6 +163,6 @@ const expectedCustomMarshalResult = `{"body":"test","msgtype":"m.text","net.maun func TestMessageEventContent__Marshal_Custom(t *testing.T) { data, err := json.Marshal(customParsedMessage) - assert.NoError(t, err) + assert.Nil(t, err) assert.Equal(t, expectedCustomMarshalResult, string(data)) } diff --git a/event/poll.go b/event/poll.go index 9082f65e..47131a8f 100644 --- a/event/poll.go +++ b/event/poll.go @@ -35,7 +35,7 @@ type MSC1767Message struct { } type PollStartEventContent struct { - RelatesTo *RelatesTo `json:"m.relates_to,omitempty"` + RelatesTo *RelatesTo `json:"m.relates_to"` Mentions *Mentions `json:"m.mentions,omitempty"` PollStart struct { Kind string `json:"kind"` diff --git a/event/powerlevels.go b/event/powerlevels.go index 668eb6d3..50df2c1f 100644 --- a/event/powerlevels.go +++ b/event/powerlevels.go @@ -28,9 +28,6 @@ type PowerLevelsEventContent struct { Events map[string]int `json:"events,omitempty"` EventsDefault int `json:"events_default,omitempty"` - beeperEphemeralLock sync.RWMutex - BeeperEphemeral map[string]int `json:"com.beeper.ephemeral,omitempty"` - Notifications *NotificationPowerLevels `json:"notifications,omitempty"` StateDefaultPtr *int `json:"state_default,omitempty"` @@ -40,8 +37,6 @@ type PowerLevelsEventContent struct { BanPtr *int `json:"ban,omitempty"` RedactPtr *int `json:"redact,omitempty"` - BeeperEphemeralDefaultPtr *int `json:"com.beeper.ephemeral_default,omitempty"` - // This is not a part of power levels, it's added by mautrix-go internally in certain places // in order to detect creator power accurately. CreateEvent *Event `json:"-"` @@ -56,7 +51,6 @@ func (pl *PowerLevelsEventContent) Clone() *PowerLevelsEventContent { UsersDefault: pl.UsersDefault, Events: maps.Clone(pl.Events), EventsDefault: pl.EventsDefault, - BeeperEphemeral: maps.Clone(pl.BeeperEphemeral), StateDefaultPtr: ptr.Clone(pl.StateDefaultPtr), Notifications: pl.Notifications.Clone(), @@ -66,8 +60,6 @@ func (pl *PowerLevelsEventContent) Clone() *PowerLevelsEventContent { BanPtr: ptr.Clone(pl.BanPtr), RedactPtr: ptr.Clone(pl.RedactPtr), - BeeperEphemeralDefaultPtr: ptr.Clone(pl.BeeperEphemeralDefaultPtr), - CreateEvent: pl.CreateEvent, } } @@ -127,13 +119,6 @@ func (pl *PowerLevelsEventContent) StateDefault() int { return 50 } -func (pl *PowerLevelsEventContent) BeeperEphemeralDefault() int { - if pl.BeeperEphemeralDefaultPtr != nil { - return *pl.BeeperEphemeralDefaultPtr - } - return pl.EventsDefault -} - func (pl *PowerLevelsEventContent) GetUserLevel(userID id.UserID) int { if pl.isCreator(userID) { return math.MaxInt @@ -147,19 +132,9 @@ func (pl *PowerLevelsEventContent) GetUserLevel(userID id.UserID) int { return level } -const maxPL = 1<<53 - 1 - func (pl *PowerLevelsEventContent) SetUserLevel(userID id.UserID, level int) { pl.usersLock.Lock() defer pl.usersLock.Unlock() - if pl.isCreator(userID) { - return - } - if level == math.MaxInt && maxPL < math.MaxInt { - // Hack to avoid breaking on 32-bit systems (they're only slightly supported) - x := int64(maxPL) - level = int(x) - } if level == pl.UsersDefault { delete(pl.Users, userID) } else { @@ -217,29 +192,6 @@ func (pl *PowerLevelsEventContent) GetEventLevel(eventType Type) int { return level } -func (pl *PowerLevelsEventContent) GetBeeperEphemeralLevel(eventType Type) int { - pl.beeperEphemeralLock.RLock() - defer pl.beeperEphemeralLock.RUnlock() - level, ok := pl.BeeperEphemeral[eventType.String()] - if !ok { - return pl.BeeperEphemeralDefault() - } - return level -} - -func (pl *PowerLevelsEventContent) SetBeeperEphemeralLevel(eventType Type, level int) { - pl.beeperEphemeralLock.Lock() - defer pl.beeperEphemeralLock.Unlock() - if level == pl.BeeperEphemeralDefault() { - delete(pl.BeeperEphemeral, eventType.String()) - } else { - if pl.BeeperEphemeral == nil { - pl.BeeperEphemeral = make(map[string]int) - } - pl.BeeperEphemeral[eventType.String()] = level - } -} - func (pl *PowerLevelsEventContent) SetEventLevel(eventType Type, level int) { pl.eventsLock.Lock() defer pl.eventsLock.Unlock() diff --git a/event/powerlevels_ephemeral_test.go b/event/powerlevels_ephemeral_test.go deleted file mode 100644 index f5861583..00000000 --- a/event/powerlevels_ephemeral_test.go +++ /dev/null @@ -1,67 +0,0 @@ -// Copyright (c) 2026 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package event_test - -import ( - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "maunium.net/go/mautrix/event" -) - -func TestPowerLevelsEventContent_BeeperEphemeralDefaultFallsBackToEventsDefault(t *testing.T) { - pl := &event.PowerLevelsEventContent{ - EventsDefault: 45, - } - - assert.Equal(t, 45, pl.BeeperEphemeralDefault()) - - override := 60 - pl.BeeperEphemeralDefaultPtr = &override - assert.Equal(t, 60, pl.BeeperEphemeralDefault()) -} - -func TestPowerLevelsEventContent_GetSetBeeperEphemeralLevel(t *testing.T) { - pl := &event.PowerLevelsEventContent{ - EventsDefault: 25, - } - evtType := event.Type{Type: "com.example.ephemeral", Class: event.EphemeralEventType} - - assert.Equal(t, 25, pl.GetBeeperEphemeralLevel(evtType)) - - pl.SetBeeperEphemeralLevel(evtType, 50) - assert.Equal(t, 50, pl.GetBeeperEphemeralLevel(evtType)) - require.NotNil(t, pl.BeeperEphemeral) - assert.Equal(t, 50, pl.BeeperEphemeral[evtType.String()]) - - pl.SetBeeperEphemeralLevel(evtType, 25) - _, exists := pl.BeeperEphemeral[evtType.String()] - assert.False(t, exists) -} - -func TestPowerLevelsEventContent_CloneCopiesBeeperEphemeralFields(t *testing.T) { - override := 70 - pl := &event.PowerLevelsEventContent{ - EventsDefault: 35, - BeeperEphemeral: map[string]int{"com.example.ephemeral": 90}, - BeeperEphemeralDefaultPtr: &override, - } - - cloned := pl.Clone() - require.NotNil(t, cloned) - require.NotNil(t, cloned.BeeperEphemeralDefaultPtr) - assert.Equal(t, 70, *cloned.BeeperEphemeralDefaultPtr) - assert.Equal(t, 90, cloned.BeeperEphemeral["com.example.ephemeral"]) - - cloned.BeeperEphemeral["com.example.ephemeral"] = 99 - *cloned.BeeperEphemeralDefaultPtr = 71 - - assert.Equal(t, 90, pl.BeeperEphemeral["com.example.ephemeral"]) - assert.Equal(t, 70, *pl.BeeperEphemeralDefaultPtr) -} diff --git a/event/reply.go b/event/reply.go index 5f55bb80..9ae1c110 100644 --- a/event/reply.go +++ b/event/reply.go @@ -32,13 +32,12 @@ func TrimReplyFallbackText(text string) string { } func (content *MessageEventContent) RemoveReplyFallback() { - if len(content.RelatesTo.GetReplyTo()) > 0 && !content.replyFallbackRemoved && content.Format == FormatHTML { - origHTML := content.FormattedBody - content.FormattedBody = TrimReplyFallbackHTML(content.FormattedBody) - if content.FormattedBody != origHTML { - content.Body = TrimReplyFallbackText(content.Body) - content.replyFallbackRemoved = true + if len(content.RelatesTo.GetReplyTo()) > 0 && !content.replyFallbackRemoved { + if content.Format == FormatHTML { + content.FormattedBody = TrimReplyFallbackHTML(content.FormattedBody) } + content.Body = TrimReplyFallbackText(content.Body) + content.replyFallbackRemoved = true } } diff --git a/event/state.go b/event/state.go index ace170a5..44a45a57 100644 --- a/event/state.go +++ b/event/state.go @@ -8,11 +8,8 @@ package event import ( "encoding/base64" - "encoding/json" "slices" - "go.mau.fi/util/jsontime" - "maunium.net/go/mautrix/id" ) @@ -56,40 +53,10 @@ type TopicEventContent struct { // m.room.topic state event as described in [MSC3765]. // // [MSC3765]: https://github.com/matrix-org/matrix-spec-proposals/pull/3765 -type ExtensibleTopic = ExtensibleTextContainer - -type ExtensibleTextContainer struct { +type ExtensibleTopic struct { Text []ExtensibleText `json:"m.text"` } -func (c *ExtensibleTextContainer) Equals(description *ExtensibleTextContainer) bool { - if c == nil || description == nil { - return c == description - } - return slices.Equal(c.Text, description.Text) -} - -func MakeExtensibleText(text string) *ExtensibleTextContainer { - return &ExtensibleTextContainer{ - Text: []ExtensibleText{{ - Body: text, - MimeType: "text/plain", - }}, - } -} - -func MakeExtensibleFormattedText(plaintext, html string) *ExtensibleTextContainer { - return &ExtensibleTextContainer{ - Text: []ExtensibleText{{ - Body: plaintext, - MimeType: "text/plain", - }, { - Body: html, - MimeType: "text/html", - }}, - } -} - // ExtensibleText represents the contents of an m.text field. type ExtensibleText struct { MimeType string `json:"mimetype,omitempty"` @@ -103,13 +70,6 @@ type TombstoneEventContent struct { ReplacementRoom id.RoomID `json:"replacement_room"` } -func (tec *TombstoneEventContent) GetReplacementRoom() id.RoomID { - if tec == nil { - return "" - } - return tec.ReplacementRoom -} - type Predecessor struct { RoomID id.RoomID `json:"room_id"` EventID id.EventID `json:"event_id"` @@ -149,13 +109,6 @@ type CreateEventContent struct { Creator id.UserID `json:"creator,omitempty"` } -func (cec *CreateEventContent) GetPredecessor() (p Predecessor) { - if cec != nil && cec.Predecessor != nil { - p = *cec.Predecessor - } - return -} - func (cec *CreateEventContent) SupportsCreatorPower() bool { if cec == nil { return false @@ -238,8 +191,7 @@ type BridgeInfoSection struct { AvatarURL id.ContentURIString `json:"avatar_url,omitempty"` ExternalURL string `json:"external_url,omitempty"` - Receiver string `json:"fi.mau.receiver,omitempty"` - MessageRequest bool `json:"com.beeper.message_request,omitempty"` + Receiver string `json:"fi.mau.receiver,omitempty"` } // BridgeEventContent represents the content of a m.bridge state event. @@ -253,32 +205,6 @@ type BridgeEventContent struct { BeeperRoomType string `json:"com.beeper.room_type,omitempty"` BeeperRoomTypeV2 string `json:"com.beeper.room_type.v2,omitempty"` - - TempSlackRemoteIDMigratedFlag bool `json:"com.beeper.slack_remote_id_migrated,omitempty"` - TempSlackRemoteIDMigratedFlag2 bool `json:"com.beeper.slack_remote_id_really_migrated,omitempty"` -} - -// DisappearingType represents the type of a disappearing message timer. -type DisappearingType string - -const ( - DisappearingTypeNone DisappearingType = "" - DisappearingTypeAfterRead DisappearingType = "after_read" - DisappearingTypeAfterSend DisappearingType = "after_send" -) - -type BeeperDisappearingTimer struct { - Type DisappearingType `json:"type"` - Timer jsontime.Milliseconds `json:"timer"` -} - -type marshalableBeeperDisappearingTimer BeeperDisappearingTimer - -func (bdt *BeeperDisappearingTimer) MarshalJSON() ([]byte, error) { - if bdt == nil || bdt.Type == DisappearingTypeNone { - return []byte("{}"), nil - } - return json.Marshal((*marshalableBeeperDisappearingTimer)(bdt)) } type SpaceChildEventContent struct { @@ -332,6 +258,12 @@ func (mpc *ModPolicyContent) EntityOrHash() string { return mpc.Entity } +// Deprecated: MSC2716 has been abandoned +type InsertionMarkerContent struct { + InsertionID id.EventID `json:"org.matrix.msc2716.marker.insertion"` + Timestamp int64 `json:"com.beeper.timestamp,omitempty"` +} + type ElementFunctionalMembersContent struct { ServiceMembers []id.UserID `json:"service_members"` } @@ -343,15 +275,3 @@ func (efmc *ElementFunctionalMembersContent) Add(mxid id.UserID) bool { efmc.ServiceMembers = append(efmc.ServiceMembers, mxid) return true } - -type PolicyServerPublicKeys struct { - Ed25519 id.Ed25519 `json:"ed25519,omitempty"` -} - -type RoomPolicyEventContent struct { - Via string `json:"via,omitempty"` - PublicKeys *PolicyServerPublicKeys `json:"public_keys,omitempty"` - - // Deprecated, only for legacy use - PublicKey id.Ed25519 `json:"public_key,omitempty"` -} diff --git a/event/type.go b/event/type.go index 80b86728..591d598d 100644 --- a/event/type.go +++ b/event/type.go @@ -108,14 +108,13 @@ func (et *Type) IsCustom() bool { func (et *Type) GuessClass() TypeClass { switch et.Type { - case StateAliases.Type, StateCanonicalAlias.Type, StateCreate.Type, StateJoinRules.Type, StateMember.Type, StateThirdPartyInvite.Type, + case StateAliases.Type, StateCanonicalAlias.Type, StateCreate.Type, StateJoinRules.Type, StateMember.Type, StatePowerLevels.Type, StateRoomName.Type, StateRoomAvatar.Type, StateServerACL.Type, StateTopic.Type, StatePinnedEvents.Type, StateTombstone.Type, StateEncryption.Type, StateBridge.Type, StateHalfShotBridge.Type, StateSpaceParent.Type, StateSpaceChild.Type, StatePolicyRoom.Type, StatePolicyServer.Type, StatePolicyUser.Type, - StateElementFunctionalMembers.Type, StateBeeperRoomFeatures.Type, StateBeeperDisappearingTimer.Type, - StateMSC4391BotCommand.Type, StateRoomPolicy.Type, StateUnstableRoomPolicy.Type: + StateInsertionMarker.Type, StateElementFunctionalMembers.Type, StateBeeperRoomFeatures.Type: return StateEventType - case EphemeralEventReceipt.Type, EphemeralEventTyping.Type, EphemeralEventPresence.Type, BeeperEphemeralEventAIStream.Type: + case EphemeralEventReceipt.Type, EphemeralEventTyping.Type, EphemeralEventPresence.Type: return EphemeralEventType case AccountDataDirectChats.Type, AccountDataPushRules.Type, AccountDataRoomTags.Type, AccountDataFullyRead.Type, AccountDataIgnoredUserList.Type, AccountDataMarkedUnread.Type, @@ -128,7 +127,7 @@ func (et *Type) GuessClass() TypeClass { InRoomVerificationKey.Type, InRoomVerificationMAC.Type, InRoomVerificationCancel.Type, CallInvite.Type, CallCandidates.Type, CallAnswer.Type, CallReject.Type, CallSelectAnswer.Type, CallNegotiate.Type, CallHangup.Type, BeeperMessageStatus.Type, EventUnstablePollStart.Type, EventUnstablePollResponse.Type, - EventUnstablePollEnd.Type, BeeperTranscription.Type, BeeperDeleteChat.Type, BeeperAcceptMessageRequest.Type: + BeeperTranscription.Type: return MessageEventType case ToDeviceRoomKey.Type, ToDeviceRoomKeyRequest.Type, ToDeviceForwardedRoomKey.Type, ToDeviceRoomKeyWithheld.Type, ToDeviceBeeperRoomKeyAck.Type: @@ -178,7 +177,6 @@ var ( StateHistoryVisibility = Type{"m.room.history_visibility", StateEventType} StateGuestAccess = Type{"m.room.guest_access", StateEventType} StateMember = Type{"m.room.member", StateEventType} - StateThirdPartyInvite = Type{"m.room.third_party_invite", StateEventType} StatePowerLevels = Type{"m.room.power_levels", StateEventType} StateRoomName = Type{"m.room.name", StateEventType} StateTopic = Type{"m.room.topic", StateEventType} @@ -195,9 +193,6 @@ var ( StateSpaceChild = Type{"m.space.child", StateEventType} StateSpaceParent = Type{"m.space.parent", StateEventType} - StateRoomPolicy = Type{"m.room.policy", StateEventType} - StateUnstableRoomPolicy = Type{"org.matrix.msc4284.policy", StateEventType} - StateLegacyPolicyRoom = Type{"m.room.rule.room", StateEventType} StateLegacyPolicyServer = Type{"m.room.rule.server", StateEventType} StateLegacyPolicyUser = Type{"m.room.rule.user", StateEventType} @@ -205,10 +200,11 @@ var ( StateUnstablePolicyServer = Type{"org.matrix.mjolnir.rule.server", StateEventType} StateUnstablePolicyUser = Type{"org.matrix.mjolnir.rule.user", StateEventType} + // Deprecated: MSC2716 has been abandoned + StateInsertionMarker = Type{"org.matrix.msc2716.marker", StateEventType} + StateElementFunctionalMembers = Type{"io.element.functional_members", StateEventType} StateBeeperRoomFeatures = Type{"com.beeper.room_features", StateEventType} - StateBeeperDisappearingTimer = Type{"com.beeper.disappearing_timer", StateEventType} - StateMSC4391BotCommand = Type{"org.matrix.msc4391.command_description", StateEventType} ) // Message events @@ -237,24 +233,18 @@ var ( CallNegotiate = Type{"m.call.negotiate", MessageEventType} CallHangup = Type{"m.call.hangup", MessageEventType} - BeeperMessageStatus = Type{"com.beeper.message_send_status", MessageEventType} - BeeperTranscription = Type{"com.beeper.transcription", MessageEventType} - BeeperDeleteChat = Type{"com.beeper.delete_chat", MessageEventType} - BeeperAcceptMessageRequest = Type{"com.beeper.accept_message_request", MessageEventType} - BeeperSendState = Type{"com.beeper.send_state", MessageEventType} + BeeperMessageStatus = Type{"com.beeper.message_send_status", MessageEventType} + BeeperTranscription = Type{"com.beeper.transcription", MessageEventType} EventUnstablePollStart = Type{Type: "org.matrix.msc3381.poll.start", Class: MessageEventType} EventUnstablePollResponse = Type{Type: "org.matrix.msc3381.poll.response", Class: MessageEventType} - EventUnstablePollEnd = Type{Type: "org.matrix.msc3381.poll.end", Class: MessageEventType} ) // Ephemeral events var ( - EphemeralEventReceipt = Type{"m.receipt", EphemeralEventType} - EphemeralEventTyping = Type{"m.typing", EphemeralEventType} - EphemeralEventPresence = Type{"m.presence", EphemeralEventType} - EphemeralEventEncrypted = Type{"m.room.encrypted", EphemeralEventType} - BeeperEphemeralEventAIStream = Type{"com.beeper.ai.stream_event", EphemeralEventType} + EphemeralEventReceipt = Type{"m.receipt", EphemeralEventType} + EphemeralEventTyping = Type{"m.typing", EphemeralEventType} + EphemeralEventPresence = Type{"m.presence", EphemeralEventType} ) // Account data events diff --git a/example/main.go b/example/main.go index 2bf4bef3..d8006d46 100644 --- a/example/main.go +++ b/example/main.go @@ -143,7 +143,7 @@ func main() { if err != nil { log.Error().Err(err).Msg("Failed to send event") } else { - log.Info().Stringer("event_id", resp.EventID).Msg("Event sent") + log.Info().Str("event_id", resp.EventID.String()).Msg("Event sent") } } cancelSync() diff --git a/federation/client.go b/federation/client.go index 183fb5d1..8f454516 100644 --- a/federation/client.go +++ b/federation/client.go @@ -30,8 +30,6 @@ type Client struct { ServerName string UserAgent string Key *SigningKey - - ResponseSizeLimit int64 } func NewClient(serverName string, key *SigningKey, cache ResolutionCache) *Client { @@ -39,16 +37,10 @@ func NewClient(serverName string, key *SigningKey, cache ResolutionCache) *Clien HTTP: &http.Client{ Transport: NewServerResolvingTransport(cache), Timeout: 120 * time.Second, - CheckRedirect: func(req *http.Request, via []*http.Request) error { - // Federation requests do not allow redirects. - return http.ErrUseLastResponse - }, }, UserAgent: mautrix.DefaultUserAgent, ServerName: serverName, Key: key, - - ResponseSizeLimit: mautrix.DefaultResponseSizeLimit, } } @@ -89,7 +81,7 @@ type RespSendTransaction struct { } func (c *Client) SendTransaction(ctx context.Context, req *ReqSendTransaction) (resp *RespSendTransaction, err error) { - err = c.MakeRequest(ctx, req.Destination, true, http.MethodPut, URLPath{"v1", "send", req.TxnID}, req, &resp) + err = c.MakeRequest(ctx, req.Destination, true, http.MethodPost, URLPath{"v1", "send", req.TxnID}, req, &resp) return } @@ -263,169 +255,6 @@ func (c *Client) GetOpenIDUserInfo(ctx context.Context, serverName, accessToken return } -type ReqMakeJoin struct { - RoomID id.RoomID - UserID id.UserID - Via string - SupportedVersions []id.RoomVersion -} - -type RespMakeJoin struct { - RoomVersion id.RoomVersion `json:"room_version"` - Event PDU `json:"event"` -} - -type ReqSendJoin struct { - RoomID id.RoomID - EventID id.EventID - OmitMembers bool - Event PDU - Via string -} - -type ReqSendKnock struct { - RoomID id.RoomID - EventID id.EventID - Event PDU - Via string -} - -type RespSendJoin struct { - AuthChain []PDU `json:"auth_chain"` - Event PDU `json:"event"` - MembersOmitted bool `json:"members_omitted"` - ServersInRoom []string `json:"servers_in_room"` - State []PDU `json:"state"` -} - -type RespSendKnock struct { - KnockRoomState []PDU `json:"knock_room_state"` -} - -type ReqSendInvite struct { - RoomID id.RoomID `json:"-"` - UserID id.UserID `json:"-"` - Event PDU `json:"event"` - InviteRoomState []PDU `json:"invite_room_state"` - RoomVersion id.RoomVersion `json:"room_version"` -} - -type RespSendInvite struct { - Event PDU `json:"event"` -} - -type ReqMakeLeave struct { - RoomID id.RoomID - UserID id.UserID - Via string -} - -type ReqSendLeave struct { - RoomID id.RoomID - EventID id.EventID - Event PDU - Via string -} - -type ( - ReqMakeKnock = ReqMakeJoin - RespMakeKnock = RespMakeJoin - RespMakeLeave = RespMakeJoin -) - -func (c *Client) MakeJoin(ctx context.Context, req *ReqMakeJoin) (resp *RespMakeJoin, err error) { - versions := make([]string, len(req.SupportedVersions)) - for i, v := range req.SupportedVersions { - versions[i] = string(v) - } - _, _, err = c.MakeFullRequest(ctx, RequestParams{ - ServerName: req.Via, - Method: http.MethodGet, - Path: URLPath{"v1", "make_join", req.RoomID, req.UserID}, - Query: url.Values{"ver": versions}, - Authenticate: true, - ResponseJSON: &resp, - }) - return -} - -func (c *Client) MakeKnock(ctx context.Context, req *ReqMakeKnock) (resp *RespMakeKnock, err error) { - versions := make([]string, len(req.SupportedVersions)) - for i, v := range req.SupportedVersions { - versions[i] = string(v) - } - _, _, err = c.MakeFullRequest(ctx, RequestParams{ - ServerName: req.Via, - Method: http.MethodGet, - Path: URLPath{"v1", "make_knock", req.RoomID, req.UserID}, - Query: url.Values{"ver": versions}, - Authenticate: true, - ResponseJSON: &resp, - }) - return -} - -func (c *Client) SendJoin(ctx context.Context, req *ReqSendJoin) (resp *RespSendJoin, err error) { - _, _, err = c.MakeFullRequest(ctx, RequestParams{ - ServerName: req.Via, - Method: http.MethodPut, - Path: URLPath{"v2", "send_join", req.RoomID, req.EventID}, - Query: url.Values{ - "omit_members": {strconv.FormatBool(req.OmitMembers)}, - }, - Authenticate: true, - RequestJSON: req.Event, - ResponseJSON: &resp, - }) - return -} - -func (c *Client) SendKnock(ctx context.Context, req *ReqSendKnock) (resp *RespSendKnock, err error) { - _, _, err = c.MakeFullRequest(ctx, RequestParams{ - ServerName: req.Via, - Method: http.MethodPut, - Path: URLPath{"v1", "send_knock", req.RoomID, req.EventID}, - Authenticate: true, - RequestJSON: req.Event, - ResponseJSON: &resp, - }) - return -} - -func (c *Client) SendInvite(ctx context.Context, req *ReqSendInvite) (resp *RespSendInvite, err error) { - _, _, err = c.MakeFullRequest(ctx, RequestParams{ - ServerName: req.UserID.Homeserver(), - Method: http.MethodPut, - Path: URLPath{"v2", "invite", req.RoomID, req.UserID}, - Authenticate: true, - RequestJSON: req, - ResponseJSON: &resp, - }) - return -} - -func (c *Client) MakeLeave(ctx context.Context, req *ReqMakeLeave) (resp *RespMakeLeave, err error) { - _, _, err = c.MakeFullRequest(ctx, RequestParams{ - ServerName: req.Via, - Method: http.MethodGet, - Path: URLPath{"v1", "make_leave", req.RoomID, req.UserID}, - Authenticate: true, - ResponseJSON: &resp, - }) - return -} - -func (c *Client) SendLeave(ctx context.Context, req *ReqSendLeave) (err error) { - _, _, err = c.MakeFullRequest(ctx, RequestParams{ - ServerName: req.Via, - Method: http.MethodPut, - Path: URLPath{"v2", "send_leave", req.RoomID, req.EventID}, - Authenticate: true, - RequestJSON: req.Event, - }) - return -} - type URLPath []any func (fup URLPath) FullPath() []any { @@ -477,27 +306,15 @@ func (c *Client) MakeFullRequest(ctx context.Context, params RequestParams) ([]b WrappedError: err, } } - if !params.DontReadBody { - defer resp.Body.Close() - } + defer func() { + _ = resp.Body.Close() + }() var body []byte - if resp.StatusCode >= 300 { + if resp.StatusCode >= 400 { body, err = mautrix.ParseErrorResponse(req, resp) return body, resp, err } else if params.ResponseJSON != nil || !params.DontReadBody { - if resp.ContentLength > c.ResponseSizeLimit { - return body, resp, mautrix.HTTPError{ - Request: req, - Response: resp, - - Message: "not reading response", - WrappedError: fmt.Errorf("%w (%.2f MiB)", mautrix.ErrResponseTooLong, float64(resp.ContentLength)/1024/1024), - } - } - body, err = io.ReadAll(io.LimitReader(resp.Body, c.ResponseSizeLimit+1)) - if err == nil && len(body) > int(c.ResponseSizeLimit) { - err = mautrix.ErrBodyReadReachedLimit - } + body, err = io.ReadAll(resp.Body) if err != nil { return body, resp, mautrix.HTTPError{ Request: req, diff --git a/federation/eventauth/eventauth.go b/federation/eventauth/eventauth.go deleted file mode 100644 index c72933c2..00000000 --- a/federation/eventauth/eventauth.go +++ /dev/null @@ -1,851 +0,0 @@ -// Copyright (c) 2025 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -//go:build goexperiment.jsonv2 - -package eventauth - -import ( - "encoding/json" - "encoding/json/jsontext" - "errors" - "fmt" - "slices" - "strconv" - "strings" - - "github.com/tidwall/gjson" - "go.mau.fi/util/exgjson" - "go.mau.fi/util/exstrings" - "go.mau.fi/util/ptr" - - "maunium.net/go/mautrix/event" - "maunium.net/go/mautrix/federation/pdu" - "maunium.net/go/mautrix/federation/signutil" - "maunium.net/go/mautrix/id" -) - -type AuthFailError struct { - Index string - Message string - Wrapped error -} - -func (afe AuthFailError) Error() string { - if afe.Message != "" { - return fmt.Sprintf("fail %s: %s", afe.Index, afe.Message) - } else if afe.Wrapped != nil { - return fmt.Sprintf("fail %s: %s", afe.Index, afe.Wrapped.Error()) - } - return fmt.Sprintf("fail %s", afe.Index) -} - -func (afe AuthFailError) Unwrap() error { - return afe.Wrapped -} - -var mFederatePath = exgjson.Path("m.federate") - -var ( - ErrCreateHasPrevEvents = AuthFailError{Index: "1.1", Message: "m.room.create event has prev_events"} - ErrCreateHasRoomID = AuthFailError{Index: "1.2", Message: "m.room.create event has room_id set"} - ErrRoomIDDoesntMatchSender = AuthFailError{Index: "1.2", Message: "room ID server doesn't match sender server"} - ErrUnknownRoomVersion = AuthFailError{Index: "1.3", Wrapped: id.ErrUnknownRoomVersion} - ErrInvalidAdditionalCreators = AuthFailError{Index: "1.4", Message: "m.room.create event has invalid additional_creators"} - ErrMissingCreator = AuthFailError{Index: "1.4", Message: "m.room.create event is missing creator field"} - - ErrInvalidRoomIDLength = AuthFailError{Index: "2", Message: "room ID length is invalid"} - ErrFailedToGetCreateEvent = AuthFailError{Index: "2", Message: "failed to get m.room.create event"} - ErrCreateEventNotFound = AuthFailError{Index: "2", Message: "m.room.create event not found using room ID as event ID"} - ErrRejectedCreateEvent = AuthFailError{Index: "2", Message: "m.room.create event was rejected"} - - ErrFailedToGetAuthEvents = AuthFailError{Index: "3", Message: "failed to get auth events"} - ErrFailedToParsePowerLevels = AuthFailError{Index: "?", Message: "failed to parse power levels"} - ErrDuplicateAuthEvent = AuthFailError{Index: "3.1", Message: "duplicate type/state key pair in auth events"} - ErrNonStateAuthEvent = AuthFailError{Index: "3.2", Message: "non-state event in auth events"} - ErrMissingAuthEvent = AuthFailError{Index: "3.2", Message: "missing auth event"} - ErrUnexpectedAuthEvent = AuthFailError{Index: "3.2", Message: "unexpected type/state key pair in auth events"} - ErrNoCreateEvent = AuthFailError{Index: "3.2", Message: "no m.room.create event found in auth events"} - ErrRejectedAuthEvent = AuthFailError{Index: "3.3", Message: "auth event was rejected"} - ErrMismatchingRoomIDInAuthEvent = AuthFailError{Index: "3.4", Message: "auth event room ID does not match event room ID"} - - ErrFederationDisabled = AuthFailError{Index: "4", Message: "federation is disabled for this room"} - - ErrMemberNotState = AuthFailError{Index: "5.1", Message: "m.room.member event is not a state event"} - ErrNotSignedByAuthoriser = AuthFailError{Index: "5.2", Message: "m.room.member event is not signed by server of join_authorised_via_users_server"} - ErrCantJoinOtherUser = AuthFailError{Index: "5.3.2", Message: "can't send join event with different state key"} - ErrCantJoinBanned = AuthFailError{Index: "5.3.3", Message: "user is banned from the room"} - ErrAuthoriserCantInvite = AuthFailError{Index: "5.3.5.2", Message: "authoriser doesn't have sufficient power level to invite"} - ErrAuthoriserNotInRoom = AuthFailError{Index: "5.3.5.2", Message: "authoriser isn't a member of the room"} - ErrCantJoinWithoutInvite = AuthFailError{Index: "5.3.7", Message: "can't join invite-only room without invite"} - ErrInvalidJoinRule = AuthFailError{Index: "5.3.7", Message: "invalid join rule in room"} - ErrThirdPartyInviteBanned = AuthFailError{Index: "5.4.1.1", Message: "third party invite target user is banned"} - ErrThirdPartyInviteMissingFields = AuthFailError{Index: "5.4.1.3", Message: "third party invite is missing mxid or token fields"} - ErrThirdPartyInviteMXIDMismatch = AuthFailError{Index: "5.4.1.4", Message: "mxid in signed third party invite doesn't match event state key"} - ErrThirdPartyInviteNotFound = AuthFailError{Index: "5.4.1.5", Message: "matching m.room.third_party_invite event not found in auth events"} - ErrThirdPartyInviteSenderMismatch = AuthFailError{Index: "5.4.1.6", Message: "sender of third party invite doesn't match sender of member event"} - ErrThirdPartyInviteNotSigned = AuthFailError{Index: "5.4.1.8", Message: "no valid signatures found for third party invite"} - ErrInviterNotInRoom = AuthFailError{Index: "5.4.2", Message: "inviter's membership is not join"} - ErrInviteTargetAlreadyInRoom = AuthFailError{Index: "5.4.3", Message: "invite target user is already in the room"} - ErrInviteTargetBanned = AuthFailError{Index: "5.4.3", Message: "invite target user is banned"} - ErrInsufficientPermissionForInvite = AuthFailError{Index: "5.4.5", Message: "inviter does not have sufficient permission to send invites"} - ErrCantLeaveWithoutBeingInRoom = AuthFailError{Index: "5.5.1", Message: "can't leave room without being in it"} - ErrCantKickWithoutBeingInRoom = AuthFailError{Index: "5.5.2", Message: "can't kick another user without being in the room"} - ErrInsufficientPermissionForUnban = AuthFailError{Index: "5.5.3", Message: "sender does not have sufficient permission to unban users"} - ErrInsufficientPermissionForKick = AuthFailError{Index: "5.5.5", Message: "sender does not have sufficient permission to kick the user"} - ErrCantBanWithoutBeingInRoom = AuthFailError{Index: "5.6.1", Message: "can't ban another user without being in the room"} - ErrInsufficientPermissionForBan = AuthFailError{Index: "5.6.3", Message: "sender does not have sufficient permission to ban the user"} - ErrNotKnockableRoom = AuthFailError{Index: "5.7.1", Message: "join rule doesn't allow knocking"} - ErrCantKnockOtherUser = AuthFailError{Index: "5.7.1", Message: "can't send knock event with different state key"} - ErrCantKnockWhileInRoom = AuthFailError{Index: "5.7.2", Message: "can't knock while joined, invited or banned"} - ErrUnknownMembership = AuthFailError{Index: "5.8", Message: "unknown membership in m.room.member event"} - - ErrNotInRoom = AuthFailError{Index: "6", Message: "sender is not a member of the room"} - - ErrInsufficientPowerForThirdPartyInvite = AuthFailError{Index: "7.1", Message: "sender does not have sufficient power level to send third party invite"} - - ErrInsufficientPowerLevel = AuthFailError{Index: "8", Message: "sender does not have sufficient power level to send event"} - - ErrMismatchingPrivateStateKey = AuthFailError{Index: "9", Message: "state keys starting with @ must match sender user ID"} - - ErrTopLevelPLNotInteger = AuthFailError{Index: "10.1", Message: "invalid type for top-level power level field"} - ErrPLNotInteger = AuthFailError{Index: "10.2", Message: "invalid type for power level"} - ErrInvalidUserIDInPL = AuthFailError{Index: "10.3", Message: "invalid user ID in power levels"} - ErrUserPLNotInteger = AuthFailError{Index: "10.3", Message: "invalid type for user power level"} - ErrCreatorInPowerLevels = AuthFailError{Index: "10.4", Message: "room creators must not be specified in power levels"} - ErrInvalidPowerChange = AuthFailError{Index: "10.x", Message: "illegal power level change"} - ErrInvalidUserPowerChange = AuthFailError{Index: "10.9", Message: "illegal power level change"} -) - -func isRejected(evt *pdu.PDU) bool { - return evt.InternalMeta.Rejected -} - -type GetEventsFunc = func(ids []id.EventID) ([]*pdu.PDU, error) - -func Authorize(roomVersion id.RoomVersion, evt *pdu.PDU, getEvents GetEventsFunc, getKey pdu.GetKeyFunc) error { - if evt.Type == event.StateCreate.Type { - // 1. If type is m.room.create: - return authorizeCreate(roomVersion, evt) - } - var createEvt *pdu.PDU - if roomVersion.RoomIDIsCreateEventID() { - // 2. If the event’s room_id is not an event ID for an accepted (not rejected) m.room.create event, - // with the sigil ! instead of $, reject. - if len(evt.RoomID) != 44 { - return fmt.Errorf("%w (%d)", ErrInvalidRoomIDLength, len(evt.RoomID)) - } else if createEvts, err := getEvents([]id.EventID{id.EventID("$" + evt.RoomID[1:])}); err != nil { - return fmt.Errorf("%w: %w", ErrFailedToGetCreateEvent, err) - } else if len(createEvts) != 1 { - return fmt.Errorf("%w (%s)", ErrCreateEventNotFound, evt.RoomID) - } else if isRejected(createEvts[0]) { - return ErrRejectedCreateEvent - } else { - createEvt = createEvts[0] - } - } - authEvents, err := getEvents(evt.AuthEvents) - if err != nil { - return fmt.Errorf("%w: %w", ErrFailedToGetAuthEvents, err) - } - expectedAuthEvents := evt.AuthEventSelection(roomVersion) - deduplicator := make(map[pdu.StateKey]id.EventID, len(expectedAuthEvents)) - // 3. Considering the event’s auth_events: - for i, ae := range authEvents { - authEvtID := evt.AuthEvents[i] - if ae == nil { - return fmt.Errorf("%w (%s)", ErrMissingAuthEvent, authEvtID) - } else if ae.StateKey == nil { - // This approximately falls under rule 3.2. - return fmt.Errorf("%w (%s)", ErrNonStateAuthEvent, authEvtID) - } - key := pdu.StateKey{Type: ae.Type, StateKey: *ae.StateKey} - if prevEvtID, alreadyFound := deduplicator[key]; alreadyFound { - // 3.1. If there are duplicate entries for a given type and state_key pair, reject. - return fmt.Errorf("%w for %s/%s: found %s and %s", ErrDuplicateAuthEvent, ae.Type, *ae.StateKey, prevEvtID, authEvtID) - } else if !expectedAuthEvents.Has(key) { - // 3.2. If there are entries whose type and state_key don’t match those specified by - // the auth events selection algorithm described in the server specification, reject. - return fmt.Errorf("%w: found %s with key %s/%s", ErrUnexpectedAuthEvent, authEvtID, ae.Type, *ae.StateKey) - } else if isRejected(ae) { - // 3.3. If there are entries which were themselves rejected under the checks performed on receipt of a PDU, reject. - return fmt.Errorf("%w (%s)", ErrRejectedAuthEvent, authEvtID) - } else if ae.RoomID != evt.RoomID { - // 3.4. If any event in auth_events has a room_id which does not match that of the event being authorised, reject. - return fmt.Errorf("%w (%s)", ErrMismatchingRoomIDInAuthEvent, authEvtID) - } else { - deduplicator[key] = authEvtID - } - if ae.Type == event.StateCreate.Type { - if createEvt == nil { - createEvt = ae - } else { - // Duplicates are prevented by deduplicator, AuthEventSelection also won't allow a create event at all for v12+ - panic(fmt.Errorf("impossible case: multiple create events found in auth events")) - } - } - } - if createEvt == nil { - // This comes either from auth_events or room_id depending on the room version. - // The checks above make sure it's from the right source. - return ErrNoCreateEvent - } - if federateVal := gjson.GetBytes(createEvt.Content, mFederatePath); federateVal.Type == gjson.False && createEvt.Sender.Homeserver() != evt.Sender.Homeserver() { - // 4. If the content of the m.room.create event in the room state has the property m.federate set to false, - // and the sender domain of the event does not match the sender domain of the create event, reject. - return ErrFederationDisabled - } - if evt.Type == event.StateMember.Type { - // 5. If type is m.room.member: - return authorizeMember(roomVersion, evt, createEvt, authEvents, getKey) - } - senderMembership := event.Membership(findEventAndReadString(authEvents, event.StateMember.Type, evt.Sender.String(), "membership", "leave")) - if senderMembership != event.MembershipJoin { - // 6. If the sender’s current membership state is not join, reject. - return ErrNotInRoom - } - powerLevels, err := getPowerLevels(roomVersion, authEvents, createEvt) - if err != nil { - return err - } - senderPL := powerLevels.GetUserLevel(evt.Sender) - if evt.Type == event.StateThirdPartyInvite.Type { - // 7.1. Allow if and only if sender’s current power level is greater than or equal to the invite level. - if senderPL >= powerLevels.Invite() { - return nil - } - return ErrInsufficientPowerForThirdPartyInvite - } - typeClass := event.MessageEventType - if evt.StateKey != nil { - typeClass = event.StateEventType - } - evtLevel := powerLevels.GetEventLevel(event.Type{Type: evt.Type, Class: typeClass}) - if evtLevel > senderPL { - // 8. If the event type’s required power level is greater than the sender’s power level, reject. - return fmt.Errorf("%w (%d > %d)", ErrInsufficientPowerLevel, evtLevel, senderPL) - } - - if evt.StateKey != nil && strings.HasPrefix(*evt.StateKey, "@") && *evt.StateKey != evt.Sender.String() { - // 9. If the event has a state_key that starts with an @ and does not match the sender, reject. - return ErrMismatchingPrivateStateKey - } - - if evt.Type == event.StatePowerLevels.Type { - // 10. If type is m.room.power_levels: - return authorizePowerLevels(roomVersion, evt, createEvt, authEvents) - } - - // 11. Otherwise, allow. - return nil -} - -var ErrUserIDNotAString = errors.New("not a string") -var ErrUserIDNotValid = errors.New("not a valid user ID") - -func isValidUserID(roomVersion id.RoomVersion, userID gjson.Result) error { - if userID.Type != gjson.String { - return ErrUserIDNotAString - } - // In a future room version, user IDs will have stricter validation - _, _, err := id.UserID(userID.Str).Parse() - if err != nil { - return ErrUserIDNotValid - } - return nil -} - -func authorizeCreate(roomVersion id.RoomVersion, evt *pdu.PDU) error { - if len(evt.PrevEvents) > 0 { - // 1.1. If it has any prev_events, reject. - return ErrCreateHasPrevEvents - } - if roomVersion.RoomIDIsCreateEventID() { - if evt.RoomID != "" { - // 1.2. If the event has a room_id, reject. - return ErrCreateHasRoomID - } - } else { - _, _, server := id.ParseCommonIdentifier(evt.RoomID) - if server == "" || server != evt.Sender.Homeserver() { - // 1.2. (v11 and below) If the domain of the room_id does not match the domain of the sender, reject. - return ErrRoomIDDoesntMatchSender - } - } - if !roomVersion.IsKnown() { - // 1.3. If content.room_version is present and is not a recognised version, reject. - return fmt.Errorf("%w %s", ErrUnknownRoomVersion, roomVersion) - } - if roomVersion.PrivilegedRoomCreators() { - additionalCreators := gjson.GetBytes(evt.Content, "additional_creators") - if additionalCreators.Exists() { - if !additionalCreators.IsArray() { - return fmt.Errorf("%w: not an array", ErrInvalidAdditionalCreators) - } - for i, item := range additionalCreators.Array() { - // 1.4. If additional_creators is present in content and is not an array of strings - // where each string passes the same user ID validation applied to sender, reject. - if err := isValidUserID(roomVersion, item); err != nil { - return fmt.Errorf("%w: item #%d %w", ErrInvalidAdditionalCreators, i+1, err) - } - } - } - } - if roomVersion.CreatorInContent() { - // 1.4. (v10 and below) If content has no creator property, reject. - if !gjson.GetBytes(evt.Content, "creator").Exists() { - return ErrMissingCreator - } - } - // 1.5. Otherwise, allow. - return nil -} - -func authorizeMember(roomVersion id.RoomVersion, evt, createEvt *pdu.PDU, authEvents []*pdu.PDU, getKey pdu.GetKeyFunc) error { - membership := event.Membership(gjson.GetBytes(evt.Content, "membership").Str) - if evt.StateKey == nil { - // 5.1. If there is no state_key property, or no membership property in content, reject. - return ErrMemberNotState - } - authorizedVia := id.UserID(gjson.GetBytes(evt.Content, "authorised_via_users_server").Str) - if authorizedVia != "" { - homeserver := authorizedVia.Homeserver() - err := evt.VerifySignature(roomVersion, homeserver, getKey) - if err != nil { - // 5.2. If content has a join_authorised_via_users_server key: - // 5.2.1. If the event is not validly signed by the homeserver of the user ID denoted by the key, reject. - return fmt.Errorf("%w: %w", ErrNotSignedByAuthoriser, err) - } - } - targetPrevMembership := event.Membership(findEventAndReadString(authEvents, event.StateMember.Type, *evt.StateKey, "membership", "leave")) - senderMembership := event.Membership(findEventAndReadString(authEvents, event.StateMember.Type, evt.Sender.String(), "membership", "leave")) - switch membership { - case event.MembershipJoin: - createEvtID, err := createEvt.GetEventID(roomVersion) - if err != nil { - return fmt.Errorf("failed to get create event ID: %w", err) - } - creator := createEvt.Sender.String() - if roomVersion.CreatorInContent() { - creator = gjson.GetBytes(evt.Content, "creator").Str - } - if len(evt.PrevEvents) == 1 && - len(evt.AuthEvents) <= 1 && - evt.PrevEvents[0] == createEvtID && - *evt.StateKey == creator { - // 5.3.1. If the only previous event is an m.room.create and the state_key is the sender of the m.room.create, allow. - return nil - } - // Spec wart: this would make more sense before the check above. - // Now you can set anyone as the sender of the first join. - if evt.Sender.String() != *evt.StateKey { - // 5.3.2. If the sender does not match state_key, reject. - return ErrCantJoinOtherUser - } - - if senderMembership == event.MembershipBan { - // 5.3.3. If the sender is banned, reject. - return ErrCantJoinBanned - } - - joinRule := event.JoinRule(findEventAndReadString(authEvents, event.StateJoinRules.Type, "", "join_rule", "invite")) - switch joinRule { - case event.JoinRuleKnock: - if !roomVersion.Knocks() { - return ErrInvalidJoinRule - } - fallthrough - case event.JoinRuleInvite: - // 5.3.4. If the join_rule is invite or knock then allow if membership state is invite or join. - if targetPrevMembership == event.MembershipJoin || targetPrevMembership == event.MembershipInvite { - return nil - } - return ErrCantJoinWithoutInvite - case event.JoinRuleKnockRestricted: - if !roomVersion.KnockRestricted() { - return ErrInvalidJoinRule - } - fallthrough - case event.JoinRuleRestricted: - if joinRule == event.JoinRuleRestricted && !roomVersion.RestrictedJoins() { - return ErrInvalidJoinRule - } - if targetPrevMembership == event.MembershipJoin || targetPrevMembership == event.MembershipInvite { - // 5.3.5.1. If membership state is join or invite, allow. - return nil - } - powerLevels, err := getPowerLevels(roomVersion, authEvents, createEvt) - if err != nil { - return err - } - if powerLevels.GetUserLevel(authorizedVia) < powerLevels.Invite() { - // 5.3.5.2. If the join_authorised_via_users_server key in content is not a user with sufficient permission to invite other users, reject. - return ErrAuthoriserCantInvite - } - authorizerMembership := event.Membership(findEventAndReadString(authEvents, event.StateMember.Type, authorizedVia.String(), "membership", string(event.MembershipLeave))) - if authorizerMembership != event.MembershipJoin { - return ErrAuthoriserNotInRoom - } - // 5.3.5.3. Otherwise, allow. - return nil - case event.JoinRulePublic: - // 5.3.6. If the join_rule is public, allow. - return nil - default: - // 5.3.7. Otherwise, reject. - return ErrInvalidJoinRule - } - case event.MembershipInvite: - tpiVal := gjson.GetBytes(evt.Content, "third_party_invite") - if tpiVal.Exists() { - if targetPrevMembership == event.MembershipBan { - return ErrThirdPartyInviteBanned - } - signed := tpiVal.Get("signed") - mxid := signed.Get("mxid").Str - token := signed.Get("token").Str - if mxid == "" || token == "" { - // 5.4.1.2. If content.third_party_invite does not have a signed property, reject. - // 5.4.1.3. If signed does not have mxid and token properties, reject. - return ErrThirdPartyInviteMissingFields - } - if mxid != *evt.StateKey { - // 5.4.1.4. If mxid does not match state_key, reject. - return ErrThirdPartyInviteMXIDMismatch - } - tpiEvt := findEvent(authEvents, event.StateThirdPartyInvite.Type, token) - if tpiEvt == nil { - // 5.4.1.5. If there is no m.room.third_party_invite event in the current room state with state_key matching token, reject. - return ErrThirdPartyInviteNotFound - } - if tpiEvt.Sender != evt.Sender { - // 5.4.1.6. If sender does not match sender of the m.room.third_party_invite, reject. - return ErrThirdPartyInviteSenderMismatch - } - var keys []id.Ed25519 - const ed25519Base64Len = 43 - oldPubKey := gjson.GetBytes(evt.Content, "public_key.token") - if oldPubKey.Type == gjson.String && len(oldPubKey.Str) == ed25519Base64Len { - keys = append(keys, id.Ed25519(oldPubKey.Str)) - } - gjson.GetBytes(evt.Content, "public_keys").ForEach(func(key, value gjson.Result) bool { - if key.Type != gjson.Number { - return false - } - if value.Type == gjson.String && len(value.Str) == ed25519Base64Len { - keys = append(keys, id.Ed25519(value.Str)) - } - return true - }) - rawSigned := jsontext.Value(exstrings.UnsafeBytes(signed.Str)) - var validated bool - for _, key := range keys { - if signutil.VerifyJSONAny(key, rawSigned) == nil { - validated = true - } - } - if validated { - // 4.4.1.7. If any signature in signed matches any public key in the m.room.third_party_invite event, allow. - return nil - } - // 4.4.1.8. Otherwise, reject. - return ErrThirdPartyInviteNotSigned - } - if senderMembership != event.MembershipJoin { - // 5.4.2. If the sender’s current membership state is not join, reject. - return ErrInviterNotInRoom - } - // 5.4.3. If target user’s current membership state is join or ban, reject. - if targetPrevMembership == event.MembershipJoin { - return ErrInviteTargetAlreadyInRoom - } else if targetPrevMembership == event.MembershipBan { - return ErrInviteTargetBanned - } - powerLevels, err := getPowerLevels(roomVersion, authEvents, createEvt) - if err != nil { - return err - } - if powerLevels.GetUserLevel(evt.Sender) >= powerLevels.Invite() { - // 5.4.4. If the sender’s power level is greater than or equal to the invite level, allow. - return nil - } - // 5.4.5. Otherwise, reject. - return ErrInsufficientPermissionForInvite - case event.MembershipLeave: - if evt.Sender.String() == *evt.StateKey { - // 5.5.1. If the sender matches state_key, allow if and only if that user’s current membership state is invite, join, or knock. - if senderMembership == event.MembershipInvite || - senderMembership == event.MembershipJoin || - (senderMembership == event.MembershipKnock && roomVersion.Knocks()) { - return nil - } - return ErrCantLeaveWithoutBeingInRoom - } - if senderMembership != event.MembershipJoin { - // 5.5.2. If the sender’s current membership state is not join, reject. - return ErrCantKickWithoutBeingInRoom - } - powerLevels, err := getPowerLevels(roomVersion, authEvents, createEvt) - if err != nil { - return err - } - senderLevel := powerLevels.GetUserLevel(evt.Sender) - if targetPrevMembership == event.MembershipBan && senderLevel < powerLevels.Ban() { - // 5.5.3. If the target user’s current membership state is ban, and the sender’s power level is less than the ban level, reject. - return ErrInsufficientPermissionForUnban - } - if senderLevel >= powerLevels.Kick() && powerLevels.GetUserLevel(id.UserID(*evt.StateKey)) < senderLevel { - // 5.5.4. If the sender’s power level is greater than or equal to the kick level, and the target user’s power level is less than the sender’s power level, allow. - return nil - } - // TODO separate errors for < kick and < target user level? - // 5.5.5. Otherwise, reject. - return ErrInsufficientPermissionForKick - case event.MembershipBan: - if senderMembership != event.MembershipJoin { - // 5.6.1. If the sender’s current membership state is not join, reject. - return ErrCantBanWithoutBeingInRoom - } - powerLevels, err := getPowerLevels(roomVersion, authEvents, createEvt) - if err != nil { - return err - } - senderLevel := powerLevels.GetUserLevel(evt.Sender) - if senderLevel >= powerLevels.Ban() && powerLevels.GetUserLevel(id.UserID(*evt.StateKey)) < senderLevel { - // 5.6.2. If the sender’s power level is greater than or equal to the ban level, and the target user’s power level is less than the sender’s power level, allow. - return nil - } - // 5.6.3. Otherwise, reject. - return ErrInsufficientPermissionForBan - case event.MembershipKnock: - joinRule := event.JoinRule(findEventAndReadString(authEvents, event.StateJoinRules.Type, "", "join_rule", "invite")) - validKnockRule := roomVersion.Knocks() && joinRule == event.JoinRuleKnock - validKnockRestrictedRule := roomVersion.KnockRestricted() && joinRule == event.JoinRuleKnockRestricted - if !validKnockRule && !validKnockRestrictedRule { - // 5.7.1. If the join_rule is anything other than knock or knock_restricted, reject. - return ErrNotKnockableRoom - } - if evt.Sender.String() != *evt.StateKey { - // 5.7.2. If the sender does not match state_key, reject. - return ErrCantKnockOtherUser - } - if senderMembership != event.MembershipBan && senderMembership != event.MembershipInvite && senderMembership != event.MembershipJoin { - // 5.7.3. If the sender’s current membership is not ban, invite, or join, allow. - return nil - } - // 5.7.4. Otherwise, reject. - return ErrCantKnockWhileInRoom - default: - // 5.8. Otherwise, the membership is unknown. Reject. - return ErrUnknownMembership - } -} - -func authorizePowerLevels(roomVersion id.RoomVersion, evt, createEvt *pdu.PDU, authEvents []*pdu.PDU) error { - if roomVersion.ValidatePowerLevelInts() { - for _, key := range []string{"users_default", "events_default", "state_default", "ban", "redact", "kick", "invite"} { - res := gjson.GetBytes(evt.Content, key) - if !res.Exists() { - continue - } - if parseIntWithVersion(roomVersion, res) == nil { - // 10.1. If any of the properties users_default, events_default, state_default, ban, redact, kick, or invite in content are present and not an integer, reject. - return fmt.Errorf("%w %s", ErrTopLevelPLNotInteger, key) - } - } - for _, key := range []string{"events", "notifications"} { - obj := gjson.GetBytes(evt.Content, key) - if !obj.Exists() { - continue - } - // 10.2. If either of the properties events or notifications in content are present and not an object [...], reject. - if !obj.IsObject() { - return fmt.Errorf("%w %s", ErrTopLevelPLNotInteger, key) - } - var err error - // 10.2. [...] are not an object with values that are integers, reject. - obj.ForEach(func(innerKey, value gjson.Result) bool { - if parseIntWithVersion(roomVersion, value) == nil { - err = fmt.Errorf("%w %s.%s", ErrPLNotInteger, key, innerKey.Str) - return false - } - return true - }) - if err != nil { - return err - } - } - } - var creators []id.UserID - if roomVersion.PrivilegedRoomCreators() { - creators = append(creators, createEvt.Sender) - gjson.GetBytes(createEvt.Content, "additional_creators").ForEach(func(key, value gjson.Result) bool { - creators = append(creators, id.UserID(value.Str)) - return true - }) - } - users := gjson.GetBytes(evt.Content, "users") - if users.Exists() { - if !users.IsObject() { - // 10.3. If the users property in content is not an object [...], reject. - return fmt.Errorf("%w users", ErrTopLevelPLNotInteger) - } - var err error - users.ForEach(func(key, value gjson.Result) bool { - if validatorErr := isValidUserID(roomVersion, key); validatorErr != nil { - // 10.3. [...] is not an object with keys that are valid user IDs [...], reject. - err = fmt.Errorf("%w: %q %w", ErrInvalidUserIDInPL, key.Str, validatorErr) - return false - } - if parseIntWithVersion(roomVersion, value) == nil { - // 10.3. [...] is not an object [...] with values that are integers, reject. - err = fmt.Errorf("%w %q", ErrUserPLNotInteger, key.Str) - return false - } - // creators is only filled if the room version has privileged room creators - if slices.Contains(creators, id.UserID(key.Str)) { - // 10.4. If the users property in content contains the sender of the m.room.create event or any of - // the additional_creators array (if present) from the content of the m.room.create event, reject. - err = fmt.Errorf("%w: %q", ErrCreatorInPowerLevels, key.Str) - return false - } - return true - }) - if err != nil { - return err - } - } - oldPL := findEvent(authEvents, event.StatePowerLevels.Type, "") - if oldPL == nil { - // 10.5. If there is no previous m.room.power_levels event in the room, allow. - return nil - } - if slices.Contains(creators, evt.Sender) { - // Skip remaining checks for creators - return nil - } - senderPLPtr := parsePythonInt(gjson.GetBytes(oldPL.Content, exgjson.Path("users", evt.Sender.String()))) - if senderPLPtr == nil { - senderPLPtr = parsePythonInt(gjson.GetBytes(oldPL.Content, "users_default")) - if senderPLPtr == nil { - senderPLPtr = ptr.Ptr(0) - } - } - for _, key := range []string{"users_default", "events_default", "state_default", "ban", "redact", "kick", "invite"} { - oldVal := gjson.GetBytes(oldPL.Content, key) - newVal := gjson.GetBytes(evt.Content, key) - if err := allowPowerChange(roomVersion, *senderPLPtr, key, oldVal, newVal); err != nil { - return err - } - } - if err := allowPowerChangeMap( - roomVersion, *senderPLPtr, "events", "", - gjson.GetBytes(oldPL.Content, "events"), - gjson.GetBytes(evt.Content, "events"), - ); err != nil { - return err - } - if err := allowPowerChangeMap( - roomVersion, *senderPLPtr, "notifications", "", - gjson.GetBytes(oldPL.Content, "notifications"), - gjson.GetBytes(evt.Content, "notifications"), - ); err != nil { - return err - } - if err := allowPowerChangeMap( - roomVersion, *senderPLPtr, "users", evt.Sender.String(), - gjson.GetBytes(oldPL.Content, "users"), - gjson.GetBytes(evt.Content, "users"), - ); err != nil { - return err - } - return nil -} - -func allowPowerChangeMap(roomVersion id.RoomVersion, maxVal int, path, ownID string, old, new gjson.Result) (err error) { - old.ForEach(func(key, value gjson.Result) bool { - newVal := new.Get(exgjson.Path(key.Str)) - err = allowPowerChange(roomVersion, maxVal, path+"."+key.Str, value, newVal) - if err == nil && ownID != "" && key.Str != ownID { - parsedOldVal := parseIntWithVersion(roomVersion, value) - parsedNewVal := parseIntWithVersion(roomVersion, newVal) - if *parsedOldVal >= maxVal && *parsedOldVal != *parsedNewVal { - err = fmt.Errorf("%w: can't change users.%s from %s to %s with sender level %d", ErrInvalidUserPowerChange, key.Str, stringifyForError(value), stringifyForError(newVal), maxVal) - } - } - return err == nil - }) - if err != nil { - return - } - new.ForEach(func(key, value gjson.Result) bool { - err = allowPowerChange(roomVersion, maxVal, path+"."+key.Str, old.Get(exgjson.Path(key.Str)), value) - return err == nil - }) - return -} - -func allowPowerChange(roomVersion id.RoomVersion, maxVal int, path string, old, new gjson.Result) error { - oldVal := parseIntWithVersion(roomVersion, old) - newVal := parseIntWithVersion(roomVersion, new) - if oldVal == nil { - if newVal == nil || *newVal <= maxVal { - return nil - } - } else if newVal == nil { - if *oldVal <= maxVal { - return nil - } - } else if *oldVal == *newVal || (*oldVal <= maxVal && *newVal <= maxVal) { - return nil - } - return fmt.Errorf("%w can't change %s from %s to %s with sender level %d", ErrInvalidPowerChange, path, stringifyForError(old), stringifyForError(new), maxVal) -} - -func stringifyForError(val gjson.Result) string { - if !val.Exists() { - return "null" - } - return val.Raw -} - -func findEvent(events []*pdu.PDU, evtType, stateKey string) *pdu.PDU { - for _, evt := range events { - if evt.Type == evtType && *evt.StateKey == stateKey { - return evt - } - } - return nil -} - -func findEventAndReadData[T any](events []*pdu.PDU, evtType, stateKey string, reader func(evt *pdu.PDU) T) T { - return reader(findEvent(events, evtType, stateKey)) -} - -func findEventAndReadString(events []*pdu.PDU, evtType, stateKey, fieldPath, defVal string) string { - return findEventAndReadData(events, evtType, stateKey, func(evt *pdu.PDU) string { - if evt == nil { - return defVal - } - res := gjson.GetBytes(evt.Content, fieldPath) - if res.Type != gjson.String { - return defVal - } - return res.Str - }) -} - -func getPowerLevels(roomVersion id.RoomVersion, authEvents []*pdu.PDU, createEvt *pdu.PDU) (*event.PowerLevelsEventContent, error) { - var err error - powerLevels := findEventAndReadData(authEvents, event.StatePowerLevels.Type, "", func(evt *pdu.PDU) *event.PowerLevelsEventContent { - if evt == nil { - return nil - } - content := evt.Content - out := &event.PowerLevelsEventContent{} - if !roomVersion.ValidatePowerLevelInts() { - safeParsePowerLevels(content, out) - } else { - err = json.Unmarshal(content, out) - } - return out - }) - if err != nil { - // This should never happen thanks to safeParsePowerLevels for v1-9 and strict validation in v10+ - return nil, fmt.Errorf("%w: %w", ErrFailedToParsePowerLevels, err) - } - if roomVersion.PrivilegedRoomCreators() { - if powerLevels == nil { - powerLevels = &event.PowerLevelsEventContent{} - } - powerLevels.CreateEvent, err = createEvt.ToClientEvent(roomVersion) - if err != nil { - return nil, fmt.Errorf("%w: %w", ErrFailedToParsePowerLevels, err) - } - err = powerLevels.CreateEvent.Content.ParseRaw(powerLevels.CreateEvent.Type) - if err != nil { - return nil, fmt.Errorf("%w: %w", ErrFailedToParsePowerLevels, err) - } - } else if powerLevels == nil { - powerLevels = &event.PowerLevelsEventContent{ - Users: map[id.UserID]int{ - createEvt.Sender: 100, - }, - } - } - return powerLevels, nil -} - -func parseIntWithVersion(roomVersion id.RoomVersion, val gjson.Result) *int { - if roomVersion.ValidatePowerLevelInts() { - if val.Type != gjson.Number { - return nil - } - return ptr.Ptr(int(val.Int())) - } - return parsePythonInt(val) -} - -func parsePythonInt(val gjson.Result) *int { - switch val.Type { - case gjson.True: - return ptr.Ptr(1) - case gjson.False: - return ptr.Ptr(0) - case gjson.Number: - return ptr.Ptr(int(val.Int())) - case gjson.String: - // strconv.Atoi accepts signs as well as leading zeroes, so we just need to trim spaces beforehand - num, err := strconv.Atoi(strings.TrimSpace(val.Str)) - if err != nil { - return nil - } - return &num - default: - // Python int() doesn't accept nulls, arrays or dicts - return nil - } -} - -func safeParsePowerLevels(content jsontext.Value, into *event.PowerLevelsEventContent) { - *into = event.PowerLevelsEventContent{ - Users: make(map[id.UserID]int), - UsersDefault: ptr.Val(parsePythonInt(gjson.GetBytes(content, "users_default"))), - Events: make(map[string]int), - EventsDefault: ptr.Val(parsePythonInt(gjson.GetBytes(content, "events_default"))), - Notifications: nil, // irrelevant for event auth - StateDefaultPtr: parsePythonInt(gjson.GetBytes(content, "state_default")), - InvitePtr: parsePythonInt(gjson.GetBytes(content, "invite")), - KickPtr: parsePythonInt(gjson.GetBytes(content, "kick")), - BanPtr: parsePythonInt(gjson.GetBytes(content, "ban")), - RedactPtr: parsePythonInt(gjson.GetBytes(content, "redact")), - } - gjson.GetBytes(content, "events").ForEach(func(key, value gjson.Result) bool { - if key.Type != gjson.String { - return false - } - val := parsePythonInt(value) - if val != nil { - into.Events[key.Str] = *val - } - return true - }) - gjson.GetBytes(content, "users").ForEach(func(key, value gjson.Result) bool { - if key.Type != gjson.String { - return false - } - val := parsePythonInt(value) - if val == nil { - return false - } - userID := id.UserID(key.Str) - if _, _, err := userID.Parse(); err != nil { - return false - } - into.Users[userID] = *val - return true - }) -} diff --git a/federation/eventauth/eventauth_internal_test.go b/federation/eventauth/eventauth_internal_test.go deleted file mode 100644 index d316f3c8..00000000 --- a/federation/eventauth/eventauth_internal_test.go +++ /dev/null @@ -1,66 +0,0 @@ -// Copyright (c) 2026 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -//go:build goexperiment.jsonv2 - -package eventauth - -import ( - "strings" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "github.com/tidwall/gjson" -) - -type pythonIntTest struct { - Name string - Input string - Expected int64 -} - -var pythonIntTests = []pythonIntTest{ - {"True", `true`, 1}, - {"False", `false`, 0}, - {"SmallFloat", `3.1415`, 3}, - {"SmallFloatRoundDown", `10.999999999999999`, 10}, - {"SmallFloatRoundUp", `10.9999999999999999`, 11}, - {"BigFloatRoundDown", `1000000.9999999999`, 1000000}, - {"BigFloatRoundUp", `1000000.99999999999`, 1000001}, - {"BigFloatPrecisionError", `9007199254740993.0`, 9007199254740992}, - {"BigFloatPrecisionError2", `9007199254740993.123`, 9007199254740994}, - {"Int64", `9223372036854775807`, 9223372036854775807}, - {"Int64String", `"9223372036854775807"`, 9223372036854775807}, - {"String", `"123"`, 123}, - {"InvalidFloatInString", `"123.456"`, 0}, - {"StringWithPlusSign", `"+123"`, 123}, - {"StringWithMinusSign", `"-123"`, -123}, - {"StringWithSpaces", `" 123 "`, 123}, - {"StringWithSpacesAndSign", `" -123 "`, -123}, - //{"StringWithUnderscores", `"123_456"`, 123456}, - //{"StringWithUnderscores", `"123_456"`, 123456}, - {"InvalidStringWithTrailingUnderscore", `"123_456_"`, 0}, - {"InvalidStringWithMultipleUnderscores", `"123__456"`, 0}, - {"InvalidStringWithLeadingUnderscore", `"_123_456"`, 0}, - {"InvalidStringWithUnderscoreAfterSign", `"+_123_456"`, 0}, - {"InvalidStringWithUnderscoreAfterSpace", `" _123_456"`, 0}, - //{"StringWithUnderscoresAndSpaces", `" +1_2_3_4_5_6 "`, 123456}, -} - -func TestParsePythonInt(t *testing.T) { - for _, test := range pythonIntTests { - t.Run(test.Name, func(t *testing.T) { - output := parsePythonInt(gjson.Parse(test.Input)) - if strings.HasPrefix(test.Name, "Invalid") { - assert.Nil(t, output) - } else { - require.NotNil(t, output) - assert.Equal(t, int(test.Expected), *output) - } - }) - } -} diff --git a/federation/eventauth/eventauth_test.go b/federation/eventauth/eventauth_test.go deleted file mode 100644 index e3c5cd76..00000000 --- a/federation/eventauth/eventauth_test.go +++ /dev/null @@ -1,85 +0,0 @@ -// Copyright (c) 2025 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -//go:build goexperiment.jsonv2 - -package eventauth_test - -import ( - "embed" - "encoding/json/jsontext" - "encoding/json/v2" - "errors" - "io" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "github.com/tidwall/gjson" - "go.mau.fi/util/exerrors" - "go.mau.fi/util/ptr" - - "maunium.net/go/mautrix/federation/eventauth" - "maunium.net/go/mautrix/federation/pdu" - "maunium.net/go/mautrix/id" -) - -//go:embed *.jsonl -var data embed.FS - -type eventMap map[id.EventID]*pdu.PDU - -func (em eventMap) Get(ids []id.EventID) ([]*pdu.PDU, error) { - output := make([]*pdu.PDU, len(ids)) - for i, evtID := range ids { - output[i] = em[evtID] - } - return output, nil -} - -func GetKey(serverName string, keyID id.KeyID, validUntilTS time.Time) (id.SigningKey, time.Time, error) { - return "", time.Time{}, nil -} - -func TestAuthorize(t *testing.T) { - files := exerrors.Must(data.ReadDir(".")) - for _, file := range files { - t.Run(file.Name(), func(t *testing.T) { - decoder := jsontext.NewDecoder(exerrors.Must(data.Open(file.Name()))) - events := make(eventMap) - var roomVersion *id.RoomVersion - for i := 1; ; i++ { - var evt *pdu.PDU - err := json.UnmarshalDecode(decoder, &evt) - if errors.Is(err, io.EOF) { - break - } - require.NoError(t, err) - if roomVersion == nil { - require.Equal(t, evt.Type, "m.room.create") - roomVersion = ptr.Ptr(id.RoomVersion(gjson.GetBytes(evt.Content, "room_version").Str)) - } - expectedEventID := gjson.GetBytes(evt.Unsigned, "event_id").Str - evtID, err := evt.GetEventID(*roomVersion) - require.NoError(t, err) - require.Equalf(t, id.EventID(expectedEventID), evtID, "Event ID mismatch for event #%d", i) - - // TODO allow redacted events - assert.True(t, evt.VerifyContentHash(), i) - - events[evtID] = evt - err = eventauth.Authorize(*roomVersion, evt, events.Get, GetKey) - if err != nil { - evt.InternalMeta.Rejected = true - } - // TODO allow testing intentionally rejected events - assert.NoErrorf(t, err, "Failed to authorize event #%d / %s of type %s", i, evtID, evt.Type) - } - }) - } - -} diff --git a/federation/eventauth/testroom-v12-success.jsonl b/federation/eventauth/testroom-v12-success.jsonl deleted file mode 100644 index 2b751de3..00000000 --- a/federation/eventauth/testroom-v12-success.jsonl +++ /dev/null @@ -1,21 +0,0 @@ -{"auth_events":[],"content":{"room_version":"12"},"depth":1,"hashes":{"sha256":"qJYytb+EqWPiiZ0ogDODcLeA8XYw/2hVTaLHihcVBZQ"},"origin_server_ts":1756071567186,"prev_events":[],"sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"/9pp+2tkLo6XcZ3opqLeIpa3D96fh3QLpR2PQrZ6Z6j7wyRAvBrcgCpAeMtuyDCzW8Wh1QFEPG4FSsGvVaEFBg"}},"state_key":"","type":"m.room.create","unsigned":{"age_ts":1756071567186,"event_id":"$lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54"}} -{"auth_events":[],"content":{"avatar_url":"mxc://maunium.net/jdlSfvudiMSmcRrleeiYjjFO","displayname":"tulir","membership":"join"},"depth":2,"hashes":{"sha256":"MXmgq0e4J9CdIP0IVKVvueFhOb+ndlsXpeyI+6l/2FI"},"origin_server_ts":1756071567259,"prev_events":["$lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"xMgRzyRg9VM9XCKpfFJA+MrYoI68b8PIddKpMTcxz/fDzmGSHEy6Ta2b59VxiX3NoJe2CigkDZ3+jVsQoZYIBA"}},"state_key":"@tulir:maunium.net","type":"m.room.member","unsigned":{"age_ts":1756071567259,"event_id":"$mmqm2KS4UExkNL65c6CIhKofn_L9fzF2OhghVqajksU"}} -{"auth_events":["$mmqm2KS4UExkNL65c6CIhKofn_L9fzF2OhghVqajksU"],"content":{"ban":50,"events":{"m.room.avatar":50,"m.room.canonical_alias":50,"m.room.encryption":100,"m.room.history_visibility":100,"m.room.name":50,"m.room.power_levels":100,"m.room.server_acl":100,"m.room.tombstone":150},"events_default":0,"historical":100,"invite":0,"kick":50,"redact":50,"state_default":50,"users":{"@tulir:envs.net":9001},"users_default":0},"depth":3,"hashes":{"sha256":"/JzQNBNqJ/i8vwj6xESDaD5EDdOqB4l/LmKlvAVl5jY"},"origin_server_ts":1756071567319,"prev_events":["$mmqm2KS4UExkNL65c6CIhKofn_L9fzF2OhghVqajksU"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"W3N3X/enja+lumXw3uz66/wT9oczoxrmHbAD5/RF069cX4wkCtqtDd61VWPkSGmKxdV1jurgbCqSX6+Q9/t3AA"}},"state_key":"","type":"m.room.power_levels","unsigned":{"age_ts":1756071567319,"event_id":"$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU"}} -{"auth_events":["$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU","$mmqm2KS4UExkNL65c6CIhKofn_L9fzF2OhghVqajksU"],"content":{"join_rule":"invite"},"depth":4,"hashes":{"sha256":"GBu5AySj75ZXlOLd65mB03KueFKOHNgvtg2o/LUnLyI"},"origin_server_ts":1756071567320,"prev_events":["$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"XqWEnFREo2PhRnaebGjNzdHdtD691BtCQKkLnpKd8P3lVDewDt8OkCbDSk/Uzh9rDtzwWEsbsIoKSYuOm+G6CA"}},"state_key":"","type":"m.room.join_rules","unsigned":{"age_ts":1756071567320,"event_id":"$deNVGs6Ef7OKVrvewhtPv7DCCqSip112cEJYp-jkP6M"}} -{"auth_events":["$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU","$mmqm2KS4UExkNL65c6CIhKofn_L9fzF2OhghVqajksU"],"content":{"history_visibility":"shared"},"depth":5,"hashes":{"sha256":"niDi5vG2akQm0f5pm0aoCYXqmWjXRfmP1ulr/ZEPm/k"},"origin_server_ts":1756071567320,"prev_events":["$deNVGs6Ef7OKVrvewhtPv7DCCqSip112cEJYp-jkP6M"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"PTIrNke/fc9+ObKAl/K0PGZfmpe8dwREyoA5rXffOXWdRHSaBifn9UIiJUqd68Bzvrv4RcADTR/ci7lUquFBBw"}},"state_key":"","type":"m.room.history_visibility","unsigned":{"age_ts":1756071567320,"event_id":"$Wmy3G9yxl9ArVg5ZsdeIDPxBsNAdgseuvHoqHTZ2vug"}} -{"auth_events":["$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU","$mmqm2KS4UExkNL65c6CIhKofn_L9fzF2OhghVqajksU"],"content":{"guest_access":"can_join"},"depth":6,"hashes":{"sha256":"sZ9QqsId4oarFF724esTohXuRxDNnaXPl+QmTDG60dw"},"origin_server_ts":1756071567321,"prev_events":["$Wmy3G9yxl9ArVg5ZsdeIDPxBsNAdgseuvHoqHTZ2vug"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"Eh2P9/hl38wfZx2AQbeS5VCD4wldXPfeP2sQsJsLtfmdwFV74jrlGVBaKIkaYcXY4eA08iDp8HW5jqttZqKKDg"}},"state_key":"","type":"m.room.guest_access","unsigned":{"age_ts":1756071567321,"event_id":"$hYVRH7F4P5mB5IqvBDDU5aXY7pYGG0ApstrryiVPKmQ"}} -{"auth_events":["$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU","$mmqm2KS4UExkNL65c6CIhKofn_L9fzF2OhghVqajksU"],"content":{"name":"event auth test v12"},"depth":7,"hashes":{"sha256":"tjwPo38yR+23Was6SbxLvPMhNx44DaXLhF3rKgngepU"},"origin_server_ts":1756071567321,"prev_events":["$hYVRH7F4P5mB5IqvBDDU5aXY7pYGG0ApstrryiVPKmQ"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"q1rk0c5m8TJYE9tePsMaLeaigatNNbvaLRom0X8KiZY0EH+itujfA+/UnksvmPmMmThfAXWlFLx5u8tcuSVyCQ"}},"state_key":"","type":"m.room.name","unsigned":{"age_ts":1756071567321,"event_id":"$fFDwIavLTEIfcnggWuryB6JwfS-L2KT6vP1ap3P6ctE"}} -{"auth_events":["$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU","$mmqm2KS4UExkNL65c6CIhKofn_L9fzF2OhghVqajksU","$deNVGs6Ef7OKVrvewhtPv7DCCqSip112cEJYp-jkP6M"],"content":{"avatar_url":"mxc://envs.net/000cf1510b7c61018f9c72ca4cc63668370782c81725865933316030464","displayname":"tulir[e]","membership":"invite"},"depth":8,"hashes":{"sha256":"r5EBUZN/4LbVcMYwuffDcVV9G4OMHzAQuNbnjigL+OE"},"origin_server_ts":1756071567548,"prev_events":["$fFDwIavLTEIfcnggWuryB6JwfS-L2KT6vP1ap3P6ctE"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:maunium.net","signatures":{"envs.net":{"ed25519:wuJyKT":"svB+uW4Tsj8/I+SYbLl+LPPjBlqxGNXE4wGyAxlP7vfyJtFf7Kn/19jx65wT9ebeCq5sTGlEDV4Fabwma9LhDA"},"maunium.net":{"ed25519:a_xxeS":"LBYMcdJVSNsLd6SmOgx5oOU/0xOeCl03o4g83VwJfHWlRuTT5l9+qlpNED28wY07uxoU9MgLgXXICJ0EezMBCg"}},"state_key":"@tulir:envs.net","type":"m.room.member","unsigned":{"age_ts":1756071567548,"event_id":"$qYZqSKiKMCNjzH6Trhr6nBSvbfuwr8Sh2bC4USSAxok","invite_room_state":[{"content":{"join_rule":"invite"},"sender":"@tulir:maunium.net","state_key":"","type":"m.room.join_rules"},{"content":{"name":"event auth test v12"},"sender":"@tulir:maunium.net","state_key":"","type":"m.room.name"},{"auth_events":[],"content":{"room_version":"12"},"depth":1,"hashes":{"sha256":"qJYytb+EqWPiiZ0ogDODcLeA8XYw/2hVTaLHihcVBZQ"},"origin_server_ts":1756071567186,"prev_events":[],"sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"/9pp+2tkLo6XcZ3opqLeIpa3D96fh3QLpR2PQrZ6Z6j7wyRAvBrcgCpAeMtuyDCzW8Wh1QFEPG4FSsGvVaEFBg"}},"state_key":"","type":"m.room.create","unsigned":{"age_ts":1756071567186}},{"content":{"avatar_url":"mxc://maunium.net/jdlSfvudiMSmcRrleeiYjjFO","displayname":"tulir","membership":"join"},"sender":"@tulir:maunium.net","state_key":"@tulir:maunium.net","type":"m.room.member"}]}} -{"auth_events":["$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU","$mmqm2KS4UExkNL65c6CIhKofn_L9fzF2OhghVqajksU"],"content":{"body":"meow","com.beeper.linkpreviews":[],"m.mentions":{},"msgtype":"m.text"},"depth":9,"hashes":{"sha256":"23rgMf7EGJcYt3Aj0qAFnmBWCxuU9Uk+ReidqtIJDKQ"},"origin_server_ts":1756071575986,"prev_events":["$qYZqSKiKMCNjzH6Trhr6nBSvbfuwr8Sh2bC4USSAxok"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"p+Fm/uWO8VXJdCYvN/dVb8HF8W3t1sssNCBiOWbzAeuS3QqYjoMKHyixLuN1mOdnCyATv7SsHHmA4+cELRGdAA"}},"type":"m.room.message","unsigned":{"age_ts":1756071576002,"event_id":"$eZDCydRWSRnR5od0c7ahz2qSZQDHbl5g5PITT0OMC3E"}} -{"auth_events":["$qYZqSKiKMCNjzH6Trhr6nBSvbfuwr8Sh2bC4USSAxok","$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU","$deNVGs6Ef7OKVrvewhtPv7DCCqSip112cEJYp-jkP6M"],"content":{"avatar_url":"mxc://envs.net/000cf1510b7c61018f9c72ca4cc63668370782c81725865933316030464","displayname":"tulir[e]","membership":"join"},"depth":10,"hashes":{"sha256":"2kJPx2UsysNzTH8QGYHUKTO/05yetxKRlI0nKFeGbts"},"origin_server_ts":1756071578631,"prev_events":["$eZDCydRWSRnR5od0c7ahz2qSZQDHbl5g5PITT0OMC3E"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:envs.net","signatures":{"envs.net":{"ed25519:wuJyKT":"Wuzxkh8nEEX6mdJzph6Bt5ku+odFkEg2RIpFAAirOqxgcrwRaz42PsJni3YbfzH1qneF+iWQ/neA+up6jLXFBw"}},"state_key":"@tulir:envs.net","type":"m.room.member","unsigned":{"age":6,"event_id":"$Bz2lxsbUYkeBDE7eMAsOm_TK_iuSuHNvQdrHnc-T1PE","replaces_state":"$qYZqSKiKMCNjzH6Trhr6nBSvbfuwr8Sh2bC4USSAxok"}} -{"auth_events":["$Bz2lxsbUYkeBDE7eMAsOm_TK_iuSuHNvQdrHnc-T1PE","$deNVGs6Ef7OKVrvewhtPv7DCCqSip112cEJYp-jkP6M","$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU"],"content":{"avatar_url":"mxc://matrix.org/BDYVQFSLvZHMaKHDGiRkvhVg","displayname":"tulir[m]","membership":"invite"},"depth":11,"hashes":{"sha256":"dRE11R2hBfFalQ5tIJdyaElUIiSE5aCKMddjek4wR3c"},"origin_server_ts":1756071591449,"prev_events":["$Bz2lxsbUYkeBDE7eMAsOm_TK_iuSuHNvQdrHnc-T1PE"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:envs.net","signatures":{"envs.net":{"ed25519:wuJyKT":"/Mi4kX40fbR+V3DCJJGI/9L3Uuf8y5Un8LHlCQv1T0O5gnFZGQ3qN6rRNaZ1Kdh3QJBU6H4NTfnd+SVj3wt3CQ"},"matrix.org":{"ed25519:a_RXGa":"ZeLm/oxP3/Cds/uCL2FaZpgjUp0vTDBlGG6YVFNl76yIVlyIKKQKR6BSVw2u5KC5Mu9M1f+0lDmLGQujR5NkBg"}},"state_key":"@tulir:matrix.org","type":"m.room.member","unsigned":{"event_id":"$g4eBtA9EFNGLkHOofvQ4U87GNt4W8NmfmNRyR0wOUO4","invite_room_state":[{"content":{"join_rule":"invite"},"sender":"@tulir:maunium.net","state_key":"","type":"m.room.join_rules"},{"content":{"name":"event auth test v12"},"sender":"@tulir:maunium.net","state_key":"","type":"m.room.name"},{"auth_events":[],"content":{"room_version":"12"},"depth":1,"hashes":{"sha256":"qJYytb+EqWPiiZ0ogDODcLeA8XYw/2hVTaLHihcVBZQ"},"origin_server_ts":1756071567186,"prev_events":[],"sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"/9pp+2tkLo6XcZ3opqLeIpa3D96fh3QLpR2PQrZ6Z6j7wyRAvBrcgCpAeMtuyDCzW8Wh1QFEPG4FSsGvVaEFBg"}},"state_key":"","type":"m.room.create","unsigned":{"age":11553}},{"content":{"avatar_url":"mxc://envs.net/000cf1510b7c61018f9c72ca4cc63668370782c81725865933316030464","displayname":"tulir[e]","membership":"join"},"sender":"@tulir:envs.net","state_key":"@tulir:envs.net","type":"m.room.member"}]}} -{"auth_events":["$g4eBtA9EFNGLkHOofvQ4U87GNt4W8NmfmNRyR0wOUO4","$deNVGs6Ef7OKVrvewhtPv7DCCqSip112cEJYp-jkP6M","$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU"],"content":{"avatar_url":"mxc://matrix.org/BDYVQFSLvZHMaKHDGiRkvhVg","displayname":"tulir[m]","membership":"join"},"depth":12,"hashes":{"sha256":"hR/fRIyFkxKnA1XNxIB+NKC0VR0vHs82EDgydhmmZXU"},"origin_server_ts":1756071609205,"prev_events":["$g4eBtA9EFNGLkHOofvQ4U87GNt4W8NmfmNRyR0wOUO4"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:matrix.org","signatures":{"matrix.org":{"ed25519:a_RXGa":"keWbZHm+LPW22XWxb14Att4Ae4GVc6XAKAnxFRr3hxhrgEhsnMcxUx7fjqlA1dk3As6kjLKdekcyCef+AQCXCA"}},"state_key":"@tulir:matrix.org","type":"m.room.member","unsigned":{"age":19,"event_id":"$_gYjNODWJdo5-S1IN0bmAk3rzIeXzr5W5cmXZSmUsNw","replaces_state":"$g4eBtA9EFNGLkHOofvQ4U87GNt4W8NmfmNRyR0wOUO4"}} -{"auth_events":["$Bz2lxsbUYkeBDE7eMAsOm_TK_iuSuHNvQdrHnc-T1PE","$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU"],"content":{"ban":50,"events":{"m.room.avatar":50,"m.room.canonical_alias":50,"m.room.encryption":100,"m.room.history_visibility":100,"m.room.name":50,"m.room.power_levels":100,"m.room.server_acl":100,"m.room.tombstone":150},"events_default":0,"historical":100,"invite":0,"kick":50,"redact":50,"state_default":50,"users":{"@tulir:envs.net":9001,"@tulir:matrix.org":9000},"users_default":0},"depth":13,"hashes":{"sha256":"30Wuw3xIbA8+eXQBa4nFDKcyHtMbKPBYhLW1zft9/fE"},"origin_server_ts":1756071643928,"prev_events":["$_gYjNODWJdo5-S1IN0bmAk3rzIeXzr5W5cmXZSmUsNw"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:envs.net","signatures":{"envs.net":{"ed25519:wuJyKT":"x6Y4uViq4nK8LVPqtMLdCuvNET2bnjxYTgiKuEe1JYfwB4jPBnPuqvrt1O9oaanMpcRWbnuiZjckq4bUlRZ7Cw"}},"state_key":"","type":"m.room.power_levels","unsigned":{"event_id":"$Qg1xRB8nL8lGykGvt9_agu_WCWq8Y3rl_p_LKa6D2Hg","replaces_state":"$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU"}} -{"auth_events":["$Qg1xRB8nL8lGykGvt9_agu_WCWq8Y3rl_p_LKa6D2Hg","$_gYjNODWJdo5-S1IN0bmAk3rzIeXzr5W5cmXZSmUsNw"],"content":{"name":"event auth test v12!"},"depth":14,"hashes":{"sha256":"WT0gz7KYXvbdNruRavqIi9Hhul3rxCdZ+YY9yMGN+Fw"},"origin_server_ts":1756071656988,"prev_events":["$Qg1xRB8nL8lGykGvt9_agu_WCWq8Y3rl_p_LKa6D2Hg"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:matrix.org","signatures":{"matrix.org":{"ed25519:a_RXGa":"bSplmqtXVhO2Z3hJ8JMQ/u7G2Wmg6yt7SwhYXObRQJfthekddJN152ME4YJIwy7YD8WFq7EkyB/NMyQoliYyCg"}},"state_key":"","type":"m.room.name","unsigned":{"event_id":"$p4xvOczrhzQMtRW3-Tf86LYUb5aqpGFIgjwHBuxWIcI","replaces_state":"$fFDwIavLTEIfcnggWuryB6JwfS-L2KT6vP1ap3P6ctE"}} -{"auth_events":["$Bz2lxsbUYkeBDE7eMAsOm_TK_iuSuHNvQdrHnc-T1PE","$Qg1xRB8nL8lGykGvt9_agu_WCWq8Y3rl_p_LKa6D2Hg"],"content":{"ban":50,"events":{"m.room.avatar":50,"m.room.canonical_alias":50,"m.room.encryption":100,"m.room.history_visibility":100,"m.room.name":50,"m.room.power_levels":100,"m.room.server_acl":100,"m.room.tombstone":9001},"events_default":0,"historical":12345,"invite":0,"kick":50,"redact":50,"state_default":50,"users":{"@tulir:envs.net":9001,"@tulir:matrix.org":9000},"users_default":0},"depth":15,"hashes":{"sha256":"FnGzbcXc8YOiB1TY33QunGA17Axoyuu3sdVOj5Z408o"},"origin_server_ts":1756071804931,"prev_events":["$p4xvOczrhzQMtRW3-Tf86LYUb5aqpGFIgjwHBuxWIcI"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:envs.net","signatures":{"envs.net":{"ed25519:wuJyKT":"uyTUsPR+CzCtlevzB5+sNXvmfbPSp6u7RZC4E4TLVsj45+pjmMRswAvuHP9PT2+Tkl6Hu8ZPigsXgbKZtR35Aw"}},"state_key":"","type":"m.room.power_levels","unsigned":{"event_id":"$uZ4OOtkM8RcbEkhjNp-YlEH0zBqgsRx1eI8b2YP7ovw","replaces_state":"$Qg1xRB8nL8lGykGvt9_agu_WCWq8Y3rl_p_LKa6D2Hg"}} -{"auth_events":["$Bz2lxsbUYkeBDE7eMAsOm_TK_iuSuHNvQdrHnc-T1PE","$uZ4OOtkM8RcbEkhjNp-YlEH0zBqgsRx1eI8b2YP7ovw"],"content":{"ban":50,"events":{"m.room.avatar":50,"m.room.canonical_alias":50,"m.room.encryption":100,"m.room.history_visibility":100,"m.room.name":50,"m.room.power_levels":100,"m.room.server_acl":100,"m.room.tombstone":100},"events_default":0,"historical":12345,"invite":0,"kick":50,"redact":50,"state_default":50,"users":{"@tulir:envs.net":9001,"@tulir:matrix.org":9000},"users_default":0},"depth":16,"hashes":{"sha256":"KcivsiLesdnUnKX23Akk3OJEJFGRSY0g4H+p7XIThnw"},"origin_server_ts":1756071812688,"prev_events":["$uZ4OOtkM8RcbEkhjNp-YlEH0zBqgsRx1eI8b2YP7ovw"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:envs.net","signatures":{"envs.net":{"ed25519:wuJyKT":"cAK8dO2AVZklY9te5aVKbF1jR/eB5rzeNOXfYPjBLf+aSAS4Z6R2aMKW6hJB9PqRS4S+UZc24DTrjUjnvMzeBA"}},"state_key":"","type":"m.room.power_levels","unsigned":{"event_id":"$iwqRXQc2cx8K4AclTjU1Se-BMJpUl4DxrLm3nfUgeQU","replaces_state":"$uZ4OOtkM8RcbEkhjNp-YlEH0zBqgsRx1eI8b2YP7ovw"}} -{"auth_events":["$Bz2lxsbUYkeBDE7eMAsOm_TK_iuSuHNvQdrHnc-T1PE","$iwqRXQc2cx8K4AclTjU1Se-BMJpUl4DxrLm3nfUgeQU"],"content":{"body":"meow #2","com.beeper.linkpreviews":[],"m.mentions":{},"msgtype":"m.text"},"depth":17,"hashes":{"sha256":"SgH9fOXGdbdqpRfYmoz1t29+gX8Ze4ThSoj6klZs3Og"},"origin_server_ts":1756247476706,"prev_events":["$iwqRXQc2cx8K4AclTjU1Se-BMJpUl4DxrLm3nfUgeQU"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:envs.net","signatures":{"envs.net":{"ed25519:wuJyKT":"SMYK7zP3SaQOKhzZUKUBVCKwffYqi3PFAlPM34kRJtmfGU3KZXNBT0zi+veXDMmxkMunqhF2RTHBD6joa0kBAQ"}},"type":"m.room.message","unsigned":{"event_id":"$KFHLO0-ENYOGQXogp84C-ISSu1xtKUzIMaZ6LiBcR_w"}} -{"auth_events":["$_gYjNODWJdo5-S1IN0bmAk3rzIeXzr5W5cmXZSmUsNw","$iwqRXQc2cx8K4AclTjU1Se-BMJpUl4DxrLm3nfUgeQU"],"content":{"ban":50,"events":{"m.room.avatar":50,"m.room.canonical_alias":50,"m.room.encryption":100,"m.room.history_visibility":100,"m.room.name":50,"m.room.power_levels":100,"m.room.server_acl":100,"m.room.tombstone":100},"events_default":0,"historical":12345,"invite":0,"kick":50,"redact":50,"state_default":50,"users":{"@tulir:beeper.com":8999,"@tulir:envs.net":9001,"@tulir:matrix.org":9000},"users_default":0},"depth":18,"hashes":{"sha256":"l8Mw3VKn/Bvntg7bZ8uh5J8M2IBZM93Xg7hsdaSci8s"},"origin_server_ts":1758918656341,"prev_events":["$KFHLO0-ENYOGQXogp84C-ISSu1xtKUzIMaZ6LiBcR_w"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:matrix.org","signatures":{"matrix.org":{"ed25519:a_RXGa":"cg5LP0WuTnVB5jFhNERLLU5b+EhmyACiOq6cp3gKJnZsTAb1yajcgJybLWKrc8QQqxPa7hPnskRBgt4OBTFNAA"}},"state_key":"","type":"m.room.power_levels","unsigned":{"event_id":"$x-CCUewbWOHQXqfcUsywOmbHvNnOSwNM1RyOu-c8SB0","replaces_state":"$iwqRXQc2cx8K4AclTjU1Se-BMJpUl4DxrLm3nfUgeQU"}} -{"auth_events":["$deNVGs6Ef7OKVrvewhtPv7DCCqSip112cEJYp-jkP6M","$_gYjNODWJdo5-S1IN0bmAk3rzIeXzr5W5cmXZSmUsNw","$x-CCUewbWOHQXqfcUsywOmbHvNnOSwNM1RyOu-c8SB0"],"content":{"avatar_url":"mxc://beeper.com/eBdwbHbllONoAySQkXLjbfFM","displayname":"tulir[b]","membership":"invite"},"depth":19,"hashes":{"sha256":"KpmaRUQnJju8TIDMPzakitUIKOWJxTvULpFB3a1CGgc"},"origin_server_ts":1758918665952,"prev_events":["$x-CCUewbWOHQXqfcUsywOmbHvNnOSwNM1RyOu-c8SB0"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:matrix.org","signatures":{"beeper.com":{"ed25519:a_zgvp":"mzI9rPkQ1xHl2/G5Yrn0qmIRt5OyjPNqRwilPfH4jmr1tP+vv3vC0m4mph/MCOq8S1c/DQaCWSpdOX1uWfchBQ"},"matrix.org":{"ed25519:a_RXGa":"kEdfr8DjxC/bdvGYxnniFI/pxDWeyG73OjG/Gu1uoHLhjdtAT/vEQ6lotJJs214/KX5eAaQWobE9qtMvtPwMDw"}},"state_key":"@tulir:beeper.com","type":"m.room.member","unsigned":{"event_id":"$PZJZoUwNySl0jY16DkHBHR0HyAppLdxc0rkSuYp5Mro","invite_room_state":[{"auth_events":[],"content":{"room_version":"12"},"depth":1,"hashes":{"sha256":"qJYytb+EqWPiiZ0ogDODcLeA8XYw/2hVTaLHihcVBZQ"},"origin_server_ts":1756071567186,"prev_events":[],"sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"/9pp+2tkLo6XcZ3opqLeIpa3D96fh3QLpR2PQrZ6Z6j7wyRAvBrcgCpAeMtuyDCzW8Wh1QFEPG4FSsGvVaEFBg"}},"state_key":"","type":"m.room.create","unsigned":{"age":11553}},{"content":{"avatar_url":"mxc://matrix.org/BDYVQFSLvZHMaKHDGiRkvhVg","displayname":"tulir[m]","membership":"join"},"sender":"@tulir:matrix.org","state_key":"@tulir:matrix.org","type":"m.room.member"},{"content":{"name":"event auth test v12!"},"sender":"@tulir:matrix.org","state_key":"","type":"m.room.name"},{"content":{"join_rule":"invite"},"sender":"@tulir:maunium.net","state_key":"","type":"m.room.join_rules"}]}} -{"auth_events":["$deNVGs6Ef7OKVrvewhtPv7DCCqSip112cEJYp-jkP6M","$PZJZoUwNySl0jY16DkHBHR0HyAppLdxc0rkSuYp5Mro","$x-CCUewbWOHQXqfcUsywOmbHvNnOSwNM1RyOu-c8SB0"],"content":{"avatar_url":"mxc://beeper.com/eBdwbHbllONoAySQkXLjbfFM","displayname":"tulir[b]","membership":"join"},"depth":20,"hashes":{"sha256":"bmaHSm4mYPNBNlUfFsauSTxLrUH4CUSAKYvr1v76qkk"},"origin_server_ts":1758918670276,"prev_events":["$PZJZoUwNySl0jY16DkHBHR0HyAppLdxc0rkSuYp5Mro"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:beeper.com","signatures":{"beeper.com":{"ed25519:a_zgvp":"D3cz3m15m89a3G4c5yWOBCjhtSeI5IxBfQKt5XOr9a44QHyc3nwjjvIJaRrKNcS5tLUJwZ2IpVzjlrpbPHpxDA"}},"state_key":"@tulir:beeper.com","type":"m.room.member","unsigned":{"age":6,"event_id":"$_hayW1Y0HRWp3VEGZZbsMf0Ncg9x6n0ikveD0lbCwMw","replaces_state":"$PZJZoUwNySl0jY16DkHBHR0HyAppLdxc0rkSuYp5Mro"}} -{"auth_events":["$x-CCUewbWOHQXqfcUsywOmbHvNnOSwNM1RyOu-c8SB0","$Bz2lxsbUYkeBDE7eMAsOm_TK_iuSuHNvQdrHnc-T1PE"],"content":{"ban":50,"events":{"m.room.avatar":50,"m.room.canonical_alias":50,"m.room.encryption":100,"m.room.history_visibility":100,"m.room.name":50,"m.room.power_levels":100,"m.room.server_acl":100,"m.room.tombstone":100},"events_default":0,"historical":12345,"invite":0,"kick":50,"redact":50,"state_default":50,"users":{"@tulir:beeper.com":9000,"@tulir:envs.net":9001,"@tulir:matrix.org":8999},"users_default":0},"depth":21,"hashes":{"sha256":"xCj9vszChHiXba9DaPzhtF79Tphek3pRViMp36DOurU"},"origin_server_ts":1758918689485,"prev_events":["$_hayW1Y0HRWp3VEGZZbsMf0Ncg9x6n0ikveD0lbCwMw"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:envs.net","signatures":{"envs.net":{"ed25519:wuJyKT":"odkrWD30+ObeYtagULtECB/QmGae7qNy66nmJMWYXiQMYUJw/GMzSmgAiLAWfVYlfD3aEvMb/CBdrhL07tfSBw"}},"state_key":"","type":"m.room.power_levels","unsigned":{"event_id":"$di6cI89-GxX8-Wbx-0T69l4wg6TUWITRkjWXzG7EBqo","replaces_state":"$x-CCUewbWOHQXqfcUsywOmbHvNnOSwNM1RyOu-c8SB0"}} diff --git a/federation/pdu/auth.go b/federation/pdu/auth.go deleted file mode 100644 index 16706fe5..00000000 --- a/federation/pdu/auth.go +++ /dev/null @@ -1,71 +0,0 @@ -// Copyright (c) 2025 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -//go:build goexperiment.jsonv2 - -package pdu - -import ( - "slices" - - "github.com/tidwall/gjson" - "go.mau.fi/util/exgjson" - - "maunium.net/go/mautrix/event" - "maunium.net/go/mautrix/id" -) - -type StateKey struct { - Type string - StateKey string -} - -var thirdPartyInviteTokenPath = exgjson.Path("third_party_invite", "signed", "token") - -type AuthEventSelection []StateKey - -func (aes *AuthEventSelection) Add(evtType, stateKey string) { - key := StateKey{Type: evtType, StateKey: stateKey} - if !aes.Has(key) { - *aes = append(*aes, key) - } -} - -func (aes *AuthEventSelection) Has(key StateKey) bool { - return slices.Contains(*aes, key) -} - -func (pdu *PDU) AuthEventSelection(roomVersion id.RoomVersion) (keys AuthEventSelection) { - if pdu.Type == event.StateCreate.Type && pdu.StateKey != nil { - return AuthEventSelection{} - } - keys = make(AuthEventSelection, 0, 3) - if !roomVersion.RoomIDIsCreateEventID() { - keys.Add(event.StateCreate.Type, "") - } - keys.Add(event.StatePowerLevels.Type, "") - keys.Add(event.StateMember.Type, pdu.Sender.String()) - if pdu.Type == event.StateMember.Type && pdu.StateKey != nil { - keys.Add(event.StateMember.Type, *pdu.StateKey) - membership := event.Membership(gjson.GetBytes(pdu.Content, "membership").Str) - if membership == event.MembershipJoin || membership == event.MembershipInvite || membership == event.MembershipKnock { - keys.Add(event.StateJoinRules.Type, "") - } - if membership == event.MembershipInvite { - thirdPartyInviteToken := gjson.GetBytes(pdu.Content, thirdPartyInviteTokenPath).Str - if thirdPartyInviteToken != "" { - keys.Add(event.StateThirdPartyInvite.Type, thirdPartyInviteToken) - } - } - if membership == event.MembershipJoin && roomVersion.RestrictedJoins() { - authorizedVia := gjson.GetBytes(pdu.Content, "authorised_via_users_server").Str - if authorizedVia != "" { - keys.Add(event.StateMember.Type, authorizedVia) - } - } - } - return -} diff --git a/federation/pdu/hash.go b/federation/pdu/hash.go deleted file mode 100644 index 38ef83e9..00000000 --- a/federation/pdu/hash.go +++ /dev/null @@ -1,118 +0,0 @@ -// Copyright (c) 2025 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -//go:build goexperiment.jsonv2 - -package pdu - -import ( - "crypto/hmac" - "crypto/sha256" - "encoding/base64" - "fmt" - - "github.com/tidwall/gjson" - - "maunium.net/go/mautrix/id" -) - -func (pdu *PDU) CalculateContentHash() ([32]byte, error) { - if pdu == nil { - return [32]byte{}, ErrPDUIsNil - } - pduClone := pdu.Clone() - pduClone.Signatures = nil - pduClone.Unsigned = nil - pduClone.Hashes = nil - rawJSON, err := marshalCanonical(pduClone) - if err != nil { - return [32]byte{}, fmt.Errorf("failed to marshal PDU to calculate content hash: %w", err) - } - return sha256.Sum256(rawJSON), nil -} - -func (pdu *PDU) FillContentHash() error { - if pdu == nil { - return ErrPDUIsNil - } else if pdu.Hashes != nil { - return nil - } else if hash, err := pdu.CalculateContentHash(); err != nil { - return err - } else { - pdu.Hashes = &Hashes{SHA256: hash[:]} - return nil - } -} - -func (pdu *PDU) VerifyContentHash() bool { - if pdu == nil || pdu.Hashes == nil { - return false - } - calculatedHash, err := pdu.CalculateContentHash() - if err != nil { - return false - } - return hmac.Equal(calculatedHash[:], pdu.Hashes.SHA256) -} - -func (pdu *PDU) GetRoomID() (id.RoomID, error) { - if pdu == nil { - return "", ErrPDUIsNil - } else if pdu.Type != "m.room.create" { - return "", fmt.Errorf("room ID can only be calculated for m.room.create events") - } else if roomVersion := id.RoomVersion(gjson.GetBytes(pdu.Content, "room_version").Str); !roomVersion.RoomIDIsCreateEventID() { - return "", fmt.Errorf("room version %s does not use m.room.create event ID as room ID", roomVersion) - } else if evtID, err := pdu.calculateEventID(roomVersion, '!'); err != nil { - return "", fmt.Errorf("failed to calculate event ID: %w", err) - } else { - return id.RoomID(evtID), nil - } -} - -var UseInternalMetaForGetEventID = false - -func (pdu *PDU) GetEventID(roomVersion id.RoomVersion) (id.EventID, error) { - if UseInternalMetaForGetEventID && pdu.InternalMeta.EventID != "" { - return pdu.InternalMeta.EventID, nil - } - return pdu.calculateEventID(roomVersion, '$') -} - -func (pdu *PDU) GetReferenceHash(roomVersion id.RoomVersion) ([32]byte, error) { - if pdu == nil { - return [32]byte{}, ErrPDUIsNil - } - if pdu.Hashes == nil || pdu.Hashes.SHA256 == nil { - if err := pdu.FillContentHash(); err != nil { - return [32]byte{}, err - } - } - rawJSON, err := marshalCanonical(pdu.Clone().RedactForSignature(roomVersion)) - if err != nil { - return [32]byte{}, fmt.Errorf("failed to marshal redacted PDU to calculate event ID: %w", err) - } - return sha256.Sum256(rawJSON), nil -} - -func (pdu *PDU) calculateEventID(roomVersion id.RoomVersion, prefix byte) (id.EventID, error) { - referenceHash, err := pdu.GetReferenceHash(roomVersion) - if err != nil { - return "", err - } - eventID := make([]byte, 44) - eventID[0] = prefix - switch roomVersion.EventIDFormat() { - case id.EventIDFormatCustom: - return "", fmt.Errorf("*pdu.PDU can only be used for room v3+") - case id.EventIDFormatBase64: - base64.RawStdEncoding.Encode(eventID[1:], referenceHash[:]) - case id.EventIDFormatURLSafeBase64: - base64.RawURLEncoding.Encode(eventID[1:], referenceHash[:]) - default: - return "", fmt.Errorf("unknown event ID format %v", roomVersion.EventIDFormat()) - } - return id.EventID(eventID), nil -} diff --git a/federation/pdu/hash_test.go b/federation/pdu/hash_test.go deleted file mode 100644 index 17417e12..00000000 --- a/federation/pdu/hash_test.go +++ /dev/null @@ -1,55 +0,0 @@ -// Copyright (c) 2025 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -//go:build goexperiment.jsonv2 - -package pdu_test - -import ( - "encoding/base64" - "testing" - - "github.com/stretchr/testify/assert" - "go.mau.fi/util/exerrors" -) - -func TestPDU_CalculateContentHash(t *testing.T) { - for _, test := range testPDUs { - if test.redacted { - continue - } - t.Run(test.name, func(t *testing.T) { - parsed := parsePDU(test.pdu) - contentHash := exerrors.Must(parsed.CalculateContentHash()) - assert.Equal( - t, - base64.RawStdEncoding.EncodeToString(parsed.Hashes.SHA256), - base64.RawStdEncoding.EncodeToString(contentHash[:]), - ) - }) - } -} - -func TestPDU_VerifyContentHash(t *testing.T) { - for _, test := range testPDUs { - if test.redacted { - continue - } - t.Run(test.name, func(t *testing.T) { - parsed := parsePDU(test.pdu) - assert.True(t, parsed.VerifyContentHash()) - }) - } -} - -func TestPDU_GetEventID(t *testing.T) { - for _, test := range testPDUs { - t.Run(test.name, func(t *testing.T) { - gotEventID := exerrors.Must(parsePDU(test.pdu).GetEventID(test.roomVersion)) - assert.Equal(t, test.eventID, gotEventID) - }) - } -} diff --git a/federation/pdu/pdu.go b/federation/pdu/pdu.go deleted file mode 100644 index 17db6995..00000000 --- a/federation/pdu/pdu.go +++ /dev/null @@ -1,156 +0,0 @@ -// Copyright (c) 2025 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -//go:build goexperiment.jsonv2 - -package pdu - -import ( - "bytes" - "crypto/ed25519" - "encoding/json/jsontext" - "encoding/json/v2" - "errors" - "fmt" - "strings" - "time" - - "github.com/tidwall/gjson" - "go.mau.fi/util/jsonbytes" - "go.mau.fi/util/ptr" - - "maunium.net/go/mautrix/crypto/canonicaljson" - "maunium.net/go/mautrix/event" - "maunium.net/go/mautrix/id" -) - -// GetKeyFunc is a callback for retrieving the key corresponding to a given key ID when verifying the signature of a PDU. -// -// The input time is the timestamp of the event. The function should attempt to fetch a key that is -// valid at or after this time, but if that is not possible, the latest available key should be -// returned without an error. The verify function will do its own validity checking based on the -// returned valid until timestamp. -type GetKeyFunc = func(serverName string, keyID id.KeyID, minValidUntil time.Time) (key id.SigningKey, validUntil time.Time, err error) - -type AnyPDU interface { - GetRoomID() (id.RoomID, error) - GetEventID(roomVersion id.RoomVersion) (id.EventID, error) - GetReferenceHash(roomVersion id.RoomVersion) ([32]byte, error) - CalculateContentHash() ([32]byte, error) - FillContentHash() error - VerifyContentHash() bool - Sign(roomVersion id.RoomVersion, serverName string, keyID id.KeyID, privateKey ed25519.PrivateKey) error - VerifySignature(roomVersion id.RoomVersion, serverName string, getKey GetKeyFunc) error - ToClientEvent(roomVersion id.RoomVersion) (*event.Event, error) - AuthEventSelection(roomVersion id.RoomVersion) (keys AuthEventSelection) -} - -var ( - _ AnyPDU = (*PDU)(nil) - _ AnyPDU = (*RoomV1PDU)(nil) -) - -type InternalMeta struct { - EventID id.EventID `json:"event_id,omitempty"` - Rejected bool `json:"rejected,omitempty"` - Extra map[string]any `json:",unknown"` -} - -type PDU struct { - AuthEvents []id.EventID `json:"auth_events"` - Content jsontext.Value `json:"content"` - Depth int64 `json:"depth"` - Hashes *Hashes `json:"hashes,omitzero"` - OriginServerTS int64 `json:"origin_server_ts"` - PrevEvents []id.EventID `json:"prev_events"` - Redacts *id.EventID `json:"redacts,omitzero"` - RoomID id.RoomID `json:"room_id,omitzero"` // not present for room v12+ create events - Sender id.UserID `json:"sender"` - Signatures map[string]map[id.KeyID]string `json:"signatures,omitzero"` - StateKey *string `json:"state_key,omitzero"` - Type string `json:"type"` - Unsigned jsontext.Value `json:"unsigned,omitzero"` - InternalMeta InternalMeta `json:"-"` - - Unknown jsontext.Value `json:",unknown"` - - // Deprecated legacy fields - DeprecatedPrevState jsontext.Value `json:"prev_state,omitzero"` - DeprecatedOrigin jsontext.Value `json:"origin,omitzero"` - DeprecatedMembership jsontext.Value `json:"membership,omitzero"` -} - -var ErrPDUIsNil = errors.New("PDU is nil") - -type Hashes struct { - SHA256 jsonbytes.UnpaddedBytes `json:"sha256"` - - Unknown jsontext.Value `json:",unknown"` -} - -func (pdu *PDU) ToClientEvent(roomVersion id.RoomVersion) (*event.Event, error) { - if pdu.Type == "m.room.create" && roomVersion == "" { - roomVersion = id.RoomVersion(gjson.GetBytes(pdu.Content, "room_version").Str) - } - evtType := event.Type{Type: pdu.Type, Class: event.MessageEventType} - if pdu.StateKey != nil { - evtType.Class = event.StateEventType - } - eventID, err := pdu.GetEventID(roomVersion) - if err != nil { - return nil, err - } - roomID := pdu.RoomID - if pdu.Type == "m.room.create" && roomVersion.RoomIDIsCreateEventID() { - roomID = id.RoomID(strings.Replace(string(eventID), "$", "!", 1)) - } - evt := &event.Event{ - StateKey: pdu.StateKey, - Sender: pdu.Sender, - Type: evtType, - Timestamp: pdu.OriginServerTS, - ID: eventID, - RoomID: roomID, - Redacts: ptr.Val(pdu.Redacts), - } - err = json.Unmarshal(pdu.Content, &evt.Content) - if err != nil { - return nil, fmt.Errorf("failed to unmarshal content: %w", err) - } - return evt, nil -} - -func (pdu *PDU) AddSignature(serverName string, keyID id.KeyID, signature string) { - if signature == "" { - return - } - if pdu.Signatures == nil { - pdu.Signatures = make(map[string]map[id.KeyID]string) - } - if _, ok := pdu.Signatures[serverName]; !ok { - pdu.Signatures[serverName] = make(map[id.KeyID]string) - } - pdu.Signatures[serverName][keyID] = signature -} - -func marshalCanonical(data any) (jsontext.Value, error) { - marshaledBytes, err := json.Marshal(data) - if err != nil { - return nil, err - } - marshaled := jsontext.Value(marshaledBytes) - err = marshaled.Canonicalize() - if err != nil { - return nil, err - } - check := canonicaljson.CanonicalJSONAssumeValid(marshaled) - if !bytes.Equal(marshaled, check) { - fmt.Println(string(marshaled)) - fmt.Println(string(check)) - return nil, fmt.Errorf("canonical JSON mismatch for %s", string(marshaled)) - } - return marshaled, nil -} diff --git a/federation/pdu/pdu_test.go b/federation/pdu/pdu_test.go deleted file mode 100644 index 59d7c3a6..00000000 --- a/federation/pdu/pdu_test.go +++ /dev/null @@ -1,193 +0,0 @@ -// Copyright (c) 2025 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -//go:build goexperiment.jsonv2 - -package pdu_test - -import ( - "encoding/json/v2" - "time" - - "go.mau.fi/util/exerrors" - - "maunium.net/go/mautrix/federation/pdu" - "maunium.net/go/mautrix/id" -) - -type serverKey struct { - key id.SigningKey - validUntilTS time.Time -} - -type serverDetails struct { - serverName string - keys map[id.KeyID]serverKey -} - -func (sd serverDetails) getKey(serverName string, keyID id.KeyID, _ time.Time) (id.SigningKey, time.Time, error) { - if serverName != sd.serverName { - return "", time.Time{}, nil - } - key, ok := sd.keys[keyID] - if ok { - return key.key, key.validUntilTS, nil - } - return "", time.Time{}, nil -} - -var mauniumNet = serverDetails{ - serverName: "maunium.net", - keys: map[id.KeyID]serverKey{ - "ed25519:a_xxeS": { - key: "lVt/CC3tv74OH6xTph2JrUmeRj/j+1q0HVa0Xf4QlCg", - validUntilTS: time.Now(), - }, - }, -} -var envsNet = serverDetails{ - serverName: "envs.net", - keys: map[id.KeyID]serverKey{ - "ed25519:a_zIqy": { - key: "vCUcZpt9hUn0aabfh/9GP/6sZvXcydww8DUstPHdJm0", - validUntilTS: time.UnixMilli(1722360538068), - }, - "ed25519:wuJyKT": { - key: "xbE1QssgomL4wCSlyMYF5/7KxVyM4HPwAbNa+nFFnx0", - validUntilTS: time.Now(), - }, - }, -} -var matrixOrg = serverDetails{ - serverName: "matrix.org", - keys: map[id.KeyID]serverKey{ - "ed25519:auto": { - key: "Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw", - validUntilTS: time.UnixMilli(1576767829750), - }, - "ed25519:a_RXGa": { - key: "l8Hft5qXKn1vfHrg3p4+W8gELQVo8N13JkluMfmn2sQ", - validUntilTS: time.Now(), - }, - }, -} -var continuwuityOrg = serverDetails{ - serverName: "continuwuity.org", - keys: map[id.KeyID]serverKey{ - "ed25519:PwHlNsFu": { - key: "8eNx2s0zWW+heKAmOH5zKv/nCPkEpraDJfGHxDu6hFI", - validUntilTS: time.Now(), - }, - }, -} -var novaAstraltechOrg = serverDetails{ - serverName: "nova.astraltech.org", - keys: map[id.KeyID]serverKey{ - "ed25519:a_afpo": { - key: "O1Y9GWuKo9xkuzuQef6gROxtTgxxAbS3WPNghPYXF3o", - validUntilTS: time.Now(), - }, - }, -} - -type testPDU struct { - name string - pdu string - eventID id.EventID - roomVersion id.RoomVersion - redacted bool - serverDetails -} - -var roomV4MessageTestPDU = testPDU{ - name: "m.room.message in v4 room", - pdu: `{"auth_events":["$OB87jNemaIVDHAfu0-pa_cP7OPFXUXCbFpjYVi8gll4","$RaWbTF9wQfGQgUpe1S13wzICtGTB2PNKRHUNHu9IO1c","$ZmEWOXw6cC4Rd1wTdY5OzeLJVzjhrkxFPwwKE4gguGk"],"content":{"body":"the last one is saying it shouldn't have effects","com.beeper.linkpreviews":[],"m.mentions":{},"msgtype":"m.text"},"depth":13103,"hashes":{"sha256":"c2wb8qMlvzIPCP1Wd+eYZ4BRgnGYxS97dR1UlJjVMeg"},"origin_server_ts":1752875275263,"prev_events":["$-7_BMI3BXwj3ayoxiJvraJxYWTKwjiQ6sh7CW_Brvj0"],"room_id":"!JiiOHXrIUCtcOJsZCa:matrix.org","sender":"@tulir:maunium.net","type":"m.room.message","signatures":{"maunium.net":{"ed25519:a_xxeS":"99TAqHpBkUEtgCraXsVXogmf/hnijPbgbG9eACtA+mbix3Y6gURI4QGQgcX/NhcE3pJQZ/YDjmbuvCnKvEccAA"}},"unsigned":{"age_ts":1752875275281}}`, - eventID: "$Jo_lmFR-e6lzrimzCA7DevIn2OwhuQYmd9xkcJBoqAA", - roomVersion: id.RoomV4, - serverDetails: mauniumNet, -} - -var roomV12MessageTestPDU = testPDU{ - name: "m.room.message in v12 room", - pdu: `{"auth_events":["$gCzdJUVV93Qory0x7p_PLG5UUiDjPJNe1H12qbHTuFA","$hyeL_nU_L3tsZ2dtZZpAHk0Skv-PqFQIipuII_By584"],"content":{"body":"meow","com.beeper.linkpreviews":[],"m.mentions":{},"msgtype":"m.text"},"depth":122,"hashes":{"sha256":"IQ0zlc+PXeEs6R3JvRkW3xTPV3zlGKSSd3x07KXGjzs"},"origin_server_ts":1755384351627,"prev_events":["$gCzdJUVV93Qory0x7p_PLG5UUiDjPJNe1H12qbHTuFA"],"room_id":"!mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ","sender":"@tulir_test:maunium.net","type":"m.room.message","signatures":{"maunium.net":{"ed25519:a_xxeS":"0GDMddL2k7gF4V1VU8sL3wTfhAIzAu5iVH5jeavZ2VEg3J9/tHLWXAOn2tzkLaMRWl0/XpINT2YlH/rd2U21Ag"}},"unsigned":{"age_ts":1755384351627}}`, - eventID: "$xmP-wZfpannuHG-Akogi6c4YvqxChMtdyYbUMGOrMWc", - roomVersion: id.RoomV12, - serverDetails: mauniumNet, -} - -var testPDUs = []testPDU{roomV4MessageTestPDU, { - name: "m.room.message in v5 room", - pdu: `{"auth_events":["$hp0ImHqYgHTRbLeWKPeTeFmxdb5SdMJN9cfmTrTk7d0","$KAj7X7tnJbR9qYYMWJSw-1g414_KlPptbbkZm7_kUtg","$V-2ShOwZYhA_nxMijaf3lqFgIJgzE2UMeFPtOLnoBYM"],"content":{"body":"meow","com.beeper.linkpreviews":[],"m.mentions":{},"msgtype":"m.text"},"depth":2248,"hashes":{"sha256":"kV+JuLbWXJ2r6PjHT3wt8bFc/TfI1nTaSN3Lamg/xHs"},"origin_server_ts":1755422945654,"prev_events":["$49lFLem2Nk4dxHk9RDXxTdaq9InIJpmkHpzVnjKcYwg"],"room_id":"!vzBgJsjNzgHSdWsmki:mozilla.org","sender":"@tulir:maunium.net","type":"m.room.message","signatures":{"maunium.net":{"ed25519:a_xxeS":"JIl60uVgfCLBZLPoSiE7wVkJ9U5cNEPVPuv1sCCYUOq5yOW56WD1adgpBUdX2UFpYkCHvkRnyQGxU0+6HBp5BA"}},"unsigned":{"age_ts":1755422945673}}`, - eventID: "$Qn4tHfuAe6PlnKXPZnygAU9wd6RXqMKtt_ZzstHTSgA", - roomVersion: id.RoomV5, - serverDetails: mauniumNet, -}, { - name: "m.room.message in v10 room", - pdu: `{"auth_events":["$--ilpwnsHaEdHrwiMrZNu5xHP6TthWG0FIXMHnlHCcs","$tn1FZUI_YUpfTr_a3Y_r8kC3inliIZZratzg0UsNdCQ","$Z-qMWmiMvm-aIEffcfSO6lN7TyjyTOsIcHIymfzoo20"],"content":{"body":"meow","com.beeper.linkpreviews":[],"m.mentions":{},"msgtype":"m.text"},"depth":100885,"hashes":{"sha256":"jc9272JPpPIVreJC3UEAm3BNVnLX8sm3U/TZs23wsHo"},"origin_server_ts":1755422792518,"prev_events":["$HDtbzpSys36Hk-F2NsiXfp9slsGXBH0b58qyddj_q5E"],"room_id":"!UzZHbJYcgggctGnlzr:envs.net","sender":"@tulir:maunium.net","type":"m.room.message","signatures":{"maunium.net":{"ed25519:a_xxeS":"sAMLo9jPtNB0Jq67IQm06siEBx82qZa2edu56IDQ4tDylEV4Mq7iFO23gCghqXA7B/MqBsjXotGBxv6AvlJ2Dw"}},"unsigned":{"age_ts":1755422792540}}`, - eventID: "$4ZFr_ypfp4DyZQP4zyxM_cvuOMFkl07doJmwi106YFY", - roomVersion: id.RoomV10, - serverDetails: mauniumNet, -}, { - name: "m.room.message in v11 room", - pdu: `{"auth_events":["$L8Ak6A939llTRIsZrytMlLDXQhI4uLEjx-wb1zSg-Bw","$QJmr7mmGeXGD4Tof0ZYSPW2oRGklseyHTKtZXnF-YNM","$7bkKK_Z-cGQ6Ae4HXWGBwXyZi3YjC6rIcQzGfVyl3Eo"],"content":{"body":"meow","com.beeper.linkpreviews":[],"m.mentions":{},"msgtype":"m.text"},"depth":3212,"hashes":{"sha256":"K549YdTnv62Jn84Y7sS5ZN3+AdmhleZHbenbhUpR2R8"},"origin_server_ts":1754242687127,"prev_events":["$DAhJg4jVsqk5FRatE2hbT1dSA8D2ASy5DbjEHIMSHwY"],"room_id":"!offtopic-2:continuwuity.org","sender":"@tulir:maunium.net","type":"m.room.message","signatures":{"maunium.net":{"ed25519:a_xxeS":"SkzZdZ+rH22kzCBBIAErTdB0Vg6vkFmzvwjlOarGul72EnufgtE/tJcd3a8szAdK7f1ZovRyQxDgVm/Ib2u0Aw"}},"unsigned":{"age_ts":1754242687146}}`, - eventID: `$qkWfTL7_l3oRZO2CItW8-Q0yAmi_l_1ua629ZDqponE`, - roomVersion: id.RoomV11, - serverDetails: mauniumNet, -}, roomV12MessageTestPDU, { - name: "m.room.create in v4 room", - pdu: `{"auth_events": [], "prev_events": [], "type": "m.room.create", "room_id": "!jxlRxnrZCsjpjDubDX:matrix.org", "sender": "@neilj:matrix.org", "content": {"room_version": "4", "predecessor": {"room_id": "!DYgXKezaHgMbiPMzjX:matrix.org", "event_id": "$156171636353XwPJT:matrix.org"}, "creator": "@neilj:matrix.org"}, "depth": 1, "prev_state": [], "state_key": "", "origin": "matrix.org", "origin_server_ts": 1561716363993, "hashes": {"sha256": "9tj8GpXjTAJvdNAbnuKLemZZk+Tjv2LAbGodSX6nJAo"}, "signatures": {"matrix.org": {"ed25519:auto": "2+sNt8uJUhzU4GPxnFVYtU2ZRgFdtVLT1vEZGUdJYN40zBpwYEGJy+kyb5matA+8/yLeYD9gu1O98lhleH0aCA"}}, "unsigned": {"age": 104769}}`, - eventID: "$ay_9_nPilrTpb3UxIwHHBBfFjTJb6hBAE_JzQwSjqeY", - roomVersion: id.RoomV4, - serverDetails: matrixOrg, -}, { - name: "m.room.create in v10 room", - pdu: `{"auth_events":[],"content":{"creator":"@creme:envs.net","predecessor":{"event_id":"$BxYNisKcyBDhPLiVC06t18qhv7wsT72MzMCqn5vRhfY","room_id":"!tEyFYiMHhwJlDXTxwf:envs.net"},"room_version":"10"},"depth":1,"hashes":{"sha256":"us3TrsIjBWpwbm+k3F9fUVnz9GIuhnb+LcaY47fWwUI"},"origin":"envs.net","origin_server_ts":1664394769527,"prev_events":[],"room_id":"!UzZHbJYcgggctGnlzr:envs.net","sender":"@creme:envs.net","state_key":"","type":"m.room.create","signatures":{"envs.net":{"ed25519:a_zIqy":"0g3FDaD1e5BekJYW2sR7dgxuKoZshrf8P067c9+jmH6frsWr2Ua86Ax08CFa/n46L8uvV2SGofP8iiVYgXCRBg"}},"unsigned":{"age":2060}}`, - eventID: "$tn1FZUI_YUpfTr_a3Y_r8kC3inliIZZratzg0UsNdCQ", - roomVersion: id.RoomV10, - serverDetails: envsNet, -}, { - name: "m.room.create in v12 room", - pdu: `{"auth_events":[],"content":{"fi.mau.randomness":"AAXZ6aIc","predecessor":{"room_id":"!#test/room\nversion 11, with @\ud83d\udc08\ufe0f:maunium.net"},"room_version":"12"},"depth":1,"hashes":{"sha256":"d3L1M3KUdyIKWcShyW6grUoJ8GOjCdSIEvQrDVHSpE8"},"origin_server_ts":1754940000000,"prev_events":[],"sender":"@tulir:maunium.net","state_key":"","type":"m.room.create","signatures":{"maunium.net":{"ed25519:a_xxeS":"ebjIRpzToc82cjb/RGY+VUzZic0yeRZrjctgx0SUTJxkprXn3/i1KdiYULfl/aD0cUJ5eL8gLakOSk2glm+sBw"}},"unsigned":{"age_ts":1754939139045}}`, - eventID: "$mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ", - roomVersion: id.RoomV12, - serverDetails: mauniumNet, -}, { - name: "m.room.member in v4 room", - pdu: `{"auth_events":["$ay_9_nPilrTpb3UxIwHHBBfFjTJb6hBAE_JzQwSjqeY","$jg2AgCfnwnjR-osoyM0lVYS21QrtfmZxhGO90PRkmO4","$wMGMP4Ucij2_d4h_fVDgIT2xooLZAgMcBruT9oo3Jio","$yyDgV8w0_e8qslmn0nh9OeSq_fO0zjpjTjSEdKFxDso"],"prev_events":["$zSjNuTXhUe3Rq6NpKD3sNyl8a_asMnBhGC5IbacHlJ4"],"type":"m.room.member","room_id":"!jxlRxnrZCsjpjDubDX:matrix.org","sender":"@tulir:maunium.net","content":{"membership":"join","displayname":"tulir","avatar_url":"mxc://maunium.net/jdlSfvudiMSmcRrleeiYjjFO","clicked \"send membership event with no changes\"":true},"depth":14370,"prev_state":[],"state_key":"@tulir:maunium.net","origin":"maunium.net","origin_server_ts":1600871136259,"hashes":{"sha256":"Ga6bG9Mk0887ruzM9TAAfa1O3DbNssb+qSFtE9oeRL4"},"signatures":{"maunium.net":{"ed25519:a_xxeS":"fzOyDG3G3pEzixtWPttkRA1DfnHETiKbiG8SEBQe2qycQbZWPky7xX8WujSrUJH/+bxTABpQwEH49d+RakxtBw"}},"unsigned":{"age_ts":1600871136259,"replaces_state":"$jg2AgCfnwnjR-osoyM0lVYS21QrtfmZxhGO90PRkmO4"}}`, - eventID: "$VtuCNOfAWGow-cxy0ajeK3fvONcC8QzF2yWa43g0Gwo", - roomVersion: id.RoomV4, - serverDetails: mauniumNet, -}, { - name: "m.room.member in v10 room", - pdu: `{"auth_events":["$HQC4hWaioLKVbMH94qKbfb3UnL4ocql2vi-VdUYI48I","$R9FUDgNAp9ms7b6ASunZOIkpqmsIRq_ROrNEznu62fs","$kEPF8Aj87EzRmFPriu2zdyEY0rY15XSqywTYVLUUlCA","$tn1FZUI_YUpfTr_a3Y_r8kC3inliIZZratzg0UsNdCQ"],"content":{"avatar_url":"mxc://maunium.net/jdlSfvudiMSmcRrleeiYjjFO","displayname":"tulir","membership":"join"},"depth":182,"hashes":{"sha256":"0HscBc921QV2dxK2qY7qrnyoAgfxBM7kKvqAXlEk+GE"},"origin":"maunium.net","origin_server_ts":1665402609039,"prev_events":["$R9FUDgNAp9ms7b6ASunZOIkpqmsIRq_ROrNEznu62fs"],"room_id":"!UzZHbJYcgggctGnlzr:envs.net","sender":"@tulir:maunium.net","state_key":"@tulir:maunium.net","type":"m.room.member","signatures":{"maunium.net":{"ed25519:a_xxeS":"lkOW0FSJ8MJ0wZpdwLH1Uf6FSl2q9/u6KthRIlM0CwHDJG4sIZ9DrMA8BdU8L/PWoDS/CoDUlLanDh99SplgBw"}},"unsigned":{"age_ts":1665402609039,"replaces_state":"$R9FUDgNAp9ms7b6ASunZOIkpqmsIRq_ROrNEznu62fs"}}`, - eventID: "$--ilpwnsHaEdHrwiMrZNu5xHP6TthWG0FIXMHnlHCcs", - roomVersion: id.RoomV10, - serverDetails: mauniumNet, -}, { - name: "m.room.member of creator in v12 room", - pdu: `{"auth_events":[],"content":{"avatar_url":"mxc://maunium.net/jdlSfvudiMSmcRrleeiYjjFO","displayname":"tulir","membership":"join"},"depth":2,"hashes":{"sha256":"IebdOBYaaWYIx2zq/lkVCnjWIXTLk1g+vgFpJMgd2/E"},"origin_server_ts":1754939139117,"prev_events":["$mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ"],"room_id":"!mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ","sender":"@tulir:maunium.net","state_key":"@tulir:maunium.net","type":"m.room.member","signatures":{"maunium.net":{"ed25519:a_xxeS":"rFCgF2hmavdm6+P6/f7rmuOdoSOmELFaH3JdWjgBLZXS2z51Ma7fa2v2+BkAH1FvBo9FLhvEoFVM4WbNQLXtAA"}},"unsigned":{"age_ts":1754939139117}}`, - eventID: "$accqGxfvhBvMP4Sf6P7t3WgnaJK6UbonO2ZmwqSE5Sg", - roomVersion: id.RoomV12, - serverDetails: mauniumNet, -}, { - name: "custom message event in v4 room", - pdu: `{"auth_events":["$VtuCNOfAWGow-cxy0ajeK3fvONcC8QzF2yWa43g0Gwo","$ay_9_nPilrTpb3UxIwHHBBfFjTJb6hBAE_JzQwSjqeY","$Gau_XwziYsr-rt3SouhbKN14twgmbKjcZZc_hz-nOgU"],"content":{"\ud83d\udc08\ufe0f":true,"\ud83d\udc15\ufe0f":false},"depth":69645,"hashes":{"sha256":"VHtWyCt+15ZesNnStU3FOkxrjzHJYZfd3JUgO9JWe0s"},"origin_server_ts":1755423939146,"prev_events":["$exmp4cj0OKOFSxuqBYiOYwQi5j_0XRc78d6EavAkhy0"],"room_id":"!jxlRxnrZCsjpjDubDX:matrix.org","sender":"@tulir:maunium.net","type":"\ud83d\udc08\ufe0f","signatures":{"maunium.net":{"ed25519:a_xxeS":"wfmP1XN4JBkKVkqrQnwysyEUslXt8hQRFwN9NC9vJaIeDMd0OJ6uqCas75808DuG71p23fzqbzhRnHckst6FCQ"}},"unsigned":{"age_ts":1755423939164}}`, - eventID: "$kAagtZAIEeZaLVCUSl74tAxQbdKbE22GU7FM-iAJBc0", - roomVersion: id.RoomV4, - serverDetails: mauniumNet, -}, { - name: "redacted m.room.member event in v11 room with 2 signatures", - pdu: `{"auth_events":["$9f12-_stoY07BOTmyguE1QlqvghLBh9Rk6PWRLoZn_M","$IP8hyjBkIDREVadyv0fPCGAW9IXGNllaZyxqQwiY_tA","$7dN5J8EveliaPkX6_QSejl4GQtem4oieavgALMeWZyE"],"content":{"membership":"join"},"depth":96978,"hashes":{"sha256":"APYA/aj3u+P0EwNaEofuSIlfqY3cK3lBz6RkwHX+Zak"},"origin_server_ts":1755664164485,"prev_events":["$XBN9W5Ll8VEH3eYqJaemxCBTDdy0hZB0sWpmyoUp93c"],"room_id":"!main-1:continuwuity.org","sender":"@6a19abdd4766:nova.astraltech.org","state_key":"@6a19abdd4766:nova.astraltech.org","type":"m.room.member","signatures":{"continuwuity.org":{"ed25519:PwHlNsFu":"+b/Fp2vWnC+Z2lI3GnCu7ZHdo3iWNDZ2AJqMoU9owMtLBPMxs4dVIsJXvaFq0ryawsgwDwKZ7f4xaFUNARJSDg"},"nova.astraltech.org":{"ed25519:a_afpo":"pXIngyxKukCPR7WOIIy8FTZxQ5L2dLiou5Oc8XS4WyY4YzJuckQzOaToigLLZxamfbN/jXbO+XUizpRpYccDAA"}},"unsigned":{}}`, - eventID: "$r6d9m125YWG28-Tln47bWtm6Jlv4mcSUWJTHijBlXLQ", - roomVersion: id.RoomV11, - serverDetails: novaAstraltechOrg, - redacted: true, -}} - -func parsePDU(pdu string) (out *pdu.PDU) { - exerrors.PanicIfNotNil(json.Unmarshal([]byte(pdu), &out)) - return -} diff --git a/federation/pdu/redact.go b/federation/pdu/redact.go deleted file mode 100644 index d7ee0c15..00000000 --- a/federation/pdu/redact.go +++ /dev/null @@ -1,111 +0,0 @@ -// Copyright (c) 2025 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -//go:build goexperiment.jsonv2 - -package pdu - -import ( - "encoding/json/jsontext" - - "github.com/tidwall/gjson" - "github.com/tidwall/sjson" - "go.mau.fi/util/exgjson" - "go.mau.fi/util/ptr" - - "maunium.net/go/mautrix/id" -) - -func filteredObject(object jsontext.Value, allowedPaths ...string) jsontext.Value { - filtered := jsontext.Value("{}") - var err error - for _, path := range allowedPaths { - res := gjson.GetBytes(object, path) - if res.Exists() { - var raw jsontext.Value - if res.Index > 0 { - raw = object[res.Index : res.Index+len(res.Raw)] - } else { - raw = jsontext.Value(res.Raw) - } - filtered, err = sjson.SetRawBytes(filtered, path, raw) - if err != nil { - panic(err) - } - } - } - return filtered -} - -func (pdu *PDU) Clone() *PDU { - return ptr.Clone(pdu) -} - -func (pdu *PDU) RedactForSignature(roomVersion id.RoomVersion) *PDU { - pdu.Signatures = nil - return pdu.Redact(roomVersion) -} - -var emptyObject = jsontext.Value("{}") - -func RedactContent(eventType string, content jsontext.Value, roomVersion id.RoomVersion) jsontext.Value { - switch eventType { - case "m.room.member": - allowedPaths := []string{"membership"} - if roomVersion.RestrictedJoinsFix() { - allowedPaths = append(allowedPaths, "join_authorised_via_users_server") - } - if roomVersion.UpdatedRedactionRules() { - allowedPaths = append(allowedPaths, exgjson.Path("third_party_invite", "signed")) - } - return filteredObject(content, allowedPaths...) - case "m.room.create": - if !roomVersion.UpdatedRedactionRules() { - return filteredObject(content, "creator") - } - return content - case "m.room.join_rules": - if roomVersion.RestrictedJoins() { - return filteredObject(content, "join_rule", "allow") - } - return filteredObject(content, "join_rule") - case "m.room.power_levels": - allowedKeys := []string{"ban", "events", "events_default", "kick", "redact", "state_default", "users", "users_default"} - if roomVersion.UpdatedRedactionRules() { - allowedKeys = append(allowedKeys, "invite") - } - return filteredObject(content, allowedKeys...) - case "m.room.history_visibility": - return filteredObject(content, "history_visibility") - case "m.room.redaction": - if roomVersion.RedactsInContent() { - return filteredObject(content, "redacts") - } - return emptyObject - case "m.room.aliases": - if roomVersion.SpecialCasedAliasesAuth() { - return filteredObject(content, "aliases") - } - return emptyObject - default: - return emptyObject - } -} - -func (pdu *PDU) Redact(roomVersion id.RoomVersion) *PDU { - pdu.Unknown = nil - pdu.Unsigned = nil - if roomVersion.UpdatedRedactionRules() { - pdu.DeprecatedPrevState = nil - pdu.DeprecatedOrigin = nil - pdu.DeprecatedMembership = nil - } - if pdu.Type != "m.room.redaction" || roomVersion.RedactsInContent() { - pdu.Redacts = nil - } - pdu.Content = RedactContent(pdu.Type, pdu.Content, roomVersion) - return pdu -} diff --git a/federation/pdu/signature.go b/federation/pdu/signature.go deleted file mode 100644 index 04e7c5ef..00000000 --- a/federation/pdu/signature.go +++ /dev/null @@ -1,60 +0,0 @@ -// Copyright (c) 2025 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -//go:build goexperiment.jsonv2 - -package pdu - -import ( - "crypto/ed25519" - "encoding/base64" - "fmt" - "time" - - "maunium.net/go/mautrix/federation/signutil" - "maunium.net/go/mautrix/id" -) - -func (pdu *PDU) Sign(roomVersion id.RoomVersion, serverName string, keyID id.KeyID, privateKey ed25519.PrivateKey) error { - err := pdu.FillContentHash() - if err != nil { - return err - } - rawJSON, err := marshalCanonical(pdu.Clone().RedactForSignature(roomVersion)) - if err != nil { - return fmt.Errorf("failed to marshal redacted PDU to sign: %w", err) - } - signature := ed25519.Sign(privateKey, rawJSON) - pdu.AddSignature(serverName, keyID, base64.RawStdEncoding.EncodeToString(signature)) - return nil -} - -func (pdu *PDU) VerifySignature(roomVersion id.RoomVersion, serverName string, getKey GetKeyFunc) error { - rawJSON, err := marshalCanonical(pdu.Clone().RedactForSignature(roomVersion)) - if err != nil { - return fmt.Errorf("failed to marshal redacted PDU to verify signature: %w", err) - } - verified := false - for keyID, sig := range pdu.Signatures[serverName] { - originServerTS := time.UnixMilli(pdu.OriginServerTS) - key, validUntil, err := getKey(serverName, keyID, originServerTS) - if err != nil { - return fmt.Errorf("failed to get key %s for %s: %w", keyID, serverName, err) - } else if key == "" { - return fmt.Errorf("key %s not found for %s", keyID, serverName) - } else if validUntil.Before(originServerTS) && roomVersion.EnforceSigningKeyValidity() { - return fmt.Errorf("key %s for %s is only valid until %s, but event is from %s", keyID, serverName, validUntil, originServerTS) - } else if err = signutil.VerifyJSONRaw(key, sig, rawJSON); err != nil { - return fmt.Errorf("failed to verify signature from key %s: %w", keyID, err) - } else { - verified = true - } - } - if !verified { - return fmt.Errorf("no verifiable signatures found for server %s", serverName) - } - return nil -} diff --git a/federation/pdu/signature_test.go b/federation/pdu/signature_test.go deleted file mode 100644 index 01df5076..00000000 --- a/federation/pdu/signature_test.go +++ /dev/null @@ -1,102 +0,0 @@ -// Copyright (c) 2025 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -//go:build goexperiment.jsonv2 - -package pdu_test - -import ( - "crypto/ed25519" - "encoding/base64" - "encoding/json/jsontext" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "go.mau.fi/util/exerrors" - - "maunium.net/go/mautrix/federation/pdu" - "maunium.net/go/mautrix/id" -) - -func TestPDU_VerifySignature(t *testing.T) { - for _, test := range testPDUs { - t.Run(test.name, func(t *testing.T) { - parsed := parsePDU(test.pdu) - err := parsed.VerifySignature(test.roomVersion, test.serverName, test.getKey) - assert.NoError(t, err) - }) - } -} - -func TestPDU_VerifySignature_Fail_NoKey(t *testing.T) { - test := roomV12MessageTestPDU - parsed := parsePDU(test.pdu) - err := parsed.VerifySignature(test.roomVersion, test.serverName, func(serverName string, keyID id.KeyID, minValidUntil time.Time) (key id.SigningKey, validUntil time.Time, err error) { - return - }) - assert.Error(t, err) -} - -func TestPDU_VerifySignature_V4ExpiredKey(t *testing.T) { - test := roomV4MessageTestPDU - parsed := parsePDU(test.pdu) - err := parsed.VerifySignature(test.roomVersion, test.serverName, func(serverName string, keyID id.KeyID, minValidUntil time.Time) (key id.SigningKey, validUntil time.Time, err error) { - key = test.keys[keyID].key - validUntil = time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC) - return - }) - assert.NoError(t, err) -} - -func TestPDU_VerifySignature_V12ExpiredKey(t *testing.T) { - test := roomV12MessageTestPDU - parsed := parsePDU(test.pdu) - err := parsed.VerifySignature(test.roomVersion, test.serverName, func(serverName string, keyID id.KeyID, minValidUntil time.Time) (key id.SigningKey, validUntil time.Time, err error) { - key = test.keys[keyID].key - validUntil = time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC) - return - }) - assert.Error(t, err) -} - -func TestPDU_VerifySignature_V12InvalidSignature(t *testing.T) { - test := roomV12MessageTestPDU - parsed := parsePDU(test.pdu) - for _, sigs := range parsed.Signatures { - for key := range sigs { - sigs[key] = sigs[key][:len(sigs[key])-3] + "ABC" - } - } - err := parsed.VerifySignature(test.roomVersion, test.serverName, test.getKey) - assert.Error(t, err) -} - -func TestPDU_Sign(t *testing.T) { - pubKey, privKey := exerrors.Must2(ed25519.GenerateKey(nil)) - evt := &pdu.PDU{ - AuthEvents: []id.EventID{"$gCzdJUVV93Qory0x7p_PLG5UUiDjPJNe1H12qbHTuFA", "$hyeL_nU_L3tsZ2dtZZpAHk0Skv-PqFQIipuII_By584"}, - Content: jsontext.Value(`{"msgtype":"m.text","body":"Hello, world!"}`), - Depth: 123, - OriginServerTS: 1755384351627, - PrevEvents: []id.EventID{"$gCzdJUVV93Qory0x7p_PLG5UUiDjPJNe1H12qbHTuFA"}, - RoomID: "!mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ", - Sender: "@tulir:example.com", - Type: "m.room.message", - } - err := evt.Sign(id.RoomV12, "example.com", "ed25519:rand", privKey) - require.NoError(t, err) - err = evt.VerifySignature(id.RoomV11, "example.com", func(serverName string, keyID id.KeyID, minValidUntil time.Time) (key id.SigningKey, validUntil time.Time, err error) { - if serverName == "example.com" && keyID == "ed25519:rand" { - key = id.SigningKey(base64.RawStdEncoding.EncodeToString(pubKey)) - validUntil = time.Now() - } - return - }) - require.NoError(t, err) - -} diff --git a/federation/pdu/v1.go b/federation/pdu/v1.go deleted file mode 100644 index 9557f8ab..00000000 --- a/federation/pdu/v1.go +++ /dev/null @@ -1,277 +0,0 @@ -// Copyright (c) 2025 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -//go:build goexperiment.jsonv2 - -package pdu - -import ( - "crypto/ed25519" - "crypto/hmac" - "crypto/sha256" - "encoding/base64" - "encoding/json/jsontext" - "encoding/json/v2" - "fmt" - "time" - - "github.com/tidwall/gjson" - "go.mau.fi/util/ptr" - - "maunium.net/go/mautrix/event" - "maunium.net/go/mautrix/federation/signutil" - "maunium.net/go/mautrix/id" -) - -type V1EventReference struct { - ID id.EventID - Hashes Hashes -} - -var ( - _ json.UnmarshalerFrom = (*V1EventReference)(nil) - _ json.MarshalerTo = (*V1EventReference)(nil) -) - -func (er *V1EventReference) MarshalJSONTo(enc *jsontext.Encoder) error { - return json.MarshalEncode(enc, []any{er.ID, er.Hashes}) -} - -func (er *V1EventReference) UnmarshalJSONFrom(dec *jsontext.Decoder) error { - var ref V1EventReference - var data []jsontext.Value - if err := json.UnmarshalDecode(dec, &data); err != nil { - return err - } else if len(data) != 2 { - return fmt.Errorf("V1EventReference.UnmarshalJSONFrom: expected array with 2 elements, got %d", len(data)) - } else if err = json.Unmarshal(data[0], &ref.ID); err != nil { - return fmt.Errorf("V1EventReference.UnmarshalJSONFrom: failed to unmarshal event ID: %w", err) - } else if err = json.Unmarshal(data[1], &ref.Hashes); err != nil { - return fmt.Errorf("V1EventReference.UnmarshalJSONFrom: failed to unmarshal hashes: %w", err) - } - *er = ref - return nil -} - -type RoomV1PDU struct { - AuthEvents []V1EventReference `json:"auth_events"` - Content jsontext.Value `json:"content"` - Depth int64 `json:"depth"` - EventID id.EventID `json:"event_id"` - Hashes *Hashes `json:"hashes,omitzero"` - OriginServerTS int64 `json:"origin_server_ts"` - PrevEvents []V1EventReference `json:"prev_events"` - Redacts *id.EventID `json:"redacts,omitzero"` - RoomID id.RoomID `json:"room_id"` - Sender id.UserID `json:"sender"` - Signatures map[string]map[id.KeyID]string `json:"signatures,omitzero"` - StateKey *string `json:"state_key,omitzero"` - Type string `json:"type"` - Unsigned jsontext.Value `json:"unsigned,omitzero"` - - Unknown jsontext.Value `json:",unknown"` - - // Deprecated legacy fields - DeprecatedPrevState jsontext.Value `json:"prev_state,omitzero"` - DeprecatedOrigin jsontext.Value `json:"origin,omitzero"` - DeprecatedMembership jsontext.Value `json:"membership,omitzero"` -} - -func (pdu *RoomV1PDU) GetRoomID() (id.RoomID, error) { - return pdu.RoomID, nil -} - -func (pdu *RoomV1PDU) GetEventID(roomVersion id.RoomVersion) (id.EventID, error) { - if !pdu.SupportsRoomVersion(roomVersion) { - return "", fmt.Errorf("RoomV1PDU.GetEventID: unsupported room version %s", roomVersion) - } - return pdu.EventID, nil -} - -func (pdu *RoomV1PDU) RedactForSignature(roomVersion id.RoomVersion) *RoomV1PDU { - pdu.Signatures = nil - return pdu.Redact(roomVersion) -} - -func (pdu *RoomV1PDU) Redact(roomVersion id.RoomVersion) *RoomV1PDU { - pdu.Unknown = nil - pdu.Unsigned = nil - if pdu.Type != "m.room.redaction" { - pdu.Redacts = nil - } - pdu.Content = RedactContent(pdu.Type, pdu.Content, roomVersion) - return pdu -} - -func (pdu *RoomV1PDU) GetReferenceHash(roomVersion id.RoomVersion) ([32]byte, error) { - if !pdu.SupportsRoomVersion(roomVersion) { - return [32]byte{}, fmt.Errorf("RoomV1PDU.GetReferenceHash: unsupported room version %s", roomVersion) - } - if pdu == nil { - return [32]byte{}, ErrPDUIsNil - } - if pdu.Hashes == nil || pdu.Hashes.SHA256 == nil { - if err := pdu.FillContentHash(); err != nil { - return [32]byte{}, err - } - } - rawJSON, err := marshalCanonical(pdu.Clone().RedactForSignature(roomVersion)) - if err != nil { - return [32]byte{}, fmt.Errorf("failed to marshal redacted PDU to calculate event ID: %w", err) - } - return sha256.Sum256(rawJSON), nil -} - -func (pdu *RoomV1PDU) CalculateContentHash() ([32]byte, error) { - if pdu == nil { - return [32]byte{}, ErrPDUIsNil - } - pduClone := pdu.Clone() - pduClone.Signatures = nil - pduClone.Unsigned = nil - pduClone.Hashes = nil - rawJSON, err := marshalCanonical(pduClone) - if err != nil { - return [32]byte{}, fmt.Errorf("failed to marshal PDU to calculate content hash: %w", err) - } - return sha256.Sum256(rawJSON), nil -} - -func (pdu *RoomV1PDU) FillContentHash() error { - if pdu == nil { - return ErrPDUIsNil - } else if pdu.Hashes != nil { - return nil - } else if hash, err := pdu.CalculateContentHash(); err != nil { - return err - } else { - pdu.Hashes = &Hashes{SHA256: hash[:]} - return nil - } -} - -func (pdu *RoomV1PDU) VerifyContentHash() bool { - if pdu == nil || pdu.Hashes == nil { - return false - } - calculatedHash, err := pdu.CalculateContentHash() - if err != nil { - return false - } - return hmac.Equal(calculatedHash[:], pdu.Hashes.SHA256) -} - -func (pdu *RoomV1PDU) Clone() *RoomV1PDU { - return ptr.Clone(pdu) -} - -func (pdu *RoomV1PDU) Sign(roomVersion id.RoomVersion, serverName string, keyID id.KeyID, privateKey ed25519.PrivateKey) error { - if !pdu.SupportsRoomVersion(roomVersion) { - return fmt.Errorf("RoomV1PDU.Sign: unsupported room version %s", roomVersion) - } - err := pdu.FillContentHash() - if err != nil { - return err - } - rawJSON, err := marshalCanonical(pdu.Clone().RedactForSignature(roomVersion)) - if err != nil { - return fmt.Errorf("failed to marshal redacted PDU to sign: %w", err) - } - signature := ed25519.Sign(privateKey, rawJSON) - if pdu.Signatures == nil { - pdu.Signatures = make(map[string]map[id.KeyID]string) - } - if _, ok := pdu.Signatures[serverName]; !ok { - pdu.Signatures[serverName] = make(map[id.KeyID]string) - } - pdu.Signatures[serverName][keyID] = base64.RawStdEncoding.EncodeToString(signature) - return nil -} - -func (pdu *RoomV1PDU) VerifySignature(roomVersion id.RoomVersion, serverName string, getKey GetKeyFunc) error { - if !pdu.SupportsRoomVersion(roomVersion) { - return fmt.Errorf("RoomV1PDU.VerifySignature: unsupported room version %s", roomVersion) - } - rawJSON, err := marshalCanonical(pdu.Clone().RedactForSignature(roomVersion)) - if err != nil { - return fmt.Errorf("failed to marshal redacted PDU to verify signature: %w", err) - } - verified := false - for keyID, sig := range pdu.Signatures[serverName] { - originServerTS := time.UnixMilli(pdu.OriginServerTS) - key, _, err := getKey(serverName, keyID, originServerTS) - if err != nil { - return fmt.Errorf("failed to get key %s for %s: %w", keyID, serverName, err) - } else if key == "" { - return fmt.Errorf("key %s not found for %s", keyID, serverName) - } else if err = signutil.VerifyJSONRaw(key, sig, rawJSON); err != nil { - return fmt.Errorf("failed to verify signature from key %s: %w", keyID, err) - } else { - verified = true - } - } - if !verified { - return fmt.Errorf("no verifiable signatures found for server %s", serverName) - } - return nil -} - -func (pdu *RoomV1PDU) SupportsRoomVersion(roomVersion id.RoomVersion) bool { - switch roomVersion { - case id.RoomV0, id.RoomV1, id.RoomV2: - return true - default: - return false - } -} - -func (pdu *RoomV1PDU) ToClientEvent(roomVersion id.RoomVersion) (*event.Event, error) { - if !pdu.SupportsRoomVersion(roomVersion) { - return nil, fmt.Errorf("RoomV1PDU.ToClientEvent: unsupported room version %s", roomVersion) - } - evtType := event.Type{Type: pdu.Type, Class: event.MessageEventType} - if pdu.StateKey != nil { - evtType.Class = event.StateEventType - } - evt := &event.Event{ - StateKey: pdu.StateKey, - Sender: pdu.Sender, - Type: evtType, - Timestamp: pdu.OriginServerTS, - ID: pdu.EventID, - RoomID: pdu.RoomID, - Redacts: ptr.Val(pdu.Redacts), - } - err := json.Unmarshal(pdu.Content, &evt.Content) - if err != nil { - return nil, fmt.Errorf("failed to unmarshal content: %w", err) - } - return evt, nil -} - -func (pdu *RoomV1PDU) AuthEventSelection(_ id.RoomVersion) (keys AuthEventSelection) { - if pdu.Type == event.StateCreate.Type && pdu.StateKey != nil { - return AuthEventSelection{} - } - keys = make(AuthEventSelection, 0, 3) - keys.Add(event.StateCreate.Type, "") - keys.Add(event.StatePowerLevels.Type, "") - keys.Add(event.StateMember.Type, pdu.Sender.String()) - if pdu.Type == event.StateMember.Type && pdu.StateKey != nil { - keys.Add(event.StateMember.Type, *pdu.StateKey) - membership := event.Membership(gjson.GetBytes(pdu.Content, "membership").Str) - if membership == event.MembershipJoin || membership == event.MembershipInvite || membership == event.MembershipKnock { - keys.Add(event.StateJoinRules.Type, "") - } - if membership == event.MembershipInvite { - thirdPartyInviteToken := gjson.GetBytes(pdu.Content, thirdPartyInviteTokenPath).Str - if thirdPartyInviteToken != "" { - keys.Add(event.StateThirdPartyInvite.Type, thirdPartyInviteToken) - } - } - } - return -} diff --git a/federation/pdu/v1_test.go b/federation/pdu/v1_test.go deleted file mode 100644 index ecf2dbd2..00000000 --- a/federation/pdu/v1_test.go +++ /dev/null @@ -1,86 +0,0 @@ -// Copyright (c) 2025 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -//go:build goexperiment.jsonv2 - -package pdu_test - -import ( - "encoding/base64" - "encoding/json/v2" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "go.mau.fi/util/exerrors" - - "maunium.net/go/mautrix/federation/pdu" - "maunium.net/go/mautrix/id" -) - -var testV1PDUs = []testPDU{{ - name: "m.room.message in v1 room", - pdu: `{"auth_events":[["$159234730483190eXavq:matrix.org",{"sha256":"VprZrhMqOQyKbfF3UE26JXE8D27ih4R/FGGc8GZ0Whs"}],["$143454825711DhCxH:matrix.org",{"sha256":"3sJh/5GOB094OKuhbjL634Gt69YIcge9GD55ciJa9ok"}],["$156837651426789wiPdh:maunium.net",{"sha256":"FGyR3sxJ/VxYabDkO/5qtwrPR3hLwGknJ0KX0w3GUHE"}]],"content":{"body":"photo-1526336024174-e58f5cdd8e13.jpg","info":{"h":1620,"mimetype":"image/jpeg","size":208053,"w":1080},"msgtype":"m.image","url":"mxc://maunium.net/aEqEghIjFPAerIhCxJCYpQeC"},"depth":16669,"event_id":"$16738169022163bokdi:maunium.net","hashes":{"sha256":"XYB47Gf2vAci3BTguIJaC75ZYGMuVY65jcvoUVgpcLA"},"origin":"maunium.net","origin_server_ts":1673816902100,"prev_events":[["$1673816901121325UMCjA:matrix.org",{"sha256":"t7e0IYHLI3ydIPoIU8a8E/pIWXH9cNLlQBEtGyGtHwc"}]],"room_id":"!jhpZBTbckszblMYjMK:matrix.org","sender":"@cat:maunium.net","type":"m.room.message","signatures":{"maunium.net":{"ed25519:a_xxeS":"uRZbEm+P+Y1ZVgwBn5I6SlaUZdzlH1bB4nv81yt5EIQ0b1fZ8YgM4UWMijrrXp3+NmqRFl0cakSM3MneJOtFCw"}},"unsigned":{"age_ts":1673816902100}}`, - eventID: "$16738169022163bokdi:maunium.net", - roomVersion: id.RoomV1, - serverDetails: mauniumNet, -}, { - name: "m.room.create in v1 room", - pdu: `{"origin": "matrix.org", "signatures": {"matrix.org": {"ed25519:auto": "XTejpXn5REoHrZWgCpJglGX7MfOWS2zUjYwJRLrwW2PQPbFdqtL+JnprBXwIP2C1NmgWSKG+am1QdApu0KoHCQ"}}, "origin_server_ts": 1434548257426, "sender": "@appservice-irc:matrix.org", "event_id": "$143454825711DhCxH:matrix.org", "prev_events": [], "unsigned": {"age": 12872287834}, "state_key": "", "content": {"creator": "@appservice-irc:matrix.org"}, "depth": 1, "prev_state": [], "room_id": "!jhpZBTbckszblMYjMK:matrix.org", "auth_events": [], "hashes": {"sha256": "+SSdmeeoKI/6yK6sY4XAFljWFiugSlCiXQf0QMCZjTs"}, "type": "m.room.create"}`, - eventID: "$143454825711DhCxH:matrix.org", - roomVersion: id.RoomV1, - serverDetails: matrixOrg, -}, { - name: "m.room.member in v1 room", - pdu: `{"auth_events": [["$1536447669931522zlyWe:matrix.org", {"sha256": "UkzPGd7cPAGvC0FVx3Yy2/Q0GZhA2kcgj8MGp5pjYV8"}], ["$143454825711DhCxH:matrix.org", {"sha256": "3sJh/5GOB094OKuhbjL634Gt69YIcge9GD55ciJa9ok"}], ["$143454825714nUEqZ:matrix.org", {"sha256": "NjuZXu8EDMfIfejPcNlC/IdnKQAGpPIcQjHaf0BZaHk"}]], "prev_events": [["$15660585503271JRRMm:maunium.net", {"sha256": "/Sm7uSLkYMHapp6I3NuEVJlk2JucW2HqjsQy9vzhciA"}]], "type": "m.room.member", "room_id": "!jhpZBTbckszblMYjMK:matrix.org", "sender": "@tulir:maunium.net", "content": {"membership": "join", "avatar_url": "mxc://maunium.net/jdlSfvudiMSmcRrleeiYjjFO", "displayname": "tulir"}, "depth": 10485, "prev_state": [], "state_key": "@tulir:maunium.net", "event_id": "$15660585693272iEryv:maunium.net", "origin": "maunium.net", "origin_server_ts": 1566058569201, "hashes": {"sha256": "1D6fdDzKsMGCxSqlXPA7I9wGQNTutVuJke1enGHoWK8"}, "signatures": {"maunium.net": {"ed25519:a_xxeS": "Lj/zDK6ozr4vgsxyL8jY56wTGWoA4jnlvkTs5paCX1w3nNKHnQnSMi+wuaqI6yv5vYh9usGWco2LLMuMzYXcBg"}}, "unsigned": {"age_ts": 1566058569201, "replaces_state": "$15660585383268liyBc:maunium.net"}}`, - eventID: "$15660585693272iEryv:maunium.net", - roomVersion: id.RoomV1, - serverDetails: mauniumNet, -}} - -func parseV1PDU(pdu string) (out *pdu.RoomV1PDU) { - exerrors.PanicIfNotNil(json.Unmarshal([]byte(pdu), &out)) - return -} - -func TestRoomV1PDU_CalculateContentHash(t *testing.T) { - for _, test := range testV1PDUs { - t.Run(test.name, func(t *testing.T) { - parsed := parseV1PDU(test.pdu) - contentHash := exerrors.Must(parsed.CalculateContentHash()) - assert.Equal( - t, - base64.RawStdEncoding.EncodeToString(parsed.Hashes.SHA256), - base64.RawStdEncoding.EncodeToString(contentHash[:]), - ) - }) - } -} - -func TestRoomV1PDU_VerifyContentHash(t *testing.T) { - for _, test := range testV1PDUs { - t.Run(test.name, func(t *testing.T) { - parsed := parseV1PDU(test.pdu) - assert.True(t, parsed.VerifyContentHash()) - }) - } -} - -func TestRoomV1PDU_VerifySignature(t *testing.T) { - for _, test := range testV1PDUs { - t.Run(test.name, func(t *testing.T) { - parsed := parseV1PDU(test.pdu) - err := parsed.VerifySignature(test.roomVersion, test.serverName, func(serverName string, keyID id.KeyID, _ time.Time) (id.SigningKey, time.Time, error) { - key, ok := test.keys[keyID] - if ok { - return key.key, key.validUntilTS, nil - } - return "", time.Time{}, nil - }) - assert.NoError(t, err) - }) - } -} diff --git a/federation/resolution.go b/federation/resolution.go index a3188266..69d4d3bf 100644 --- a/federation/resolution.go +++ b/federation/resolution.go @@ -20,8 +20,6 @@ import ( "time" "github.com/rs/zerolog" - - "maunium.net/go/mautrix" ) type ResolvedServerName struct { @@ -80,10 +78,7 @@ func ResolveServerName(ctx context.Context, serverName string, opts ...*ResolveS } else if wellKnown != nil { output.Expires = expiry output.HostHeader = wellKnown.Server - wkHost, wkPort, ok := ParseServerName(wellKnown.Server) - if ok { - hostname, port = wkHost, wkPort - } + hostname, port, ok = ParseServerName(wellKnown.Server) // Step 3.1 and 3.2: IP literals and hostnames with port inside .well-known if net.ParseIP(hostname) != nil || port != 0 { if port == 0 { @@ -176,11 +171,9 @@ func RequestWellKnown(ctx context.Context, cli *http.Client, hostname string) (* defer resp.Body.Close() if resp.StatusCode != http.StatusOK { return nil, time.Time{}, fmt.Errorf("unexpected status code %d", resp.StatusCode) - } else if resp.ContentLength > mautrix.WellKnownMaxSize { - return nil, time.Time{}, fmt.Errorf("response too large: %d bytes", resp.ContentLength) } var respData RespWellKnown - err = json.NewDecoder(io.LimitReader(resp.Body, mautrix.WellKnownMaxSize)).Decode(&respData) + err = json.NewDecoder(io.LimitReader(resp.Body, 50*1024)).Decode(&respData) if err != nil { return nil, time.Time{}, fmt.Errorf("failed to decode response: %w", err) } else if respData.Server == "" { diff --git a/federation/serverauth.go b/federation/serverauth.go index cd300341..f46c7991 100644 --- a/federation/serverauth.go +++ b/federation/serverauth.go @@ -231,7 +231,7 @@ func (sa *ServerAuth) Authenticate(r *http.Request) (*http.Request, *mautrix.Res } err = (&signableRequest{ Method: r.Method, - URI: r.URL.RequestURI(), + URI: r.URL.EscapedPath(), Origin: parsed.Origin, Destination: destination, Content: reqBody, diff --git a/federation/serverauth_test.go b/federation/serverauth_test.go index f99fc6cf..9fa15459 100644 --- a/federation/serverauth_test.go +++ b/federation/serverauth_test.go @@ -19,9 +19,9 @@ import ( func TestServerKeyResponse_VerifySelfSignature(t *testing.T) { cli := federation.NewClient("", nil, nil) ctx := context.Background() - for _, name := range []string{"matrix.org", "maunium.net", "cd.mau.dev", "uwu.mau.dev"} { + for _, name := range []string{"matrix.org", "maunium.net", "continuwuity.org"} { t.Run(name, func(t *testing.T) { - resp, err := cli.ServerKeys(ctx, name) + resp, err := cli.ServerKeys(ctx, "matrix.org") require.NoError(t, err) assert.NoError(t, resp.VerifySelfSignature()) }) diff --git a/federation/signutil/verify.go b/federation/signutil/verify.go index ea0e7886..8fe55b2f 100644 --- a/federation/signutil/verify.go +++ b/federation/signutil/verify.go @@ -48,47 +48,6 @@ func VerifyJSON(serverName string, keyID id.KeyID, key id.SigningKey, data any) return VerifyJSONRaw(key, sigVal.Str, message) } -func VerifyJSONAny(key id.SigningKey, data any) error { - var err error - message, ok := data.(json.RawMessage) - if !ok { - message, err = json.Marshal(data) - if err != nil { - return fmt.Errorf("failed to marshal data: %w", err) - } - } - sigs := gjson.GetBytes(message, "signatures") - if !sigs.IsObject() { - return ErrSignatureNotFound - } - message, err = sjson.DeleteBytes(message, "signatures") - if err != nil { - return fmt.Errorf("failed to delete signatures: %w", err) - } - message, err = sjson.DeleteBytes(message, "unsigned") - if err != nil { - return fmt.Errorf("failed to delete unsigned: %w", err) - } - var validated bool - sigs.ForEach(func(_, value gjson.Result) bool { - if !value.IsObject() { - return true - } - value.ForEach(func(_, value gjson.Result) bool { - if value.Type != gjson.String { - return true - } - validated = VerifyJSONRaw(key, value.Str, message) == nil - return !validated - }) - return !validated - }) - if !validated { - return ErrInvalidSignature - } - return nil -} - func VerifyJSONRaw(key id.SigningKey, sig string, message json.RawMessage) error { sigBytes, err := base64.RawStdEncoding.DecodeString(sig) if err != nil { diff --git a/filter.go b/filter.go index 54973dab..c6c8211b 100644 --- a/filter.go +++ b/filter.go @@ -57,7 +57,7 @@ type FilterPart struct { // Validate checks if the filter contains valid property values func (filter *Filter) Validate() error { if filter.EventFormat != EventFormatClient && filter.EventFormat != EventFormatFederation { - return errors.New("bad event_format value") + return errors.New("Bad event_format value. Must be one of [\"client\", \"federation\"]") } return nil } diff --git a/format/htmlparser.go b/format/htmlparser.go index e0507d93..e5f92896 100644 --- a/format/htmlparser.go +++ b/format/htmlparser.go @@ -93,30 +93,6 @@ func DefaultPillConverter(displayname, mxid, eventID string, ctx Context) string } } -func onlyBacktickCount(line string) (count int) { - for i := 0; i < len(line); i++ { - if line[i] != '`' { - return -1 - } - count++ - } - return -} - -func DefaultMonospaceBlockConverter(code, language string, ctx Context) string { - if len(code) == 0 || code[len(code)-1] != '\n' { - code += "\n" - } - fence := "```" - for line := range strings.SplitSeq(code, "\n") { - count := onlyBacktickCount(strings.TrimSpace(line)) - if count >= len(fence) { - fence = strings.Repeat("`", count+1) - } - } - return fmt.Sprintf("%s%s\n%s%s", fence, language, code, fence) -} - // HTMLParser is a somewhat customizable Matrix HTML parser. type HTMLParser struct { PillConverter PillConverter @@ -372,7 +348,10 @@ func (parser *HTMLParser) tagToString(node *html.Node, ctx Context) string { if parser.MonospaceBlockConverter != nil { return parser.MonospaceBlockConverter(preStr, language, ctx) } - return DefaultMonospaceBlockConverter(preStr, language, ctx) + if len(preStr) == 0 || preStr[len(preStr)-1] != '\n' { + preStr += "\n" + } + return fmt.Sprintf("```%s\n%s```", language, preStr) default: return parser.nodeToTagAwareString(node.FirstChild, ctx) } diff --git a/format/markdown.go b/format/markdown.go index 77ced0dc..3d9979b4 100644 --- a/format/markdown.go +++ b/format/markdown.go @@ -57,18 +57,7 @@ type uriAble interface { } func MarkdownMention(id uriAble) string { - return MarkdownMentionWithName(id.String(), id) -} - -func MarkdownMentionWithName(name string, id uriAble) string { - return MarkdownLink(name, id.URI().MatrixToURL()) -} - -func MarkdownMentionRoomID(name string, id id.RoomID, via ...string) string { - if name == "" { - name = id.String() - } - return MarkdownLink(name, id.URI(via...).MatrixToURL()) + return MarkdownLink(id.String(), id.URI().MatrixToURL()) } func MarkdownLink(name string, url string) string { diff --git a/go.mod b/go.mod index 49a1d4e4..4abdc4ff 100644 --- a/go.mod +++ b/go.mod @@ -1,42 +1,42 @@ module maunium.net/go/mautrix -go 1.25.0 +go 1.24.0 -toolchain go1.26.0 +toolchain go1.25.0 require ( - filippo.io/edwards25519 v1.2.0 + filippo.io/edwards25519 v1.1.0 github.com/chzyer/readline v1.5.1 - github.com/coder/websocket v1.8.14 - github.com/lib/pq v1.11.2 - github.com/mattn/go-sqlite3 v1.14.34 + github.com/coder/websocket v1.8.13 + github.com/lib/pq v1.10.9 + github.com/mattn/go-sqlite3 v1.14.32 github.com/rs/xid v1.6.0 github.com/rs/zerolog v1.34.0 github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e - github.com/stretchr/testify v1.11.1 + github.com/stretchr/testify v1.10.0 github.com/tidwall/gjson v1.18.0 github.com/tidwall/sjson v1.2.5 - github.com/yuin/goldmark v1.7.16 - go.mau.fi/util v0.9.6 + github.com/yuin/goldmark v1.7.13 + go.mau.fi/util v0.9.0 go.mau.fi/zeroconfig v0.2.0 - golang.org/x/crypto v0.48.0 - golang.org/x/exp v0.0.0-20260218203240-3dfff04db8fa - golang.org/x/net v0.50.0 - golang.org/x/sync v0.19.0 + golang.org/x/crypto v0.41.0 + golang.org/x/exp v0.0.0-20250813145105-42675adae3e6 + golang.org/x/net v0.43.0 + golang.org/x/sync v0.16.0 gopkg.in/yaml.v3 v3.0.1 maunium.net/go/mauflag v1.0.0 ) require ( - github.com/coreos/go-systemd/v22 v22.6.0 // indirect + github.com/coreos/go-systemd/v22 v22.5.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-isatty v0.0.20 // indirect - github.com/petermattis/goid v0.0.0-20260113132338-7c7de50cc741 // indirect + github.com/petermattis/goid v0.0.0-20250813065127-a731cc31b4fe // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.1 // indirect - golang.org/x/sys v0.41.0 // indirect - golang.org/x/text v0.34.0 // indirect + golang.org/x/sys v0.35.0 // indirect + golang.org/x/text v0.28.0 // indirect gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect ) diff --git a/go.sum b/go.sum index 871a5156..bb5d5cdb 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,5 @@ -filippo.io/edwards25519 v1.2.0 h1:crnVqOiS4jqYleHd9vaKZ+HKtHfllngJIiOpNpoJsjo= -filippo.io/edwards25519 v1.2.0/go.mod h1:xzAOLCNug/yB62zG1bQ8uziwrIqIuxhctzJT18Q77mc= +filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= +filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU= github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU= github.com/chzyer/logex v1.2.1 h1:XHDu3E6q+gdHgsdTPH6ImJMIp436vR6MPtH8gP05QzM= @@ -8,16 +8,15 @@ github.com/chzyer/readline v1.5.1 h1:upd/6fQk4src78LMRzh5vItIt361/o4uq553V8B5sGI github.com/chzyer/readline v1.5.1/go.mod h1:Eh+b79XXUwfKfcPLepksvw2tcLE/Ct21YObkaSkeBlk= github.com/chzyer/test v1.0.0 h1:p3BQDXSxOhOG0P9z6/hGnII4LGiEPOYBhs8asl/fC04= github.com/chzyer/test v1.0.0/go.mod h1:2JlltgoNkt4TW/z9V/IzDdFaMTM2JPIi26O1pF38GC8= -github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g= -github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg= +github.com/coder/websocket v1.8.13 h1:f3QZdXy7uGVz+4uCJy2nTZyM0yTBj8yANEHhqlXZ9FE= +github.com/coder/websocket v1.8.13/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs= +github.com/coreos/go-systemd/v22 v22.5.0 h1:RrqgGjYQKalulkV8NGVIfkXQf6YYmOyiJKk8iXXhfZs= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= -github.com/coreos/go-systemd/v22 v22.6.0 h1:aGVa/v8B7hpb0TKl0MWoAavPDmHvobFe5R5zn0bCJWo= -github.com/coreos/go-systemd/v22 v22.6.0/go.mod h1:iG+pp635Fo7ZmV/j14KUcmEyWF+0X7Lua8rrTWzYgWU= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= -github.com/lib/pq v1.11.2 h1:x6gxUeu39V0BHZiugWe8LXZYZ+Utk7hSJGThs8sdzfs= -github.com/lib/pq v1.11.2/go.mod h1:/p+8NSbOcwzAEI7wiMXFlgydTwcgTr3OSKMsD2BitpA= +github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= +github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= @@ -25,10 +24,10 @@ github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/ github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= -github.com/mattn/go-sqlite3 v1.14.34 h1:3NtcvcUnFBPsuRcno8pUtupspG/GM+9nZ88zgJcp6Zk= -github.com/mattn/go-sqlite3 v1.14.34/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= -github.com/petermattis/goid v0.0.0-20260113132338-7c7de50cc741 h1:KPpdlQLZcHfTMQRi6bFQ7ogNO0ltFT4PmtwTLW4W+14= -github.com/petermattis/goid v0.0.0-20260113132338-7c7de50cc741/go.mod h1:pxMtw7cyUw6B2bRH0ZBANSPg+AoSud1I1iyJHI69jH4= +github.com/mattn/go-sqlite3 v1.14.32 h1:JD12Ag3oLy1zQA+BNn74xRgaBbdhbNIDYvQUEuuErjs= +github.com/mattn/go-sqlite3 v1.14.32/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= +github.com/petermattis/goid v0.0.0-20250813065127-a731cc31b4fe h1:vHpqOnPlnkba8iSxU4j/CvDSS9J4+F4473esQsYLGoE= +github.com/petermattis/goid v0.0.0-20250813065127-a731cc31b4fe/go.mod h1:pxMtw7cyUw6B2bRH0ZBANSPg+AoSud1I1iyJHI69jH4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= @@ -38,8 +37,8 @@ github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY= github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ= github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0= github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M= -github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= -github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY= github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= @@ -50,28 +49,28 @@ github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= -github.com/yuin/goldmark v1.7.16 h1:n+CJdUxaFMiDUNnWC3dMWCIQJSkxH4uz3ZwQBkAlVNE= -github.com/yuin/goldmark v1.7.16/go.mod h1:ip/1k0VRfGynBgxOz0yCqHrbZXhcjxyuS66Brc7iBKg= -go.mau.fi/util v0.9.6 h1:2nsvxm49KhI3wrFltr0+wSUBlnQ4CMtykuELjpIU+ts= -go.mau.fi/util v0.9.6/go.mod h1:sIJpRH7Iy5Ad1SBuxQoatxtIeErgzxCtjd/2hCMkYMI= +github.com/yuin/goldmark v1.7.13 h1:GPddIs617DnBLFFVJFgpo1aBfe/4xcvMc3SB5t/D0pA= +github.com/yuin/goldmark v1.7.13/go.mod h1:ip/1k0VRfGynBgxOz0yCqHrbZXhcjxyuS66Brc7iBKg= +go.mau.fi/util v0.9.0 h1:ya3s3pX+Y8R2fgp0DbE7a0o3FwncoelDX5iyaeVE8ls= +go.mau.fi/util v0.9.0/go.mod h1:pdL3lg2aaeeHIreGXNnPwhJPXkXdc3ZxsI6le8hOWEA= go.mau.fi/zeroconfig v0.2.0 h1:e/OGEERqVRRKlgaro7E6bh8xXiKFSXB3eNNIud7FUjU= go.mau.fi/zeroconfig v0.2.0/go.mod h1:J0Vn0prHNOm493oZoQ84kq83ZaNCYZnq+noI1b1eN8w= -golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts= -golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos= -golang.org/x/exp v0.0.0-20260218203240-3dfff04db8fa h1:Zt3DZoOFFYkKhDT3v7Lm9FDMEV06GpzjG2jrqW+QTE0= -golang.org/x/exp v0.0.0-20260218203240-3dfff04db8fa/go.mod h1:K79w1Vqn7PoiZn+TkNpx3BUWUQksGO3JcVX6qIjytmA= -golang.org/x/net v0.50.0 h1:ucWh9eiCGyDR3vtzso0WMQinm2Dnt8cFMuQa9K33J60= -golang.org/x/net v0.50.0/go.mod h1:UgoSli3F/pBgdJBHCTc+tp3gmrU4XswgGRgtnwWTfyM= -golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= -golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= +golang.org/x/crypto v0.41.0 h1:WKYxWedPGCTVVl5+WHSSrOBT0O8lx32+zxmHxijgXp4= +golang.org/x/crypto v0.41.0/go.mod h1:pO5AFd7FA68rFak7rOAGVuygIISepHftHnr8dr6+sUc= +golang.org/x/exp v0.0.0-20250813145105-42675adae3e6 h1:SbTAbRFnd5kjQXbczszQ0hdk3ctwYf3qBNH9jIsGclE= +golang.org/x/exp v0.0.0-20250813145105-42675adae3e6/go.mod h1:4QTo5u+SEIbbKW1RacMZq1YEfOBqeXa19JeshGi+zc4= +golang.org/x/net v0.43.0 h1:lat02VYK2j4aLzMzecihNvTlJNQUq316m2Mr9rnM6YE= +golang.org/x/net v0.43.0/go.mod h1:vhO1fvI4dGsIjh73sWfUVjj3N7CA9WkKJNQm2svM6Jg= +golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw= +golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k= -golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= -golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk= -golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA= +golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI= +golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng= +golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc= diff --git a/id/contenturi.go b/id/contenturi.go index 67127b6c..e6a313f5 100644 --- a/id/contenturi.go +++ b/id/contenturi.go @@ -17,14 +17,8 @@ import ( ) var ( - ErrInvalidContentURI = errors.New("invalid Matrix content URI") - ErrInputNotJSONString = errors.New("input doesn't look like a JSON string") -) - -// Deprecated: use variables prefixed with Err -var ( - InvalidContentURI = ErrInvalidContentURI - InputNotJSONString = ErrInputNotJSONString + InvalidContentURI = errors.New("invalid Matrix content URI") + InputNotJSONString = errors.New("input doesn't look like a JSON string") ) // ContentURIString is a string that's expected to be a Matrix content URI. @@ -61,9 +55,9 @@ func ParseContentURI(uri string) (parsed ContentURI, err error) { if len(uri) == 0 { return } else if !strings.HasPrefix(uri, "mxc://") { - err = ErrInvalidContentURI + err = InvalidContentURI } else if index := strings.IndexRune(uri[6:], '/'); index == -1 || index == len(uri)-7 { - err = ErrInvalidContentURI + err = InvalidContentURI } else { parsed.Homeserver = uri[6 : 6+index] parsed.FileID = uri[6+index+1:] @@ -77,9 +71,9 @@ func ParseContentURIBytes(uri []byte) (parsed ContentURI, err error) { if len(uri) == 0 { return } else if !bytes.HasPrefix(uri, mxcBytes) { - err = ErrInvalidContentURI + err = InvalidContentURI } else if index := bytes.IndexRune(uri[6:], '/'); index == -1 || index == len(uri)-7 { - err = ErrInvalidContentURI + err = InvalidContentURI } else { parsed.Homeserver = string(uri[6 : 6+index]) parsed.FileID = string(uri[6+index+1:]) @@ -92,7 +86,7 @@ func (uri *ContentURI) UnmarshalJSON(raw []byte) (err error) { *uri = ContentURI{} return nil } else if len(raw) < 2 || raw[0] != '"' || raw[len(raw)-1] != '"' { - return fmt.Errorf("ContentURI: %w", ErrInputNotJSONString) + return InputNotJSONString } parsed, err := ParseContentURIBytes(raw[1 : len(raw)-1]) if err != nil { diff --git a/id/crypto.go b/id/crypto.go index ee857f78..355a84a8 100644 --- a/id/crypto.go +++ b/id/crypto.go @@ -53,34 +53,6 @@ const ( KeyBackupAlgorithmMegolmBackupV1 KeyBackupAlgorithm = "m.megolm_backup.v1.curve25519-aes-sha2" ) -type KeySource string - -func (source KeySource) String() string { - return string(source) -} - -func (source KeySource) Int() int { - switch source { - case KeySourceDirect: - return 100 - case KeySourceBackup: - return 90 - case KeySourceImport: - return 80 - case KeySourceForward: - return 50 - default: - return 0 - } -} - -const ( - KeySourceDirect KeySource = "direct" - KeySourceBackup KeySource = "backup" - KeySourceImport KeySource = "import" - KeySourceForward KeySource = "forward" -) - // BackupVersion is an arbitrary string that identifies a server side key backup. type KeyBackupVersion string diff --git a/id/matrixuri.go b/id/matrixuri.go index d5c78bc7..8f5ec849 100644 --- a/id/matrixuri.go +++ b/id/matrixuri.go @@ -54,7 +54,7 @@ var SigilToPathSegment = map[rune]string{ func (uri *MatrixURI) getQuery() url.Values { q := make(url.Values) - if len(uri.Via) > 0 { + if uri.Via != nil && len(uri.Via) > 0 { q["via"] = uri.Via } if len(uri.Action) > 0 { diff --git a/id/opaque.go b/id/opaque.go index c1ad4988..1d9f0dcf 100644 --- a/id/opaque.go +++ b/id/opaque.go @@ -32,9 +32,6 @@ type EventID string // https://github.com/matrix-org/matrix-doc/pull/2716 type BatchID string -// A DelayID is a string identifying a delayed event. -type DelayID string - func (roomID RoomID) String() string { return string(roomID) } diff --git a/id/trust.go b/id/trust.go index 6255093e..04f6e36b 100644 --- a/id/trust.go +++ b/id/trust.go @@ -16,7 +16,6 @@ type TrustState int const ( TrustStateBlacklisted TrustState = -100 - TrustStateDeviceKeyMismatch TrustState = -5 TrustStateUnset TrustState = 0 TrustStateUnknownDevice TrustState = 10 TrustStateForwarded TrustState = 20 @@ -24,7 +23,7 @@ const ( TrustStateCrossSignedTOFU TrustState = 100 TrustStateCrossSignedVerified TrustState = 200 TrustStateVerified TrustState = 300 - TrustStateInvalid TrustState = -2147483647 + TrustStateInvalid TrustState = (1 << 31) - 1 ) func (ts *TrustState) UnmarshalText(data []byte) error { @@ -45,8 +44,6 @@ func ParseTrustState(val string) TrustState { switch strings.ToLower(val) { case "blacklisted": return TrustStateBlacklisted - case "device-key-mismatch": - return TrustStateDeviceKeyMismatch case "unverified": return TrustStateUnset case "cross-signed-untrusted": @@ -70,8 +67,6 @@ func (ts TrustState) String() string { switch ts { case TrustStateBlacklisted: return "blacklisted" - case TrustStateDeviceKeyMismatch: - return "device-key-mismatch" case TrustStateUnset: return "unverified" case TrustStateCrossSignedUntrusted: diff --git a/id/userid.go b/id/userid.go index 726a0d58..6d9f4080 100644 --- a/id/userid.go +++ b/id/userid.go @@ -104,24 +104,16 @@ func ValidateUserLocalpart(localpart string) error { return nil } -// ParseAndValidateStrict is a stricter version of ParseAndValidateRelaxed that checks the localpart to only allow non-historical localparts. -// This should be used with care: there are real users still using historical localparts. -func (userID UserID) ParseAndValidateStrict() (localpart, homeserver string, err error) { - localpart, homeserver, err = userID.ParseAndValidateRelaxed() +// ParseAndValidate parses the user ID into the localpart and server name like Parse, +// and also validates that the localpart is allowed according to the user identifiers spec. +func (userID UserID) ParseAndValidate() (localpart, homeserver string, err error) { + localpart, homeserver, err = userID.Parse() if err == nil { err = ValidateUserLocalpart(localpart) } - return -} - -// ParseAndValidateRelaxed parses the user ID into the localpart and server name like Parse, -// and also validates that the user ID is not too long and that the server name is valid. -func (userID UserID) ParseAndValidateRelaxed() (localpart, homeserver string, err error) { - if len(userID) > UserIDMaxLength { + if err == nil && len(userID) > UserIDMaxLength { err = ErrUserIDTooLong - return } - localpart, homeserver, err = userID.Parse() if err == nil && !ValidateServerName(homeserver) { err = fmt.Errorf("%q %q", homeserver, ErrNoncompliantServerPart) } @@ -129,7 +121,7 @@ func (userID UserID) ParseAndValidateRelaxed() (localpart, homeserver string, er } func (userID UserID) ParseAndDecode() (localpart, homeserver string, err error) { - localpart, homeserver, err = userID.ParseAndValidateStrict() + localpart, homeserver, err = userID.ParseAndValidate() if err == nil { localpart, err = DecodeUserLocalpart(localpart) } @@ -219,15 +211,15 @@ func DecodeUserLocalpart(str string) (string, error) { for i := 0; i < len(strBytes); i++ { b := strBytes[i] if !isValidByte(b) { - return "", fmt.Errorf("invalid encoded byte at position %d: %c", i, b) + return "", fmt.Errorf("Byte pos %d: Invalid byte", i) } if b == '_' { // next byte is a-z and should be upper-case or is another _ and should be a literal _ if i+1 >= len(strBytes) { - return "", fmt.Errorf("unexpected end of string after underscore at %d", i) + return "", fmt.Errorf("Byte pos %d: expected _[a-z_] encoding but ran out of string", i) } if !isValidEscapedChar(strBytes[i+1]) { // invalid escaping - return "", fmt.Errorf("unexpected byte %c after underscore at %d", strBytes[i+1], i) + return "", fmt.Errorf("Byte pos %d: expected _[a-z_] encoding", i) } if strBytes[i+1] == '_' { outputBuffer.WriteByte('_') @@ -237,7 +229,7 @@ func DecodeUserLocalpart(str string) (string, error) { i++ // skip next byte since we just handled it } else if b == '=' { // next 2 bytes are hex and should be buffered ready to be read as utf8 if i+2 >= len(strBytes) { - return "", fmt.Errorf("unexpected end of string after equals sign at %d", i) + return "", fmt.Errorf("Byte pos: %d: expected quote-printable encoding but ran out of string", i) } dst := make([]byte, 1) _, err := hex.Decode(dst, strBytes[i+1:i+3]) diff --git a/id/userid_test.go b/id/userid_test.go index 57a88066..359bc687 100644 --- a/id/userid_test.go +++ b/id/userid_test.go @@ -38,30 +38,30 @@ func TestUserID_Parse_Invalid(t *testing.T) { assert.True(t, errors.Is(err, id.ErrInvalidUserID)) } -func TestUserID_ParseAndValidateStrict_Invalid(t *testing.T) { +func TestUserID_ParseAndValidate_Invalid(t *testing.T) { const inputUserID = "@s p a c e:maunium.net" - _, _, err := id.UserID(inputUserID).ParseAndValidateStrict() + _, _, err := id.UserID(inputUserID).ParseAndValidate() assert.Error(t, err) assert.True(t, errors.Is(err, id.ErrNoncompliantLocalpart)) } -func TestUserID_ParseAndValidateStrict_Empty(t *testing.T) { +func TestUserID_ParseAndValidate_Empty(t *testing.T) { const inputUserID = "@:ponies.im" - _, _, err := id.UserID(inputUserID).ParseAndValidateStrict() + _, _, err := id.UserID(inputUserID).ParseAndValidate() assert.Error(t, err) assert.True(t, errors.Is(err, id.ErrEmptyLocalpart)) } -func TestUserID_ParseAndValidateStrict_Long(t *testing.T) { +func TestUserID_ParseAndValidate_Long(t *testing.T) { const inputUserID = "@aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa:example.com" - _, _, err := id.UserID(inputUserID).ParseAndValidateStrict() + _, _, err := id.UserID(inputUserID).ParseAndValidate() assert.Error(t, err) assert.True(t, errors.Is(err, id.ErrUserIDTooLong)) } -func TestUserID_ParseAndValidateStrict_NotLong(t *testing.T) { +func TestUserID_ParseAndValidate_NotLong(t *testing.T) { const inputUserID = "@aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa:example.com" - _, _, err := id.UserID(inputUserID).ParseAndValidateStrict() + _, _, err := id.UserID(inputUserID).ParseAndValidate() assert.NoError(t, err) } @@ -70,7 +70,7 @@ func TestUserIDEncoding(t *testing.T) { const encodedLocalpart = "_this=20local+part=20contains=20_il_le_ga_l=20ch=c3=a4racters=20=f0=9f=9a=a8" const inputServerName = "example.com" userID := id.NewEncodedUserID(inputLocalpart, inputServerName) - parsedLocalpart, parsedServerName, err := userID.ParseAndValidateStrict() + parsedLocalpart, parsedServerName, err := userID.ParseAndValidate() assert.NoError(t, err) assert.Equal(t, encodedLocalpart, parsedLocalpart) assert.Equal(t, inputServerName, parsedServerName) diff --git a/mediaproxy/mediaproxy.go b/mediaproxy/mediaproxy.go index 4d2bc7cf..07e30810 100644 --- a/mediaproxy/mediaproxy.go +++ b/mediaproxy/mediaproxy.go @@ -95,13 +95,9 @@ func (d *GetMediaResponseCallback) GetContentType() string { return d.ContentType } -type FileMeta struct { - ContentType string - ReplacementFile string -} - type GetMediaResponseFile struct { - Callback func(w *os.File) (*FileMeta, error) + Callback func(w *os.File) error + ContentType string } type GetMediaFunc = func(ctx context.Context, mediaID string, params map[string]string) (response GetMediaResponse, err error) @@ -143,7 +139,6 @@ func New(serverName string, serverKey string, getMedia GetMediaFunc) (*MediaProx } mp.FederationRouter = http.NewServeMux() mp.FederationRouter.HandleFunc("GET /v1/media/download/{mediaID}", mp.DownloadMediaFederation) - mp.FederationRouter.HandleFunc("GET /v1/media/thumbnail/{mediaID}", mp.DownloadMediaFederation) mp.FederationRouter.HandleFunc("GET /v1/version", mp.KeyServer.GetServerVersion) mp.ClientMediaRouter = http.NewServeMux() mp.ClientMediaRouter.HandleFunc("GET /download/{serverName}/{mediaID}", mp.DownloadMedia) @@ -458,35 +453,23 @@ func doTempFileDownload( if err != nil { return false, fmt.Errorf("failed to create temp file: %w", err) } - origTempFile := tempFile defer func() { - _ = origTempFile.Close() - _ = os.Remove(origTempFile.Name()) + _ = tempFile.Close() + _ = os.Remove(tempFile.Name()) }() - meta, err := data.Callback(tempFile) + err = data.Callback(tempFile) if err != nil { return false, err } - if meta.ReplacementFile != "" { - tempFile, err = os.Open(meta.ReplacementFile) - if err != nil { - return false, fmt.Errorf("failed to open replacement file: %w", err) - } - defer func() { - _ = tempFile.Close() - _ = os.Remove(origTempFile.Name()) - }() - } else { - _, err = tempFile.Seek(0, io.SeekStart) - if err != nil { - return false, fmt.Errorf("failed to seek to start of temp file: %w", err) - } + _, err = tempFile.Seek(0, io.SeekStart) + if err != nil { + return false, fmt.Errorf("failed to seek to start of temp file: %w", err) } fileInfo, err := tempFile.Stat() if err != nil { return false, fmt.Errorf("failed to stat temp file: %w", err) } - mimeType := meta.ContentType + mimeType := data.ContentType if mimeType == "" { buf := make([]byte, 512) n, err := tempFile.Read(buf) diff --git a/mockserver/mockserver.go b/mockserver/mockserver.go deleted file mode 100644 index 507c24a5..00000000 --- a/mockserver/mockserver.go +++ /dev/null @@ -1,307 +0,0 @@ -// Copyright (c) 2025 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package mockserver - -import ( - "context" - "encoding/json" - "fmt" - "io" - "maps" - "net/http" - "net/http/httptest" - "strings" - "testing" - - globallog "github.com/rs/zerolog/log" // zerolog-allow-global-log - "github.com/stretchr/testify/require" - "go.mau.fi/util/dbutil" - "go.mau.fi/util/exerrors" - "go.mau.fi/util/exhttp" - "go.mau.fi/util/random" - - "maunium.net/go/mautrix" - "maunium.net/go/mautrix/crypto" - "maunium.net/go/mautrix/crypto/cryptohelper" - "maunium.net/go/mautrix/event" - "maunium.net/go/mautrix/id" -) - -func mustDecode(r *http.Request, data any) { - exerrors.PanicIfNotNil(json.NewDecoder(r.Body).Decode(data)) -} - -type userAndDeviceID struct { - UserID id.UserID - DeviceID id.DeviceID -} - -type MockServer struct { - Router *http.ServeMux - Server *httptest.Server - - AccessTokenToUserID map[string]userAndDeviceID - DeviceInbox map[id.UserID]map[id.DeviceID][]event.Event - AccountData map[id.UserID]map[event.Type]json.RawMessage - DeviceKeys map[id.UserID]map[id.DeviceID]mautrix.DeviceKeys - OneTimeKeys map[id.UserID]map[id.DeviceID]map[id.KeyID]mautrix.OneTimeKey - MasterKeys map[id.UserID]mautrix.CrossSigningKeys - SelfSigningKeys map[id.UserID]mautrix.CrossSigningKeys - UserSigningKeys map[id.UserID]mautrix.CrossSigningKeys - - PopOTKs bool - MemoryStore bool -} - -func Create(t testing.TB) *MockServer { - t.Helper() - - server := MockServer{ - AccessTokenToUserID: map[string]userAndDeviceID{}, - DeviceInbox: map[id.UserID]map[id.DeviceID][]event.Event{}, - AccountData: map[id.UserID]map[event.Type]json.RawMessage{}, - DeviceKeys: map[id.UserID]map[id.DeviceID]mautrix.DeviceKeys{}, - OneTimeKeys: map[id.UserID]map[id.DeviceID]map[id.KeyID]mautrix.OneTimeKey{}, - MasterKeys: map[id.UserID]mautrix.CrossSigningKeys{}, - SelfSigningKeys: map[id.UserID]mautrix.CrossSigningKeys{}, - UserSigningKeys: map[id.UserID]mautrix.CrossSigningKeys{}, - PopOTKs: true, - MemoryStore: true, - } - - router := http.NewServeMux() - router.HandleFunc("POST /_matrix/client/v3/login", server.postLogin) - router.HandleFunc("POST /_matrix/client/v3/keys/query", server.postKeysQuery) - router.HandleFunc("POST /_matrix/client/v3/keys/claim", server.postKeysClaim) - router.HandleFunc("PUT /_matrix/client/v3/sendToDevice/{type}/{txn}", server.putSendToDevice) - router.HandleFunc("PUT /_matrix/client/v3/user/{userID}/account_data/{type}", server.putAccountData) - router.HandleFunc("POST /_matrix/client/v3/keys/device_signing/upload", server.postDeviceSigningUpload) - router.HandleFunc("POST /_matrix/client/v3/keys/signatures/upload", server.emptyResp) - router.HandleFunc("POST /_matrix/client/v3/keys/upload", server.postKeysUpload) - server.Router = router - server.Server = httptest.NewServer(router) - t.Cleanup(server.Server.Close) - return &server -} - -func (ms *MockServer) getUserID(r *http.Request) userAndDeviceID { - authHeader := r.Header.Get("Authorization") - authHeader = strings.TrimPrefix(authHeader, "Bearer ") - userID, ok := ms.AccessTokenToUserID[authHeader] - if !ok { - panic("no user ID found for access token " + authHeader) - } - return userID -} - -func (ms *MockServer) emptyResp(w http.ResponseWriter, _ *http.Request) { - exhttp.WriteEmptyJSONResponse(w, http.StatusOK) -} - -func (ms *MockServer) postLogin(w http.ResponseWriter, r *http.Request) { - var loginReq mautrix.ReqLogin - mustDecode(r, &loginReq) - - deviceID := loginReq.DeviceID - if deviceID == "" { - deviceID = id.DeviceID(random.String(10)) - } - - accessToken := random.String(30) - userID := id.UserID(loginReq.Identifier.User) - ms.AccessTokenToUserID[accessToken] = userAndDeviceID{ - UserID: userID, - DeviceID: deviceID, - } - - exhttp.WriteJSONResponse(w, http.StatusOK, &mautrix.RespLogin{ - AccessToken: accessToken, - DeviceID: deviceID, - UserID: userID, - }) -} - -func (ms *MockServer) putSendToDevice(w http.ResponseWriter, r *http.Request) { - var req mautrix.ReqSendToDevice - mustDecode(r, &req) - evtType := event.Type{Type: r.PathValue("type"), Class: event.ToDeviceEventType} - - for user, devices := range req.Messages { - for device, content := range devices { - if _, ok := ms.DeviceInbox[user]; !ok { - ms.DeviceInbox[user] = map[id.DeviceID][]event.Event{} - } - content.ParseRaw(evtType) - ms.DeviceInbox[user][device] = append(ms.DeviceInbox[user][device], event.Event{ - Sender: ms.getUserID(r).UserID, - Type: evtType, - Content: *content, - }) - } - } - ms.emptyResp(w, r) -} - -func (ms *MockServer) putAccountData(w http.ResponseWriter, r *http.Request) { - userID := id.UserID(r.PathValue("userID")) - eventType := event.Type{Type: r.PathValue("type"), Class: event.AccountDataEventType} - - jsonData, _ := io.ReadAll(r.Body) - if _, ok := ms.AccountData[userID]; !ok { - ms.AccountData[userID] = map[event.Type]json.RawMessage{} - } - ms.AccountData[userID][eventType] = json.RawMessage(jsonData) - ms.emptyResp(w, r) -} - -func (ms *MockServer) postKeysQuery(w http.ResponseWriter, r *http.Request) { - var req mautrix.ReqQueryKeys - mustDecode(r, &req) - resp := mautrix.RespQueryKeys{ - MasterKeys: map[id.UserID]mautrix.CrossSigningKeys{}, - UserSigningKeys: map[id.UserID]mautrix.CrossSigningKeys{}, - SelfSigningKeys: map[id.UserID]mautrix.CrossSigningKeys{}, - DeviceKeys: map[id.UserID]map[id.DeviceID]mautrix.DeviceKeys{}, - } - for user := range req.DeviceKeys { - resp.MasterKeys[user] = ms.MasterKeys[user] - resp.UserSigningKeys[user] = ms.UserSigningKeys[user] - resp.SelfSigningKeys[user] = ms.SelfSigningKeys[user] - resp.DeviceKeys[user] = ms.DeviceKeys[user] - } - exhttp.WriteJSONResponse(w, http.StatusOK, &resp) -} - -func (ms *MockServer) postKeysClaim(w http.ResponseWriter, r *http.Request) { - var req mautrix.ReqClaimKeys - mustDecode(r, &req) - resp := mautrix.RespClaimKeys{ - OneTimeKeys: map[id.UserID]map[id.DeviceID]map[id.KeyID]mautrix.OneTimeKey{}, - } - for user, devices := range req.OneTimeKeys { - resp.OneTimeKeys[user] = map[id.DeviceID]map[id.KeyID]mautrix.OneTimeKey{} - for device := range devices { - keys := ms.OneTimeKeys[user][device] - for keyID, key := range keys { - if ms.PopOTKs { - delete(keys, keyID) - } - resp.OneTimeKeys[user][device] = map[id.KeyID]mautrix.OneTimeKey{ - keyID: key, - } - break - } - } - } - exhttp.WriteJSONResponse(w, http.StatusOK, &resp) -} - -func (ms *MockServer) postKeysUpload(w http.ResponseWriter, r *http.Request) { - var req mautrix.ReqUploadKeys - mustDecode(r, &req) - - uid := ms.getUserID(r) - userID := uid.UserID - if _, ok := ms.DeviceKeys[userID]; !ok { - ms.DeviceKeys[userID] = map[id.DeviceID]mautrix.DeviceKeys{} - } - if _, ok := ms.OneTimeKeys[userID]; !ok { - ms.OneTimeKeys[userID] = map[id.DeviceID]map[id.KeyID]mautrix.OneTimeKey{} - } - - if req.DeviceKeys != nil { - ms.DeviceKeys[userID][uid.DeviceID] = *req.DeviceKeys - } - otks, ok := ms.OneTimeKeys[userID][uid.DeviceID] - if !ok { - otks = map[id.KeyID]mautrix.OneTimeKey{} - ms.OneTimeKeys[userID][uid.DeviceID] = otks - } - if req.OneTimeKeys != nil { - maps.Copy(otks, req.OneTimeKeys) - } - - exhttp.WriteJSONResponse(w, http.StatusOK, &mautrix.RespUploadKeys{ - OneTimeKeyCounts: mautrix.OTKCount{SignedCurve25519: len(otks)}, - }) -} - -func (ms *MockServer) postDeviceSigningUpload(w http.ResponseWriter, r *http.Request) { - var req mautrix.UploadCrossSigningKeysReq[any] - mustDecode(r, &req) - - userID := ms.getUserID(r).UserID - ms.MasterKeys[userID] = req.Master - ms.SelfSigningKeys[userID] = req.SelfSigning - ms.UserSigningKeys[userID] = req.UserSigning - - ms.emptyResp(w, r) -} - -func (ms *MockServer) Login(t testing.TB, ctx context.Context, userID id.UserID, deviceID id.DeviceID) (*mautrix.Client, crypto.Store) { - t.Helper() - if ctx == nil { - ctx = context.TODO() - } - client, err := mautrix.NewClient(ms.Server.URL, "", "") - require.NoError(t, err) - client.Client = ms.Server.Client() - - _, err = client.Login(ctx, &mautrix.ReqLogin{ - Type: mautrix.AuthTypePassword, - Identifier: mautrix.UserIdentifier{ - Type: mautrix.IdentifierTypeUser, - User: userID.String(), - }, - DeviceID: deviceID, - Password: "password", - StoreCredentials: true, - }) - require.NoError(t, err) - - var store any - if ms.MemoryStore { - store = crypto.NewMemoryStore(nil) - client.StateStore = mautrix.NewMemoryStateStore() - } else { - store, err = dbutil.NewFromConfig("", dbutil.Config{ - PoolConfig: dbutil.PoolConfig{ - Type: "sqlite3-fk-wal", - URI: fmt.Sprintf("file:%s?mode=memory&cache=shared&_txlock=immediate", random.String(10)), - MaxOpenConns: 5, - MaxIdleConns: 1, - }, - }, nil) - require.NoError(t, err) - } - cryptoHelper, err := cryptohelper.NewCryptoHelper(client, []byte("test"), store) - require.NoError(t, err) - client.Crypto = cryptoHelper - - err = cryptoHelper.Init(ctx) - require.NoError(t, err) - - machineLog := globallog.Logger.With(). - Stringer("my_user_id", userID). - Stringer("my_device_id", deviceID). - Logger() - cryptoHelper.Machine().Log = &machineLog - - err = cryptoHelper.Machine().ShareKeys(ctx, 50) - require.NoError(t, err) - - return client, cryptoHelper.Machine().CryptoStore -} - -func (ms *MockServer) DispatchToDevice(t testing.TB, ctx context.Context, client *mautrix.Client) { - t.Helper() - - for _, evt := range ms.DeviceInbox[client.UserID][client.DeviceID] { - client.Syncer.(*mautrix.DefaultSyncer).Dispatch(ctx, &evt) - ms.DeviceInbox[client.UserID][client.DeviceID] = ms.DeviceInbox[client.UserID][client.DeviceID][1:] - } -} diff --git a/pushrules/action.go b/pushrules/action.go index b5a884b2..9838e88b 100644 --- a/pushrules/action.go +++ b/pushrules/action.go @@ -105,7 +105,7 @@ func (action *PushAction) UnmarshalJSON(raw []byte) error { if ok { action.Action = ActionSetTweak action.Tweak = PushActionTweak(tweak) - action.Value = val["value"] + action.Value, _ = val["value"] } } return nil diff --git a/pushrules/action_test.go b/pushrules/action_test.go index 3c0aa168..a8f68415 100644 --- a/pushrules/action_test.go +++ b/pushrules/action_test.go @@ -139,9 +139,9 @@ func TestPushAction_UnmarshalJSON_InvalidTypeDoesNothing(t *testing.T) { } err := pa.UnmarshalJSON([]byte(`{"foo": "bar"}`)) - assert.NoError(t, err) + assert.Nil(t, err) err = pa.UnmarshalJSON([]byte(`9001`)) - assert.NoError(t, err) + assert.Nil(t, err) assert.Equal(t, pushrules.PushActionType("unchanged"), pa.Action) assert.Equal(t, pushrules.PushActionTweak("unchanged"), pa.Tweak) @@ -156,7 +156,7 @@ func TestPushAction_UnmarshalJSON_StringChangesActionType(t *testing.T) { } err := pa.UnmarshalJSON([]byte(`"foo"`)) - assert.NoError(t, err) + assert.Nil(t, err) assert.Equal(t, pushrules.PushActionType("foo"), pa.Action) assert.Equal(t, pushrules.PushActionTweak("unchanged"), pa.Tweak) @@ -171,7 +171,7 @@ func TestPushAction_UnmarshalJSON_SetTweakChangesTweak(t *testing.T) { } err := pa.UnmarshalJSON([]byte(`{"set_tweak": "foo", "value": 123.0}`)) - assert.NoError(t, err) + assert.Nil(t, err) assert.Equal(t, pushrules.ActionSetTweak, pa.Action) assert.Equal(t, pushrules.PushActionTweak("foo"), pa.Tweak) @@ -185,7 +185,7 @@ func TestPushAction_MarshalJSON_TweakOutputWorks(t *testing.T) { Value: "bar", } data, err := pa.MarshalJSON() - assert.NoError(t, err) + assert.Nil(t, err) assert.Equal(t, []byte(`{"set_tweak":"foo","value":"bar"}`), data) } @@ -196,6 +196,6 @@ func TestPushAction_MarshalJSON_OtherOutputWorks(t *testing.T) { Value: "bar", } data, err := pa.MarshalJSON() - assert.NoError(t, err) + assert.Nil(t, err) assert.Equal(t, []byte(`"something else"`), data) } diff --git a/pushrules/condition_test.go b/pushrules/condition_test.go index 37af3e34..0d3eaf7a 100644 --- a/pushrules/condition_test.go +++ b/pushrules/condition_test.go @@ -102,6 +102,14 @@ func newEventPropertyIsPushCondition(key string, value any) *pushrules.PushCondi } } +func newEventPropertyContainsPushCondition(key string, value any) *pushrules.PushCondition { + return &pushrules.PushCondition{ + Kind: pushrules.KindEventPropertyContains, + Key: key, + Value: value, + } +} + func TestPushCondition_Match_InvalidKind(t *testing.T) { condition := &pushrules.PushCondition{ Kind: pushrules.PushCondKind("invalid"), diff --git a/pushrules/pushrules_test.go b/pushrules/pushrules_test.go index a5a0f5e7..a531ca28 100644 --- a/pushrules/pushrules_test.go +++ b/pushrules/pushrules_test.go @@ -25,7 +25,7 @@ func TestEventToPushRules(t *testing.T) { }, } pushRuleset, err := pushrules.EventToPushRules(evt) - assert.NoError(t, err) + assert.Nil(t, err) assert.NotNil(t, pushRuleset) assert.IsType(t, pushRuleset.Override, pushrules.PushRuleArray{}) diff --git a/requests.go b/requests.go index cc8b7266..8f31e52f 100644 --- a/requests.go +++ b/requests.go @@ -66,14 +66,14 @@ const ( ) // ReqRegister is the JSON request for https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3register -type ReqRegister[UIAType any] struct { +type ReqRegister struct { Username string `json:"username,omitempty"` Password string `json:"password,omitempty"` DeviceID id.DeviceID `json:"device_id,omitempty"` InitialDeviceDisplayName string `json:"initial_device_display_name,omitempty"` InhibitLogin bool `json:"inhibit_login,omitempty"` RefreshToken bool `json:"refresh_token,omitempty"` - Auth UIAType `json:"auth,omitempty"` + Auth interface{} `json:"auth,omitempty"` // Type for registration, only used for appservice user registrations // https://spec.matrix.org/v1.2/application-service-api/#server-admin-style-permissions @@ -183,11 +183,6 @@ type ReqKnockRoom struct { Reason string `json:"reason,omitempty"` } -type ReqSearchUserDirectory struct { - SearchTerm string `json:"search_term"` - Limit int `json:"limit,omitempty"` -} - type ReqMutualRooms struct { From string `json:"-"` } @@ -320,11 +315,11 @@ func (csk *CrossSigningKeys) FirstKey() id.Ed25519 { return "" } -type UploadCrossSigningKeysReq[UIAType any] struct { +type UploadCrossSigningKeysReq struct { Master CrossSigningKeys `json:"master_key"` SelfSigning CrossSigningKeys `json:"self_signing_key"` UserSigning CrossSigningKeys `json:"user_signing_key"` - Auth UIAType `json:"auth,omitempty"` + Auth interface{} `json:"auth,omitempty"` } type KeyMap map[id.DeviceKeyID]string @@ -367,23 +362,18 @@ type ReqSendToDevice struct { } type ReqSendEvent struct { - Timestamp int64 - TransactionID string - UnstableDelay time.Duration - UnstableStickyDuration time.Duration - DontEncrypt bool - MeowEventID id.EventID -} + Timestamp int64 + TransactionID string + UnstableDelay time.Duration -type ReqDelayedEvents struct { - DelayID id.DelayID `json:"-"` - Status event.DelayStatus `json:"-"` - NextBatch string `json:"-"` + DontEncrypt bool + + MeowEventID id.EventID } type ReqUpdateDelayedEvent struct { - DelayID id.DelayID `json:"-"` - Action event.DelayAction `json:"action"` + DelayID string `json:"-"` + Action string `json:"action"` // TODO use enum } // ReqDeviceInfo is the JSON request for https://spec.matrix.org/v1.2/client-server-api/#put_matrixclientv3devicesdeviceid @@ -392,14 +382,14 @@ type ReqDeviceInfo struct { } // ReqDeleteDevice is the JSON request for https://spec.matrix.org/v1.2/client-server-api/#delete_matrixclientv3devicesdeviceid -type ReqDeleteDevice[UIAType any] struct { - Auth UIAType `json:"auth,omitempty"` +type ReqDeleteDevice struct { + Auth interface{} `json:"auth,omitempty"` } // ReqDeleteDevices is the JSON request for https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3delete_devices -type ReqDeleteDevices[UIAType any] struct { +type ReqDeleteDevices struct { Devices []id.DeviceID `json:"devices"` - Auth UIAType `json:"auth,omitempty"` + Auth interface{} `json:"auth,omitempty"` } type ReqPutPushRule struct { @@ -411,6 +401,18 @@ type ReqPutPushRule struct { Pattern string `json:"pattern"` } +// Deprecated: MSC2716 was abandoned +type ReqBatchSend struct { + PrevEventID id.EventID `json:"-"` + BatchID id.BatchID `json:"-"` + + BeeperNewMessages bool `json:"-"` + BeeperMarkReadBy id.UserID `json:"-"` + + StateEventsAtStart []*event.Event `json:"state_events_at_start"` + Events []*event.Event `json:"events"` +} + type ReqBeeperBatchSend struct { // ForwardIfNoMessages should be set to true if the batch should be forward // backfilled if there are no messages currently in the room. @@ -606,13 +608,3 @@ func (rgr *ReqGetRelations) Query() map[string]string { } return query } - -// ReqSuspend is the request body for https://github.com/matrix-org/matrix-spec-proposals/pull/4323 -type ReqSuspend struct { - Suspended bool `json:"suspended"` -} - -// ReqLocked is the request body for https://github.com/matrix-org/matrix-spec-proposals/pull/4323 -type ReqLocked struct { - Locked bool `json:"locked"` -} diff --git a/responses.go b/responses.go index 4fbe1fbc..27d96ffe 100644 --- a/responses.go +++ b/responses.go @@ -6,14 +6,12 @@ import ( "fmt" "maps" "reflect" - "slices" "strconv" "strings" "github.com/tidwall/gjson" "github.com/tidwall/sjson" "go.mau.fi/util/jsontime" - "go.mau.fi/util/ptr" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" @@ -106,22 +104,11 @@ type RespContext struct { type RespSendEvent struct { EventID id.EventID `json:"event_id"` - UnstableDelayID id.DelayID `json:"delay_id,omitempty"` + UnstableDelayID string `json:"delay_id,omitempty"` } type RespUpdateDelayedEvent struct{} -type RespDelayedEvents struct { - Scheduled []*event.ScheduledDelayedEvent `json:"scheduled,omitempty"` - Finalised []*event.FinalisedDelayedEvent `json:"finalised,omitempty"` - NextBatch string `json:"next_batch,omitempty"` - - // Deprecated: Synapse implementation still returns this - DelayedEvents []*event.ScheduledDelayedEvent `json:"delayed_events,omitempty"` - // Deprecated: Synapse implementation still returns this - FinalisedEvents []*event.FinalisedDelayedEvent `json:"finalised_events,omitempty"` -} - type RespRedactUserEvents struct { IsMoreEvents bool `json:"is_more_events"` RedactedEvents struct { @@ -223,48 +210,21 @@ func (r *RespUserProfile) MarshalJSON() ([]byte, error) { } else { delete(marshalMap, "avatar_url") } - return json.Marshal(marshalMap) -} - -type RespSearchUserDirectory struct { - Limited bool `json:"limited"` - Results []*UserDirectoryEntry `json:"results"` -} - -type UserDirectoryEntry struct { - RespUserProfile - UserID id.UserID `json:"user_id"` -} - -func (r *UserDirectoryEntry) UnmarshalJSON(data []byte) error { - err := r.RespUserProfile.UnmarshalJSON(data) - if err != nil { - return err - } - userIDStr, _ := r.Extra["user_id"].(string) - r.UserID = id.UserID(userIDStr) - delete(r.Extra, "user_id") - return nil -} - -func (r *UserDirectoryEntry) MarshalJSON() ([]byte, error) { - if r.Extra == nil { - r.Extra = make(map[string]any) - } - r.Extra["user_id"] = r.UserID.String() - return r.RespUserProfile.MarshalJSON() + return json.Marshal(r.Extra) } type RespMutualRooms struct { Joined []id.RoomID `json:"joined"` NextBatch string `json:"next_batch,omitempty"` - Count int `json:"count,omitempty"` } type RespRoomSummary struct { PublicRoomInfo - Membership event.Membership `json:"membership,omitempty"` + Membership event.Membership `json:"membership,omitempty"` + RoomVersion id.RoomVersion `json:"room_version,omitempty"` + Encryption id.Algorithm `json:"encryption,omitempty"` + AllowedRoomIDs []id.RoomID `json:"allowed_room_ids,omitempty"` UnstableRoomVersion id.RoomVersion `json:"im.nheko.summary.room_version,omitempty"` UnstableRoomVersionOld id.RoomVersion `json:"im.nheko.summary.version,omitempty"` @@ -342,24 +302,6 @@ type LazyLoadSummary struct { InvitedMemberCount *int `json:"m.invited_member_count,omitempty"` } -func (lls *LazyLoadSummary) MemberCount() int { - if lls == nil { - return 0 - } - return ptr.Val(lls.JoinedMemberCount) + ptr.Val(lls.InvitedMemberCount) -} - -func (lls *LazyLoadSummary) Equal(other *LazyLoadSummary) bool { - if lls == other { - return true - } else if lls == nil || other == nil { - return false - } - return ptr.Val(lls.JoinedMemberCount) == ptr.Val(other.JoinedMemberCount) && - ptr.Val(lls.InvitedMemberCount) == ptr.Val(other.InvitedMemberCount) && - slices.Equal(lls.Heroes, other.Heroes) -} - type SyncEventsList struct { Events []*event.Event `json:"events,omitempty"` } @@ -455,7 +397,7 @@ type BeeperInboxPreviewEvent struct { type SyncJoinedRoom struct { Summary LazyLoadSummary `json:"summary"` State SyncEventsList `json:"state"` - StateAfter *SyncEventsList `json:"state_after,omitempty"` + StateAfter *SyncEventsList `json:"org.matrix.msc4222.state_after,omitempty"` Timeline SyncTimeline `json:"timeline"` Ephemeral SyncEventsList `json:"ephemeral"` AccountData SyncEventsList `json:"account_data"` @@ -546,19 +488,30 @@ type RespDeviceInfo struct { LastSeenTS int64 `json:"last_seen_ts"` } +// Deprecated: MSC2716 was abandoned +type RespBatchSend struct { + StateEventIDs []id.EventID `json:"state_event_ids"` + EventIDs []id.EventID `json:"event_ids"` + + InsertionEventID id.EventID `json:"insertion_event_id"` + BatchEventID id.EventID `json:"batch_event_id"` + BaseInsertionEventID id.EventID `json:"base_insertion_event_id"` + + NextBatchID id.BatchID `json:"next_batch_id"` +} + type RespBeeperBatchSend struct { EventIDs []id.EventID `json:"event_ids"` } // RespCapabilities is the JSON response for https://spec.matrix.org/v1.3/client-server-api/#get_matrixclientv3capabilities type RespCapabilities struct { - RoomVersions *CapRoomVersions `json:"m.room_versions,omitempty"` - ChangePassword *CapBooleanTrue `json:"m.change_password,omitempty"` - SetDisplayname *CapBooleanTrue `json:"m.set_displayname,omitempty"` - SetAvatarURL *CapBooleanTrue `json:"m.set_avatar_url,omitempty"` - ThreePIDChanges *CapBooleanTrue `json:"m.3pid_changes,omitempty"` - GetLoginToken *CapBooleanTrue `json:"m.get_login_token,omitempty"` - UnstableAccountModeration *CapUnstableAccountModeration `json:"uk.timedout.msc4323,omitempty"` + RoomVersions *CapRoomVersions `json:"m.room_versions,omitempty"` + ChangePassword *CapBooleanTrue `json:"m.change_password,omitempty"` + SetDisplayname *CapBooleanTrue `json:"m.set_displayname,omitempty"` + SetAvatarURL *CapBooleanTrue `json:"m.set_avatar_url,omitempty"` + ThreePIDChanges *CapBooleanTrue `json:"m.3pid_changes,omitempty"` + GetLoginToken *CapBooleanTrue `json:"m.get_login_token,omitempty"` Custom map[string]interface{} `json:"-"` } @@ -667,11 +620,6 @@ func (vers *CapRoomVersions) IsAvailable(version string) bool { return available } -type CapUnstableAccountModeration struct { - Suspend bool `json:"suspend"` - Lock bool `json:"lock"` -} - type RespPublicRooms struct { Chunk []*PublicRoomInfo `json:"chunk"` NextBatch string `json:"next_batch,omitempty"` @@ -690,10 +638,6 @@ type PublicRoomInfo struct { RoomType event.RoomType `json:"room_type"` Topic string `json:"topic,omitempty"` WorldReadable bool `json:"world_readable"` - - RoomVersion id.RoomVersion `json:"room_version,omitempty"` - Encryption id.Algorithm `json:"encryption,omitempty"` - AllowedRoomIDs []id.RoomID `json:"allowed_room_ids,omitempty"` } // RespHierarchy is the JSON response for https://spec.matrix.org/v1.4/client-server-api/#get_matrixclientv1roomsroomidhierarchy @@ -767,33 +711,3 @@ type RespGetRelations struct { PrevBatch string `json:"prev_batch,omitempty"` RecursionDepth int `json:"recursion_depth,omitempty"` } - -// RespSuspended is the response body for https://github.com/matrix-org/matrix-spec-proposals/pull/4323 -type RespSuspended struct { - Suspended bool `json:"suspended"` -} - -// RespLocked is the response body for https://github.com/matrix-org/matrix-spec-proposals/pull/4323 -type RespLocked struct { - Locked bool `json:"locked"` -} - -type ConnectionInfo struct { - IP string `json:"ip,omitempty"` - LastSeen jsontime.UnixMilli `json:"last_seen,omitempty"` - UserAgent string `json:"user_agent,omitempty"` -} - -type SessionInfo struct { - Connections []ConnectionInfo `json:"connections,omitempty"` -} - -type DeviceInfo struct { - Sessions []SessionInfo `json:"sessions,omitempty"` -} - -// RespWhoIs is the response body for https://spec.matrix.org/v1.15/client-server-api/#get_matrixclientv3adminwhoisuserid -type RespWhoIs struct { - UserID id.UserID `json:"user_id,omitempty"` - Devices map[id.DeviceID]DeviceInfo `json:"devices,omitempty"` -} diff --git a/responses_test.go b/responses_test.go index 73d82635..b23d85ad 100644 --- a/responses_test.go +++ b/responses_test.go @@ -8,6 +8,7 @@ package mautrix_test import ( "encoding/json" + "fmt" "testing" "github.com/stretchr/testify/assert" @@ -85,6 +86,7 @@ func TestRespCapabilities_UnmarshalJSON(t *testing.T) { var caps mautrix.RespCapabilities err := json.Unmarshal([]byte(sampleData), &caps) require.NoError(t, err) + fmt.Println(caps) require.NotNil(t, caps.RoomVersions) assert.Equal(t, "9", caps.RoomVersions.Default) diff --git a/room.go b/room.go index 4292bff5..c3ddb7e6 100644 --- a/room.go +++ b/room.go @@ -5,6 +5,8 @@ import ( "maunium.net/go/mautrix/id" ) +type RoomStateMap = map[event.Type]map[string]*event.Event + // Room represents a single Matrix room. type Room struct { ID id.RoomID @@ -23,8 +25,8 @@ func (room Room) UpdateState(evt *event.Event) { // GetStateEvent returns the state event for the given type/state_key combo, or nil. func (room Room) GetStateEvent(eventType event.Type, stateKey string) *event.Event { - stateEventMap := room.State[eventType] - evt := stateEventMap[stateKey] + stateEventMap, _ := room.State[eventType] + evt, _ := stateEventMap[stateKey] return evt } diff --git a/sqlstatestore/statestore.go b/sqlstatestore/statestore.go index 11957dfa..0ed4b698 100644 --- a/sqlstatestore/statestore.go +++ b/sqlstatestore/statestore.go @@ -370,7 +370,7 @@ func (store *SQLStateStore) SetEncryptionEvent(ctx context.Context, roomID id.Ro func (store *SQLStateStore) GetEncryptionEvent(ctx context.Context, roomID id.RoomID) (*event.EncryptionEventContent, error) { var data []byte err := store. - QueryRow(ctx, "SELECT encryption FROM mx_room_state WHERE room_id=$1 AND encryption IS NOT NULL", roomID). + QueryRow(ctx, "SELECT encryption FROM mx_room_state WHERE room_id=$1", roomID). Scan(&data) if errors.Is(err, sql.ErrNoRows) { return nil, nil @@ -406,7 +406,7 @@ func (store *SQLStateStore) SetPowerLevels(ctx context.Context, roomID id.RoomID func (store *SQLStateStore) GetPowerLevels(ctx context.Context, roomID id.RoomID) (levels *event.PowerLevelsEventContent, err error) { levels = &event.PowerLevelsEventContent{} err = store. - QueryRow(ctx, "SELECT power_levels, create_event FROM mx_room_state WHERE room_id=$1 AND power_levels IS NOT NULL", roomID). + QueryRow(ctx, "SELECT power_levels, create_event FROM mx_room_state WHERE room_id=$1", roomID). Scan(&dbutil.JSON{Data: &levels}, &dbutil.JSON{Data: &levels.CreateEvent}) if errors.Is(err, sql.ErrNoRows) { return nil, nil @@ -458,7 +458,7 @@ func (store *SQLStateStore) SetCreate(ctx context.Context, evt *event.Event) err func (store *SQLStateStore) GetCreate(ctx context.Context, roomID id.RoomID) (evt *event.Event, err error) { err = store. - QueryRow(ctx, "SELECT create_event FROM mx_room_state WHERE room_id=$1 AND create_event IS NOT NULL", roomID). + QueryRow(ctx, "SELECT create_event FROM mx_room_state WHERE room_id=$1", roomID). Scan(&dbutil.JSON{Data: &evt}) if errors.Is(err, sql.ErrNoRows) { return nil, nil @@ -470,26 +470,3 @@ func (store *SQLStateStore) GetCreate(ctx context.Context, roomID id.RoomID) (ev } return } - -func (store *SQLStateStore) SetJoinRules(ctx context.Context, roomID id.RoomID, rules *event.JoinRulesEventContent) error { - if roomID == "" { - return fmt.Errorf("room ID is empty") - } - _, err := store.Exec(ctx, ` - INSERT INTO mx_room_state (room_id, join_rules) VALUES ($1, $2) - ON CONFLICT (room_id) DO UPDATE SET join_rules=excluded.join_rules - `, roomID, dbutil.JSON{Data: rules}) - return err -} - -func (store *SQLStateStore) GetJoinRules(ctx context.Context, roomID id.RoomID) (levels *event.JoinRulesEventContent, err error) { - levels = &event.JoinRulesEventContent{} - err = store. - QueryRow(ctx, "SELECT join_rules FROM mx_room_state WHERE room_id=$1 AND join_rules IS NOT NULL", roomID). - Scan(&dbutil.JSON{Data: &levels}) - if errors.Is(err, sql.ErrNoRows) { - levels = nil - err = nil - } - return -} diff --git a/sqlstatestore/v00-latest-revision.sql b/sqlstatestore/v00-latest-revision.sql index 4679f1c6..b5a858ec 100644 --- a/sqlstatestore/v00-latest-revision.sql +++ b/sqlstatestore/v00-latest-revision.sql @@ -1,4 +1,4 @@ --- v0 -> v10 (compatible with v3+): Latest revision +-- v0 -> v9 (compatible with v3+): Latest revision CREATE TABLE mx_registrations ( user_id TEXT PRIMARY KEY @@ -27,6 +27,5 @@ CREATE TABLE mx_room_state ( power_levels jsonb, encryption jsonb, create_event jsonb, - join_rules jsonb, members_fetched BOOLEAN NOT NULL DEFAULT false ); diff --git a/sqlstatestore/v10-join-rules.sql b/sqlstatestore/v10-join-rules.sql deleted file mode 100644 index 3074c46a..00000000 --- a/sqlstatestore/v10-join-rules.sql +++ /dev/null @@ -1,2 +0,0 @@ --- v10 (compatible with v3+): Add join rules to room state table -ALTER TABLE mx_room_state ADD COLUMN join_rules jsonb; diff --git a/statestore.go b/statestore.go index 2bd498dd..1933ab95 100644 --- a/statestore.go +++ b/statestore.go @@ -37,9 +37,6 @@ type StateStore interface { SetCreate(ctx context.Context, evt *event.Event) error GetCreate(ctx context.Context, roomID id.RoomID) (*event.Event, error) - GetJoinRules(ctx context.Context, roomID id.RoomID) (*event.JoinRulesEventContent, error) - SetJoinRules(ctx context.Context, roomID id.RoomID, content *event.JoinRulesEventContent) error - HasFetchedMembers(ctx context.Context, roomID id.RoomID) (bool, error) MarkMembersFetched(ctx context.Context, roomID id.RoomID) error GetAllMembers(ctx context.Context, roomID id.RoomID) (map[id.UserID]*event.MemberEventContent, error) @@ -76,8 +73,6 @@ func UpdateStateStore(ctx context.Context, store StateStore, evt *event.Event) { err = store.SetEncryptionEvent(ctx, evt.RoomID, content) case *event.CreateEventContent: err = store.SetCreate(ctx, evt) - case *event.JoinRulesEventContent: - err = store.SetJoinRules(ctx, evt.RoomID, content) default: switch evt.Type { case event.StateMember, event.StatePowerLevels, event.StateEncryption, event.StateCreate: @@ -112,13 +107,11 @@ type MemoryStateStore struct { PowerLevels map[id.RoomID]*event.PowerLevelsEventContent `json:"power_levels"` Encryption map[id.RoomID]*event.EncryptionEventContent `json:"encryption"` Create map[id.RoomID]*event.Event `json:"create"` - JoinRules map[id.RoomID]*event.JoinRulesEventContent `json:"join_rules"` registrationsLock sync.RWMutex membersLock sync.RWMutex powerLevelsLock sync.RWMutex encryptionLock sync.RWMutex - joinRulesLock sync.RWMutex } func NewMemoryStateStore() StateStore { @@ -129,7 +122,6 @@ func NewMemoryStateStore() StateStore { PowerLevels: make(map[id.RoomID]*event.PowerLevelsEventContent), Encryption: make(map[id.RoomID]*event.EncryptionEventContent), Create: make(map[id.RoomID]*event.Event), - JoinRules: make(map[id.RoomID]*event.JoinRulesEventContent), } } @@ -362,19 +354,6 @@ func (store *MemoryStateStore) GetEncryptionEvent(_ context.Context, roomID id.R return store.Encryption[roomID], nil } -func (store *MemoryStateStore) SetJoinRules(ctx context.Context, roomID id.RoomID, content *event.JoinRulesEventContent) error { - store.joinRulesLock.Lock() - store.JoinRules[roomID] = content - store.joinRulesLock.Unlock() - return nil -} - -func (store *MemoryStateStore) GetJoinRules(ctx context.Context, roomID id.RoomID) (*event.JoinRulesEventContent, error) { - store.joinRulesLock.RLock() - defer store.joinRulesLock.RUnlock() - return store.JoinRules[roomID], nil -} - func (store *MemoryStateStore) IsEncrypted(ctx context.Context, roomID id.RoomID) (bool, error) { cfg, err := store.GetEncryptionEvent(ctx, roomID) return cfg != nil && cfg.Algorithm == id.AlgorithmMegolmV1, err diff --git a/synapseadmin/roomapi.go b/synapseadmin/roomapi.go index 0925b748..a09ba174 100644 --- a/synapseadmin/roomapi.go +++ b/synapseadmin/roomapi.go @@ -75,7 +75,8 @@ type RespListRooms struct { // https://matrix-org.github.io/synapse/latest/admin_api/rooms.html#list-room-api func (cli *Client) ListRooms(ctx context.Context, req ReqListRoom) (RespListRooms, error) { var resp RespListRooms - reqURL := cli.Client.BuildURLWithQuery(mautrix.SynapseAdminURLPath{"v1", "rooms"}, req.BuildQuery()) + var reqURL string + reqURL = cli.Client.BuildURLWithQuery(mautrix.SynapseAdminURLPath{"v1", "rooms"}, req.BuildQuery()) _, err := cli.Client.MakeRequest(ctx, http.MethodGet, reqURL, nil, &resp) return resp, err } @@ -116,7 +117,6 @@ func (cli *Client) RoomMessages(ctx context.Context, roomID id.RoomID, from, to type ReqDeleteRoom struct { Purge bool `json:"purge,omitempty"` - ForcePurge bool `json:"force_purge,omitempty"` Block bool `json:"block,omitempty"` Message string `json:"message,omitempty"` RoomName string `json:"room_name,omitempty"` diff --git a/sync.go b/sync.go index 598df8e0..c52bd2f9 100644 --- a/sync.go +++ b/sync.go @@ -90,7 +90,6 @@ func (s *DefaultSyncer) ProcessResponse(ctx context.Context, res *RespSync, sinc err = fmt.Errorf("ProcessResponse panicked! since=%s panic=%s\n%s", since, r, debug.Stack()) } }() - ctx = context.WithValue(ctx, SyncTokenContextKey, since) for _, listener := range s.syncListeners { if !listener(ctx, res, since) { diff --git a/url.go b/url.go index 91b3d49d..d888956a 100644 --- a/url.go +++ b/url.go @@ -98,8 +98,10 @@ func (saup SynapseAdminURLPath) FullPath() []any { // and appservice user ID set already. func (cli *Client) BuildURLWithQuery(urlPath PrefixableURLPath, urlQuery map[string]string) string { return cli.BuildURLWithFullQuery(urlPath, func(q url.Values) { - for k, v := range urlQuery { - q.Set(k, v) + if urlQuery != nil { + for k, v := range urlQuery { + q.Set(k, v) + } } }) } diff --git a/version.go b/version.go index f00bbf39..fd0d0a8d 100644 --- a/version.go +++ b/version.go @@ -4,11 +4,10 @@ import ( "fmt" "regexp" "runtime" - "runtime/debug" "strings" ) -const Version = "v0.26.3" +const Version = "v0.25.0" var GoModVersion = "" var Commit = "" @@ -16,20 +15,11 @@ var VersionWithCommit = Version var DefaultUserAgent = "mautrix-go/" + Version + " go/" + strings.TrimPrefix(runtime.Version(), "go") +var goModVersionRegex = regexp.MustCompile(`v.+\d{14}-([0-9a-f]{12})`) + func init() { - if GoModVersion == "" { - info, _ := debug.ReadBuildInfo() - if info != nil { - for _, mod := range info.Deps { - if mod.Path == "maunium.net/go/mautrix" { - GoModVersion = mod.Version - break - } - } - } - } if GoModVersion != "" { - match := regexp.MustCompile(`v.+\d{14}-([0-9a-f]{12})`).FindStringSubmatch(GoModVersion) + match := goModVersionRegex.FindStringSubmatch(GoModVersion) if match != nil { Commit = match[1] } diff --git a/versions.go b/versions.go index 61b2e4ea..f87bddda 100644 --- a/versions.go +++ b/versions.go @@ -60,28 +60,20 @@ type UnstableFeature struct { } var ( - FeatureAsyncUploads = UnstableFeature{UnstableFlag: "fi.mau.msc2246.stable", SpecVersion: SpecV17} - FeatureAppservicePing = UnstableFeature{UnstableFlag: "fi.mau.msc2659.stable", SpecVersion: SpecV17} - FeatureAuthenticatedMedia = UnstableFeature{UnstableFlag: "org.matrix.msc3916.stable", SpecVersion: SpecV111} - FeatureUnstableMutualRooms = UnstableFeature{UnstableFlag: "uk.half-shot.msc2666.query_mutual_rooms"} - FeatureStableMutualRooms = UnstableFeature{UnstableFlag: "uk.half-shot.msc2666.query_mutual_rooms.stable" /*, SpecVersion: SpecV118*/} - FeatureUserRedaction = UnstableFeature{UnstableFlag: "org.matrix.msc4194"} - FeatureViewRedactedContent = UnstableFeature{UnstableFlag: "fi.mau.msc2815"} - FeatureUnstableAccountModeration = UnstableFeature{UnstableFlag: "uk.timedout.msc4323"} - FeatureStableAccountModeration = UnstableFeature{UnstableFlag: "uk.timedout.msc4323.stable" /*, SpecVersion: SpecV118*/} - FeatureUnstableProfileFields = UnstableFeature{UnstableFlag: "uk.tcpip.msc4133"} - FeatureArbitraryProfileFields = UnstableFeature{UnstableFlag: "uk.tcpip.msc4133.stable", SpecVersion: SpecV116} - FeatureRedactSendAsEvent = UnstableFeature{UnstableFlag: "com.beeper.msc4169"} + FeatureAsyncUploads = UnstableFeature{UnstableFlag: "fi.mau.msc2246.stable", SpecVersion: SpecV17} + FeatureAppservicePing = UnstableFeature{UnstableFlag: "fi.mau.msc2659.stable", SpecVersion: SpecV17} + FeatureAuthenticatedMedia = UnstableFeature{UnstableFlag: "org.matrix.msc3916.stable", SpecVersion: SpecV111} + FeatureMutualRooms = UnstableFeature{UnstableFlag: "uk.half-shot.msc2666.query_mutual_rooms"} + FeatureUserRedaction = UnstableFeature{UnstableFlag: "org.matrix.msc4194"} + FeatureViewRedactedContent = UnstableFeature{UnstableFlag: "fi.mau.msc2815"} - BeeperFeatureHungry = UnstableFeature{UnstableFlag: "com.beeper.hungry"} - BeeperFeatureBatchSending = UnstableFeature{UnstableFlag: "com.beeper.batch_sending"} - BeeperFeatureRoomYeeting = UnstableFeature{UnstableFlag: "com.beeper.room_yeeting"} - BeeperFeatureAutojoinInvites = UnstableFeature{UnstableFlag: "com.beeper.room_create_autojoin_invites"} - BeeperFeatureArbitraryProfileMeta = UnstableFeature{UnstableFlag: "com.beeper.arbitrary_profile_meta"} - BeeperFeatureAccountDataMute = UnstableFeature{UnstableFlag: "com.beeper.account_data_mute"} - BeeperFeatureInboxState = UnstableFeature{UnstableFlag: "com.beeper.inbox_state"} - BeeperFeatureArbitraryMemberChange = UnstableFeature{UnstableFlag: "com.beeper.arbitrary_member_change"} - BeeperFeatureEphemeralEvents = UnstableFeature{UnstableFlag: "com.beeper.ephemeral"} + BeeperFeatureHungry = UnstableFeature{UnstableFlag: "com.beeper.hungry"} + BeeperFeatureBatchSending = UnstableFeature{UnstableFlag: "com.beeper.batch_sending"} + BeeperFeatureRoomYeeting = UnstableFeature{UnstableFlag: "com.beeper.room_yeeting"} + BeeperFeatureAutojoinInvites = UnstableFeature{UnstableFlag: "com.beeper.room_create_autojoin_invites"} + BeeperFeatureArbitraryProfileMeta = UnstableFeature{UnstableFlag: "com.beeper.arbitrary_profile_meta"} + BeeperFeatureAccountDataMute = UnstableFeature{UnstableFlag: "com.beeper.account_data_mute"} + BeeperFeatureInboxState = UnstableFeature{UnstableFlag: "com.beeper.inbox_state"} ) func (versions *RespVersions) Supports(feature UnstableFeature) bool { @@ -125,8 +117,6 @@ var ( SpecV113 = MustParseSpecVersion("v1.13") SpecV114 = MustParseSpecVersion("v1.14") SpecV115 = MustParseSpecVersion("v1.15") - SpecV116 = MustParseSpecVersion("v1.16") - SpecV117 = MustParseSpecVersion("v1.17") ) func (svf SpecVersionFormat) String() string {