diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 10025368..c0add220 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -10,12 +10,12 @@ jobs: runs-on: ubuntu-latest name: Lint (latest) steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - name: Set up Go - uses: actions/setup-go@v5 + uses: actions/setup-go@v6 with: - go-version: "1.23" + go-version: "1.26" cache: true - name: Install libolm @@ -24,6 +24,7 @@ jobs: - name: Install goimports run: | go install golang.org/x/tools/cmd/goimports@latest + go install honnef.co/go/tools/cmd/staticcheck@latest export PATH="$HOME/go/bin:$PATH" - name: Run pre-commit @@ -34,14 +35,14 @@ jobs: strategy: fail-fast: false matrix: - go-version: ["1.22", "1.23"] - name: Build (${{ matrix.go-version == '1.23' && 'latest' || 'old' }}, libolm) + go-version: ["1.25", "1.26"] + name: Build (${{ matrix.go-version == '1.26' && 'latest' || 'old' }}, libolm) steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - name: Set up Go ${{ matrix.go-version }} - uses: actions/setup-go@v5 + uses: actions/setup-go@v6 with: go-version: ${{ matrix.go-version }} cache: true @@ -60,28 +61,28 @@ jobs: - name: Test run: go test -json -v ./... 2>&1 | gotestfmt + - name: Test (jsonv2) + env: + GOEXPERIMENT: jsonv2 + run: go test -json -v ./... 2>&1 | gotestfmt + build-goolm: runs-on: ubuntu-latest strategy: fail-fast: false matrix: - go-version: ["1.22", "1.23"] - name: Build (${{ matrix.go-version == '1.23' && 'latest' || 'old' }}, goolm) + go-version: ["1.25", "1.26"] + name: Build (${{ matrix.go-version == '1.26' && 'latest' || 'old' }}, goolm) steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - name: Set up Go ${{ matrix.go-version }} - uses: actions/setup-go@v5 + uses: actions/setup-go@v6 with: go-version: ${{ matrix.go-version }} cache: true - - name: Set up gotestfmt - uses: GoTestTools/gotestfmt-action@v2 - with: - token: ${{ secrets.GITHUB_TOKEN }} - - name: Build run: | rm -rf crypto/libolm diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml index 578349c9..9a9e7375 100644 --- a/.github/workflows/stale.yml +++ b/.github/workflows/stale.yml @@ -17,7 +17,7 @@ jobs: lock-stale: runs-on: ubuntu-latest steps: - - uses: dessant/lock-threads@v5 + - uses: dessant/lock-threads@v6 id: lock with: issue-inactive-days: 90 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 81701203..616fccb2 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v5.0.0 + rev: v6.0.0 hooks: - id: trailing-whitespace exclude_types: [markdown] @@ -9,7 +9,7 @@ repos: - id: check-added-large-files - repo: https://github.com/tekwizely/pre-commit-golang - rev: v1.0.0-rc.1 + rev: v1.0.0-rc.4 hooks: - id: go-imports-repo args: @@ -18,8 +18,7 @@ repos: - "-w" - id: go-vet-repo-mod - id: go-mod-tidy - # TODO enable this - #- id: go-staticcheck-repo-mod + - id: go-staticcheck-repo-mod - repo: https://github.com/beeper/pre-commit-go rev: v0.4.2 @@ -27,3 +26,4 @@ repos: - id: prevent-literal-http-methods - id: zerolog-ban-global-log - id: zerolog-ban-msgf + - id: zerolog-use-stringer diff --git a/CHANGELOG.md b/CHANGELOG.md index 0cc96b60..f2829199 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,420 @@ +## v0.26.3 (2026-02-16) + +* Bumped minimum Go version to 1.25. +* *(client)* Added fields for sending [MSC4354] sticky events. +* *(bridgev2)* Added automatic message request accepting when sending message. +* *(mediaproxy)* Added support for federation thumbnail endpoint. +* *(crypto/ssss)* Improved support for recovery keys with slightly broken + metadata. +* *(crypto)* Changed key import to call session received callback even for + sessions that already exist in the database. +* *(appservice)* Fixed building websocket URL accidentally using file path + separators instead of always `/`. +* *(crypto)* Fixed key exports not including the `sender_claimed_keys` field. +* *(client)* Fixed incorrect context usage in async uploads. +* *(crypto)* Fixed panic when passing invalid input to megolm message index + parser used for debugging. +* *(bridgev2/provisioning)* Fixed completed or failed logins not being cleaned + up properly. + +[MSC4354]: https://github.com/matrix-org/matrix-spec-proposals/pull/4354 + +## v0.26.2 (2026-01-16) + +* *(bridgev2)* Added chunked portal deletion to avoid database locks when + deleting large portals. +* *(crypto,bridgev2)* Added option to encrypt reaction and reply metadata + as per [MSC4392]. +* *(bridgev2/login)* Added `default_value` for user input fields. +* *(bridgev2)* Added interfaces to let the Matrix connector provide suggested + HTTP client settings and to reset active connections of the network connector. +* *(bridgev2)* Added interface to let network connectors get the provisioning + API HTTP router and add new endpoints. +* *(event)* Added blurhash field to Beeper link preview objects. +* *(event)* Added [MSC4391] support for bot commands. +* *(event)* Dropped [MSC4332] support for bot commands. +* *(client)* Changed media download methods to return an error if the provided + MXC URI is empty. +* *(client)* Stabilized support for [MSC4323]. +* *(bridgev2/matrix)* Fixed `GetEvent` panicking when trying to decrypt events. +* *(bridgev2)* Fixed some deadlocks when room creation happens in parallel with + a portal re-ID call. + +[MSC4391]: https://github.com/matrix-org/matrix-spec-proposals/pull/4391 +[MSC4392]: https://github.com/matrix-org/matrix-spec-proposals/pull/4392 + +## v0.26.1 (2025-12-16) + +* **Breaking change *(mediaproxy)*** Changed `GetMediaResponseFile` to return + the mime type from the callback rather than in the return get media return + value. The callback can now also redirect the caller to a different file. +* *(federation)* Added join/knock/leave functions + (thanks to [@nexy7574] in [#422]). +* *(federation/eventauth)* Fixed various incorrect checks. +* *(client)* Added backoff for retrying media uploads to external URLs + (with MSC3870). +* *(bridgev2/config)* Added support for overriding config fields using + environment variables. +* *(bridgev2/commands)* Added command to mute chat on remote network. +* *(bridgev2)* Added interface for network connectors to redirect to a different + user ID when handling an invite from Matrix. +* *(bridgev2)* Added interface for signaling message request status of portals. +* *(bridgev2)* Changed portal creation to not backfill unless `CanBackfill` flag + is set in chat info. +* *(bridgev2)* Changed Matrix reaction handling to only delete old reaction if + bridging the new one is successful. +* *(bridgev2/mxmain)* Improved error message when trying to run bridge with + pre-megabridge database when no database migration exists. +* *(bridgev2)* Improved reliability of database migration when enabling split + portals. +* *(bridgev2)* Improved detection of orphaned DM rooms when starting new chats. +* *(bridgev2)* Stopped sending redundant invites when joining ghosts to public + portal rooms. +* *(bridgev2)* Stopped hardcoding room versions in favor of checking + server capabilities to determine appropriate `/createRoom` parameters. + +[#422]: https://github.com/mautrix/go/pull/422 + +## v0.26.0 (2025-11-16) + +* *(client,appservice)* Deprecated `SendMassagedStateEvent` as `SendStateEvent` + has been able to do the same for a while now. +* *(client,federation)* Added size limits for responses to make it safer to send + requests to untrusted servers. +* *(client)* Added wrapper for `/admin/whois` client API + (thanks to [@nexy7574] in [#411]). +* *(synapseadmin)* Added `force_purge` option to DeleteRoom + (thanks to [@nexy7574] in [#420]). +* *(statestore)* Added saving join rules for rooms. +* *(bridgev2)* Added optional automatic rollback of room state if bridging the + change to the remote network fails. +* *(bridgev2)* Added management room notices if transient disconnect state + doesn't resolve within 3 minutes. +* *(bridgev2)* Added interface to signal that certain participants couldn't be + invited when creating a group. +* *(bridgev2)* Added `select` type for user input fields in login. +* *(bridgev2)* Added interface to let network connector customize personal + filtering space. +* *(bridgev2/matrix)* Added checks to avoid sending error messages in reply to + other bots. +* *(bridgev2/matrix)* Switched to using [MSC4169] to send redactions whenever + possible. +* *(bridgev2/publicmedia)* Added support for custom path prefixes, file names, + and encrypted files. +* *(bridgev2/commands)* Added command to resync a single portal. +* *(bridgev2/commands)* Added create group command. +* *(bridgev2/config)* Added option to limit maximum number of logins. +* *(bridgev2)* Changed ghost joining to skip unnecessary invite if portal room + is public. +* *(bridgev2/disappear)* Changed read receipt handling to only start + disappearing timers for messages up to the read message (note: may not work in + all cases if the read receipt points at an unknown event). +* *(event/reply)* Changed plaintext reply fallback removal to only happen when + an HTML reply fallback is removed successfully. +* *(bridgev2/matrix)* Fixed unnecessary sleep after registering bot on first run. +* *(crypto/goolm)* Fixed panic when processing certain malformed Olm messages. +* *(federation)* Fixed HTTP method for sending transactions + (thanks to [@nexy7574] in [#426]). +* *(federation)* Fixed response body being closed even when using `DontReadBody` + parameter. +* *(federation)* Fixed validating auth for requests with query params. +* *(federation/eventauth)* Fixed typo causing restricted joins to not work. + +[MSC4169]: https://github.com/matrix-org/matrix-spec-proposals/pull/4169 +[#411]: github.com/mautrix/go/pull/411 +[#420]: github.com/mautrix/go/pull/420 +[#426]: github.com/mautrix/go/pull/426 + +## v0.25.2 (2025-10-16) + +* **Breaking change *(id)*** Split `UserID.ParseAndValidate` into + `ParseAndValidateRelaxed` and `ParseAndValidateStrict`. Strict is the old + behavior, but most users likely want the relaxed version, as there are real + users whose user IDs aren't valid under the strict rules. +* *(crypto)* Added helper methods for generating and verifying with recovery + keys. +* *(bridgev2/matrix)* Added config option to automatically generate a recovery + key for the bridge bot and self-sign the bridge's device. +* *(bridgev2/matrix)* Added initial support for using appservice/MSC3202 mode + for encryption with standard servers like Synapse. +* *(bridgev2)* Added optional support for implicit read receipts. +* *(bridgev2)* Added interface for deleting chats on remote network. +* *(bridgev2)* Added local enforcement of media duration and size limits. +* *(bridgev2)* Extended event duration logging to log any event taking too long. +* *(bridgev2)* Improved validation in group creation provisioning API. +* *(event)* Added event type constant for poll end events. +* *(client)* Added wrapper for searching user directory. +* *(client)* Improved support for managing [MSC4140] delayed events. +* *(crypto/helper)* Changed default sync handling to not block on waiting for + decryption keys. On initial sync, keys won't be requested at all by default. +* *(crypto)* Fixed olm unwedging not working (regressed in v0.25.1). +* *(bridgev2)* Fixed various bugs with migrating to split portals. +* *(event)* Fixed poll start events having incorrect null `m.relates_to`. +* *(client)* Fixed `RespUserProfile` losing standard fields when re-marshaling. +* *(federation)* Fixed various bugs in event auth. + +## v0.25.1 (2025-09-16) + +* *(client)* Fixed HTTP method of delete devices API call + (thanks to [@fmseals] in [#393]). +* *(client)* Added wrappers for [MSC4323]: User suspension & locking endpoints + (thanks to [@nexy7574] in [#407]). +* *(client)* Stabilized support for extensible profiles. +* *(client)* Stabilized support for `state_after` in sync. +* *(client)* Removed deprecated MSC2716 requests. +* *(crypto)* Added fallback to ensure `m.relates_to` is always copied even if + the content struct doesn't implement `Relatable`. +* *(crypto)* Changed olm unwedging to ignore newly created sessions if they + haven't been used successfully in either direction. +* *(federation)* Added utilities for generating, parsing, validating and + authorizing PDUs. + * Note: the new PDU code depends on `GOEXPERIMENT=jsonv2` +* *(event)* Added `is_animated` flag from [MSC4230] to file info. +* *(event)* Added types for [MSC4332]: In-room bot commands. +* *(event)* Added missing poll end event type for [MSC3381]. +* *(appservice)* Fixed URLs not being escaped properly when using unix socket + for homeserver connections. +* *(format)* Added more helpers for forming markdown links. +* *(event,bridgev2)* Added support for Beeper's disappearing message state event. +* *(bridgev2)* Redesigned group creation interface and added support in commands + and provisioning API. +* *(bridgev2)* Added GetEvent to Matrix interface to allow network connectors to + get an old event. The method is best effort only, as some configurations don't + allow fetching old events. +* *(bridgev2)* Added shared logic for provisioning that can be reused by the + API, commands and other sources. +* *(bridgev2)* Fixed mentions and URL previews not being copied over when + caption and media are merged. +* *(bridgev2)* Removed config option to change provisioning API prefix, which + had already broken in the previous release. + +[@fmseals]: https://github.com/fmseals +[#393]: https://github.com/mautrix/go/pull/393 +[#407]: https://github.com/mautrix/go/pull/407 +[MSC3381]: https://github.com/matrix-org/matrix-spec-proposals/pull/3381 +[MSC4230]: https://github.com/matrix-org/matrix-spec-proposals/pull/4230 +[MSC4323]: https://github.com/matrix-org/matrix-spec-proposals/pull/4323 +[MSC4332]: https://github.com/matrix-org/matrix-spec-proposals/pull/4332 + +## v0.25.0 (2025-08-16) + +* Bumped minimum Go version to 1.24. +* **Breaking change *(appservice,bridgev2,federation)*** Replaced gorilla/mux + with standard library ServeMux. +* *(client,bridgev2)* Added support for creator power in room v12. +* *(client)* Added option to not set `User-Agent` header for improved Wasm + compatibility. +* *(bridgev2)* Added support for following tombstones. +* *(bridgev2)* Added interface for getting arbitrary state event from Matrix. +* *(bridgev2)* Added batching to disappearing message queue to ensure it doesn't + use too many resources even if there are a large number of messages. +* *(bridgev2/commands)* Added support for canceling QR login with `cancel` + command. +* *(client)* Added option to override HTTP client used for .well-known + resolution. +* *(crypto/backup)* Added method for encrypting key backup session without + private keys. +* *(event->id)* Moved room version type and constants to id package. +* *(bridgev2)* Bots in DM portals will now be added to the functional members + state event to hide them from the room name calculation. +* *(bridgev2)* Changed message delete handling to ignore "delete for me" events + if there are multiple Matrix users in the room. +* *(format/htmlparser)* Changed text processing to collapse multiple spaces into + one when outside `pre`/`code` tags. +* *(format/htmlparser)* Removed link suffix in plaintext output when link text + is only missing protocol part of href. + * e.g. `example.com` will turn into + `example.com` rather than `example.com (https://example.com)` +* *(appservice)* Switched appservice websockets from gorilla/websocket to + coder/websocket. +* *(bridgev2/matrix)* Fixed encryption key sharing not ignoring ghosts properly. +* *(crypto/attachments)* Fixed hash check when decrypting file streams. +* *(crypto)* Removed unnecessary `AlreadyShared` error in `ShareGroupSession`. + The function will now act as if it was successful instead. + +## v0.24.2 (2025-07-16) + +* *(bridgev2)* Added support for return values from portal event handlers. Note + that the return value will always be "queued" unless the event buffer is + disabled. +* *(bridgev2)* Added support for [MSC4144] per-message profile passthrough in + relay mode. +* *(bridgev2)* Added option to auto-reconnect logins after a certain period if + they hit an `UNKNOWN_ERROR` state. +* *(bridgev2)* Added analytics for event handler panics. +* *(bridgev2)* Changed new room creation to hardcode room v11 to avoid v12 rooms + being created before proper support for them can be added. +* *(bridgev2)* Changed queuing events to block instead of dropping events if the + buffer is full. +* *(bridgev2)* Fixed assumption that replies to unknown messages are cross-room. +* *(id)* Fixed server name validation not including ports correctly + (thanks to [@krombel] in [#392]). +* *(federation)* Fixed base64 algorithm in signature generation. +* *(event)* Fixed [MSC4144] fallbacks not being removed from edits. + +[@krombel]: https://github.com/krombel +[#392]: https://github.com/mautrix/go/pull/392 + +## v0.24.1 (2025-06-16) + +* *(commands)* Added framework for using reactions as buttons that execute + command handlers. +* *(client)* Added wrapper for `/relations` endpoints. +* *(client)* Added support for stable version of room summary endpoint. +* *(client)* Fixed parsing URL preview responses where width/height are strings. +* *(federation)* Fixed bugs in server auth. +* *(id)* Added utilities for validating server names. +* *(event)* Fixed incorrect empty `entity` field when sending hashed moderation + policy events. +* *(event)* Added [MSC4293] redact events field to member events. +* *(event)* Added support for fallbacks in [MSC4144] per-message profiles. +* *(format)* Added `MarkdownLink` and `MarkdownMention` utility functions for + generating properly escaped markdown. +* *(synapseadmin)* Added support for synchronous (v1) room delete endpoint. +* *(synapseadmin)* Changed `Client` struct to not embed the `mautrix.Client`. + This is a breaking change if you were relying on accessing non-admin functions + from the admin client. +* *(bridgev2/provisioning)* Fixed `/display_and_wait` not passing through errors + from the network connector properly. +* *(bridgev2/crypto)* Fixed encryption not working if the user's ID had the same + prefix as the bridge ghosts (e.g. `@whatsappbridgeuser:example.com` with a + `@whatsapp_` prefix). +* *(bridgev2)* Fixed portals not being saved after creating a DM portal from a + Matrix DM invite. +* *(bridgev2)* Added config option to determine whether cross-room replies + should be bridged. +* *(appservice)* Fixed `EnsureRegistered` not being called when sending a custom + member event for the controlled user. + +[MSC4293]: https://github.com/matrix-org/matrix-spec-proposals/pull/4293 + +## v0.24.0 (2025-05-16) + +* *(commands)* Added generic framework for implementing bot commands. +* *(client)* Added support for specifying maximum number of HTTP retries using + a context value instead of having to call `MakeFullRequest` manually. +* *(client,federation)* Added methods for fetching room directories. +* *(federation)* Added support for server side of request authentication. +* *(synapseadmin)* Added wrapper for the account suspension endpoint. +* *(format)* Added method for safely wrapping a string in markdown inline code. +* *(crypto)* Added method to import key backup without persisting to database, + to allow the client more control over the process. +* *(bridgev2)* Added viewing chat interface to signal when the user is viewing + a given chat. +* *(bridgev2)* Added option to pass through transaction ID from client when + sending messages to remote network. +* *(crypto)* Fixed unnecessary error log when decrypting dummy events used for + unwedging Olm sessions. +* *(crypto)* Fixed `forwarding_curve25519_key_chain` not being set consistently + when backing up keys. +* *(event)* Fixed marshaling legacy VoIP events with no version field. +* *(bridgev2)* Fixed disappearing message references not being deleted when the + portal is deleted. +* *(bridgev2)* Fixed read receipt bridging not ignoring fake message entries + and causing unnecessary error logs. + +## v0.23.3 (2025-04-16) + +* *(commands)* Added generic command processing framework for bots. +* *(client)* Added `allowed_room_ids` field to room summary responses + (thanks to [@nexy7574] in [#367]). +* *(bridgev2)* Added support for custom timeouts on outgoing messages which have + to wait for a remote echo. +* *(bridgev2)* Added automatic typing stop event if the ghost user had sent a + typing event before a message. +* *(bridgev2)* The saved management room is now cleared if the user leaves the + room, allowing the next DM to be automatically marked as a management room. +* *(bridge)* Removed deprecated fallback package for bridge statuses. + The status package is now only available under bridgev2. + +[#367]: https://github.com/mautrix/go/pull/367 + +## v0.23.2 (2025-03-16) + +* **Breaking change *(bridge)*** Removed legacy bridge module. +* **Breaking change *(event)*** Changed `m.federate` field in room create event + content to a pointer to allow detecting omitted values. +* *(bridgev2/commands)* Added `set-management-room` command to set a new + management room. +* *(bridgev2/portal)* Changed edit bridging to ignore remote edits if the + original sender on Matrix can't be puppeted. +* *(bridgv2)* Added config option to disable bridging `m.notice` messages. +* *(appservice/http)* Switched access token validation to use constant time + comparisons. +* *(event)* Added support for [MSC3765] rich text topics. +* *(event)* Added fields to policy list event contents for [MSC4204] and + [MSC4205]. +* *(client)* Added method for getting the content of a redacted event using + [MSC2815]. +* *(client)* Added methods for sending and updating [MSC4140] delayed events. +* *(client)* Added support for [MSC4222] in sync payloads. +* *(crypto/cryptohelper)* Switched to using `sqlite3-fk-wal` instead of plain + `sqlite3` by default. +* *(crypto/encryptolm)* Added generic method for encrypting to-device events. +* *(crypto/ssss)* Fixed panic if server-side key metadata is corrupted. +* *(crypto/sqlstore)* Fixed error when marking over 32 thousand device lists + as outdated on SQLite. + +[MSC2815]: https://github.com/matrix-org/matrix-spec-proposals/pull/2815 +[MSC3765]: https://github.com/matrix-org/matrix-spec-proposals/pull/3765 +[MSC4140]: https://github.com/matrix-org/matrix-spec-proposals/pull/4140 +[MSC4204]: https://github.com/matrix-org/matrix-spec-proposals/pull/4204 +[MSC4205]: https://github.com/matrix-org/matrix-spec-proposals/pull/4205 +[MSC4222]: https://github.com/matrix-org/matrix-spec-proposals/pull/4222 + +## v0.23.1 (2025-02-16) + +* *(client)* Added `FullStateEvent` method to get a state event including + metadata (using the `?format=event` query parameter). +* *(client)* Added wrapper method for [MSC4194]'s redact endpoint. +* *(pushrules)* Fixed content rules not considering word boundaries and being + case-sensitive. +* *(crypto)* Fixed bugs that would cause key exports to fail for no reason. +* *(crypto)* Deprecated `ResolveTrust` in favor of `ResolveTrustContext`. +* *(crypto)* Stopped accepting secret shares from unverified devices. +* **Breaking change *(crypto)*** Changed `GetAndVerifyLatestKeyBackupVersion` + to take an optional private key parameter. The method will now trust the + public key if it matches the provided private key even if there are no valid + signatures. +* **Breaking change *(crypto)*** Added context parameter to `IsDeviceTrusted`. + +[MSC4194]: https://github.com/matrix-org/matrix-spec-proposals/pull/4194 + +## v0.23.0 (2025-01-16) + +* **Breaking change *(client)*** Changed `JoinRoom` parameters to allow multiple + `via`s. +* **Breaking change *(bridgev2)*** Updated capability system. + * The return type of `NetworkAPI.GetCapabilities` is now different. + * Media type capabilities are enforced automatically by bridgev2. + * Capabilities are now sent to Matrix rooms using the + `com.beeper.room_features` state event. +* *(client)* Added `GetRoomSummary` to implement [MSC3266]. +* *(client)* Added support for arbitrary profile fields to implement [MSC4133] + (thanks to [@nexy7574] in [#337]). +* *(crypto)* Started storing olm message hashes to prevent decryption errors + if messages are repeated (e.g. if the app crashes right after decrypting). +* *(crypto)* Improved olm session unwedging to check when the last session was + created instead of only relying on an in-memory map. +* *(crypto/verificationhelper)* Fixed emoji verification not doing cross-signing + properly after a successful verification. +* *(bridgev2/config)* Moved MSC4190 flag from `appservice` to `encryption`. +* *(bridgev2/space)* Fixed failing to add rooms to spaces if the room create + call was made with a temporary context. +* *(bridgev2/commands)* Changed `help` command to hide commands which require + interfaces that aren't implemented by the network connector. +* *(bridgev2/matrixinterface)* Moved deterministic room ID generation to Matrix + connector. +* *(bridgev2)* Fixed service member state event not being set correctly when + creating a DM by inviting a ghost user. +* *(bridgev2)* Fixed `RemoteReactionSync` events replacing all reactions every + time instead of only changed ones. + +[MSC3266]: https://github.com/matrix-org/matrix-spec-proposals/pull/3266 +[MSC4133]: https://github.com/matrix-org/matrix-spec-proposals/pull/4133 +[@nexy7574]: https://github.com/nexy7574 +[#337]: https://github.com/mautrix/go/pull/337 + ## v0.22.1 (2024-12-16) * *(crypto)* Added automatic cleanup when there are too many olm sessions with @@ -20,6 +437,7 @@ [MSC4156]: https://github.com/matrix-org/matrix-spec-proposals/pull/4156 [MSC4190]: https://github.com/matrix-org/matrix-spec-proposals/pull/4190 [#288]: https://github.com/mautrix/go/pull/288 +[@onestacked]: https://github.com/onestacked ## v0.22.0 (2024-11-16) diff --git a/README.md b/README.md index ac41ca78..b1a2edf8 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,9 @@ # mautrix-go [![GoDoc](https://pkg.go.dev/badge/maunium.net/go/mautrix)](https://pkg.go.dev/maunium.net/go/mautrix) -A Golang Matrix framework. Used by [gomuks](https://matrix.org/docs/projects/client/gomuks), -[go-neb](https://github.com/matrix-org/go-neb), [mautrix-whatsapp](https://github.com/mautrix/whatsapp) +A Golang Matrix framework. Used by [gomuks](https://gomuks.app), +[go-neb](https://github.com/matrix-org/go-neb), +[mautrix-whatsapp](https://github.com/mautrix/whatsapp) and others. Matrix room: [`#go:maunium.net`](https://matrix.to/#/#go:maunium.net) @@ -13,9 +14,10 @@ The original project is licensed under [Apache 2.0](https://github.com/matrix-or In addition to the basic client API features the original project has, this framework also has: * Appservice support (Intent API like mautrix-python, room state storage, etc) -* End-to-end encryption support (incl. interactive SAS verification) +* End-to-end encryption support (incl. key backup, cross-signing, interactive verification, etc) * High-level module for building puppeting bridges -* High-level module for building chat clients +* Partial federation module (making requests, PDU processing and event authorization) +* A media proxy server which can be used to expose anything as a Matrix media repo * Wrapper functions for the Synapse admin API * Structs for parsing event content * Helpers for parsing and generating Matrix HTML diff --git a/appservice/appservice.go b/appservice/appservice.go index 518e1073..d7037ef6 100644 --- a/appservice/appservice.go +++ b/appservice/appservice.go @@ -1,4 +1,4 @@ -// Copyright (c) 2023 Tulir Asokan +// Copyright (c) 2025 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this @@ -19,8 +19,7 @@ import ( "syscall" "time" - "github.com/gorilla/mux" - "github.com/gorilla/websocket" + "github.com/coder/websocket" "github.com/rs/zerolog" "golang.org/x/net/publicsuffix" "gopkg.in/yaml.v3" @@ -43,7 +42,7 @@ func Create() *AppService { intents: make(map[id.UserID]*IntentAPI), HTTPClient: &http.Client{Timeout: 180 * time.Second, Jar: jar}, StateStore: mautrix.NewMemoryStateStore().(StateStore), - Router: mux.NewRouter(), + Router: http.NewServeMux(), UserAgent: mautrix.DefaultUserAgent, txnIDC: NewTransactionIDCache(128), Live: true, @@ -61,12 +60,12 @@ func Create() *AppService { DefaultHTTPRetries: 4, } - as.Router.HandleFunc("/_matrix/app/v1/transactions/{txnID}", as.PutTransaction).Methods(http.MethodPut) - as.Router.HandleFunc("/_matrix/app/v1/rooms/{roomAlias}", as.GetRoom).Methods(http.MethodGet) - as.Router.HandleFunc("/_matrix/app/v1/users/{userID}", as.GetUser).Methods(http.MethodGet) - as.Router.HandleFunc("/_matrix/app/v1/ping", as.PostPing).Methods(http.MethodPost) - as.Router.HandleFunc("/_matrix/mau/live", as.GetLive).Methods(http.MethodGet) - as.Router.HandleFunc("/_matrix/mau/ready", as.GetReady).Methods(http.MethodGet) + as.Router.HandleFunc("PUT /_matrix/app/v1/transactions/{txnID}", as.PutTransaction) + as.Router.HandleFunc("GET /_matrix/app/v1/rooms/{roomAlias}", as.GetRoom) + as.Router.HandleFunc("GET /_matrix/app/v1/users/{userID}", as.GetUser) + as.Router.HandleFunc("POST /_matrix/app/v1/ping", as.PostPing) + as.Router.HandleFunc("GET /_matrix/mau/live", as.GetLive) + as.Router.HandleFunc("GET /_matrix/mau/ready", as.GetReady) return as } @@ -114,13 +113,13 @@ var _ StateStore = (*mautrix.MemoryStateStore)(nil) // QueryHandler handles room alias and user ID queries from the homeserver. type QueryHandler interface { - QueryAlias(alias string) bool + QueryAlias(alias id.RoomAlias) bool QueryUser(userID id.UserID) bool } type QueryHandlerStub struct{} -func (qh *QueryHandlerStub) QueryAlias(alias string) bool { +func (qh *QueryHandlerStub) QueryAlias(alias id.RoomAlias) bool { return false } @@ -128,7 +127,7 @@ func (qh *QueryHandlerStub) QueryUser(userID id.UserID) bool { return false } -type WebsocketHandler func(WebsocketCommand) (ok bool, data interface{}) +type WebsocketHandler func(WebsocketCommand) (ok bool, data any) type StateStore interface { mautrix.StateStore @@ -160,7 +159,7 @@ type AppService struct { QueryHandler QueryHandler StateStore StateStore - Router *mux.Router + Router *http.ServeMux UserAgent string server *http.Server HTTPClient *http.Client @@ -179,7 +178,6 @@ type AppService struct { intentsLock sync.RWMutex ws *websocket.Conn - wsWriteLock sync.Mutex StopWebsocket func(error) websocketHandlers map[string]WebsocketHandler websocketHandlersLock sync.RWMutex @@ -336,7 +334,7 @@ func (as *AppService) SetHomeserverURL(homeserverURL string) error { } else if as.hsURLForClient.Scheme == "" { as.hsURLForClient.Scheme = "https" } - as.hsURLForClient.RawPath = parsedURL.EscapedPath() + as.hsURLForClient.RawPath = as.hsURLForClient.EscapedPath() jar, _ := cookiejar.New(&cookiejar.Options{PublicSuffixList: publicsuffix.List}) as.HTTPClient = &http.Client{Timeout: 180 * time.Second, Jar: jar} @@ -362,7 +360,7 @@ func (as *AppService) NewMautrixClient(userID id.UserID) *mautrix.Client { AccessToken: as.Registration.AppToken, UserAgent: as.UserAgent, StateStore: as.StateStore, - Log: as.Log.With().Str("as_user_id", userID.String()).Logger(), + Log: as.Log.With().Stringer("as_user_id", userID).Logger(), Client: as.HTTPClient, DefaultHTTPRetries: as.DefaultHTTPRetries, SpecVersions: as.SpecVersions, diff --git a/appservice/http.go b/appservice/http.go index 661513b4..27ce6288 100644 --- a/appservice/http.go +++ b/appservice/http.go @@ -1,4 +1,4 @@ -// Copyright (c) 2023 Tulir Asokan +// Copyright (c) 2025 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this @@ -17,8 +17,9 @@ import ( "syscall" "time" - "github.com/gorilla/mux" "github.com/rs/zerolog" + "go.mau.fi/util/exhttp" + "go.mau.fi/util/exstrings" "maunium.net/go/mautrix" "maunium.net/go/mautrix/event" @@ -78,17 +79,9 @@ func (as *AppService) Stop() { func (as *AppService) CheckServerToken(w http.ResponseWriter, r *http.Request) (isValid bool) { authHeader := r.Header.Get("Authorization") if !strings.HasPrefix(authHeader, "Bearer ") { - Error{ - ErrorCode: ErrUnknownToken, - HTTPStatus: http.StatusForbidden, - Message: "Missing access token", - }.Write(w) - } else if authHeader[len("Bearer "):] != as.Registration.ServerToken { - Error{ - ErrorCode: ErrUnknownToken, - HTTPStatus: http.StatusForbidden, - Message: "Incorrect access token", - }.Write(w) + mautrix.MMissingToken.WithMessage("Missing access token").Write(w) + } else if !exstrings.ConstantTimeEqual(authHeader[len("Bearer "):], as.Registration.ServerToken) { + mautrix.MUnknownToken.WithMessage("Invalid access token").Write(w) } else { isValid = true } @@ -101,24 +94,15 @@ func (as *AppService) PutTransaction(w http.ResponseWriter, r *http.Request) { return } - vars := mux.Vars(r) - txnID := vars["txnID"] + txnID := r.PathValue("txnID") if len(txnID) == 0 { - Error{ - ErrorCode: ErrNoTransactionID, - HTTPStatus: http.StatusBadRequest, - Message: "Missing transaction ID", - }.Write(w) + mautrix.MInvalidParam.WithMessage("Missing transaction ID").Write(w) return } defer r.Body.Close() body, err := io.ReadAll(r.Body) if err != nil || len(body) == 0 { - Error{ - ErrorCode: ErrNotJSON, - HTTPStatus: http.StatusBadRequest, - Message: "Missing request body", - }.Write(w) + mautrix.MNotJSON.WithMessage("Failed to read response body").Write(w) return } log := as.Log.With().Str("transaction_id", txnID).Logger() @@ -127,7 +111,7 @@ func (as *AppService) PutTransaction(w http.ResponseWriter, r *http.Request) { ctx = log.WithContext(ctx) if as.txnIDC.IsProcessed(txnID) { // Duplicate transaction ID: no-op - WriteBlankOK(w) + exhttp.WriteEmptyJSONResponse(w, http.StatusOK) log.Debug().Msg("Ignoring duplicate transaction") return } @@ -136,14 +120,10 @@ func (as *AppService) PutTransaction(w http.ResponseWriter, r *http.Request) { err = json.Unmarshal(body, &txn) if err != nil { log.Error().Err(err).Msg("Failed to parse transaction content") - Error{ - ErrorCode: ErrBadJSON, - HTTPStatus: http.StatusBadRequest, - Message: "Failed to parse body JSON", - }.Write(w) + mautrix.MBadJSON.WithMessage("Failed to parse transaction content").Write(w) } else { as.handleTransaction(ctx, txnID, &txn) - WriteBlankOK(w) + exhttp.WriteEmptyJSONResponse(w, http.StatusOK) } } @@ -221,7 +201,7 @@ func (as *AppService) handleEvents(ctx context.Context, evts []*event.Event, def } err := evt.Content.ParseRaw(evt.Type) if errors.Is(err, event.ErrUnsupportedContentType) { - log.Debug().Str("event_id", evt.ID.String()).Msg("Not parsing content of unsupported event") + log.Debug().Stringer("event_id", evt.ID).Msg("Not parsing content of unsupported event") } else if err != nil { log.Warn().Err(err). Str("event_id", evt.ID.String()). @@ -258,16 +238,12 @@ func (as *AppService) GetRoom(w http.ResponseWriter, r *http.Request) { return } - vars := mux.Vars(r) - roomAlias := vars["roomAlias"] + roomAlias := id.RoomAlias(r.PathValue("roomAlias")) ok := as.QueryHandler.QueryAlias(roomAlias) if ok { - WriteBlankOK(w) + exhttp.WriteEmptyJSONResponse(w, http.StatusOK) } else { - Error{ - ErrorCode: ErrUnknown, - HTTPStatus: http.StatusNotFound, - }.Write(w) + mautrix.MNotFound.WithMessage("Alias not found").Write(w) } } @@ -277,16 +253,12 @@ func (as *AppService) GetUser(w http.ResponseWriter, r *http.Request) { return } - vars := mux.Vars(r) - userID := id.UserID(vars["userID"]) + userID := id.UserID(r.PathValue("userID")) ok := as.QueryHandler.QueryUser(userID) if ok { - WriteBlankOK(w) + exhttp.WriteEmptyJSONResponse(w, http.StatusOK) } else { - Error{ - ErrorCode: ErrUnknown, - HTTPStatus: http.StatusNotFound, - }.Write(w) + mautrix.MNotFound.WithMessage("User not found").Write(w) } } @@ -296,11 +268,7 @@ func (as *AppService) PostPing(w http.ResponseWriter, r *http.Request) { } body, err := io.ReadAll(r.Body) if err != nil || len(body) == 0 || !json.Valid(body) { - Error{ - ErrorCode: ErrNotJSON, - HTTPStatus: http.StatusBadRequest, - Message: "Missing request body", - }.Write(w) + mautrix.MNotJSON.WithMessage("Invalid or missing request body").Write(w) return } @@ -308,27 +276,21 @@ func (as *AppService) PostPing(w http.ResponseWriter, r *http.Request) { _ = json.Unmarshal(body, &txn) as.Log.Debug().Str("txn_id", txn.TxnID).Msg("Received ping from homeserver") - w.Header().Add("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - w.Write([]byte("{}")) + exhttp.WriteEmptyJSONResponse(w, http.StatusOK) } func (as *AppService) GetLive(w http.ResponseWriter, r *http.Request) { - w.Header().Add("Content-Type", "application/json") if as.Live { - w.WriteHeader(http.StatusOK) + exhttp.WriteEmptyJSONResponse(w, http.StatusOK) } else { - w.WriteHeader(http.StatusInternalServerError) + exhttp.WriteEmptyJSONResponse(w, http.StatusInternalServerError) } - w.Write([]byte("{}")) } func (as *AppService) GetReady(w http.ResponseWriter, r *http.Request) { - w.Header().Add("Content-Type", "application/json") if as.Ready { - w.WriteHeader(http.StatusOK) + exhttp.WriteEmptyJSONResponse(w, http.StatusOK) } else { - w.WriteHeader(http.StatusInternalServerError) + exhttp.WriteEmptyJSONResponse(w, http.StatusInternalServerError) } - w.Write([]byte("{}")) } diff --git a/appservice/intent.go b/appservice/intent.go index 6848f28c..5d43f190 100644 --- a/appservice/intent.go +++ b/appservice/intent.go @@ -51,7 +51,7 @@ func (as *AppService) NewIntentAPI(localpart string) *IntentAPI { } func (intent *IntentAPI) Register(ctx context.Context) error { - _, err := intent.Client.MakeRequest(ctx, http.MethodPost, intent.BuildClientURL("v3", "register"), &mautrix.ReqRegister{ + _, err := intent.Client.MakeRequest(ctx, http.MethodPost, intent.BuildClientURL("v3", "register"), &mautrix.ReqRegister[any]{ Username: intent.Localpart, Type: mautrix.AuthTypeAppservice, InhibitLogin: true, @@ -86,6 +86,7 @@ func (intent *IntentAPI) EnsureRegistered(ctx context.Context) error { type EnsureJoinedParams struct { IgnoreCache bool BotOverride *mautrix.Client + Via []string } func (intent *IntentAPI) EnsureJoined(ctx context.Context, roomID id.RoomID, extra ...EnsureJoinedParams) error { @@ -99,11 +100,17 @@ func (intent *IntentAPI) EnsureJoined(ctx context.Context, roomID id.RoomID, ext return nil } - if err := intent.EnsureRegistered(ctx); err != nil { + err := intent.EnsureRegistered(ctx) + if err != nil { return fmt.Errorf("failed to ensure joined: %w", err) } - resp, err := intent.JoinRoomByID(ctx, roomID) + var resp *mautrix.RespJoinRoom + if len(params.Via) > 0 { + resp, err = intent.JoinRoom(ctx, roomID.String(), &mautrix.ReqJoinRoom{Via: params.Via}) + } else { + resp, err = intent.JoinRoomByID(ctx, roomID) + } if err != nil { bot := intent.bot if params.BotOverride != nil { @@ -142,12 +149,16 @@ func (intent *IntentAPI) EnsureJoined(ctx context.Context, roomID id.RoomID, ext return nil } +func (intent *IntentAPI) IsDoublePuppet() bool { + return intent.IsCustomPuppet && intent.as.DoublePuppetValue != "" +} + func (intent *IntentAPI) AddDoublePuppetValue(into any) any { return intent.AddDoublePuppetValueWithTS(into, 0) } func (intent *IntentAPI) AddDoublePuppetValueWithTS(into any, ts int64) any { - if !intent.IsCustomPuppet || intent.as.DoublePuppetValue == "" { + if !intent.IsDoublePuppet() { return into } // Only use ts deduplication feature with appservice double puppeting @@ -203,38 +214,45 @@ func (intent *IntentAPI) AddDoublePuppetValueWithTS(into any, ts int64) any { } } -func (intent *IntentAPI) SendMessageEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, contentJSON interface{}) (*mautrix.RespSendEvent, error) { +func (intent *IntentAPI) SendMessageEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, contentJSON any, extra ...mautrix.ReqSendEvent) (*mautrix.RespSendEvent, error) { if err := intent.EnsureJoined(ctx, roomID); err != nil { return nil, err } contentJSON = intent.AddDoublePuppetValue(contentJSON) - return intent.Client.SendMessageEvent(ctx, roomID, eventType, contentJSON) + return intent.Client.SendMessageEvent(ctx, roomID, eventType, contentJSON, extra...) } -func (intent *IntentAPI) SendMassagedMessageEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, contentJSON interface{}, ts int64) (*mautrix.RespSendEvent, error) { +func (intent *IntentAPI) BeeperSendEphemeralEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, contentJSON any, extra ...mautrix.ReqSendEvent) (*mautrix.RespSendEvent, error) { if err := intent.EnsureJoined(ctx, roomID); err != nil { return nil, err } - contentJSON = intent.AddDoublePuppetValueWithTS(contentJSON, ts) - return intent.Client.SendMessageEvent(ctx, roomID, eventType, contentJSON, mautrix.ReqSendEvent{Timestamp: ts}) + if !intent.SpecVersions.Supports(mautrix.BeeperFeatureEphemeralEvents) { + return nil, mautrix.MUnrecognized.WithMessage("Homeserver does not advertise com.beeper.ephemeral support") + } + contentJSON = intent.AddDoublePuppetValue(contentJSON) + return intent.Client.BeeperSendEphemeralEvent(ctx, roomID, eventType, contentJSON, extra...) } -func (intent *IntentAPI) SendStateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, contentJSON interface{}) (*mautrix.RespSendEvent, error) { +// Deprecated: use SendMessageEvent with mautrix.ReqSendEvent.Timestamp instead +func (intent *IntentAPI) SendMassagedMessageEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, contentJSON interface{}, ts int64) (*mautrix.RespSendEvent, error) { + return intent.SendMessageEvent(ctx, roomID, eventType, contentJSON, mautrix.ReqSendEvent{Timestamp: ts}) +} + +func (intent *IntentAPI) SendStateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, contentJSON any, extra ...mautrix.ReqSendEvent) (*mautrix.RespSendEvent, error) { if eventType != event.StateMember || stateKey != string(intent.UserID) { if err := intent.EnsureJoined(ctx, roomID); err != nil { return nil, err } - } - contentJSON = intent.AddDoublePuppetValue(contentJSON) - return intent.Client.SendStateEvent(ctx, roomID, eventType, stateKey, contentJSON) -} - -func (intent *IntentAPI) SendMassagedStateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, contentJSON interface{}, ts int64) (*mautrix.RespSendEvent, error) { - if err := intent.EnsureJoined(ctx, roomID); err != nil { + } else if err := intent.EnsureRegistered(ctx); err != nil { return nil, err } - contentJSON = intent.AddDoublePuppetValueWithTS(contentJSON, ts) - return intent.Client.SendMassagedStateEvent(ctx, roomID, eventType, stateKey, contentJSON, ts) + contentJSON = intent.AddDoublePuppetValue(contentJSON) + return intent.Client.SendStateEvent(ctx, roomID, eventType, stateKey, contentJSON, extra...) +} + +// Deprecated: use SendStateEvent with mautrix.ReqSendEvent.Timestamp instead +func (intent *IntentAPI) SendMassagedStateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, contentJSON interface{}, ts int64) (*mautrix.RespSendEvent, error) { + return intent.SendStateEvent(ctx, roomID, eventType, stateKey, contentJSON, mautrix.ReqSendEvent{Timestamp: ts}) } func (intent *IntentAPI) StateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, outContent interface{}) error { @@ -293,7 +311,7 @@ func (intent *IntentAPI) SendCustomMembershipEvent(ctx context.Context, roomID i func (intent *IntentAPI) JoinRoomByID(ctx context.Context, roomID id.RoomID, extraContent ...map[string]interface{}) (resp *mautrix.RespJoinRoom, err error) { if intent.IsCustomPuppet || len(extraContent) > 0 { _, err = intent.SendCustomMembershipEvent(ctx, roomID, intent.UserID, event.MembershipJoin, "", extraContent...) - return &mautrix.RespJoinRoom{}, err + return &mautrix.RespJoinRoom{RoomID: roomID}, err } return intent.Client.JoinRoomByID(ctx, roomID) } @@ -362,6 +380,24 @@ func (intent *IntentAPI) Member(ctx context.Context, roomID id.RoomID, userID id return member } +func (intent *IntentAPI) FillPowerLevelCreateEvent(ctx context.Context, roomID id.RoomID, pl *event.PowerLevelsEventContent) error { + if pl.CreateEvent != nil { + return nil + } + var err error + pl.CreateEvent, err = intent.StateStore.GetCreate(ctx, roomID) + if err != nil { + return fmt.Errorf("failed to get create event from cache: %w", err) + } else if pl.CreateEvent != nil { + return nil + } + pl.CreateEvent, err = intent.FullStateEvent(ctx, roomID, event.StateCreate, "") + if err != nil { + return fmt.Errorf("failed to get create event from server: %w", err) + } + return nil +} + func (intent *IntentAPI) PowerLevels(ctx context.Context, roomID id.RoomID) (pl *event.PowerLevelsEventContent, err error) { pl, err = intent.as.StateStore.GetPowerLevels(ctx, roomID) if err != nil { @@ -371,6 +407,12 @@ func (intent *IntentAPI) PowerLevels(ctx context.Context, roomID id.RoomID) (pl if pl == nil { pl = &event.PowerLevelsEventContent{} err = intent.StateEvent(ctx, roomID, event.StatePowerLevels, "", pl) + if err != nil { + return + } + } + if pl.CreateEvent == nil { + pl.CreateEvent, err = intent.FullStateEvent(ctx, roomID, event.StateCreate, "") } return } @@ -385,8 +427,7 @@ func (intent *IntentAPI) SetPowerLevel(ctx context.Context, roomID id.RoomID, us return nil, err } - if pl.GetUserLevel(userID) != level { - pl.SetUserLevel(userID, level) + if pl.EnsureUserLevelAs(intent.UserID, userID, level) { return intent.SendStateEvent(ctx, roomID, event.StatePowerLevels, "", &pl) } return nil, nil @@ -475,7 +516,7 @@ func (intent *IntentAPI) SetAvatarURL(ctx context.Context, avatarURL id.ContentU // No need to update return nil } - if !avatarURL.IsEmpty() { + if !avatarURL.IsEmpty() && !intent.SpecVersions.Supports(mautrix.BeeperFeatureHungry) { // Some homeservers require the avatar to be downloaded before setting it resp, _ := intent.Download(ctx, avatarURL) if resp != nil { diff --git a/appservice/ping.go b/appservice/ping.go new file mode 100644 index 00000000..774ec423 --- /dev/null +++ b/appservice/ping.go @@ -0,0 +1,68 @@ +// Copyright (c) 2025 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package appservice + +import ( + "context" + "encoding/json" + "errors" + "os" + "strings" + "time" + + "github.com/rs/zerolog" + + "maunium.net/go/mautrix" +) + +func (intent *IntentAPI) EnsureAppserviceConnection(ctx context.Context) { + var pingResp *mautrix.RespAppservicePing + var txnID string + var retryCount int + var err error + const maxRetries = 6 + for { + txnID = intent.TxnID() + pingResp, err = intent.AppservicePing(ctx, intent.as.Registration.ID, txnID) + if err == nil { + break + } + var httpErr mautrix.HTTPError + var pingErrBody string + if errors.As(err, &httpErr) && httpErr.RespError != nil { + if val, ok := httpErr.RespError.ExtraData["body"].(string); ok { + pingErrBody = strings.TrimSpace(val) + } + } + outOfRetries := retryCount >= maxRetries + level := zerolog.ErrorLevel + if outOfRetries { + level = zerolog.FatalLevel + } + evt := zerolog.Ctx(ctx).WithLevel(level).Err(err).Str("txn_id", txnID) + if pingErrBody != "" { + bodyBytes := []byte(pingErrBody) + if json.Valid(bodyBytes) { + evt.RawJSON("body", bodyBytes) + } else { + evt.Str("body", pingErrBody) + } + } + if outOfRetries { + evt.Msg("Homeserver -> appservice connection is not working") + zerolog.Ctx(ctx).Info().Msg("See https://docs.mau.fi/faq/as-ping for more info") + os.Exit(13) + } + evt.Msg("Homeserver -> appservice connection is not working, retrying in 5 seconds...") + time.Sleep(5 * time.Second) + retryCount++ + } + zerolog.Ctx(ctx).Debug(). + Str("txn_id", txnID). + Int64("duration_ms", pingResp.DurationMS). + Msg("Homeserver -> appservice connection works") +} diff --git a/appservice/protocol.go b/appservice/protocol.go index 7a9891ef..7c493bcb 100644 --- a/appservice/protocol.go +++ b/appservice/protocol.go @@ -1,4 +1,4 @@ -// Copyright (c) 2023 Tulir Asokan +// Copyright (c) 2025 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this @@ -7,9 +7,7 @@ package appservice import ( - "encoding/json" "fmt" - "net/http" "strings" "github.com/rs/zerolog" @@ -103,50 +101,3 @@ func (txn *Transaction) ContentString() string { // EventListener is a function that receives events. type EventListener func(evt *event.Event) - -// WriteBlankOK writes a blank OK message as a reply to a HTTP request. -func WriteBlankOK(w http.ResponseWriter) { - w.Header().Add("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - _, _ = w.Write([]byte("{}")) -} - -// Respond responds to a HTTP request with a JSON object. -func Respond(w http.ResponseWriter, data interface{}) error { - w.Header().Add("Content-Type", "application/json") - dataStr, err := json.Marshal(data) - if err != nil { - return err - } - _, err = w.Write(dataStr) - return err -} - -// Error represents a Matrix protocol error. -type Error struct { - HTTPStatus int `json:"-"` - ErrorCode ErrorCode `json:"errcode"` - Message string `json:"error"` -} - -func (err Error) Write(w http.ResponseWriter) { - w.Header().Add("Content-Type", "application/json") - w.WriteHeader(err.HTTPStatus) - _ = Respond(w, &err) -} - -// ErrorCode is the machine-readable code in an Error. -type ErrorCode string - -// Native ErrorCodes -const ( - ErrUnknownToken ErrorCode = "M_UNKNOWN_TOKEN" - ErrBadJSON ErrorCode = "M_BAD_JSON" - ErrNotJSON ErrorCode = "M_NOT_JSON" - ErrUnknown ErrorCode = "M_UNKNOWN" -) - -// Custom ErrorCodes -const ( - ErrNoTransactionID ErrorCode = "NET.MAUNIUM.NO_TRANSACTION_ID" -) diff --git a/appservice/websocket.go b/appservice/websocket.go index 598d70d1..ef65e65a 100644 --- a/appservice/websocket.go +++ b/appservice/websocket.go @@ -1,4 +1,4 @@ -// Copyright (c) 2023 Tulir Asokan +// Copyright (c) 2025 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this @@ -11,26 +11,26 @@ import ( "encoding/json" "errors" "fmt" + "io" "net/http" "net/url" - "path/filepath" + "path" "strings" "sync" "sync/atomic" - "time" - "github.com/gorilla/websocket" + "github.com/coder/websocket" "github.com/rs/zerolog" "github.com/tidwall/gjson" "github.com/tidwall/sjson" + + "maunium.net/go/mautrix" ) type WebsocketRequest struct { - ReqID int `json:"id,omitempty"` - Command string `json:"command"` - Data interface{} `json:"data"` - - Deadline time.Duration `json:"-"` + ReqID int `json:"id,omitempty"` + Command string `json:"command"` + Data any `json:"data"` } type WebsocketCommand struct { @@ -41,7 +41,7 @@ type WebsocketCommand struct { Ctx context.Context `json:"-"` } -func (wsc *WebsocketCommand) MakeResponse(ok bool, data interface{}) *WebsocketRequest { +func (wsc *WebsocketCommand) MakeResponse(ok bool, data any) *WebsocketRequest { if wsc.ReqID == 0 || wsc.Command == "response" || wsc.Command == "error" { return nil } @@ -56,7 +56,7 @@ func (wsc *WebsocketCommand) MakeResponse(ok bool, data interface{}) *WebsocketR var prefixMessage string for unwrappedErr != nil { errorData, jsonErr = json.Marshal(unwrappedErr) - if errorData != nil && len(errorData) > 2 && jsonErr == nil { + if len(errorData) > 2 && jsonErr == nil { prefixMessage = strings.Replace(err.Error(), unwrappedErr.Error(), "", 1) prefixMessage = strings.TrimRight(prefixMessage, ": ") break @@ -98,8 +98,8 @@ type WebsocketMessage struct { } const ( - WebsocketCloseConnReplaced = 4001 - WebsocketCloseTxnNotAcknowledged = 4002 + WebsocketCloseConnReplaced websocket.StatusCode = 4001 + WebsocketCloseTxnNotAcknowledged websocket.StatusCode = 4002 ) type MeowWebsocketCloseCode string @@ -133,7 +133,7 @@ func (mwcc MeowWebsocketCloseCode) String() string { } type CloseCommand struct { - Code int `json:"-"` + Code websocket.StatusCode `json:"-"` Command string `json:"command"` Status MeowWebsocketCloseCode `json:"status"` } @@ -143,15 +143,15 @@ func (cc CloseCommand) Error() string { } func parseCloseError(err error) error { - closeError := &websocket.CloseError{} + var closeError websocket.CloseError if !errors.As(err, &closeError) { return err } var closeCommand CloseCommand closeCommand.Code = closeError.Code closeCommand.Command = "disconnect" - if len(closeError.Text) > 0 { - jsonErr := json.Unmarshal([]byte(closeError.Text), &closeCommand) + if len(closeError.Reason) > 0 { + jsonErr := json.Unmarshal([]byte(closeError.Reason), &closeCommand) if jsonErr != nil { return err } @@ -159,7 +159,7 @@ func parseCloseError(err error) error { if len(closeCommand.Status) == 0 { if closeCommand.Code == WebsocketCloseConnReplaced { closeCommand.Status = MeowConnectionReplaced - } else if closeCommand.Code == websocket.CloseServiceRestart { + } else if closeCommand.Code == websocket.StatusServiceRestart { closeCommand.Status = MeowServerShuttingDown } } @@ -170,20 +170,23 @@ func (as *AppService) HasWebsocket() bool { return as.ws != nil } -func (as *AppService) SendWebsocket(cmd *WebsocketRequest) error { +func (as *AppService) SendWebsocket(ctx context.Context, cmd *WebsocketRequest) error { ws := as.ws if cmd == nil { return nil } else if ws == nil { return ErrWebsocketNotConnected } - as.wsWriteLock.Lock() - defer as.wsWriteLock.Unlock() - if cmd.Deadline == 0 { - cmd.Deadline = 3 * time.Minute + wr, err := ws.Writer(ctx, websocket.MessageText) + if err != nil { + return err } - _ = ws.SetWriteDeadline(time.Now().Add(cmd.Deadline)) - return ws.WriteJSON(cmd) + err = json.NewEncoder(wr).Encode(cmd) + if err != nil { + _ = wr.Close() + return err + } + return wr.Close() } func (as *AppService) clearWebsocketResponseWaiters() { @@ -220,12 +223,12 @@ func (er *ErrorResponse) Error() string { return fmt.Sprintf("%s: %s", er.Code, er.Message) } -func (as *AppService) RequestWebsocket(ctx context.Context, cmd *WebsocketRequest, response interface{}) error { +func (as *AppService) RequestWebsocket(ctx context.Context, cmd *WebsocketRequest, response any) error { cmd.ReqID = int(atomic.AddInt32(&as.websocketRequestID, 1)) respChan := make(chan *WebsocketCommand, 1) as.addWebsocketResponseWaiter(cmd.ReqID, respChan) defer as.removeWebsocketResponseWaiter(cmd.ReqID, respChan) - err := as.SendWebsocket(cmd) + err := as.SendWebsocket(ctx, cmd) if err != nil { return err } @@ -254,7 +257,7 @@ func (as *AppService) RequestWebsocket(ctx context.Context, cmd *WebsocketReques } } -func (as *AppService) unknownCommandHandler(cmd WebsocketCommand) (bool, interface{}) { +func (as *AppService) unknownCommandHandler(cmd WebsocketCommand) (bool, any) { zerolog.Ctx(cmd.Ctx).Warn().Msg("No handler for websocket command") return false, fmt.Errorf("unknown request type") } @@ -278,14 +281,28 @@ func (as *AppService) defaultHandleWebsocketTransaction(ctx context.Context, msg return true, &WebsocketTransactionResponse{TxnID: msg.TxnID} } -func (as *AppService) consumeWebsocket(stopFunc func(error), ws *websocket.Conn) { +func (as *AppService) consumeWebsocket(ctx context.Context, stopFunc func(error), ws *websocket.Conn) { defer stopFunc(ErrWebsocketUnknownError) - ctx := context.Background() for { - var msg WebsocketMessage - err := ws.ReadJSON(&msg) + msgType, reader, err := ws.Reader(ctx) if err != nil { - as.Log.Debug().Err(err).Msg("Error reading from websocket") + as.Log.Debug().Err(err).Msg("Error getting reader from websocket") + stopFunc(parseCloseError(err)) + return + } else if msgType != websocket.MessageText { + as.Log.Debug().Msg("Ignoring non-text message from websocket") + continue + } + data, err := io.ReadAll(reader) + if err != nil { + as.Log.Debug().Err(err).Msg("Error reading data from websocket") + stopFunc(parseCloseError(err)) + return + } + var msg WebsocketMessage + err = json.Unmarshal(data, &msg) + if err != nil { + as.Log.Debug().Err(err).Msg("Error parsing JSON received from websocket") stopFunc(parseCloseError(err)) return } @@ -296,11 +313,11 @@ func (as *AppService) consumeWebsocket(stopFunc func(error), ws *websocket.Conn) with = with.Str("transaction_id", msg.TxnID) } log := with.Logger() - ctx = log.WithContext(ctx) + ctx := log.WithContext(ctx) if msg.Command == "" || msg.Command == "transaction" { ok, resp := as.WebsocketTransactionHandler(ctx, msg) go func() { - err := as.SendWebsocket(msg.MakeResponse(ok, resp)) + err := as.SendWebsocket(ctx, msg.MakeResponse(ok, resp)) if err != nil { log.Warn().Err(err).Msg("Failed to send response to websocket transaction") } else { @@ -332,7 +349,7 @@ func (as *AppService) consumeWebsocket(stopFunc func(error), ws *websocket.Conn) } go func() { okResp, data := handler(msg.WebsocketCommand) - err := as.SendWebsocket(msg.MakeResponse(okResp, data)) + err := as.SendWebsocket(ctx, msg.MakeResponse(okResp, data)) if err != nil { log.Error().Err(err).Msg("Failed to send response to websocket command") } else if okResp { @@ -345,7 +362,7 @@ func (as *AppService) consumeWebsocket(stopFunc func(error), ws *websocket.Conn) } } -func (as *AppService) StartWebsocket(baseURL string, onConnect func()) error { +func (as *AppService) StartWebsocket(ctx context.Context, baseURL string, onConnect func()) error { var parsed *url.URL if baseURL != "" { var err error @@ -357,26 +374,29 @@ func (as *AppService) StartWebsocket(baseURL string, onConnect func()) error { copiedURL := *as.hsURLForClient parsed = &copiedURL } - parsed.Path = filepath.Join(parsed.Path, "_matrix/client/unstable/fi.mau.as_sync") + parsed.Path = path.Join(parsed.Path, "_matrix/client/unstable/fi.mau.as_sync") if parsed.Scheme == "http" { parsed.Scheme = "ws" } else if parsed.Scheme == "https" { parsed.Scheme = "wss" } - ws, resp, err := websocket.DefaultDialer.Dial(parsed.String(), http.Header{ - "Authorization": []string{fmt.Sprintf("Bearer %s", as.Registration.AppToken)}, - "User-Agent": []string{as.BotClient().UserAgent}, + ws, resp, err := websocket.Dial(ctx, parsed.String(), &websocket.DialOptions{ + HTTPClient: as.HTTPClient, + HTTPHeader: http.Header{ + "Authorization": []string{fmt.Sprintf("Bearer %s", as.Registration.AppToken)}, + "User-Agent": []string{as.BotClient().UserAgent}, - "X-Mautrix-Process-ID": []string{as.ProcessID}, - "X-Mautrix-Websocket-Version": []string{"3"}, + "X-Mautrix-Process-ID": []string{as.ProcessID}, + "X-Mautrix-Websocket-Version": []string{"3"}, + }, }) if resp != nil && resp.StatusCode >= 400 { - var errResp Error + var errResp mautrix.RespError err = json.NewDecoder(resp.Body).Decode(&errResp) if err != nil { return fmt.Errorf("websocket request returned HTTP %d with non-JSON body", resp.StatusCode) } else { - return fmt.Errorf("websocket request returned %s (HTTP %d): %s", errResp.ErrorCode, resp.StatusCode, errResp.Message) + return fmt.Errorf("websocket request returned %s (HTTP %d): %s", errResp.ErrCode, resp.StatusCode, errResp.Err) } } else if err != nil { return fmt.Errorf("failed to open websocket: %w", err) @@ -399,12 +419,13 @@ func (as *AppService) StartWebsocket(baseURL string, onConnect func()) error { } }) } + ws.SetReadLimit(50 * 1024 * 1024) as.ws = ws as.StopWebsocket = stopFunc as.PrepareWebsocket() as.Log.Debug().Msg("Appservice transaction websocket opened") - go as.consumeWebsocket(stopFunc, ws) + go as.consumeWebsocket(ctx, stopFunc, ws) var onConnectDone atomic.Bool if onConnect != nil { @@ -426,12 +447,7 @@ func (as *AppService) StartWebsocket(baseURL string, onConnect func()) error { as.ws = nil } - _ = ws.SetWriteDeadline(time.Now().Add(3 * time.Second)) - err = ws.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseGoingAway, "")) - if err != nil && !errors.Is(err, websocket.ErrCloseSent) { - as.Log.Warn().Err(err).Msg("Error writing close message to websocket") - } - err = ws.Close() + err = ws.Close(websocket.StatusGoingAway, "") if err != nil { as.Log.Warn().Err(err).Msg("Error closing websocket") } diff --git a/bridge/bridge.go b/bridge/bridge.go deleted file mode 100644 index 17a4a30c..00000000 --- a/bridge/bridge.go +++ /dev/null @@ -1,936 +0,0 @@ -// Copyright (c) 2023 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package bridge - -import ( - "context" - "encoding/json" - "errors" - "fmt" - "net/url" - "os" - "os/signal" - "runtime" - "strings" - "sync" - "syscall" - "time" - - "github.com/lib/pq" - "github.com/mattn/go-sqlite3" - "github.com/rs/zerolog" - "go.mau.fi/util/configupgrade" - "go.mau.fi/util/dbutil" - _ "go.mau.fi/util/dbutil/litestream" - "go.mau.fi/util/exzerolog" - "gopkg.in/yaml.v3" - flag "maunium.net/go/mauflag" - - "maunium.net/go/mautrix" - "maunium.net/go/mautrix/appservice" - "maunium.net/go/mautrix/bridge/bridgeconfig" - "maunium.net/go/mautrix/bridge/status" - "maunium.net/go/mautrix/event" - "maunium.net/go/mautrix/id" - "maunium.net/go/mautrix/sqlstatestore" -) - -var configPath = flag.MakeFull("c", "config", "The path to your config file.", "config.yaml").String() -var dontSaveConfig = flag.MakeFull("n", "no-update", "Don't save updated config to disk.", "false").Bool() -var registrationPath = flag.MakeFull("r", "registration", "The path where to save the appservice registration.", "registration.yaml").String() -var generateRegistration = flag.MakeFull("g", "generate-registration", "Generate registration and quit.", "false").Bool() -var version = flag.MakeFull("v", "version", "View bridge version and quit.", "false").Bool() -var versionJSON = flag.Make().LongKey("version-json").Usage("Print a JSON object representing the bridge version and quit.").Default("false").Bool() -var ignoreUnsupportedDatabase = flag.Make().LongKey("ignore-unsupported-database").Usage("Run even if the database schema is too new").Default("false").Bool() -var ignoreForeignTables = flag.Make().LongKey("ignore-foreign-tables").Usage("Run even if the database contains tables from other programs (like Synapse)").Default("false").Bool() -var ignoreUnsupportedServer = flag.Make().LongKey("ignore-unsupported-server").Usage("Run even if the Matrix homeserver is outdated").Default("false").Bool() -var wantHelp, _ = flag.MakeHelpFlag() - -var _ appservice.StateStore = (*sqlstatestore.SQLStateStore)(nil) - -type Portal interface { - IsEncrypted() bool - IsPrivateChat() bool - MarkEncrypted() - MainIntent() *appservice.IntentAPI - - ReceiveMatrixEvent(user User, evt *event.Event) - UpdateBridgeInfo(ctx context.Context) -} - -type MembershipHandlingPortal interface { - Portal - HandleMatrixLeave(sender User, evt *event.Event) - HandleMatrixKick(sender User, ghost Ghost, evt *event.Event) - HandleMatrixInvite(sender User, ghost Ghost, evt *event.Event) -} - -type ReadReceiptHandlingPortal interface { - Portal - HandleMatrixReadReceipt(sender User, eventID id.EventID, receipt event.ReadReceipt) -} - -type TypingPortal interface { - Portal - HandleMatrixTyping(userIDs []id.UserID) -} - -type MetaHandlingPortal interface { - Portal - HandleMatrixMeta(sender User, evt *event.Event) -} - -type DisappearingPortal interface { - Portal - ScheduleDisappearing() -} - -type PowerLevelHandlingPortal interface { - Portal - HandleMatrixPowerLevels(sender User, evt *event.Event) -} - -type JoinRuleHandlingPortal interface { - Portal - HandleMatrixJoinRule(sender User, evt *event.Event) -} - -type BanHandlingPortal interface { - Portal - HandleMatrixBan(sender User, ghost Ghost, evt *event.Event) - HandleMatrixUnban(sender User, ghost Ghost, evt *event.Event) -} - -type KnockHandlingPortal interface { - Portal - HandleMatrixKnock(sender User, evt *event.Event) - HandleMatrixRetractKnock(sender User, evt *event.Event) - HandleMatrixAcceptKnock(sender User, ghost Ghost, evt *event.Event) - HandleMatrixRejectKnock(sender User, ghost Ghost, evt *event.Event) -} - -type InviteHandlingPortal interface { - Portal - HandleMatrixAcceptInvite(sender User, evt *event.Event) - HandleMatrixRejectInvite(sender User, evt *event.Event) - HandleMatrixRetractInvite(sender User, ghost Ghost, evt *event.Event) -} - -type User interface { - GetPermissionLevel() bridgeconfig.PermissionLevel - IsLoggedIn() bool - GetManagementRoomID() id.RoomID - SetManagementRoom(id.RoomID) - GetMXID() id.UserID - GetIDoublePuppet() DoublePuppet - GetIGhost() Ghost -} - -type DoublePuppet interface { - CustomIntent() *appservice.IntentAPI - SwitchCustomMXID(accessToken string, userID id.UserID) error - ClearCustomMXID() -} - -type Ghost interface { - DoublePuppet - DefaultIntent() *appservice.IntentAPI - GetMXID() id.UserID -} - -type GhostWithProfile interface { - Ghost - GetDisplayname() string - GetAvatarURL() id.ContentURI -} - -type ChildOverride interface { - GetExampleConfig() string - GetConfigPtr() interface{} - - Init() - Start() - Stop() - - GetIPortal(id.RoomID) Portal - GetAllIPortals() []Portal - GetIUser(id id.UserID, create bool) User - IsGhost(id.UserID) bool - GetIGhost(id.UserID) Ghost - CreatePrivatePortal(id.RoomID, User, Ghost) -} - -type ConfigValidatingBridge interface { - ChildOverride - ValidateConfig() error -} - -type FlagHandlingBridge interface { - ChildOverride - HandleFlags() bool -} - -type PreInitableBridge interface { - ChildOverride - PreInit() -} - -type WebsocketStartingBridge interface { - ChildOverride - OnWebsocketConnect() -} - -type CSFeatureRequirer interface { - CheckFeatures(versions *mautrix.RespVersions) (string, bool) -} - -type Bridge struct { - Name string - URL string - Description string - Version string - ProtocolName string - BeeperServiceName string - BeeperNetworkName string - - AdditionalShortFlags string - AdditionalLongFlags string - - VersionDesc string - LinkifiedVersion string - BuildTime string - commit string - baseVersion string - - PublicHSAddress *url.URL - - DoublePuppet *doublePuppetUtil - - AS *appservice.AppService - EventProcessor *appservice.EventProcessor - CommandProcessor CommandProcessor - MatrixHandler *MatrixHandler - Bot *appservice.IntentAPI - Config bridgeconfig.BaseConfig - ConfigPath string - RegistrationPath string - SaveConfig bool - ConfigUpgrader configupgrade.BaseUpgrader - DB *dbutil.Database - StateStore *sqlstatestore.SQLStateStore - Crypto Crypto - CryptoPickleKey string - - ZLog *zerolog.Logger - - MediaConfig mautrix.RespMediaConfig - SpecVersions mautrix.RespVersions - - Child ChildOverride - - manualStop chan int - Stopping bool - - latestState *status.BridgeState - - Websocket bool - wsStopPinger chan struct{} - wsStarted chan struct{} - wsStopped chan struct{} - wsShortCircuitReconnectBackoff chan struct{} - wsStartupWait *sync.WaitGroup -} - -type Crypto interface { - HandleMemberEvent(context.Context, *event.Event) - Decrypt(context.Context, *event.Event) (*event.Event, error) - Encrypt(context.Context, id.RoomID, event.Type, *event.Content) error - WaitForSession(context.Context, id.RoomID, id.SenderKey, id.SessionID, time.Duration) bool - RequestSession(context.Context, id.RoomID, id.SenderKey, id.SessionID, id.UserID, id.DeviceID) - ResetSession(context.Context, id.RoomID) - Init(ctx context.Context) error - Start() - Stop() - Reset(ctx context.Context, startAfterReset bool) - Client() *mautrix.Client - ShareKeys(context.Context) error -} - -func (br *Bridge) GenerateRegistration() { - if !br.SaveConfig { - // We need to save the generated as_token and hs_token in the config - _, _ = fmt.Fprintln(os.Stderr, "--no-update is not compatible with --generate-registration") - os.Exit(5) - } else if br.Config.Homeserver.Domain == "example.com" { - _, _ = fmt.Fprintln(os.Stderr, "Homeserver domain is not set") - os.Exit(20) - } - reg := br.Config.GenerateRegistration() - err := reg.Save(br.RegistrationPath) - if err != nil { - _, _ = fmt.Fprintln(os.Stderr, "Failed to save registration:", err) - os.Exit(21) - } - - updateTokens := func(helper configupgrade.Helper) { - helper.Set(configupgrade.Str, reg.AppToken, "appservice", "as_token") - helper.Set(configupgrade.Str, reg.ServerToken, "appservice", "hs_token") - } - _, _, err = configupgrade.Do(br.ConfigPath, true, br.ConfigUpgrader, configupgrade.SimpleUpgrader(updateTokens)) - if err != nil { - _, _ = fmt.Fprintln(os.Stderr, "Failed to save config:", err) - os.Exit(22) - } - fmt.Println("Registration generated. See https://docs.mau.fi/bridges/general/registering-appservices.html for instructions on installing the registration.") - os.Exit(0) -} - -func (br *Bridge) InitVersion(tag, commit, buildTime string) { - br.baseVersion = br.Version - if len(tag) > 0 && tag[0] == 'v' { - tag = tag[1:] - } - if tag != br.Version { - suffix := "" - if !strings.HasSuffix(br.Version, "+dev") { - suffix = "+dev" - } - if len(commit) > 8 { - br.Version = fmt.Sprintf("%s%s.%s", br.Version, suffix, commit[:8]) - } else { - br.Version = fmt.Sprintf("%s%s.unknown", br.Version, suffix) - } - } - - br.LinkifiedVersion = fmt.Sprintf("v%s", br.Version) - if tag == br.Version { - br.LinkifiedVersion = fmt.Sprintf("[v%s](%s/releases/v%s)", br.Version, br.URL, tag) - } else if len(commit) > 8 { - br.LinkifiedVersion = strings.Replace(br.LinkifiedVersion, commit[:8], fmt.Sprintf("[%s](%s/commit/%s)", commit[:8], br.URL, commit), 1) - } - mautrix.DefaultUserAgent = fmt.Sprintf("%s/%s %s", br.Name, br.Version, mautrix.DefaultUserAgent) - br.VersionDesc = fmt.Sprintf("%s %s (%s with %s)", br.Name, br.Version, buildTime, runtime.Version()) - br.commit = commit - br.BuildTime = buildTime -} - -var MinSpecVersion = mautrix.SpecV14 - -func (br *Bridge) logInitialRequestError(err error, defaultMessage string) { - if errors.Is(err, mautrix.MUnknownToken) { - br.ZLog.WithLevel(zerolog.FatalLevel).Msg("The as_token was not accepted. Is the registration file installed in your homeserver correctly?") - br.ZLog.Info().Msg("See https://docs.mau.fi/faq/as-token for more info") - } else if errors.Is(err, mautrix.MExclusive) { - br.ZLog.WithLevel(zerolog.FatalLevel).Msg("The as_token was accepted, but the /register request was not. Are the homeserver domain, bot username and username template in the config correct, and do they match the values in the registration?") - br.ZLog.Info().Msg("See https://docs.mau.fi/faq/as-register for more info") - } else { - br.ZLog.WithLevel(zerolog.FatalLevel).Err(err).Msg(defaultMessage) - } -} - -func (br *Bridge) ensureConnection(ctx context.Context) { - for { - versions, err := br.Bot.Versions(ctx) - if err != nil { - if errors.Is(err, mautrix.MForbidden) { - br.ZLog.Debug().Msg("M_FORBIDDEN in /versions, trying to register before retrying") - err = br.Bot.EnsureRegistered(ctx) - if err != nil { - br.logInitialRequestError(err, "Failed to register after /versions failed with M_FORBIDDEN") - os.Exit(16) - } - } else { - br.ZLog.Err(err).Msg("Failed to connect to homeserver, retrying in 10 seconds...") - time.Sleep(10 * time.Second) - } - } else { - br.SpecVersions = *versions - *br.AS.SpecVersions = *versions - break - } - } - - unsupportedServerLogLevel := zerolog.FatalLevel - if *ignoreUnsupportedServer { - unsupportedServerLogLevel = zerolog.ErrorLevel - } - if br.Config.Homeserver.Software == bridgeconfig.SoftwareHungry && !br.SpecVersions.Supports(mautrix.BeeperFeatureHungry) { - br.ZLog.WithLevel(zerolog.FatalLevel).Msg("The config claims the homeserver is hungryserv, but the /versions response didn't confirm it") - os.Exit(18) - } else if !br.SpecVersions.ContainsGreaterOrEqual(MinSpecVersion) { - br.ZLog.WithLevel(unsupportedServerLogLevel). - Stringer("server_supports", br.SpecVersions.GetLatest()). - Stringer("bridge_requires", MinSpecVersion). - Msg("The homeserver is outdated (supported spec versions are below minimum required by bridge)") - if !*ignoreUnsupportedServer { - os.Exit(18) - } - } else if fr, ok := br.Child.(CSFeatureRequirer); ok { - if msg, hasFeatures := fr.CheckFeatures(&br.SpecVersions); !hasFeatures { - br.ZLog.WithLevel(unsupportedServerLogLevel).Msg(msg) - if !*ignoreUnsupportedServer { - os.Exit(18) - } - } - } - - resp, err := br.Bot.Whoami(ctx) - if err != nil { - br.logInitialRequestError(err, "/whoami request failed with unknown error") - os.Exit(16) - } else if resp.UserID != br.Bot.UserID { - br.ZLog.WithLevel(zerolog.FatalLevel). - Stringer("got_user_id", resp.UserID). - Stringer("expected_user_id", br.Bot.UserID). - Msg("Unexpected user ID in whoami call") - os.Exit(17) - } - - if br.Websocket { - br.ZLog.Debug().Msg("Websocket mode: no need to check status of homeserver -> bridge connection") - return - } else if !br.SpecVersions.Supports(mautrix.FeatureAppservicePing) { - br.ZLog.Debug().Msg("Homeserver does not support checking status of homeserver -> bridge connection") - return - } - var pingResp *mautrix.RespAppservicePing - var txnID string - var retryCount int - const maxRetries = 6 - for { - txnID = br.Bot.TxnID() - pingResp, err = br.Bot.AppservicePing(ctx, br.Config.AppService.ID, txnID) - if err == nil { - break - } - var httpErr mautrix.HTTPError - var pingErrBody string - if errors.As(err, &httpErr) && httpErr.RespError != nil { - if val, ok := httpErr.RespError.ExtraData["body"].(string); ok { - pingErrBody = strings.TrimSpace(val) - } - } - outOfRetries := retryCount >= maxRetries - level := zerolog.ErrorLevel - if outOfRetries { - level = zerolog.FatalLevel - } - evt := br.ZLog.WithLevel(level).Err(err).Str("txn_id", txnID) - if pingErrBody != "" { - bodyBytes := []byte(pingErrBody) - if json.Valid(bodyBytes) { - evt.RawJSON("body", bodyBytes) - } else { - evt.Str("body", pingErrBody) - } - } - if outOfRetries { - evt.Msg("Homeserver -> bridge connection is not working") - br.ZLog.Info().Msg("See https://docs.mau.fi/faq/as-ping for more info") - os.Exit(13) - } - evt.Msg("Homeserver -> bridge connection is not working, retrying in 5 seconds...") - time.Sleep(5 * time.Second) - retryCount++ - } - br.ZLog.Debug(). - Str("txn_id", txnID). - Int64("duration_ms", pingResp.DurationMS). - Msg("Homeserver -> bridge connection works") -} - -func (br *Bridge) fetchMediaConfig(ctx context.Context) { - cfg, err := br.Bot.GetMediaConfig(ctx) - if err != nil { - br.ZLog.Warn().Err(err).Msg("Failed to fetch media config") - } else { - if cfg.UploadSize == 0 { - cfg.UploadSize = 50 * 1024 * 1024 - } - br.MediaConfig = *cfg - } -} - -func (br *Bridge) UpdateBotProfile(ctx context.Context) { - br.ZLog.Debug().Msg("Updating bot profile") - botConfig := &br.Config.AppService.Bot - - var err error - var mxc id.ContentURI - if botConfig.Avatar == "remove" { - err = br.Bot.SetAvatarURL(ctx, mxc) - } else if !botConfig.ParsedAvatar.IsEmpty() { - err = br.Bot.SetAvatarURL(ctx, botConfig.ParsedAvatar) - } - if err != nil { - br.ZLog.Warn().Err(err).Msg("Failed to update bot avatar") - } - - if botConfig.Displayname == "remove" { - err = br.Bot.SetDisplayName(ctx, "") - } else if len(botConfig.Displayname) > 0 { - err = br.Bot.SetDisplayName(ctx, botConfig.Displayname) - } - if err != nil { - br.ZLog.Warn().Err(err).Msg("Failed to update bot displayname") - } - - if br.SpecVersions.Supports(mautrix.BeeperFeatureArbitraryProfileMeta) && br.BeeperNetworkName != "" { - br.ZLog.Debug().Msg("Setting contact info on the appservice bot") - br.Bot.BeeperUpdateProfile(ctx, map[string]any{ - "com.beeper.bridge.service": br.BeeperServiceName, - "com.beeper.bridge.network": br.BeeperNetworkName, - "com.beeper.bridge.is_bridge_bot": true, - }) - } -} - -func (br *Bridge) loadConfig() { - configData, upgraded, err := configupgrade.Do(br.ConfigPath, br.SaveConfig, br.ConfigUpgrader) - if err != nil { - _, _ = fmt.Fprintln(os.Stderr, "Error updating config:", err) - if configData == nil { - os.Exit(10) - } - } - - target := br.Child.GetConfigPtr() - if !upgraded { - // Fallback: if config upgrading failed, load example config for base values - err = yaml.Unmarshal([]byte(br.Child.GetExampleConfig()), &target) - if err != nil { - _, _ = fmt.Fprintln(os.Stderr, "Failed to unmarshal example config:", err) - os.Exit(10) - } - } - err = yaml.Unmarshal(configData, target) - if err != nil { - _, _ = fmt.Fprintln(os.Stderr, "Failed to parse config:", err) - os.Exit(10) - } -} - -func (br *Bridge) validateConfig() error { - switch { - case br.Config.Homeserver.Address == "https://matrix.example.com": - return errors.New("homeserver.address not configured") - case br.Config.Homeserver.Domain == "example.com": - return errors.New("homeserver.domain not configured") - case !bridgeconfig.AllowedHomeserverSoftware[br.Config.Homeserver.Software]: - return errors.New("invalid value for homeserver.software (use `standard` if you don't know what the field is for)") - case br.Config.AppService.ASToken == "This value is generated when generating the registration": - return errors.New("appservice.as_token not configured. Did you forget to generate the registration? ") - case br.Config.AppService.HSToken == "This value is generated when generating the registration": - return errors.New("appservice.hs_token not configured. Did you forget to generate the registration? ") - case br.Config.AppService.Database.URI == "postgres://user:password@host/database?sslmode=disable": - return errors.New("appservice.database not configured") - default: - err := br.Config.Bridge.Validate() - if err != nil { - return err - } - validator, ok := br.Child.(ConfigValidatingBridge) - if ok { - return validator.ValidateConfig() - } - return nil - } -} - -func (br *Bridge) getProfile(userID id.UserID, roomID id.RoomID) *event.MemberEventContent { - ghost := br.Child.GetIGhost(userID) - if ghost == nil { - return nil - } - profilefulGhost, ok := ghost.(GhostWithProfile) - if ok { - return &event.MemberEventContent{ - Displayname: profilefulGhost.GetDisplayname(), - AvatarURL: profilefulGhost.GetAvatarURL().CUString(), - } - } - return nil -} - -func (br *Bridge) init() { - pib, ok := br.Child.(PreInitableBridge) - if ok { - pib.PreInit() - } - - var err error - - br.MediaConfig.UploadSize = 50 * 1024 * 1024 - - br.ZLog, err = br.Config.Logging.Compile() - if err != nil { - _, _ = fmt.Fprintln(os.Stderr, "Failed to initialize logger:", err) - os.Exit(12) - } - exzerolog.SetupDefaults(br.ZLog) - - br.DoublePuppet = &doublePuppetUtil{br: br, log: br.ZLog.With().Str("component", "double puppet").Logger()} - - err = br.validateConfig() - if err != nil { - br.ZLog.WithLevel(zerolog.FatalLevel).Err(err).Msg("Configuration error") - br.ZLog.Info().Msg("See https://docs.mau.fi/faq/field-unconfigured for more info") - os.Exit(11) - } - - br.ZLog.Info(). - Str("name", br.Name). - Str("version", br.Version). - Str("built_at", br.BuildTime). - Str("go_version", runtime.Version()). - Msg("Initializing bridge") - - br.ZLog.Debug().Msg("Initializing database connection") - dbConfig := br.Config.AppService.Database - if (dbConfig.Type == "sqlite3-fk-wal" || dbConfig.Type == "litestream") && dbConfig.MaxOpenConns != 1 && !strings.Contains(dbConfig.URI, "_txlock=immediate") { - var fixedExampleURI string - if !strings.HasPrefix(dbConfig.URI, "file:") { - fixedExampleURI = fmt.Sprintf("file:%s?_txlock=immediate", dbConfig.URI) - } else if !strings.ContainsRune(dbConfig.URI, '?') { - fixedExampleURI = fmt.Sprintf("%s?_txlock=immediate", dbConfig.URI) - } else { - fixedExampleURI = fmt.Sprintf("%s&_txlock=immediate", dbConfig.URI) - } - br.ZLog.Warn(). - Str("fixed_uri_example", fixedExampleURI). - Msg("Using SQLite without _txlock=immediate is not recommended") - } - br.DB, err = dbutil.NewFromConfig(br.Name, dbConfig, dbutil.ZeroLogger(br.ZLog.With().Str("db_section", "main").Logger())) - if err != nil { - br.ZLog.WithLevel(zerolog.FatalLevel).Err(err).Msg("Failed to initialize database connection") - if sqlError := (&sqlite3.Error{}); errors.As(err, sqlError) && sqlError.Code == sqlite3.ErrCorrupt { - os.Exit(18) - } - os.Exit(14) - } - br.DB.IgnoreUnsupportedDatabase = *ignoreUnsupportedDatabase - br.DB.IgnoreForeignTables = *ignoreForeignTables - - br.ZLog.Debug().Msg("Initializing state store") - br.StateStore = sqlstatestore.NewSQLStateStore(br.DB, dbutil.ZeroLogger(br.ZLog.With().Str("db_section", "matrix_state").Logger()), true) - - br.AS, err = appservice.CreateFull(appservice.CreateOpts{ - Registration: br.Config.AppService.GetRegistration(), - HomeserverDomain: br.Config.Homeserver.Domain, - HomeserverURL: br.Config.Homeserver.Address, - HostConfig: appservice.HostConfig{ - Hostname: br.Config.AppService.Hostname, - Port: br.Config.AppService.Port, - }, - StateStore: br.StateStore, - }) - if err != nil { - br.ZLog.WithLevel(zerolog.FatalLevel).Err(err). - Msg("Failed to initialize appservice") - os.Exit(15) - } - br.AS.Log = *br.ZLog - br.AS.DoublePuppetValue = br.Name - br.AS.GetProfile = br.getProfile - br.Bot = br.AS.BotIntent() - - br.ZLog.Debug().Msg("Initializing Matrix event processor") - br.EventProcessor = appservice.NewEventProcessor(br.AS) - if !br.Config.AppService.AsyncTransactions { - br.EventProcessor.ExecMode = appservice.Sync - } - br.ZLog.Debug().Msg("Initializing Matrix event handler") - br.MatrixHandler = NewMatrixHandler(br) - - br.Crypto = NewCryptoHelper(br) - - hsURL := br.Config.Homeserver.Address - if br.Config.Homeserver.PublicAddress != "" { - hsURL = br.Config.Homeserver.PublicAddress - } - br.PublicHSAddress, err = url.Parse(hsURL) - if err != nil { - br.ZLog.WithLevel(zerolog.FatalLevel).Err(err). - Str("input", hsURL). - Msg("Failed to parse public homeserver URL") - os.Exit(15) - } - - br.Child.Init() -} - -type zerologPQError pq.Error - -func (zpe *zerologPQError) MarshalZerologObject(evt *zerolog.Event) { - maybeStr := func(field, value string) { - if value != "" { - evt.Str(field, value) - } - } - maybeStr("severity", zpe.Severity) - if name := zpe.Code.Name(); name != "" { - evt.Str("code", name) - } else if zpe.Code != "" { - evt.Str("code", string(zpe.Code)) - } - //maybeStr("message", zpe.Message) - maybeStr("detail", zpe.Detail) - maybeStr("hint", zpe.Hint) - maybeStr("position", zpe.Position) - maybeStr("internal_position", zpe.InternalPosition) - maybeStr("internal_query", zpe.InternalQuery) - maybeStr("where", zpe.Where) - maybeStr("schema", zpe.Schema) - maybeStr("table", zpe.Table) - maybeStr("column", zpe.Column) - maybeStr("data_type_name", zpe.DataTypeName) - maybeStr("constraint", zpe.Constraint) - maybeStr("file", zpe.File) - maybeStr("line", zpe.Line) - maybeStr("routine", zpe.Routine) -} - -func (br *Bridge) LogDBUpgradeErrorAndExit(name string, err error) { - logEvt := br.ZLog.WithLevel(zerolog.FatalLevel). - Err(err). - Str("db_section", name) - var errWithLine *dbutil.PQErrorWithLine - if errors.As(err, &errWithLine) { - logEvt.Str("sql_line", errWithLine.Line) - } - var pqe *pq.Error - if errors.As(err, &pqe) { - logEvt.Object("pq_error", (*zerologPQError)(pqe)) - } - logEvt.Msg("Failed to initialize database") - if sqlError := (&sqlite3.Error{}); errors.As(err, sqlError) && sqlError.Code == sqlite3.ErrCorrupt { - os.Exit(18) - } else if errors.Is(err, dbutil.ErrForeignTables) { - br.ZLog.Info().Msg("See https://docs.mau.fi/faq/foreign-tables for more info") - } else if errors.Is(err, dbutil.ErrNotOwned) { - br.ZLog.Info().Msg("Sharing the same database with different programs is not supported") - } else if errors.Is(err, dbutil.ErrUnsupportedDatabaseVersion) { - br.ZLog.Info().Msg("Downgrading the bridge is not supported") - } - os.Exit(15) -} - -func (br *Bridge) WaitWebsocketConnected() { - if br.wsStartupWait != nil { - br.wsStartupWait.Wait() - } -} - -func (br *Bridge) start() { - br.ZLog.Debug().Msg("Running database upgrades") - err := br.DB.Upgrade(br.ZLog.With().Str("db_section", "main").Logger().WithContext(context.TODO())) - if err != nil { - br.LogDBUpgradeErrorAndExit("main", err) - } else if err = br.StateStore.Upgrade(br.ZLog.With().Str("db_section", "matrix_state").Logger().WithContext(context.TODO())); err != nil { - br.LogDBUpgradeErrorAndExit("matrix_state", err) - } - - if br.Config.Homeserver.Websocket || len(br.Config.Homeserver.WSProxy) > 0 { - br.Websocket = true - br.ZLog.Debug().Msg("Starting application service websocket") - var wg sync.WaitGroup - wg.Add(1) - br.wsStartupWait = &wg - br.wsShortCircuitReconnectBackoff = make(chan struct{}) - go br.startWebsocket(&wg) - } else if br.AS.Host.IsConfigured() { - br.ZLog.Debug().Msg("Starting application service HTTP server") - go br.AS.Start() - } else { - br.ZLog.WithLevel(zerolog.FatalLevel).Msg("Neither appservice HTTP listener nor websocket is enabled") - os.Exit(23) - } - br.ZLog.Debug().Msg("Checking connection to homeserver") - - ctx := br.ZLog.WithContext(context.Background()) - br.ensureConnection(ctx) - go br.fetchMediaConfig(ctx) - - if br.Crypto != nil { - err = br.Crypto.Init(ctx) - if err != nil { - br.ZLog.WithLevel(zerolog.FatalLevel).Err(err).Msg("Error initializing end-to-bridge encryption") - os.Exit(19) - } - } - - br.ZLog.Debug().Msg("Starting event processor") - br.EventProcessor.Start(ctx) - - go br.UpdateBotProfile(ctx) - if br.Crypto != nil { - go br.Crypto.Start() - } - - br.Child.Start() - br.WaitWebsocketConnected() - br.AS.Ready = true - - if br.Config.Bridge.GetResendBridgeInfo() { - go br.ResendBridgeInfo() - } - if br.Websocket && br.Config.Homeserver.WSPingInterval > 0 { - br.wsStopPinger = make(chan struct{}, 1) - go br.websocketServerPinger() - } -} - -func (br *Bridge) ResendBridgeInfo() { - if !br.SaveConfig { - br.ZLog.Warn().Msg("Not setting resend_bridge_info to false in config due to --no-update flag") - } else { - _, _, err := configupgrade.Do(br.ConfigPath, true, br.ConfigUpgrader, configupgrade.SimpleUpgrader(func(helper configupgrade.Helper) { - helper.Set(configupgrade.Bool, "false", "bridge", "resend_bridge_info") - })) - if err != nil { - br.ZLog.Err(err).Msg("Failed to save config after setting resend_bridge_info to false") - } - } - br.ZLog.Info().Msg("Re-sending bridge info state event to all portals") - for _, portal := range br.Child.GetAllIPortals() { - portal.UpdateBridgeInfo(context.TODO()) - } - br.ZLog.Info().Msg("Finished re-sending bridge info state events") -} - -func sendStopSignal(ch chan struct{}) { - if ch != nil { - select { - case ch <- struct{}{}: - default: - } - } -} - -func (br *Bridge) stop() { - br.Stopping = true - if br.Crypto != nil { - br.Crypto.Stop() - } - waitForWS := false - if br.AS.StopWebsocket != nil { - br.ZLog.Debug().Msg("Stopping application service websocket") - br.AS.StopWebsocket(appservice.ErrWebsocketManualStop) - waitForWS = true - } - br.AS.Stop() - sendStopSignal(br.wsStopPinger) - sendStopSignal(br.wsShortCircuitReconnectBackoff) - br.EventProcessor.Stop() - br.Child.Stop() - err := br.DB.Close() - if err != nil { - br.ZLog.Warn().Err(err).Msg("Error closing database") - } - if waitForWS { - select { - case <-br.wsStopped: - case <-time.After(4 * time.Second): - br.ZLog.Warn().Msg("Timed out waiting for websocket to close") - } - } -} - -func (br *Bridge) ManualStop(exitCode int) { - if br.manualStop != nil { - br.manualStop <- exitCode - } else { - os.Exit(exitCode) - } -} - -type VersionJSONOutput struct { - Name string - URL string - - Version string - IsRelease bool - Commit string - FormattedVersion string - BuildTime string - - OS string - Arch string - - Mautrix struct { - Version string - Commit string - } -} - -func (br *Bridge) Main() { - flag.SetHelpTitles( - fmt.Sprintf("%s - %s", br.Name, br.Description), - fmt.Sprintf("%s [-hgvn%s] [-c ] [-r ]%s", br.Name, br.AdditionalShortFlags, br.AdditionalLongFlags)) - err := flag.Parse() - br.ConfigPath = *configPath - br.RegistrationPath = *registrationPath - br.SaveConfig = !*dontSaveConfig - if err != nil { - _, _ = fmt.Fprintln(os.Stderr, err) - flag.PrintHelp() - os.Exit(1) - } else if *wantHelp { - flag.PrintHelp() - os.Exit(0) - } else if *version { - fmt.Println(br.VersionDesc) - return - } else if *versionJSON { - output := VersionJSONOutput{ - URL: br.URL, - Name: br.Name, - - Version: br.baseVersion, - IsRelease: br.Version == br.baseVersion, - Commit: br.commit, - FormattedVersion: br.Version, - BuildTime: br.BuildTime, - - OS: runtime.GOOS, - Arch: runtime.GOARCH, - } - output.Mautrix.Commit = mautrix.Commit - output.Mautrix.Version = mautrix.Version - _ = json.NewEncoder(os.Stdout).Encode(output) - return - } else if flagHandler, ok := br.Child.(FlagHandlingBridge); ok && flagHandler.HandleFlags() { - return - } - - br.loadConfig() - - if *generateRegistration { - br.GenerateRegistration() - return - } - - br.manualStop = make(chan int, 1) - br.init() - br.ZLog.Info().Msg("Bridge initialization complete, starting...") - br.start() - br.ZLog.Info().Msg("Bridge started!") - - c := make(chan os.Signal, 1) - signal.Notify(c, os.Interrupt, syscall.SIGTERM) - var exitCode int - select { - case <-c: - br.ZLog.Info().Msg("Interrupt received, stopping...") - case exitCode = <-br.manualStop: - br.ZLog.Info().Int("exit_code", exitCode).Msg("Manual stop requested") - } - - br.stop() - br.ZLog.Info().Msg("Bridge stopped.") - os.Exit(exitCode) -} diff --git a/bridge/bridgeconfig/config.go b/bridge/bridgeconfig/config.go deleted file mode 100644 index dfb6b7e5..00000000 --- a/bridge/bridgeconfig/config.go +++ /dev/null @@ -1,337 +0,0 @@ -// Copyright (c) 2023 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package bridgeconfig - -import ( - "fmt" - "os" - "path/filepath" - "regexp" - "strings" - - "github.com/rs/zerolog" - up "go.mau.fi/util/configupgrade" - "go.mau.fi/util/dbutil" - "go.mau.fi/util/random" - "go.mau.fi/zeroconfig" - "gopkg.in/yaml.v3" - - "maunium.net/go/mautrix/appservice" - "maunium.net/go/mautrix/id" -) - -type HomeserverSoftware string - -const ( - SoftwareStandard HomeserverSoftware = "standard" - SoftwareAsmux HomeserverSoftware = "asmux" - SoftwareHungry HomeserverSoftware = "hungry" -) - -var AllowedHomeserverSoftware = map[HomeserverSoftware]bool{ - SoftwareStandard: true, - SoftwareAsmux: true, - SoftwareHungry: true, -} - -type HomeserverConfig struct { - Address string `yaml:"address"` - Domain string `yaml:"domain"` - AsyncMedia bool `yaml:"async_media"` - - PublicAddress string `yaml:"public_address,omitempty"` - - Software HomeserverSoftware `yaml:"software"` - - StatusEndpoint string `yaml:"status_endpoint"` - MessageSendCheckpointEndpoint string `yaml:"message_send_checkpoint_endpoint"` - - Websocket bool `yaml:"websocket"` - WSProxy string `yaml:"websocket_proxy"` - WSPingInterval int `yaml:"ping_interval_seconds"` -} - -type AppserviceConfig struct { - Address string `yaml:"address"` - Hostname string `yaml:"hostname"` - Port uint16 `yaml:"port"` - - Database dbutil.Config `yaml:"database"` - - ID string `yaml:"id"` - Bot BotUserConfig `yaml:"bot"` - - ASToken string `yaml:"as_token"` - HSToken string `yaml:"hs_token"` - - EphemeralEvents bool `yaml:"ephemeral_events"` - AsyncTransactions bool `yaml:"async_transactions"` -} - -func (config *BaseConfig) MakeUserIDRegex(matcher string) *regexp.Regexp { - usernamePlaceholder := strings.ToLower(random.String(16)) - usernameTemplate := fmt.Sprintf("@%s:%s", - config.Bridge.FormatUsername(usernamePlaceholder), - config.Homeserver.Domain) - usernameTemplate = regexp.QuoteMeta(usernameTemplate) - usernameTemplate = strings.Replace(usernameTemplate, usernamePlaceholder, matcher, 1) - usernameTemplate = fmt.Sprintf("^%s$", usernameTemplate) - return regexp.MustCompile(usernameTemplate) -} - -// GenerateRegistration generates a registration file for the homeserver. -func (config *BaseConfig) GenerateRegistration() *appservice.Registration { - registration := appservice.CreateRegistration() - config.AppService.HSToken = registration.ServerToken - config.AppService.ASToken = registration.AppToken - config.AppService.copyToRegistration(registration) - - registration.SenderLocalpart = random.String(32) - botRegex := regexp.MustCompile(fmt.Sprintf("^@%s:%s$", - regexp.QuoteMeta(config.AppService.Bot.Username), - regexp.QuoteMeta(config.Homeserver.Domain))) - registration.Namespaces.UserIDs.Register(botRegex, true) - registration.Namespaces.UserIDs.Register(config.MakeUserIDRegex(".*"), true) - - return registration -} - -func (config *BaseConfig) MakeAppService() *appservice.AppService { - as := appservice.Create() - as.HomeserverDomain = config.Homeserver.Domain - _ = as.SetHomeserverURL(config.Homeserver.Address) - as.Host.Hostname = config.AppService.Hostname - as.Host.Port = config.AppService.Port - as.Registration = config.AppService.GetRegistration() - return as -} - -// GetRegistration copies the data from the bridge config into an *appservice.Registration struct. -// This can't be used with the homeserver, see GenerateRegistration for generating files for the homeserver. -func (asc *AppserviceConfig) GetRegistration() *appservice.Registration { - reg := &appservice.Registration{} - asc.copyToRegistration(reg) - reg.SenderLocalpart = asc.Bot.Username - reg.ServerToken = asc.HSToken - reg.AppToken = asc.ASToken - return reg -} - -func (asc *AppserviceConfig) copyToRegistration(registration *appservice.Registration) { - registration.ID = asc.ID - registration.URL = asc.Address - falseVal := false - registration.RateLimited = &falseVal - registration.EphemeralEvents = asc.EphemeralEvents - registration.SoruEphemeralEvents = asc.EphemeralEvents -} - -type BotUserConfig struct { - Username string `yaml:"username"` - Displayname string `yaml:"displayname"` - Avatar string `yaml:"avatar"` - - ParsedAvatar id.ContentURI `yaml:"-"` -} - -type serializableBUC BotUserConfig - -func (buc *BotUserConfig) UnmarshalYAML(unmarshal func(interface{}) error) error { - var sbuc serializableBUC - err := unmarshal(&sbuc) - if err != nil { - return err - } - *buc = (BotUserConfig)(sbuc) - if buc.Avatar != "" && buc.Avatar != "remove" { - buc.ParsedAvatar, err = id.ParseContentURI(buc.Avatar) - if err != nil { - return fmt.Errorf("%w in bot avatar", err) - } - } - return nil -} - -type BridgeConfig interface { - FormatUsername(username string) string - GetEncryptionConfig() EncryptionConfig - GetCommandPrefix() string - GetManagementRoomTexts() ManagementRoomTexts - GetDoublePuppetConfig() DoublePuppetConfig - GetResendBridgeInfo() bool - EnableMessageStatusEvents() bool - EnableMessageErrorNotices() bool - Validate() error -} - -type DoublePuppetConfig struct { - ServerMap map[string]string `yaml:"double_puppet_server_map"` - AllowDiscovery bool `yaml:"double_puppet_allow_discovery"` - SharedSecretMap map[string]string `yaml:"login_shared_secret_map"` -} - -type EncryptionConfig struct { - Allow bool `yaml:"allow"` - Default bool `yaml:"default"` - Require bool `yaml:"require"` - Appservice bool `yaml:"appservice"` - - PlaintextMentions bool `yaml:"plaintext_mentions"` - - DeleteKeys struct { - DeleteOutboundOnAck bool `yaml:"delete_outbound_on_ack"` - DontStoreOutbound bool `yaml:"dont_store_outbound"` - RatchetOnDecrypt bool `yaml:"ratchet_on_decrypt"` - DeleteFullyUsedOnDecrypt bool `yaml:"delete_fully_used_on_decrypt"` - DeletePrevOnNewSession bool `yaml:"delete_prev_on_new_session"` - DeleteOnDeviceDelete bool `yaml:"delete_on_device_delete"` - PeriodicallyDeleteExpired bool `yaml:"periodically_delete_expired"` - DeleteOutdatedInbound bool `yaml:"delete_outdated_inbound"` - } `yaml:"delete_keys"` - - VerificationLevels struct { - Receive id.TrustState `yaml:"receive"` - Send id.TrustState `yaml:"send"` - Share id.TrustState `yaml:"share"` - } `yaml:"verification_levels"` - AllowKeySharing bool `yaml:"allow_key_sharing"` - - Rotation struct { - EnableCustom bool `yaml:"enable_custom"` - Milliseconds int64 `yaml:"milliseconds"` - Messages int `yaml:"messages"` - - DisableDeviceChangeKeyRotation bool `yaml:"disable_device_change_key_rotation"` - } `yaml:"rotation"` -} - -type ManagementRoomTexts struct { - Welcome string `yaml:"welcome"` - WelcomeConnected string `yaml:"welcome_connected"` - WelcomeUnconnected string `yaml:"welcome_unconnected"` - AdditionalHelp string `yaml:"additional_help"` -} - -type BaseConfig struct { - Homeserver HomeserverConfig `yaml:"homeserver"` - AppService AppserviceConfig `yaml:"appservice"` - Bridge BridgeConfig `yaml:"-"` - Logging zeroconfig.Config `yaml:"logging"` -} - -func doUpgrade(helper up.Helper) { - helper.Copy(up.Str, "homeserver", "address") - helper.Copy(up.Str, "homeserver", "domain") - if legacyAsmuxFlag, ok := helper.Get(up.Bool, "homeserver", "asmux"); ok && legacyAsmuxFlag == "true" { - helper.Set(up.Str, string(SoftwareAsmux), "homeserver", "software") - } else { - helper.Copy(up.Str, "homeserver", "software") - } - helper.Copy(up.Str|up.Null, "homeserver", "status_endpoint") - helper.Copy(up.Str|up.Null, "homeserver", "message_send_checkpoint_endpoint") - helper.Copy(up.Bool, "homeserver", "async_media") - helper.Copy(up.Str|up.Null, "homeserver", "websocket_proxy") - helper.Copy(up.Bool, "homeserver", "websocket") - helper.Copy(up.Int, "homeserver", "ping_interval_seconds") - - helper.Copy(up.Str|up.Null, "appservice", "address") - helper.Copy(up.Str|up.Null, "appservice", "hostname") - helper.Copy(up.Int|up.Null, "appservice", "port") - if dbType, ok := helper.Get(up.Str, "appservice", "database", "type"); ok && dbType == "sqlite3" { - helper.Set(up.Str, "sqlite3-fk-wal", "appservice", "database", "type") - } else { - helper.Copy(up.Str, "appservice", "database", "type") - } - helper.Copy(up.Str, "appservice", "database", "uri") - helper.Copy(up.Int, "appservice", "database", "max_open_conns") - helper.Copy(up.Int, "appservice", "database", "max_idle_conns") - helper.Copy(up.Str|up.Null, "appservice", "database", "max_conn_idle_time") - helper.Copy(up.Str|up.Null, "appservice", "database", "max_conn_lifetime") - helper.Copy(up.Str, "appservice", "id") - helper.Copy(up.Str, "appservice", "bot", "username") - helper.Copy(up.Str, "appservice", "bot", "displayname") - helper.Copy(up.Str, "appservice", "bot", "avatar") - helper.Copy(up.Bool, "appservice", "ephemeral_events") - helper.Copy(up.Bool, "appservice", "async_transactions") - helper.Copy(up.Str, "appservice", "as_token") - helper.Copy(up.Str, "appservice", "hs_token") - - if helper.GetNode("logging", "writers") == nil && (helper.GetNode("logging", "print_level") != nil || helper.GetNode("logging", "file_name_format") != nil) { - _, _ = fmt.Fprintln(os.Stderr, "Migrating legacy log config") - migrateLegacyLogConfig(helper) - } else if helper.GetNode("logging", "writers") == nil && (helper.GetNode("logging", "handlers") != nil) { - _, _ = fmt.Fprintln(os.Stderr, "Migrating Python log config is not currently supported") - // TODO implement? - //migratePythonLogConfig(helper) - } else { - helper.Copy(up.Map, "logging") - } -} - -type legacyLogConfig struct { - Directory string `yaml:"directory"` - FileNameFormat string `yaml:"file_name_format"` - FileDateFormat string `yaml:"file_date_format"` - FileMode uint32 `yaml:"file_mode"` - TimestampFormat string `yaml:"timestamp_format"` - RawPrintLevel string `yaml:"print_level"` - JSONStdout bool `yaml:"print_json"` - JSONFile bool `yaml:"file_json"` -} - -func migrateLegacyLogConfig(helper up.Helper) { - var llc legacyLogConfig - var newConfig zeroconfig.Config - err := helper.GetBaseNode("logging").Decode(&newConfig) - if err != nil { - _, _ = fmt.Fprintln(os.Stderr, "Base config is corrupted: failed to decode example log config:", err) - return - } else if len(newConfig.Writers) != 2 || newConfig.Writers[0].Type != "stdout" || newConfig.Writers[1].Type != "file" { - _, _ = fmt.Fprintln(os.Stderr, "Base log config is not in expected format") - return - } - err = helper.GetNode("logging").Decode(&llc) - if err != nil { - _, _ = fmt.Fprintln(os.Stderr, "Failed to decode legacy log config:", err) - return - } - if llc.RawPrintLevel != "" { - level, err := zerolog.ParseLevel(llc.RawPrintLevel) - if err != nil { - _, _ = fmt.Fprintln(os.Stderr, "Failed to parse minimum stdout log level:", err) - } else { - newConfig.Writers[0].MinLevel = &level - } - } - if llc.Directory != "" && llc.FileNameFormat != "" { - if llc.FileNameFormat == "{{.Date}}-{{.Index}}.log" { - llc.FileNameFormat = "bridge.log" - } else { - llc.FileNameFormat = strings.ReplaceAll(llc.FileNameFormat, "{{.Date}}", "") - llc.FileNameFormat = strings.ReplaceAll(llc.FileNameFormat, "{{.Index}}", "") - } - newConfig.Writers[1].Filename = filepath.Join(llc.Directory, llc.FileNameFormat) - } else if llc.FileNameFormat == "" { - newConfig.Writers = newConfig.Writers[0:1] - } - if llc.JSONStdout { - newConfig.Writers[0].TimeFormat = "" - newConfig.Writers[0].Format = "json" - } else if llc.TimestampFormat != "" { - newConfig.Writers[0].TimeFormat = llc.TimestampFormat - } - var updatedConfig yaml.Node - err = updatedConfig.Encode(&newConfig) - if err != nil { - _, _ = fmt.Fprintln(os.Stderr, "Failed to encode migrated log config:", err) - return - } - *helper.GetBaseNode("logging").Node = updatedConfig -} - -// Upgrader is a config upgrader that copies the default fields in the homeserver, appservice and logging blocks. -var Upgrader = up.SimpleUpgrader(doUpgrade) diff --git a/bridge/bridgeconfig/permissions.go b/bridge/bridgeconfig/permissions.go deleted file mode 100644 index 198e140e..00000000 --- a/bridge/bridgeconfig/permissions.go +++ /dev/null @@ -1,71 +0,0 @@ -// Copyright (c) 2023 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package bridgeconfig - -import ( - "strconv" - "strings" - - "maunium.net/go/mautrix/id" -) - -type PermissionConfig map[string]PermissionLevel - -type PermissionLevel int - -const ( - PermissionLevelBlock PermissionLevel = 0 - PermissionLevelRelay PermissionLevel = 5 - PermissionLevelUser PermissionLevel = 10 - PermissionLevelAdmin PermissionLevel = 100 -) - -var namesToLevels = map[string]PermissionLevel{ - "block": PermissionLevelBlock, - "relay": PermissionLevelRelay, - "user": PermissionLevelUser, - "admin": PermissionLevelAdmin, -} - -func RegisterPermissionLevel(name string, level PermissionLevel) { - namesToLevels[name] = level -} - -func (pc *PermissionConfig) UnmarshalYAML(unmarshal func(interface{}) error) error { - rawPC := make(map[string]string) - err := unmarshal(&rawPC) - if err != nil { - return err - } - - if *pc == nil { - *pc = make(map[string]PermissionLevel) - } - for key, value := range rawPC { - level, ok := namesToLevels[strings.ToLower(value)] - if ok { - (*pc)[key] = level - } else if val, err := strconv.Atoi(value); err == nil { - (*pc)[key] = PermissionLevel(val) - } else { - (*pc)[key] = PermissionLevelBlock - } - } - return nil -} - -func (pc PermissionConfig) Get(userID id.UserID) PermissionLevel { - if level, ok := pc[string(userID)]; ok { - return level - } else if level, ok = pc[userID.Homeserver()]; len(userID.Homeserver()) > 0 && ok { - return level - } else if level, ok = pc["*"]; ok { - return level - } else { - return PermissionLevelBlock - } -} diff --git a/bridge/bridgestate.go b/bridge/bridgestate.go deleted file mode 100644 index f9c3a3c6..00000000 --- a/bridge/bridgestate.go +++ /dev/null @@ -1,156 +0,0 @@ -// Copyright (c) 2023 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package bridge - -import ( - "context" - "runtime/debug" - "time" - - "github.com/rs/zerolog" - - "maunium.net/go/mautrix/appservice" - "maunium.net/go/mautrix/bridge/status" -) - -func (br *Bridge) SendBridgeState(ctx context.Context, state *status.BridgeState) error { - if br.Websocket { - // FIXME this doesn't account for multiple users - br.latestState = state - - return br.AS.SendWebsocket(&appservice.WebsocketRequest{ - Command: "bridge_status", - Data: state, - }) - } else if br.Config.Homeserver.StatusEndpoint != "" { - return state.SendHTTP(ctx, br.Config.Homeserver.StatusEndpoint, br.Config.AppService.ASToken) - } else { - return nil - } -} - -func (br *Bridge) SendGlobalBridgeState(state status.BridgeState) { - if len(br.Config.Homeserver.StatusEndpoint) == 0 && !br.Websocket { - return - } - - for { - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - if err := br.SendBridgeState(ctx, &state); err != nil { - br.ZLog.Warn().Err(err).Msg("Failed to update global bridge state") - cancel() - time.Sleep(5 * time.Second) - continue - } else { - br.ZLog.Debug().Interface("bridge_state", state).Msg("Sent new global bridge state") - cancel() - break - } - } -} - -type BridgeStateQueue struct { - prev *status.BridgeState - ch chan status.BridgeState - bridge *Bridge - user status.BridgeStateFiller -} - -func (br *Bridge) NewBridgeStateQueue(user status.BridgeStateFiller) *BridgeStateQueue { - if len(br.Config.Homeserver.StatusEndpoint) == 0 && !br.Websocket { - return nil - } - bsq := &BridgeStateQueue{ - ch: make(chan status.BridgeState, 10), - bridge: br, - user: user, - } - go bsq.loop() - return bsq -} - -func (bsq *BridgeStateQueue) loop() { - defer func() { - err := recover() - if err != nil { - bsq.bridge.ZLog.Error(). - Str(zerolog.ErrorStackFieldName, string(debug.Stack())). - Interface(zerolog.ErrorFieldName, err). - Msg("Panic in bridge state loop") - } - }() - for state := range bsq.ch { - bsq.immediateSendBridgeState(state) - } -} - -func (bsq *BridgeStateQueue) immediateSendBridgeState(state status.BridgeState) { - retryIn := 2 - for { - if bsq.prev != nil && bsq.prev.ShouldDeduplicate(&state) { - bsq.bridge.ZLog.Debug(). - Str("state_event", string(state.StateEvent)). - Msg("Not sending bridge state as it's a duplicate") - return - } - - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - err := bsq.bridge.SendBridgeState(ctx, &state) - cancel() - - if err != nil { - bsq.bridge.ZLog.Warn().Err(err). - Int("retry_in_seconds", retryIn). - Msg("Failed to update bridge state") - time.Sleep(time.Duration(retryIn) * time.Second) - retryIn *= 2 - if retryIn > 64 { - retryIn = 64 - } - } else { - bsq.prev = &state - bsq.bridge.ZLog.Debug(). - Interface("bridge_state", state). - Msg("Sent new bridge state") - return - } - } -} - -func (bsq *BridgeStateQueue) Send(state status.BridgeState) { - if bsq == nil { - return - } - - state = state.Fill(bsq.user) - - if len(bsq.ch) >= 8 { - bsq.bridge.ZLog.Warn().Msg("Bridge state queue is nearly full, discarding an item") - select { - case <-bsq.ch: - default: - } - } - select { - case bsq.ch <- state: - default: - bsq.bridge.ZLog.Error().Msg("Bridge state queue is full, dropped new state") - } -} - -func (bsq *BridgeStateQueue) GetPrev() status.BridgeState { - if bsq != nil && bsq.prev != nil { - return *bsq.prev - } - return status.BridgeState{} -} - -func (bsq *BridgeStateQueue) SetPrev(prev status.BridgeState) { - if bsq != nil { - bsq.prev = &prev - } -} diff --git a/bridge/commands/admin.go b/bridge/commands/admin.go deleted file mode 100644 index ff3340e3..00000000 --- a/bridge/commands/admin.go +++ /dev/null @@ -1,77 +0,0 @@ -// Copyright (c) 2022 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package commands - -import ( - "strconv" - - "maunium.net/go/mautrix/id" -) - -var CommandDiscardMegolmSession = &FullHandler{ - Func: func(ce *Event) { - if ce.Bridge.Crypto == nil { - ce.Reply("This bridge instance doesn't have end-to-bridge encryption enabled") - } else { - ce.Bridge.Crypto.ResetSession(ce.Ctx, ce.RoomID) - ce.Reply("Successfully reset Megolm session in this room. New decryption keys will be shared the next time a message is sent from the remote network.") - } - }, - Name: "discard-megolm-session", - Aliases: []string{"discard-session"}, - Help: HelpMeta{ - Section: HelpSectionAdmin, - Description: "Discard the Megolm session in the room", - }, - RequiresAdmin: true, -} - -func fnSetPowerLevel(ce *Event) { - var level int - var userID id.UserID - var err error - if len(ce.Args) == 1 { - level, err = strconv.Atoi(ce.Args[0]) - if err != nil { - ce.Reply("Invalid power level \"%s\"", ce.Args[0]) - return - } - userID = ce.User.GetMXID() - } else if len(ce.Args) == 2 { - userID = id.UserID(ce.Args[0]) - _, _, err := userID.Parse() - if err != nil { - ce.Reply("Invalid user ID \"%s\"", ce.Args[0]) - return - } - level, err = strconv.Atoi(ce.Args[1]) - if err != nil { - ce.Reply("Invalid power level \"%s\"", ce.Args[1]) - return - } - } else { - ce.Reply("**Usage:** `set-pl [user] `") - return - } - _, err = ce.Portal.MainIntent().SetPowerLevel(ce.Ctx, ce.RoomID, userID, level) - if err != nil { - ce.Reply("Failed to set power levels: %v", err) - } -} - -var CommandSetPowerLevel = &FullHandler{ - Func: fnSetPowerLevel, - Name: "set-pl", - Aliases: []string{"set-power-level"}, - Help: HelpMeta{ - Section: HelpSectionAdmin, - Description: "Change the power level in a portal room.", - Args: "[_user ID_] <_power level_>", - }, - RequiresAdmin: true, - RequiresPortal: true, -} diff --git a/bridge/commands/doublepuppet.go b/bridge/commands/doublepuppet.go deleted file mode 100644 index 3f074951..00000000 --- a/bridge/commands/doublepuppet.go +++ /dev/null @@ -1,83 +0,0 @@ -// Copyright (c) 2022 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package commands - -var CommandLoginMatrix = &FullHandler{ - Func: fnLoginMatrix, - Name: "login-matrix", - Help: HelpMeta{ - Section: HelpSectionAuth, - Description: "Enable double puppeting.", - Args: "<_access token_>", - }, - RequiresLogin: true, -} - -func fnLoginMatrix(ce *Event) { - if len(ce.Args) == 0 { - ce.Reply("**Usage:** `login-matrix `") - return - } - puppet := ce.User.GetIDoublePuppet() - if puppet == nil { - puppet = ce.User.GetIGhost() - if puppet == nil { - ce.Reply("Didn't get a ghost :(") - return - } - } - err := puppet.SwitchCustomMXID(ce.Args[0], ce.User.GetMXID()) - if err != nil { - ce.Reply("Failed to enable double puppeting: %v", err) - } else { - ce.Reply("Successfully switched puppet") - } -} - -var CommandPingMatrix = &FullHandler{ - Func: fnPingMatrix, - Name: "ping-matrix", - Help: HelpMeta{ - Section: HelpSectionAuth, - Description: "Ping the Matrix server with the double puppet.", - }, - RequiresLogin: true, -} - -func fnPingMatrix(ce *Event) { - puppet := ce.User.GetIDoublePuppet() - if puppet == nil || puppet.CustomIntent() == nil { - ce.Reply("You are not logged in with your Matrix account.") - return - } - resp, err := puppet.CustomIntent().Whoami(ce.Ctx) - if err != nil { - ce.Reply("Failed to validate Matrix login: %v", err) - } else { - ce.Reply("Confirmed valid access token for %s / %s", resp.UserID, resp.DeviceID) - } -} - -var CommandLogoutMatrix = &FullHandler{ - Func: fnLogoutMatrix, - Name: "logout-matrix", - Help: HelpMeta{ - Section: HelpSectionAuth, - Description: "Disable double puppeting.", - }, - RequiresLogin: true, -} - -func fnLogoutMatrix(ce *Event) { - puppet := ce.User.GetIDoublePuppet() - if puppet == nil || puppet.CustomIntent() == nil { - ce.Reply("You don't have double puppeting enabled.") - return - } - puppet.ClearCustomMXID() - ce.Reply("Successfully disabled double puppeting.") -} diff --git a/bridge/commands/event.go b/bridge/commands/event.go deleted file mode 100644 index 49a8b277..00000000 --- a/bridge/commands/event.go +++ /dev/null @@ -1,95 +0,0 @@ -// Copyright (c) 2023 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package commands - -import ( - "context" - "fmt" - "strings" - - "github.com/rs/zerolog" - - "maunium.net/go/mautrix" - "maunium.net/go/mautrix/appservice" - "maunium.net/go/mautrix/bridge" - "maunium.net/go/mautrix/event" - "maunium.net/go/mautrix/format" - "maunium.net/go/mautrix/id" -) - -// Event stores all data which might be used to handle commands -type Event struct { - Bot *appservice.IntentAPI - Bridge *bridge.Bridge - Portal bridge.Portal - Processor *Processor - Handler MinimalHandler - RoomID id.RoomID - EventID id.EventID - User bridge.User - Command string - Args []string - RawArgs string - ReplyTo id.EventID - Ctx context.Context - ZLog *zerolog.Logger -} - -// MainIntent returns the intent to use when replying to the command. -// -// It prefers the bridge bot, but falls back to the other user in DMs if the bridge bot is not present. -func (ce *Event) MainIntent() *appservice.IntentAPI { - intent := ce.Bot - if ce.Portal != nil && ce.Portal.IsPrivateChat() && !ce.Portal.IsEncrypted() { - intent = ce.Portal.MainIntent() - } - return intent -} - -// Reply sends a reply to command as notice, with optional string formatting and automatic $cmdprefix replacement. -func (ce *Event) Reply(msg string, args ...interface{}) { - msg = strings.ReplaceAll(msg, "$cmdprefix ", ce.Bridge.Config.Bridge.GetCommandPrefix()+" ") - if len(args) > 0 { - msg = fmt.Sprintf(msg, args...) - } - ce.ReplyAdvanced(msg, true, false) -} - -// ReplyAdvanced sends a reply to command as notice. It allows using HTML and disabling markdown, -// but doesn't have built-in string formatting. -func (ce *Event) ReplyAdvanced(msg string, allowMarkdown, allowHTML bool) { - content := format.RenderMarkdown(msg, allowMarkdown, allowHTML) - content.MsgType = event.MsgNotice - _, err := ce.MainIntent().SendMessageEvent(ce.Ctx, ce.RoomID, event.EventMessage, content) - if err != nil { - ce.ZLog.Error().Err(err).Msg("Failed to reply to command") - } -} - -// React sends a reaction to the command. -func (ce *Event) React(key string) { - _, err := ce.MainIntent().SendReaction(ce.Ctx, ce.RoomID, ce.EventID, key) - if err != nil { - ce.ZLog.Error().Err(err).Msg("Failed to react to command") - } -} - -// Redact redacts the command. -func (ce *Event) Redact(req ...mautrix.ReqRedact) { - _, err := ce.MainIntent().RedactEvent(ce.Ctx, ce.RoomID, ce.EventID, req...) - if err != nil { - ce.ZLog.Error().Err(err).Msg("Failed to redact command") - } -} - -// MarkRead marks the command event as read. -func (ce *Event) MarkRead() { - err := ce.MainIntent().SendReceipt(ce.Ctx, ce.RoomID, ce.EventID, event.ReceiptTypeRead, nil) - if err != nil { - ce.ZLog.Error().Err(err).Msg("Failed to mark command as read") - } -} diff --git a/bridge/commands/handler.go b/bridge/commands/handler.go deleted file mode 100644 index ab6899c0..00000000 --- a/bridge/commands/handler.go +++ /dev/null @@ -1,100 +0,0 @@ -// Copyright (c) 2022 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package commands - -import ( - "maunium.net/go/mautrix/bridge" - "maunium.net/go/mautrix/bridge/bridgeconfig" - "maunium.net/go/mautrix/event" -) - -type MinimalHandler interface { - Run(*Event) -} - -type MinimalHandlerFunc func(*Event) - -func (mhf MinimalHandlerFunc) Run(ce *Event) { - mhf(ce) -} - -type CommandState struct { - Next MinimalHandler - Action string - Meta interface{} -} - -type CommandingUser interface { - bridge.User - GetCommandState() *CommandState - SetCommandState(*CommandState) -} - -type Handler interface { - MinimalHandler - GetName() string -} - -type AliasedHandler interface { - Handler - GetAliases() []string -} - -type FullHandler struct { - Func func(*Event) - - Name string - Aliases []string - Help HelpMeta - - RequiresAdmin bool - RequiresPortal bool - RequiresLogin bool - - RequiresEventLevel event.Type -} - -func (fh *FullHandler) GetHelp() HelpMeta { - fh.Help.Command = fh.Name - return fh.Help -} - -func (fh *FullHandler) GetName() string { - return fh.Name -} - -func (fh *FullHandler) GetAliases() []string { - return fh.Aliases -} - -func (fh *FullHandler) ShowInHelp(ce *Event) bool { - return !fh.RequiresAdmin || ce.User.GetPermissionLevel() >= bridgeconfig.PermissionLevelAdmin -} - -func (fh *FullHandler) userHasRoomPermission(ce *Event) bool { - levels, err := ce.MainIntent().PowerLevels(ce.Ctx, ce.RoomID) - if err != nil { - ce.ZLog.Warn().Err(err).Msg("Failed to check room power levels") - ce.Reply("Failed to get room power levels to see if you're allowed to use that command") - return false - } - return levels.GetUserLevel(ce.User.GetMXID()) >= levels.GetEventLevel(fh.RequiresEventLevel) -} - -func (fh *FullHandler) Run(ce *Event) { - if fh.RequiresAdmin && ce.User.GetPermissionLevel() < bridgeconfig.PermissionLevelAdmin { - ce.Reply("That command is limited to bridge administrators.") - } else if fh.RequiresEventLevel.Type != "" && ce.User.GetPermissionLevel() < bridgeconfig.PermissionLevelAdmin && !fh.userHasRoomPermission(ce) { - ce.Reply("That command requires room admin rights.") - } else if fh.RequiresPortal && ce.Portal == nil { - ce.Reply("That command can only be ran in portal rooms.") - } else if fh.RequiresLogin && !ce.User.IsLoggedIn() { - ce.Reply("That command requires you to be logged in.") - } else { - fh.Func(ce) - } -} diff --git a/bridge/commands/help.go b/bridge/commands/help.go deleted file mode 100644 index f4891555..00000000 --- a/bridge/commands/help.go +++ /dev/null @@ -1,129 +0,0 @@ -// Copyright (c) 2022 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package commands - -import ( - "fmt" - "sort" - "strings" -) - -type HelpfulHandler interface { - Handler - GetHelp() HelpMeta - ShowInHelp(*Event) bool -} - -type HelpSection struct { - Name string - Order int -} - -var ( - // Deprecated: this should be used as a placeholder that needs to be fixed - HelpSectionUnclassified = HelpSection{"Unclassified", -1} - - HelpSectionGeneral = HelpSection{"General", 0} - HelpSectionAuth = HelpSection{"Authentication", 10} - HelpSectionAdmin = HelpSection{"Administration", 50} -) - -type HelpMeta struct { - Command string - Section HelpSection - Description string - Args string -} - -func (hm *HelpMeta) String() string { - if len(hm.Args) == 0 { - return fmt.Sprintf("**%s** - %s", hm.Command, hm.Description) - } - return fmt.Sprintf("**%s** %s - %s", hm.Command, hm.Args, hm.Description) -} - -type helpSectionList []HelpSection - -func (h helpSectionList) Len() int { - return len(h) -} - -func (h helpSectionList) Less(i, j int) bool { - return h[i].Order < h[j].Order -} - -func (h helpSectionList) Swap(i, j int) { - h[i], h[j] = h[j], h[i] -} - -type helpMetaList []HelpMeta - -func (h helpMetaList) Len() int { - return len(h) -} - -func (h helpMetaList) Less(i, j int) bool { - return h[i].Command < h[j].Command -} - -func (h helpMetaList) Swap(i, j int) { - h[i], h[j] = h[j], h[i] -} - -var _ sort.Interface = (helpSectionList)(nil) -var _ sort.Interface = (helpMetaList)(nil) - -func FormatHelp(ce *Event) string { - sections := make(map[HelpSection]helpMetaList) - for _, handler := range ce.Processor.handlers { - helpfulHandler, ok := handler.(HelpfulHandler) - if !ok || !helpfulHandler.ShowInHelp(ce) { - continue - } - help := helpfulHandler.GetHelp() - if help.Description == "" { - continue - } - sections[help.Section] = append(sections[help.Section], help) - } - - sortedSections := make(helpSectionList, 0, len(sections)) - for section := range sections { - sortedSections = append(sortedSections, section) - } - sort.Sort(sortedSections) - - var output strings.Builder - output.Grow(10240) - - var prefixMsg string - if ce.RoomID == ce.User.GetManagementRoomID() { - prefixMsg = "This is your management room: prefixing commands with `%s` is not required." - } else if ce.Portal != nil { - prefixMsg = "**This is a portal room**: you must always prefix commands with `%s`. Management commands will not be bridged." - } else { - prefixMsg = "This is not your management room: prefixing commands with `%s` is required." - } - _, _ = fmt.Fprintf(&output, prefixMsg, ce.Bridge.Config.Bridge.GetCommandPrefix()) - output.WriteByte('\n') - output.WriteString("Parameters in [square brackets] are optional, while parameters in are required.") - output.WriteByte('\n') - output.WriteByte('\n') - - for _, section := range sortedSections { - output.WriteString("#### ") - output.WriteString(section.Name) - output.WriteByte('\n') - sort.Sort(sections[section]) - for _, command := range sections[section] { - output.WriteString(command.String()) - output.WriteByte('\n') - } - output.WriteByte('\n') - } - return output.String() -} diff --git a/bridge/commands/meta.go b/bridge/commands/meta.go deleted file mode 100644 index 615f6a34..00000000 --- a/bridge/commands/meta.go +++ /dev/null @@ -1,56 +0,0 @@ -// Copyright (c) 2022 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package commands - -var CommandHelp = &FullHandler{ - Func: func(ce *Event) { - ce.Reply(FormatHelp(ce)) - }, - Name: "help", - Help: HelpMeta{ - Section: HelpSectionGeneral, - Description: "Show this help message.", - }, -} - -var CommandVersion = &FullHandler{ - Func: func(ce *Event) { - ce.Reply("[%s](%s) %s (%s)", ce.Bridge.Name, ce.Bridge.URL, ce.Bridge.LinkifiedVersion, ce.Bridge.BuildTime) - }, - Name: "version", - Help: HelpMeta{ - Section: HelpSectionGeneral, - Description: "Get the bridge version.", - }, -} - -var CommandCancel = &FullHandler{ - Func: func(ce *Event) { - commandingUser, ok := ce.User.(CommandingUser) - if !ok { - ce.Reply("This bridge does not implement cancelable commands") - return - } - state := commandingUser.GetCommandState() - - if state != nil { - action := state.Action - if action == "" { - action = "Unknown action" - } - commandingUser.SetCommandState(nil) - ce.Reply("%s cancelled.", action) - } else { - ce.Reply("No ongoing command.") - } - }, - Name: "cancel", - Help: HelpMeta{ - Section: HelpSectionGeneral, - Description: "Cancel an ongoing action.", - }, -} diff --git a/bridge/commands/processor.go b/bridge/commands/processor.go deleted file mode 100644 index 6158a7cd..00000000 --- a/bridge/commands/processor.go +++ /dev/null @@ -1,122 +0,0 @@ -// Copyright (c) 2023 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package commands - -import ( - "context" - "runtime/debug" - "strings" - - "github.com/rs/zerolog" - - "maunium.net/go/mautrix/bridge" - "maunium.net/go/mautrix/id" -) - -type Processor struct { - bridge *bridge.Bridge - log *zerolog.Logger - - handlers map[string]Handler - aliases map[string]string -} - -// NewProcessor creates a Processor -func NewProcessor(bridge *bridge.Bridge) *Processor { - proc := &Processor{ - bridge: bridge, - log: bridge.ZLog, - - handlers: make(map[string]Handler), - aliases: make(map[string]string), - } - proc.AddHandlers( - CommandHelp, CommandVersion, CommandCancel, - CommandLoginMatrix, CommandLogoutMatrix, CommandPingMatrix, - CommandDiscardMegolmSession, CommandSetPowerLevel) - return proc -} - -func (proc *Processor) AddHandlers(handlers ...Handler) { - for _, handler := range handlers { - proc.AddHandler(handler) - } -} - -func (proc *Processor) AddHandler(handler Handler) { - proc.handlers[handler.GetName()] = handler - aliased, ok := handler.(AliasedHandler) - if ok { - for _, alias := range aliased.GetAliases() { - proc.aliases[alias] = handler.GetName() - } - } -} - -// Handle handles messages to the bridge -func (proc *Processor) Handle(ctx context.Context, roomID id.RoomID, eventID id.EventID, user bridge.User, message string, replyTo id.EventID) { - defer func() { - err := recover() - if err != nil { - zerolog.Ctx(ctx).Error(). - Str(zerolog.ErrorStackFieldName, string(debug.Stack())). - Interface(zerolog.ErrorFieldName, err). - Msg("Panic in Matrix command handler") - } - }() - args := strings.Fields(message) - if len(args) == 0 { - args = []string{"unknown-command"} - } - command := strings.ToLower(args[0]) - rawArgs := strings.TrimLeft(strings.TrimPrefix(message, command), " ") - log := zerolog.Ctx(ctx).With().Str("mx_command", command).Logger() - ctx = log.WithContext(ctx) - ce := &Event{ - Bot: proc.bridge.Bot, - Bridge: proc.bridge, - Portal: proc.bridge.Child.GetIPortal(roomID), - Processor: proc, - RoomID: roomID, - EventID: eventID, - User: user, - Command: command, - Args: args[1:], - RawArgs: rawArgs, - ReplyTo: replyTo, - Ctx: ctx, - ZLog: &log, - } - log.Debug().Msg("Received command") - - realCommand, ok := proc.aliases[ce.Command] - if !ok { - realCommand = ce.Command - } - commandingUser, ok := ce.User.(CommandingUser) - - var handler MinimalHandler - handler, ok = proc.handlers[realCommand] - if !ok { - var state *CommandState - if commandingUser != nil { - state = commandingUser.GetCommandState() - } - if state != nil && state.Next != nil { - ce.Command = "" - ce.RawArgs = message - ce.Args = args - ce.Handler = state.Next - state.Next.Run(ce) - } else { - ce.Reply("Unknown command, use the `help` command for help.") - } - } else { - ce.Handler = handler - handler.Run(ce) - } -} diff --git a/bridge/crypto.go b/bridge/crypto.go deleted file mode 100644 index f0b90056..00000000 --- a/bridge/crypto.go +++ /dev/null @@ -1,511 +0,0 @@ -// Copyright (c) 2024 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -//go:build cgo && !nocrypto - -package bridge - -import ( - "context" - "errors" - "fmt" - "os" - "runtime/debug" - "sync" - "time" - - "github.com/rs/zerolog" - "go.mau.fi/util/dbutil" - - "maunium.net/go/mautrix" - "maunium.net/go/mautrix/bridge/bridgeconfig" - "maunium.net/go/mautrix/crypto" - "maunium.net/go/mautrix/crypto/olm" - "maunium.net/go/mautrix/event" - "maunium.net/go/mautrix/id" - "maunium.net/go/mautrix/sqlstatestore" -) - -var _ crypto.StateStore = (*sqlstatestore.SQLStateStore)(nil) - -var NoSessionFound = crypto.NoSessionFound -var DuplicateMessageIndex = crypto.DuplicateMessageIndex -var UnknownMessageIndex = olm.UnknownMessageIndex - -type CryptoHelper struct { - bridge *Bridge - client *mautrix.Client - mach *crypto.OlmMachine - store *SQLCryptoStore - log *zerolog.Logger - - lock sync.RWMutex - syncDone sync.WaitGroup - cancelSync func() - - cancelPeriodicDeleteLoop func() -} - -func NewCryptoHelper(bridge *Bridge) Crypto { - if !bridge.Config.Bridge.GetEncryptionConfig().Allow { - bridge.ZLog.Debug().Msg("Bridge built with end-to-bridge encryption, but disabled in config") - return nil - } - log := bridge.ZLog.With().Str("component", "crypto").Logger() - return &CryptoHelper{ - bridge: bridge, - log: &log, - } -} - -func (helper *CryptoHelper) Init(ctx context.Context) error { - if len(helper.bridge.CryptoPickleKey) == 0 { - panic("CryptoPickleKey not set") - } - helper.log.Debug().Msg("Initializing end-to-bridge encryption...") - - helper.store = NewSQLCryptoStore( - helper.bridge.DB, - dbutil.ZeroLogger(helper.bridge.ZLog.With().Str("db_section", "crypto").Logger()), - helper.bridge.AS.BotMXID(), - fmt.Sprintf("@%s:%s", helper.bridge.Config.Bridge.FormatUsername("%"), helper.bridge.AS.HomeserverDomain), - helper.bridge.CryptoPickleKey, - ) - - err := helper.store.DB.Upgrade(ctx) - if err != nil { - helper.bridge.LogDBUpgradeErrorAndExit("crypto", err) - } - - var isExistingDevice bool - helper.client, isExistingDevice, err = helper.loginBot(ctx) - if err != nil { - return err - } - - helper.log.Debug(). - Str("device_id", helper.client.DeviceID.String()). - Msg("Logged in as bridge bot") - stateStore := &cryptoStateStore{helper.bridge} - helper.mach = crypto.NewOlmMachine(helper.client, helper.log, helper.store, stateStore) - helper.mach.AllowKeyShare = helper.allowKeyShare - - encryptionConfig := helper.bridge.Config.Bridge.GetEncryptionConfig() - helper.mach.SendKeysMinTrust = encryptionConfig.VerificationLevels.Receive - helper.mach.PlaintextMentions = encryptionConfig.PlaintextMentions - - helper.mach.DeleteOutboundKeysOnAck = encryptionConfig.DeleteKeys.DeleteOutboundOnAck - helper.mach.DontStoreOutboundKeys = encryptionConfig.DeleteKeys.DontStoreOutbound - helper.mach.RatchetKeysOnDecrypt = encryptionConfig.DeleteKeys.RatchetOnDecrypt - helper.mach.DeleteFullyUsedKeysOnDecrypt = encryptionConfig.DeleteKeys.DeleteFullyUsedOnDecrypt - helper.mach.DeletePreviousKeysOnReceive = encryptionConfig.DeleteKeys.DeletePrevOnNewSession - helper.mach.DeleteKeysOnDeviceDelete = encryptionConfig.DeleteKeys.DeleteOnDeviceDelete - helper.mach.DisableDeviceChangeKeyRotation = encryptionConfig.Rotation.DisableDeviceChangeKeyRotation - if encryptionConfig.DeleteKeys.PeriodicallyDeleteExpired { - ctx, cancel := context.WithCancel(context.Background()) - helper.cancelPeriodicDeleteLoop = cancel - go helper.mach.ExpiredKeyDeleteLoop(ctx) - } - - if encryptionConfig.DeleteKeys.DeleteOutdatedInbound { - deleted, err := helper.store.RedactOutdatedGroupSessions(ctx) - if err != nil { - return err - } - if len(deleted) > 0 { - helper.log.Debug().Int("deleted", len(deleted)).Msg("Deleted inbound keys which lacked expiration metadata") - } - } - - helper.client.Syncer = &cryptoSyncer{helper.mach} - helper.client.Store = helper.store - - err = helper.mach.Load(ctx) - if err != nil { - return err - } - if isExistingDevice { - helper.verifyKeysAreOnServer(ctx) - } - - go helper.resyncEncryptionInfo(context.TODO()) - - return nil -} - -func (helper *CryptoHelper) resyncEncryptionInfo(ctx context.Context) { - log := helper.log.With().Str("action", "resync encryption event").Logger() - rows, err := helper.bridge.DB.Query(ctx, `SELECT room_id FROM mx_room_state WHERE encryption='{"resync":true}'`) - if err != nil { - log.Err(err).Msg("Failed to query rooms for resync") - return - } - roomIDs, err := dbutil.NewRowIter(rows, dbutil.ScanSingleColumn[id.RoomID]).AsList() - if err != nil { - log.Err(err).Msg("Failed to scan rooms for resync") - return - } - if len(roomIDs) > 0 { - log.Debug().Interface("room_ids", roomIDs).Msg("Resyncing rooms") - for _, roomID := range roomIDs { - var evt event.EncryptionEventContent - err = helper.client.StateEvent(ctx, roomID, event.StateEncryption, "", &evt) - if err != nil { - log.Err(err).Str("room_id", roomID.String()).Msg("Failed to get encryption event") - _, err = helper.bridge.DB.Exec(ctx, ` - UPDATE mx_room_state SET encryption=NULL WHERE room_id=$1 AND encryption='{"resync":true}' - `, roomID) - if err != nil { - log.Err(err).Str("room_id", roomID.String()).Msg("Failed to unmark room for resync after failed sync") - } - } else { - maxAge := evt.RotationPeriodMillis - if maxAge <= 0 { - maxAge = (7 * 24 * time.Hour).Milliseconds() - } - maxMessages := evt.RotationPeriodMessages - if maxMessages <= 0 { - maxMessages = 100 - } - log.Debug(). - Str("room_id", roomID.String()). - Int64("max_age_ms", maxAge). - Int("max_messages", maxMessages). - Interface("content", &evt). - Msg("Resynced encryption event") - _, err = helper.bridge.DB.Exec(ctx, ` - UPDATE crypto_megolm_inbound_session - SET max_age=$1, max_messages=$2 - WHERE room_id=$3 AND max_age IS NULL AND max_messages IS NULL - `, maxAge, maxMessages, roomID) - if err != nil { - log.Err(err).Str("room_id", roomID.String()).Msg("Failed to update megolm session table") - } else { - log.Debug().Str("room_id", roomID.String()).Msg("Updated megolm session table") - } - } - } - } -} - -func (helper *CryptoHelper) allowKeyShare(ctx context.Context, device *id.Device, info event.RequestedKeyInfo) *crypto.KeyShareRejection { - cfg := helper.bridge.Config.Bridge.GetEncryptionConfig() - if !cfg.AllowKeySharing { - return &crypto.KeyShareRejectNoResponse - } else if device.Trust == id.TrustStateBlacklisted { - return &crypto.KeyShareRejectBlacklisted - } else if trustState := helper.mach.ResolveTrust(device); trustState >= cfg.VerificationLevels.Share { - portal := helper.bridge.Child.GetIPortal(info.RoomID) - if portal == nil { - zerolog.Ctx(ctx).Debug().Msg("Rejecting key request: room is not a portal") - return &crypto.KeyShareRejection{Code: event.RoomKeyWithheldUnavailable, Reason: "Requested room is not a portal room"} - } - user := helper.bridge.Child.GetIUser(device.UserID, true) - // FIXME reimplement IsInPortal - if user.GetPermissionLevel() < bridgeconfig.PermissionLevelAdmin /*&& !user.IsInPortal(portal.Key)*/ { - zerolog.Ctx(ctx).Debug().Msg("Rejecting key request: user is not in portal") - return &crypto.KeyShareRejection{Code: event.RoomKeyWithheldUnauthorized, Reason: "You're not in that portal"} - } - zerolog.Ctx(ctx).Debug().Msg("Accepting key request") - return nil - } else { - return &crypto.KeyShareRejectUnverified - } -} - -func (helper *CryptoHelper) loginBot(ctx context.Context) (*mautrix.Client, bool, error) { - deviceID, err := helper.store.FindDeviceID(ctx) - if err != nil { - return nil, false, fmt.Errorf("failed to find existing device ID: %w", err) - } else if len(deviceID) > 0 { - helper.log.Debug().Str("device_id", deviceID.String()).Msg("Found existing device ID for bot in database") - } - // Create a new client instance with the default AS settings (including as_token), - // the Login call will then override the access token in the client. - client := helper.bridge.AS.NewMautrixClient(helper.bridge.AS.BotMXID()) - flows, err := client.GetLoginFlows(ctx) - if err != nil { - return nil, deviceID != "", fmt.Errorf("failed to get supported login flows: %w", err) - } else if !flows.HasFlow(mautrix.AuthTypeAppservice) { - return nil, deviceID != "", fmt.Errorf("homeserver does not support appservice login") - } - resp, err := client.Login(ctx, &mautrix.ReqLogin{ - Type: mautrix.AuthTypeAppservice, - Identifier: mautrix.UserIdentifier{ - Type: mautrix.IdentifierTypeUser, - User: string(helper.bridge.AS.BotMXID()), - }, - DeviceID: deviceID, - StoreCredentials: true, - - InitialDeviceDisplayName: fmt.Sprintf("%s bridge", helper.bridge.ProtocolName), - }) - if err != nil { - return nil, deviceID != "", fmt.Errorf("failed to log in as bridge bot: %w", err) - } - helper.store.DeviceID = resp.DeviceID - return client, deviceID != "", nil -} - -func (helper *CryptoHelper) verifyKeysAreOnServer(ctx context.Context) { - helper.log.Debug().Msg("Making sure keys are still on server") - resp, err := helper.client.QueryKeys(ctx, &mautrix.ReqQueryKeys{ - DeviceKeys: map[id.UserID]mautrix.DeviceIDList{ - helper.client.UserID: {helper.client.DeviceID}, - }, - }) - if err != nil { - helper.log.WithLevel(zerolog.FatalLevel).Err(err).Msg("Failed to query own keys to make sure device still exists") - os.Exit(33) - } - device, ok := resp.DeviceKeys[helper.client.UserID][helper.client.DeviceID] - if ok && len(device.Keys) > 0 { - return - } - helper.log.Warn().Msg("Existing device doesn't have keys on server, resetting crypto") - helper.Reset(ctx, false) -} - -func (helper *CryptoHelper) Start() { - if helper.bridge.Config.Bridge.GetEncryptionConfig().Appservice { - helper.log.Debug().Msg("End-to-bridge encryption is in appservice mode, registering event listeners and not starting syncer") - helper.bridge.AS.Registration.EphemeralEvents = true - helper.mach.AddAppserviceListener(helper.bridge.EventProcessor) - return - } - helper.syncDone.Add(1) - defer helper.syncDone.Done() - helper.log.Debug().Msg("Starting syncer for receiving to-device messages") - var ctx context.Context - ctx, helper.cancelSync = context.WithCancel(context.Background()) - err := helper.client.SyncWithContext(ctx) - if err != nil && !errors.Is(err, context.Canceled) { - helper.log.WithLevel(zerolog.FatalLevel).Err(err).Msg("Fatal error syncing") - os.Exit(51) - } else { - helper.log.Info().Msg("Bridge bot to-device syncer stopped without error") - } -} - -func (helper *CryptoHelper) Stop() { - helper.log.Debug().Msg("CryptoHelper.Stop() called, stopping bridge bot sync") - helper.client.StopSync() - if helper.cancelSync != nil { - helper.cancelSync() - } - if helper.cancelPeriodicDeleteLoop != nil { - helper.cancelPeriodicDeleteLoop() - } - helper.syncDone.Wait() -} - -func (helper *CryptoHelper) clearDatabase(ctx context.Context) { - _, err := helper.store.DB.Exec(ctx, "DELETE FROM crypto_account") - if err != nil { - helper.log.Warn().Err(err).Msg("Failed to clear crypto_account table") - } - _, err = helper.store.DB.Exec(ctx, "DELETE FROM crypto_olm_session") - if err != nil { - helper.log.Warn().Err(err).Msg("Failed to clear crypto_olm_session table") - } - _, err = helper.store.DB.Exec(ctx, "DELETE FROM crypto_megolm_outbound_session") - if err != nil { - helper.log.Warn().Err(err).Msg("Failed to clear crypto_megolm_outbound_session table") - } - //_, _ = helper.store.DB.Exec("DELETE FROM crypto_device") - //_, _ = helper.store.DB.Exec("DELETE FROM crypto_tracked_user") - //_, _ = helper.store.DB.Exec("DELETE FROM crypto_cross_signing_keys") - //_, _ = helper.store.DB.Exec("DELETE FROM crypto_cross_signing_signatures") -} - -func (helper *CryptoHelper) Reset(ctx context.Context, startAfterReset bool) { - helper.lock.Lock() - defer helper.lock.Unlock() - helper.log.Info().Msg("Resetting end-to-bridge encryption device") - helper.Stop() - helper.log.Debug().Msg("Crypto syncer stopped, clearing database") - helper.clearDatabase(ctx) - helper.log.Debug().Msg("Crypto database cleared, logging out of all sessions") - _, err := helper.client.LogoutAll(ctx) - if err != nil { - helper.log.Warn().Err(err).Msg("Failed to log out all devices") - } - helper.client = nil - helper.store = nil - helper.mach = nil - err = helper.Init(ctx) - if err != nil { - helper.log.WithLevel(zerolog.FatalLevel).Err(err).Msg("Error reinitializing end-to-bridge encryption") - os.Exit(50) - } - helper.log.Info().Msg("End-to-bridge encryption successfully reset") - if startAfterReset { - go helper.Start() - } -} - -func (helper *CryptoHelper) Client() *mautrix.Client { - return helper.client -} - -func (helper *CryptoHelper) Decrypt(ctx context.Context, evt *event.Event) (*event.Event, error) { - return helper.mach.DecryptMegolmEvent(ctx, evt) -} - -func (helper *CryptoHelper) Encrypt(ctx context.Context, roomID id.RoomID, evtType event.Type, content *event.Content) (err error) { - helper.lock.RLock() - defer helper.lock.RUnlock() - var encrypted *event.EncryptedEventContent - encrypted, err = helper.mach.EncryptMegolmEvent(ctx, roomID, evtType, content) - if err != nil { - if !errors.Is(err, crypto.SessionExpired) && !errors.Is(err, crypto.SessionNotShared) && !errors.Is(err, crypto.NoGroupSession) { - return - } - helper.log.Debug().Err(err). - Str("room_id", roomID.String()). - Msg("Got error while encrypting event for room, sharing group session and trying again...") - var users []id.UserID - users, err = helper.store.GetRoomJoinedOrInvitedMembers(ctx, roomID) - if err != nil { - err = fmt.Errorf("failed to get room member list: %w", err) - } else if err = helper.mach.ShareGroupSession(ctx, roomID, users); err != nil { - err = fmt.Errorf("failed to share group session: %w", err) - } else if encrypted, err = helper.mach.EncryptMegolmEvent(ctx, roomID, evtType, content); err != nil { - err = fmt.Errorf("failed to encrypt event after re-sharing group session: %w", err) - } - } - if encrypted != nil { - content.Parsed = encrypted - content.Raw = nil - } - return -} - -func (helper *CryptoHelper) WaitForSession(ctx context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, timeout time.Duration) bool { - helper.lock.RLock() - defer helper.lock.RUnlock() - return helper.mach.WaitForSession(ctx, roomID, senderKey, sessionID, timeout) -} - -func (helper *CryptoHelper) RequestSession(ctx context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, userID id.UserID, deviceID id.DeviceID) { - helper.lock.RLock() - defer helper.lock.RUnlock() - if deviceID == "" { - deviceID = "*" - } - err := helper.mach.SendRoomKeyRequest(ctx, roomID, senderKey, sessionID, "", map[id.UserID][]id.DeviceID{userID: {deviceID}}) - if err != nil { - helper.log.Warn().Err(err). - Str("user_id", userID.String()). - Str("device_id", deviceID.String()). - Str("session_id", sessionID.String()). - Str("room_id", roomID.String()). - Msg("Failed to send key request") - } else { - helper.log.Debug(). - Str("user_id", userID.String()). - Str("device_id", deviceID.String()). - Str("session_id", sessionID.String()). - Str("room_id", roomID.String()). - Msg("Sent key request") - } -} - -func (helper *CryptoHelper) ResetSession(ctx context.Context, roomID id.RoomID) { - helper.lock.RLock() - defer helper.lock.RUnlock() - err := helper.mach.CryptoStore.RemoveOutboundGroupSession(ctx, roomID) - if err != nil { - helper.log.Debug().Err(err). - Str("room_id", roomID.String()). - Msg("Error manually removing outbound group session in room") - } -} - -func (helper *CryptoHelper) HandleMemberEvent(ctx context.Context, evt *event.Event) { - helper.lock.RLock() - defer helper.lock.RUnlock() - helper.mach.HandleMemberEvent(ctx, evt) -} - -// ShareKeys uploads the given number of one-time-keys to the server. -func (helper *CryptoHelper) ShareKeys(ctx context.Context) error { - return helper.mach.ShareKeys(ctx, -1) -} - -type cryptoSyncer struct { - *crypto.OlmMachine -} - -func (syncer *cryptoSyncer) ProcessResponse(ctx context.Context, resp *mautrix.RespSync, since string) error { - done := make(chan struct{}) - go func() { - defer func() { - if err := recover(); err != nil { - syncer.Log.Error(). - Str("since", since). - Interface("error", err). - Str("stack", string(debug.Stack())). - Msg("Processing sync response panicked") - } - done <- struct{}{} - }() - syncer.Log.Trace().Str("since", since).Msg("Starting sync response handling") - syncer.ProcessSyncResponse(ctx, resp, since) - syncer.Log.Trace().Str("since", since).Msg("Successfully handled sync response") - }() - select { - case <-done: - case <-time.After(30 * time.Second): - syncer.Log.Warn().Str("since", since).Msg("Handling sync response is taking unusually long") - } - return nil -} - -func (syncer *cryptoSyncer) OnFailedSync(_ *mautrix.RespSync, err error) (time.Duration, error) { - if errors.Is(err, mautrix.MUnknownToken) { - return 0, err - } - syncer.Log.Error().Err(err).Msg("Error /syncing, waiting 10 seconds") - return 10 * time.Second, nil -} - -func (syncer *cryptoSyncer) GetFilterJSON(_ id.UserID) *mautrix.Filter { - everything := []event.Type{{Type: "*"}} - return &mautrix.Filter{ - Presence: mautrix.FilterPart{NotTypes: everything}, - AccountData: mautrix.FilterPart{NotTypes: everything}, - Room: mautrix.RoomFilter{ - IncludeLeave: false, - Ephemeral: mautrix.FilterPart{NotTypes: everything}, - AccountData: mautrix.FilterPart{NotTypes: everything}, - State: mautrix.FilterPart{NotTypes: everything}, - Timeline: mautrix.FilterPart{NotTypes: everything}, - }, - } -} - -type cryptoStateStore struct { - bridge *Bridge -} - -var _ crypto.StateStore = (*cryptoStateStore)(nil) - -func (c *cryptoStateStore) IsEncrypted(ctx context.Context, id id.RoomID) (bool, error) { - portal := c.bridge.Child.GetIPortal(id) - if portal != nil { - return portal.IsEncrypted(), nil - } - return c.bridge.StateStore.IsEncrypted(ctx, id) -} - -func (c *cryptoStateStore) FindSharedRooms(ctx context.Context, id id.UserID) ([]id.RoomID, error) { - return c.bridge.StateStore.FindSharedRooms(ctx, id) -} - -func (c *cryptoStateStore) GetEncryptionEvent(ctx context.Context, id id.RoomID) (*event.EncryptionEventContent, error) { - return c.bridge.StateStore.GetEncryptionEvent(ctx, id) -} diff --git a/bridge/cryptostore.go b/bridge/cryptostore.go deleted file mode 100644 index dde48a25..00000000 --- a/bridge/cryptostore.go +++ /dev/null @@ -1,63 +0,0 @@ -// Copyright (c) 2022 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -//go:build cgo && !nocrypto - -package bridge - -import ( - "context" - - "github.com/lib/pq" - "go.mau.fi/util/dbutil" - - "maunium.net/go/mautrix/crypto" - "maunium.net/go/mautrix/id" -) - -func init() { - crypto.PostgresArrayWrapper = pq.Array -} - -type SQLCryptoStore struct { - *crypto.SQLCryptoStore - UserID id.UserID - GhostIDFormat string -} - -var _ crypto.Store = (*SQLCryptoStore)(nil) - -func NewSQLCryptoStore(db *dbutil.Database, log dbutil.DatabaseLogger, userID id.UserID, ghostIDFormat, pickleKey string) *SQLCryptoStore { - return &SQLCryptoStore{ - SQLCryptoStore: crypto.NewSQLCryptoStore(db, log, "", "", []byte(pickleKey)), - UserID: userID, - GhostIDFormat: ghostIDFormat, - } -} - -func (store *SQLCryptoStore) GetRoomJoinedOrInvitedMembers(ctx context.Context, roomID id.RoomID) (members []id.UserID, err error) { - var rows dbutil.Rows - rows, err = store.DB.Query(ctx, ` - SELECT user_id FROM mx_user_profile - WHERE room_id=$1 - AND (membership='join' OR membership='invite') - AND user_id<>$2 - AND user_id NOT LIKE $3 - `, roomID, store.UserID, store.GhostIDFormat) - if err != nil { - return - } - for rows.Next() { - var userID id.UserID - err = rows.Scan(&userID) - if err != nil { - return members, err - } else { - members = append(members, userID) - } - } - return -} diff --git a/bridge/doublepuppet.go b/bridge/doublepuppet.go deleted file mode 100644 index 265d3d5c..00000000 --- a/bridge/doublepuppet.go +++ /dev/null @@ -1,173 +0,0 @@ -// Copyright (c) 2023 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package bridge - -import ( - "context" - "crypto/hmac" - "crypto/sha512" - "encoding/hex" - "errors" - "fmt" - "strings" - - "github.com/rs/zerolog" - - "maunium.net/go/mautrix" - "maunium.net/go/mautrix/appservice" - "maunium.net/go/mautrix/id" -) - -type doublePuppetUtil struct { - br *Bridge - log zerolog.Logger -} - -func (dp *doublePuppetUtil) newClient(ctx context.Context, mxid id.UserID, accessToken string) (*mautrix.Client, error) { - _, homeserver, err := mxid.Parse() - if err != nil { - return nil, err - } - homeserverURL, found := dp.br.Config.Bridge.GetDoublePuppetConfig().ServerMap[homeserver] - if !found { - if homeserver == dp.br.AS.HomeserverDomain { - homeserverURL = "" - } else if dp.br.Config.Bridge.GetDoublePuppetConfig().AllowDiscovery { - resp, err := mautrix.DiscoverClientAPI(ctx, homeserver) - if err != nil { - return nil, fmt.Errorf("failed to find homeserver URL for %s: %v", homeserver, err) - } - homeserverURL = resp.Homeserver.BaseURL - dp.log.Debug(). - Str("homeserver", homeserver). - Str("url", homeserverURL). - Str("user_id", mxid.String()). - Msg("Discovered URL to enable double puppeting for user") - } else { - return nil, fmt.Errorf("double puppeting from %s is not allowed", homeserver) - } - } - return dp.br.AS.NewExternalMautrixClient(mxid, accessToken, homeserverURL) -} - -func (dp *doublePuppetUtil) newIntent(ctx context.Context, mxid id.UserID, accessToken string) (*appservice.IntentAPI, error) { - client, err := dp.newClient(ctx, mxid, accessToken) - if err != nil { - return nil, err - } - - ia := dp.br.AS.NewIntentAPI("custom") - ia.Client = client - ia.Localpart, _, _ = mxid.Parse() - ia.UserID = mxid - ia.IsCustomPuppet = true - return ia, nil -} - -func (dp *doublePuppetUtil) autoLogin(ctx context.Context, mxid id.UserID, loginSecret string) (string, error) { - dp.log.Debug().Str("user_id", mxid.String()).Msg("Logging into user account with shared secret") - client, err := dp.newClient(ctx, mxid, "") - if err != nil { - return "", fmt.Errorf("failed to create mautrix client to log in: %v", err) - } - bridgeName := fmt.Sprintf("%s Bridge", dp.br.ProtocolName) - req := mautrix.ReqLogin{ - Identifier: mautrix.UserIdentifier{Type: mautrix.IdentifierTypeUser, User: string(mxid)}, - DeviceID: id.DeviceID(bridgeName), - InitialDeviceDisplayName: bridgeName, - } - if loginSecret == "appservice" { - client.AccessToken = dp.br.AS.Registration.AppToken - req.Type = mautrix.AuthTypeAppservice - } else { - loginFlows, err := client.GetLoginFlows(ctx) - if err != nil { - return "", fmt.Errorf("failed to get supported login flows: %w", err) - } - mac := hmac.New(sha512.New, []byte(loginSecret)) - mac.Write([]byte(mxid)) - token := hex.EncodeToString(mac.Sum(nil)) - switch { - case loginFlows.HasFlow(mautrix.AuthTypeDevtureSharedSecret): - req.Type = mautrix.AuthTypeDevtureSharedSecret - req.Token = token - case loginFlows.HasFlow(mautrix.AuthTypePassword): - req.Type = mautrix.AuthTypePassword - req.Password = token - default: - return "", fmt.Errorf("no supported auth types for shared secret auth found") - } - } - resp, err := client.Login(ctx, &req) - if err != nil { - return "", err - } - return resp.AccessToken, nil -} - -var ( - ErrMismatchingMXID = errors.New("whoami result does not match custom mxid") - ErrNoAccessToken = errors.New("no access token provided") - ErrNoMXID = errors.New("no mxid provided") -) - -const useConfigASToken = "appservice-config" -const asTokenModePrefix = "as_token:" - -func (dp *doublePuppetUtil) Setup(ctx context.Context, mxid id.UserID, savedAccessToken string, reloginOnFail bool) (intent *appservice.IntentAPI, newAccessToken string, err error) { - if len(mxid) == 0 { - err = ErrNoMXID - return - } - _, homeserver, _ := mxid.Parse() - loginSecret, hasSecret := dp.br.Config.Bridge.GetDoublePuppetConfig().SharedSecretMap[homeserver] - // Special case appservice: prefix to not login and use it as an as_token directly. - if hasSecret && strings.HasPrefix(loginSecret, asTokenModePrefix) { - intent, err = dp.newIntent(ctx, mxid, strings.TrimPrefix(loginSecret, asTokenModePrefix)) - if err != nil { - return - } - intent.SetAppServiceUserID = true - if savedAccessToken != useConfigASToken { - var resp *mautrix.RespWhoami - resp, err = intent.Whoami(ctx) - if err == nil && resp.UserID != mxid { - err = ErrMismatchingMXID - } - } - return intent, useConfigASToken, err - } - if savedAccessToken == "" || savedAccessToken == useConfigASToken { - if reloginOnFail && hasSecret { - savedAccessToken, err = dp.autoLogin(ctx, mxid, loginSecret) - } else { - err = ErrNoAccessToken - } - if err != nil { - return - } - } - intent, err = dp.newIntent(ctx, mxid, savedAccessToken) - if err != nil { - return - } - var resp *mautrix.RespWhoami - resp, err = intent.Whoami(ctx) - if err != nil { - if reloginOnFail && hasSecret && errors.Is(err, mautrix.MUnknownToken) { - intent.AccessToken, err = dp.autoLogin(ctx, mxid, loginSecret) - if err == nil { - newAccessToken = intent.AccessToken - } - } - } else if resp.UserID != mxid { - err = ErrMismatchingMXID - } else { - newAccessToken = savedAccessToken - } - return -} diff --git a/bridge/matrix.go b/bridge/matrix.go deleted file mode 100644 index 446a0b0a..00000000 --- a/bridge/matrix.go +++ /dev/null @@ -1,755 +0,0 @@ -// Copyright (c) 2023 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package bridge - -import ( - "context" - "errors" - "fmt" - "strings" - "time" - - "github.com/rs/zerolog" - - "maunium.net/go/mautrix" - "maunium.net/go/mautrix/appservice" - "maunium.net/go/mautrix/bridge/bridgeconfig" - "maunium.net/go/mautrix/bridge/status" - "maunium.net/go/mautrix/event" - "maunium.net/go/mautrix/format" - "maunium.net/go/mautrix/id" -) - -type CommandProcessor interface { - Handle(ctx context.Context, roomID id.RoomID, eventID id.EventID, user User, message string, replyTo id.EventID) -} - -type MatrixHandler struct { - bridge *Bridge - as *appservice.AppService - log *zerolog.Logger - - TrackEventDuration func(event.Type) func() -} - -func noop() {} - -func noopTrack(_ event.Type) func() { - return noop -} - -func NewMatrixHandler(br *Bridge) *MatrixHandler { - handler := &MatrixHandler{ - bridge: br, - as: br.AS, - log: br.ZLog, - - TrackEventDuration: noopTrack, - } - for evtType := range status.CheckpointTypes { - br.EventProcessor.On(evtType, handler.sendBridgeCheckpoint) - } - br.EventProcessor.On(event.EventMessage, handler.HandleMessage) - br.EventProcessor.On(event.EventEncrypted, handler.HandleEncrypted) - br.EventProcessor.On(event.EventSticker, handler.HandleMessage) - br.EventProcessor.On(event.EventReaction, handler.HandleReaction) - br.EventProcessor.On(event.EventRedaction, handler.HandleRedaction) - br.EventProcessor.On(event.StateMember, handler.HandleMembership) - br.EventProcessor.On(event.StateRoomName, handler.HandleRoomMetadata) - br.EventProcessor.On(event.StateRoomAvatar, handler.HandleRoomMetadata) - br.EventProcessor.On(event.StateTopic, handler.HandleRoomMetadata) - br.EventProcessor.On(event.StateEncryption, handler.HandleEncryption) - br.EventProcessor.On(event.EphemeralEventReceipt, handler.HandleReceipt) - br.EventProcessor.On(event.EphemeralEventTyping, handler.HandleTyping) - br.EventProcessor.On(event.StatePowerLevels, handler.HandlePowerLevels) - br.EventProcessor.On(event.StateJoinRules, handler.HandleJoinRule) - return handler -} - -func (mx *MatrixHandler) sendBridgeCheckpoint(_ context.Context, evt *event.Event) { - if !evt.Mautrix.CheckpointSent { - go mx.bridge.SendMessageSuccessCheckpoint(evt, status.MsgStepBridge, 0) - } -} - -func (mx *MatrixHandler) HandleEncryption(ctx context.Context, evt *event.Event) { - defer mx.TrackEventDuration(evt.Type)() - if evt.Content.AsEncryption().Algorithm != id.AlgorithmMegolmV1 { - return - } - portal := mx.bridge.Child.GetIPortal(evt.RoomID) - if portal != nil && !portal.IsEncrypted() { - mx.log.Debug(). - Str("user_id", evt.Sender.String()). - Str("room_id", evt.RoomID.String()). - Msg("Encryption was enabled in room") - portal.MarkEncrypted() - if portal.IsPrivateChat() { - err := mx.as.BotIntent().EnsureJoined(ctx, evt.RoomID, appservice.EnsureJoinedParams{BotOverride: portal.MainIntent().Client}) - if err != nil { - mx.log.Err(err). - Str("room_id", evt.RoomID.String()). - Msg("Failed to join bot to room after encryption was enabled") - } - } - } -} - -func (mx *MatrixHandler) joinAndCheckMembers(ctx context.Context, evt *event.Event, intent *appservice.IntentAPI) *mautrix.RespJoinedMembers { - log := zerolog.Ctx(ctx) - resp, err := intent.JoinRoomByID(ctx, evt.RoomID) - if err != nil { - log.Warn().Err(err).Msg("Failed to join room with invite") - return nil - } - - members, err := intent.JoinedMembers(ctx, resp.RoomID) - if err != nil { - log.Warn().Err(err).Msg("Failed to get members in room after accepting invite, leaving room") - _, _ = intent.LeaveRoom(ctx, resp.RoomID) - return nil - } - - if len(members.Joined) < 2 { - log.Debug().Msg("Leaving empty room after accepting invite") - _, _ = intent.LeaveRoom(ctx, resp.RoomID) - return nil - } - return members -} - -func (mx *MatrixHandler) sendNoticeWithMarkdown(ctx context.Context, roomID id.RoomID, message string) (*mautrix.RespSendEvent, error) { - intent := mx.as.BotIntent() - content := format.RenderMarkdown(message, true, false) - content.MsgType = event.MsgNotice - return intent.SendMessageEvent(ctx, roomID, event.EventMessage, content) -} - -func (mx *MatrixHandler) HandleBotInvite(ctx context.Context, evt *event.Event) { - intent := mx.as.BotIntent() - - user := mx.bridge.Child.GetIUser(evt.Sender, true) - if user == nil { - return - } - - members := mx.joinAndCheckMembers(ctx, evt, intent) - if members == nil { - return - } - - if user.GetPermissionLevel() < bridgeconfig.PermissionLevelUser { - _, _ = intent.SendNotice(ctx, evt.RoomID, "You are not whitelisted to use this bridge.\n"+ - "If you're the owner of this bridge, see the bridge.permissions section in your config file.") - _, _ = intent.LeaveRoom(ctx, evt.RoomID) - return - } - - texts := mx.bridge.Config.Bridge.GetManagementRoomTexts() - _, _ = mx.sendNoticeWithMarkdown(ctx, evt.RoomID, texts.Welcome) - - if len(members.Joined) == 2 && (len(user.GetManagementRoomID()) == 0 || evt.Content.AsMember().IsDirect) { - user.SetManagementRoom(evt.RoomID) - _, _ = intent.SendNotice(ctx, user.GetManagementRoomID(), "This room has been registered as your bridge management/status room.") - zerolog.Ctx(ctx).Debug().Msg("Registered room as management room with inviter") - } - - if evt.RoomID == user.GetManagementRoomID() { - if user.IsLoggedIn() { - _, _ = mx.sendNoticeWithMarkdown(ctx, evt.RoomID, texts.WelcomeConnected) - } else { - _, _ = mx.sendNoticeWithMarkdown(ctx, evt.RoomID, texts.WelcomeUnconnected) - } - - additionalHelp := texts.AdditionalHelp - if len(additionalHelp) > 0 { - _, _ = mx.sendNoticeWithMarkdown(ctx, evt.RoomID, additionalHelp) - } - } -} - -func (mx *MatrixHandler) HandleGhostInvite(ctx context.Context, evt *event.Event, inviter User, ghost Ghost) { - log := zerolog.Ctx(ctx) - intent := ghost.DefaultIntent() - - if inviter.GetPermissionLevel() < bridgeconfig.PermissionLevelUser { - log.Debug().Msg("Rejecting invite: inviter is not whitelisted") - _, err := intent.LeaveRoom(ctx, evt.RoomID, &mautrix.ReqLeave{ - Reason: "You're not whitelisted to use this bridge", - }) - if err != nil { - log.Error().Err(err).Msg("Failed to reject invite") - } - return - } else if !inviter.IsLoggedIn() { - log.Debug().Msg("Rejecting invite: inviter is not logged in") - _, err := intent.LeaveRoom(ctx, evt.RoomID, &mautrix.ReqLeave{ - Reason: "You're not logged into this bridge", - }) - if err != nil { - log.Error().Err(err).Msg("Failed to reject invite") - } - return - } - - members := mx.joinAndCheckMembers(ctx, evt, intent) - if members == nil { - return - } - var createEvent event.CreateEventContent - if err := intent.StateEvent(ctx, evt.RoomID, event.StateCreate, "", &createEvent); err != nil { - log.Warn().Err(err).Msg("Failed to check m.room.create event in room") - } else if createEvent.Type != "" { - log.Warn().Str("room_type", string(createEvent.Type)).Msg("Non-standard room type, leaving room") - _, err = intent.LeaveRoom(ctx, evt.RoomID, &mautrix.ReqLeave{ - Reason: "Unsupported room type", - }) - if err != nil { - log.Error().Err(err).Msg("Failed to leave room") - } - return - } - var hasBridgeBot, hasOtherUsers bool - for mxid, _ := range members.Joined { - if mxid == intent.UserID || mxid == inviter.GetMXID() { - continue - } else if mxid == mx.bridge.Bot.UserID { - hasBridgeBot = true - } else { - hasOtherUsers = true - } - } - if !hasBridgeBot && !hasOtherUsers && evt.Content.AsMember().IsDirect { - mx.bridge.Child.CreatePrivatePortal(evt.RoomID, inviter, ghost) - } else if !hasBridgeBot { - log.Debug().Msg("Leaving multi-user room after accepting invite") - _, _ = intent.SendNotice(ctx, evt.RoomID, "Please invite the bridge bot first if you want to bridge to a remote chat.") - _, _ = intent.LeaveRoom(ctx, evt.RoomID) - } else { - _, _ = intent.SendNotice(ctx, evt.RoomID, "This puppet will remain inactive until this room is bridged to a remote chat.") - } -} - -func (mx *MatrixHandler) HandleMembership(ctx context.Context, evt *event.Event) { - if evt.Sender == mx.bridge.Bot.UserID || mx.bridge.Child.IsGhost(evt.Sender) { - return - } - defer mx.TrackEventDuration(evt.Type)() - - if mx.bridge.Crypto != nil { - mx.bridge.Crypto.HandleMemberEvent(ctx, evt) - } - - log := mx.log.With(). - Str("sender", evt.Sender.String()). - Str("target", evt.GetStateKey()). - Str("room_id", evt.RoomID.String()). - Logger() - ctx = log.WithContext(ctx) - - content := evt.Content.AsMember() - if content.Membership == event.MembershipInvite && id.UserID(evt.GetStateKey()) == mx.as.BotMXID() { - mx.HandleBotInvite(ctx, evt) - return - } - - if mx.shouldIgnoreEvent(evt) { - return - } - - user := mx.bridge.Child.GetIUser(evt.Sender, true) - if user == nil { - return - } - isSelf := id.UserID(evt.GetStateKey()) == evt.Sender - ghost := mx.bridge.Child.GetIGhost(id.UserID(evt.GetStateKey())) - portal := mx.bridge.Child.GetIPortal(evt.RoomID) - if portal == nil { - if ghost != nil && content.Membership == event.MembershipInvite { - mx.HandleGhostInvite(ctx, evt, user, ghost) - } - return - } else if user.GetPermissionLevel() < bridgeconfig.PermissionLevelUser || !user.IsLoggedIn() { - return - } - bhp, bhpOk := portal.(BanHandlingPortal) - mhp, mhpOk := portal.(MembershipHandlingPortal) - khp, khpOk := portal.(KnockHandlingPortal) - ihp, ihpOk := portal.(InviteHandlingPortal) - if !(mhpOk || bhpOk || khpOk) { - return - } - prevContent := &event.MemberEventContent{Membership: event.MembershipLeave} - if evt.Unsigned.PrevContent != nil { - _ = evt.Unsigned.PrevContent.ParseRaw(evt.Type) - prevContent, _ = evt.Unsigned.PrevContent.Parsed.(*event.MemberEventContent) - } - if ihpOk && prevContent.Membership == event.MembershipInvite && content.Membership != event.MembershipBan { - if content.Membership == event.MembershipJoin { - ihp.HandleMatrixAcceptInvite(user, evt) - } - if content.Membership == event.MembershipLeave { - if isSelf { - ihp.HandleMatrixRejectInvite(user, evt) - } else if ghost != nil { - ihp.HandleMatrixRetractInvite(user, ghost, evt) - } - } - } - if bhpOk && ghost != nil { - if content.Membership == event.MembershipBan { - bhp.HandleMatrixBan(user, ghost, evt) - } else if content.Membership == event.MembershipLeave && prevContent.Membership == event.MembershipBan { - bhp.HandleMatrixUnban(user, ghost, evt) - } - } - if khpOk { - if content.Membership == event.MembershipKnock { - khp.HandleMatrixKnock(user, evt) - } else if prevContent.Membership == event.MembershipKnock { - if content.Membership == event.MembershipInvite && ghost != nil { - khp.HandleMatrixAcceptKnock(user, ghost, evt) - } else if content.Membership == event.MembershipLeave { - if isSelf { - khp.HandleMatrixRetractKnock(user, evt) - } else if ghost != nil { - khp.HandleMatrixRejectKnock(user, ghost, evt) - } - } - } - } - if mhpOk { - if content.Membership == event.MembershipLeave && prevContent.Membership == event.MembershipJoin { - if isSelf { - mhp.HandleMatrixLeave(user, evt) - } else if ghost != nil { - mhp.HandleMatrixKick(user, ghost, evt) - } - } else if content.Membership == event.MembershipInvite && !isSelf && ghost != nil { - mhp.HandleMatrixInvite(user, ghost, evt) - } - } - // TODO kicking/inviting non-ghost users users -} - -func (mx *MatrixHandler) HandleRoomMetadata(ctx context.Context, evt *event.Event) { - defer mx.TrackEventDuration(evt.Type)() - if mx.shouldIgnoreEvent(evt) { - return - } - - user := mx.bridge.Child.GetIUser(evt.Sender, true) - if user == nil { - return - } - - portal := mx.bridge.Child.GetIPortal(evt.RoomID) - if portal == nil || portal.IsPrivateChat() { - return - } - - metaPortal, ok := portal.(MetaHandlingPortal) - if !ok { - return - } - - metaPortal.HandleMatrixMeta(user, evt) -} - -func (mx *MatrixHandler) shouldIgnoreEvent(evt *event.Event) bool { - if evt.Sender == mx.bridge.Bot.UserID || mx.bridge.Child.IsGhost(evt.Sender) { - return true - } - user := mx.bridge.Child.GetIUser(evt.Sender, true) - if user == nil || user.GetPermissionLevel() <= 0 { - return true - } else if val, ok := evt.Content.Raw[appservice.DoublePuppetKey]; ok && val == mx.bridge.Name && user.GetIDoublePuppet() != nil { - return true - } - return false -} - -const initialSessionWaitTimeout = 3 * time.Second -const extendedSessionWaitTimeout = 22 * time.Second - -func (mx *MatrixHandler) sendCryptoStatusError(ctx context.Context, evt *event.Event, editEvent id.EventID, err error, retryCount int, isFinal bool) id.EventID { - mx.bridge.SendMessageErrorCheckpoint(evt, status.MsgStepDecrypted, err, isFinal, retryCount) - - if mx.bridge.Config.Bridge.EnableMessageStatusEvents() { - statusEvent := &event.BeeperMessageStatusEventContent{ - // TODO: network - RelatesTo: event.RelatesTo{ - Type: event.RelReference, - EventID: evt.ID, - }, - Status: event.MessageStatusRetriable, - Reason: event.MessageStatusUndecryptable, - Error: err.Error(), - Message: errorToHumanMessage(err), - } - if !isFinal { - statusEvent.Status = event.MessageStatusPending - } - _, sendErr := mx.bridge.Bot.SendMessageEvent(ctx, evt.RoomID, event.BeeperMessageStatus, statusEvent) - if sendErr != nil { - zerolog.Ctx(ctx).Error().Err(err).Msg("Failed to send message status event") - } - } - if mx.bridge.Config.Bridge.EnableMessageErrorNotices() { - update := event.MessageEventContent{ - MsgType: event.MsgNotice, - Body: fmt.Sprintf("\u26a0 Your message was not bridged: %v.", err), - } - if errors.Is(err, errNoCrypto) { - update.Body = "🔒 This bridge has not been configured to support encryption" - } - relatable, ok := evt.Content.Parsed.(event.Relatable) - if editEvent != "" { - update.SetEdit(editEvent) - } else if ok && relatable.OptionalGetRelatesTo().GetThreadParent() != "" { - update.GetRelatesTo().SetThread(relatable.OptionalGetRelatesTo().GetThreadParent(), evt.ID) - } - resp, sendErr := mx.bridge.Bot.SendMessageEvent(ctx, evt.RoomID, event.EventMessage, &update) - if sendErr != nil { - zerolog.Ctx(ctx).Error().Err(sendErr).Msg("Failed to send decryption error notice") - } else if resp != nil { - return resp.EventID - } - } - return "" -} - -var ( - errDeviceNotTrusted = errors.New("your device is not trusted") - errMessageNotEncrypted = errors.New("unencrypted message") - errNoDecryptionKeys = errors.New("the bridge hasn't received the decryption keys") - errNoCrypto = errors.New("this bridge has not been configured to support encryption") -) - -func errorToHumanMessage(err error) string { - var withheld *event.RoomKeyWithheldEventContent - switch { - case errors.Is(err, errDeviceNotTrusted), errors.Is(err, errNoDecryptionKeys): - return err.Error() - case errors.Is(err, UnknownMessageIndex): - return "the keys received by the bridge can't decrypt the message" - case errors.Is(err, DuplicateMessageIndex): - return "your client encrypted multiple messages with the same key" - case errors.As(err, &withheld): - if withheld.Code == event.RoomKeyWithheldBeeperRedacted { - return "your client used an outdated encryption session" - } - return "your client refused to share decryption keys with the bridge" - case errors.Is(err, errMessageNotEncrypted): - return "the message is not encrypted" - default: - return "the bridge failed to decrypt the message" - } -} - -func deviceUnverifiedErrorWithExplanation(trust id.TrustState) error { - var explanation string - switch trust { - case id.TrustStateBlacklisted: - explanation = "device is blacklisted" - case id.TrustStateUnset: - explanation = "unverified" - case id.TrustStateUnknownDevice: - explanation = "device info not found" - case id.TrustStateForwarded: - explanation = "keys were forwarded from an unknown device" - case id.TrustStateCrossSignedUntrusted: - explanation = "cross-signing keys changed after setting up the bridge" - default: - return errDeviceNotTrusted - } - return fmt.Errorf("%w (%s)", errDeviceNotTrusted, explanation) -} - -func copySomeKeys(original, decrypted *event.Event) { - isScheduled, _ := original.Content.Raw["com.beeper.scheduled"].(bool) - _, alreadyExists := decrypted.Content.Raw["com.beeper.scheduled"] - if isScheduled && !alreadyExists { - decrypted.Content.Raw["com.beeper.scheduled"] = true - } -} - -func (mx *MatrixHandler) postDecrypt(ctx context.Context, original, decrypted *event.Event, retryCount int, errorEventID id.EventID, duration time.Duration) { - log := zerolog.Ctx(ctx) - minLevel := mx.bridge.Config.Bridge.GetEncryptionConfig().VerificationLevels.Send - if decrypted.Mautrix.TrustState < minLevel { - logEvt := log.Warn(). - Str("user_id", decrypted.Sender.String()). - Bool("forwarded_keys", decrypted.Mautrix.ForwardedKeys). - Stringer("device_trust", decrypted.Mautrix.TrustState). - Stringer("min_trust", minLevel) - if decrypted.Mautrix.TrustSource != nil { - dev := decrypted.Mautrix.TrustSource - logEvt. - Str("device_id", dev.DeviceID.String()). - Str("device_signing_key", dev.SigningKey.String()) - } else { - logEvt.Str("device_id", "unknown") - } - logEvt.Msg("Dropping event due to insufficient verification level") - err := deviceUnverifiedErrorWithExplanation(decrypted.Mautrix.TrustState) - go mx.sendCryptoStatusError(ctx, decrypted, errorEventID, err, retryCount, true) - return - } - copySomeKeys(original, decrypted) - - mx.bridge.SendMessageSuccessCheckpoint(decrypted, status.MsgStepDecrypted, retryCount) - decrypted.Mautrix.CheckpointSent = true - decrypted.Mautrix.DecryptionDuration = duration - decrypted.Mautrix.EventSource |= event.SourceDecrypted - mx.bridge.EventProcessor.Dispatch(ctx, decrypted) - if errorEventID != "" { - _, _ = mx.bridge.Bot.RedactEvent(ctx, decrypted.RoomID, errorEventID) - } -} - -func (mx *MatrixHandler) HandleEncrypted(ctx context.Context, evt *event.Event) { - defer mx.TrackEventDuration(evt.Type)() - if mx.shouldIgnoreEvent(evt) { - return - } - content := evt.Content.AsEncrypted() - log := zerolog.Ctx(ctx).With(). - Str("event_id", evt.ID.String()). - Str("session_id", content.SessionID.String()). - Logger() - ctx = log.WithContext(ctx) - if mx.bridge.Crypto == nil { - go mx.sendCryptoStatusError(ctx, evt, "", errNoCrypto, 0, true) - return - } - log.Debug().Msg("Decrypting received event") - - decryptionStart := time.Now() - decrypted, err := mx.bridge.Crypto.Decrypt(ctx, evt) - decryptionRetryCount := 0 - if errors.Is(err, NoSessionFound) { - decryptionRetryCount = 1 - log.Debug(). - Int("wait_seconds", int(initialSessionWaitTimeout.Seconds())). - Msg("Couldn't find session, waiting for keys to arrive...") - mx.bridge.SendMessageErrorCheckpoint(evt, status.MsgStepDecrypted, err, false, 0) - if mx.bridge.Crypto.WaitForSession(ctx, evt.RoomID, content.SenderKey, content.SessionID, initialSessionWaitTimeout) { - log.Debug().Msg("Got keys after waiting, trying to decrypt event again") - decrypted, err = mx.bridge.Crypto.Decrypt(ctx, evt) - } else { - go mx.waitLongerForSession(ctx, evt, decryptionStart) - return - } - } - if err != nil { - mx.bridge.SendMessageErrorCheckpoint(evt, status.MsgStepDecrypted, err, true, decryptionRetryCount) - log.Warn().Err(err).Msg("Failed to decrypt event") - go mx.sendCryptoStatusError(ctx, evt, "", err, decryptionRetryCount, true) - return - } - mx.postDecrypt(ctx, evt, decrypted, decryptionRetryCount, "", time.Since(decryptionStart)) -} - -func (mx *MatrixHandler) waitLongerForSession(ctx context.Context, evt *event.Event, decryptionStart time.Time) { - log := zerolog.Ctx(ctx) - content := evt.Content.AsEncrypted() - log.Debug(). - Int("wait_seconds", int(extendedSessionWaitTimeout.Seconds())). - Msg("Couldn't find session, requesting keys and waiting longer...") - - go mx.bridge.Crypto.RequestSession(ctx, evt.RoomID, content.SenderKey, content.SessionID, evt.Sender, content.DeviceID) - errorEventID := mx.sendCryptoStatusError(ctx, evt, "", fmt.Errorf("%w. The bridge will retry for %d seconds", errNoDecryptionKeys, int(extendedSessionWaitTimeout.Seconds())), 1, false) - - if !mx.bridge.Crypto.WaitForSession(ctx, evt.RoomID, content.SenderKey, content.SessionID, extendedSessionWaitTimeout) { - log.Debug().Msg("Didn't get session, giving up trying to decrypt event") - mx.sendCryptoStatusError(ctx, evt, errorEventID, errNoDecryptionKeys, 2, true) - return - } - - log.Debug().Msg("Got keys after waiting longer, trying to decrypt event again") - decrypted, err := mx.bridge.Crypto.Decrypt(ctx, evt) - if err != nil { - log.Error().Err(err).Msg("Failed to decrypt event") - mx.sendCryptoStatusError(ctx, evt, errorEventID, err, 2, true) - return - } - - mx.postDecrypt(ctx, evt, decrypted, 2, errorEventID, time.Since(decryptionStart)) -} - -func (mx *MatrixHandler) HandleMessage(ctx context.Context, evt *event.Event) { - defer mx.TrackEventDuration(evt.Type)() - log := zerolog.Ctx(ctx).With(). - Str("event_id", evt.ID.String()). - Str("room_id", evt.RoomID.String()). - Str("sender", evt.Sender.String()). - Logger() - ctx = log.WithContext(ctx) - if mx.shouldIgnoreEvent(evt) { - return - } else if !evt.Mautrix.WasEncrypted && mx.bridge.Config.Bridge.GetEncryptionConfig().Require { - log.Warn().Msg("Dropping unencrypted event") - mx.sendCryptoStatusError(ctx, evt, "", errMessageNotEncrypted, 0, true) - return - } - - user := mx.bridge.Child.GetIUser(evt.Sender, true) - if user == nil { - return - } - - content := evt.Content.AsMessage() - content.RemoveReplyFallback() - if user.GetPermissionLevel() >= bridgeconfig.PermissionLevelUser && content.MsgType == event.MsgText { - commandPrefix := mx.bridge.Config.Bridge.GetCommandPrefix() - hasCommandPrefix := strings.HasPrefix(content.Body, commandPrefix) - if hasCommandPrefix { - content.Body = strings.TrimLeft(strings.TrimPrefix(content.Body, commandPrefix), " ") - } - if hasCommandPrefix || evt.RoomID == user.GetManagementRoomID() { - go mx.bridge.CommandProcessor.Handle(ctx, evt.RoomID, evt.ID, user, content.Body, content.RelatesTo.GetReplyTo()) - go mx.bridge.SendMessageSuccessCheckpoint(evt, status.MsgStepCommand, 0) - if mx.bridge.Config.Bridge.EnableMessageStatusEvents() { - statusEvent := &event.BeeperMessageStatusEventContent{ - // TODO: network - RelatesTo: event.RelatesTo{ - Type: event.RelReference, - EventID: evt.ID, - }, - Status: event.MessageStatusSuccess, - } - _, sendErr := mx.bridge.Bot.SendMessageEvent(ctx, evt.RoomID, event.BeeperMessageStatus, statusEvent) - if sendErr != nil { - log.Warn().Err(sendErr).Msg("Failed to send message status event for command") - } - } - return - } - } - - portal := mx.bridge.Child.GetIPortal(evt.RoomID) - if portal != nil { - portal.ReceiveMatrixEvent(user, evt) - } else { - mx.bridge.SendMessageErrorCheckpoint(evt, status.MsgStepRemote, fmt.Errorf("unknown room"), true, 0) - } -} - -func (mx *MatrixHandler) HandleReaction(_ context.Context, evt *event.Event) { - defer mx.TrackEventDuration(evt.Type)() - if mx.shouldIgnoreEvent(evt) { - return - } - - user := mx.bridge.Child.GetIUser(evt.Sender, true) - if user == nil || user.GetPermissionLevel() < bridgeconfig.PermissionLevelUser || !user.IsLoggedIn() { - return - } - - portal := mx.bridge.Child.GetIPortal(evt.RoomID) - if portal != nil { - portal.ReceiveMatrixEvent(user, evt) - } else { - mx.bridge.SendMessageErrorCheckpoint(evt, status.MsgStepRemote, fmt.Errorf("unknown room"), true, 0) - } -} - -func (mx *MatrixHandler) HandleRedaction(_ context.Context, evt *event.Event) { - defer mx.TrackEventDuration(evt.Type)() - if mx.shouldIgnoreEvent(evt) { - return - } - - user := mx.bridge.Child.GetIUser(evt.Sender, true) - if user == nil { - return - } - - portal := mx.bridge.Child.GetIPortal(evt.RoomID) - if portal != nil { - portal.ReceiveMatrixEvent(user, evt) - } else { - mx.bridge.SendMessageErrorCheckpoint(evt, status.MsgStepRemote, fmt.Errorf("unknown room"), true, 0) - } -} - -func (mx *MatrixHandler) HandleReceipt(_ context.Context, evt *event.Event) { - portal := mx.bridge.Child.GetIPortal(evt.RoomID) - if portal == nil { - return - } - - rrPortal, ok := portal.(ReadReceiptHandlingPortal) - if !ok { - return - } - - for eventID, receipts := range *evt.Content.AsReceipt() { - for userID, receipt := range receipts[event.ReceiptTypeRead] { - user := mx.bridge.Child.GetIUser(userID, false) - if user == nil { - // Not a bridge user - continue - } - customPuppet := user.GetIDoublePuppet() - if val, ok := receipt.Extra[appservice.DoublePuppetKey].(string); ok && customPuppet != nil && val == mx.bridge.Name { - // Ignore double puppeted read receipts. - mx.log.Debug().Interface("content", evt.Content.Raw).Msg("Ignoring double-puppeted read receipt") - // But do start disappearing messages, because the user read the chat - dp, ok := portal.(DisappearingPortal) - if ok { - dp.ScheduleDisappearing() - } - } else { - rrPortal.HandleMatrixReadReceipt(user, eventID, receipt) - } - } - } -} - -func (mx *MatrixHandler) HandleTyping(_ context.Context, evt *event.Event) { - portal := mx.bridge.Child.GetIPortal(evt.RoomID) - if portal == nil { - return - } - typingPortal, ok := portal.(TypingPortal) - if !ok { - return - } - typingPortal.HandleMatrixTyping(evt.Content.AsTyping().UserIDs) -} - -func (mx *MatrixHandler) HandlePowerLevels(_ context.Context, evt *event.Event) { - if mx.shouldIgnoreEvent(evt) { - return - } - portal := mx.bridge.Child.GetIPortal(evt.RoomID) - if portal == nil { - return - } - powerLevelPortal, ok := portal.(PowerLevelHandlingPortal) - if ok { - user := mx.bridge.Child.GetIUser(evt.Sender, true) - powerLevelPortal.HandleMatrixPowerLevels(user, evt) - } -} - -func (mx *MatrixHandler) HandleJoinRule(_ context.Context, evt *event.Event) { - if mx.shouldIgnoreEvent(evt) { - return - } - portal := mx.bridge.Child.GetIPortal(evt.RoomID) - if portal == nil { - return - } - joinRulePortal, ok := portal.(JoinRuleHandlingPortal) - if ok { - user := mx.bridge.Child.GetIUser(evt.Sender, true) - joinRulePortal.HandleMatrixJoinRule(user, evt) - } -} diff --git a/bridge/messagecheckpoint.go b/bridge/messagecheckpoint.go deleted file mode 100644 index a95d2160..00000000 --- a/bridge/messagecheckpoint.go +++ /dev/null @@ -1,61 +0,0 @@ -// Copyright (c) 2021 Sumner Evans -// Copyright (c) 2023 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package bridge - -import ( - "maunium.net/go/mautrix/appservice" - "maunium.net/go/mautrix/bridge/status" - "maunium.net/go/mautrix/event" -) - -func (br *Bridge) SendMessageSuccessCheckpoint(evt *event.Event, step status.MessageCheckpointStep, retryNum int) { - br.SendMessageCheckpoint(evt, step, nil, status.MsgStatusSuccess, retryNum) -} - -func (br *Bridge) SendMessageErrorCheckpoint(evt *event.Event, step status.MessageCheckpointStep, err error, permanent bool, retryNum int) { - s := status.MsgStatusWillRetry - if permanent { - s = status.MsgStatusPermFailure - } - br.SendMessageCheckpoint(evt, step, err, s, retryNum) -} - -func (br *Bridge) SendMessageCheckpoint(evt *event.Event, step status.MessageCheckpointStep, err error, s status.MessageCheckpointStatus, retryNum int) { - checkpoint := status.NewMessageCheckpoint(evt, step, s, retryNum) - if err != nil { - checkpoint.Info = err.Error() - } - go br.SendRawMessageCheckpoint(checkpoint) -} - -func (br *Bridge) SendRawMessageCheckpoint(cp *status.MessageCheckpoint) { - err := br.SendMessageCheckpoints([]*status.MessageCheckpoint{cp}) - if err != nil { - br.ZLog.Warn().Err(err).Interface("message_checkpoint", cp).Msg("Error sending message checkpoint") - } else { - br.ZLog.Debug().Interface("message_checkpoint", cp).Msg("Sent message checkpoint") - } -} - -func (br *Bridge) SendMessageCheckpoints(checkpoints []*status.MessageCheckpoint) error { - checkpointsJSON := status.CheckpointsJSON{Checkpoints: checkpoints} - - if br.Websocket { - return br.AS.SendWebsocket(&appservice.WebsocketRequest{ - Command: "message_checkpoint", - Data: checkpointsJSON, - }) - } - - endpoint := br.Config.Homeserver.MessageSendCheckpointEndpoint - if endpoint == "" { - return nil - } - - return checkpointsJSON.SendHTTP(endpoint, br.AS.Registration.AppToken) -} diff --git a/bridge/no-crypto.go b/bridge/no-crypto.go deleted file mode 100644 index 019ab7c1..00000000 --- a/bridge/no-crypto.go +++ /dev/null @@ -1,26 +0,0 @@ -// Copyright (c) 2023 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -//go:build !cgo || nocrypto - -package bridge - -import ( - "errors" -) - -func NewCryptoHelper(bridge *Bridge) Crypto { - if bridge.Config.Bridge.GetEncryptionConfig().Allow { - bridge.ZLog.Warn().Msg("Bridge built without end-to-bridge encryption, but encryption is enabled in config") - } else { - bridge.ZLog.Debug().Msg("Bridge built without end-to-bridge encryption") - } - return nil -} - -var NoSessionFound = errors.New("nil") -var UnknownMessageIndex = NoSessionFound -var DuplicateMessageIndex = NoSessionFound diff --git a/bridge/websocket.go b/bridge/websocket.go deleted file mode 100644 index 44a3d8d8..00000000 --- a/bridge/websocket.go +++ /dev/null @@ -1,163 +0,0 @@ -package bridge - -import ( - "context" - "errors" - "fmt" - "sync" - "time" - - "go.mau.fi/util/jsontime" - - "maunium.net/go/mautrix/appservice" -) - -const defaultReconnectBackoff = 2 * time.Second -const maxReconnectBackoff = 2 * time.Minute -const reconnectBackoffReset = 5 * time.Minute - -func (br *Bridge) startWebsocket(wg *sync.WaitGroup) { - log := br.ZLog.With().Str("action", "appservice websocket").Logger() - var wgOnce sync.Once - onConnect := func() { - wssBr, ok := br.Child.(WebsocketStartingBridge) - if ok { - wssBr.OnWebsocketConnect() - } - if br.latestState != nil { - go func() { - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - defer cancel() - br.latestState.Timestamp = jsontime.UnixNow() - err := br.SendBridgeState(ctx, br.latestState) - if err != nil { - log.Err(err).Msg("Failed to resend latest bridge state after websocket reconnect") - } else { - log.Debug().Any("bridge_state", br.latestState).Msg("Resent bridge state after websocket reconnect") - } - }() - } - wgOnce.Do(wg.Done) - select { - case br.wsStarted <- struct{}{}: - default: - } - } - reconnectBackoff := defaultReconnectBackoff - lastDisconnect := time.Now().UnixNano() - br.wsStopped = make(chan struct{}) - defer func() { - log.Debug().Msg("Appservice websocket loop finished") - close(br.wsStopped) - }() - addr := br.Config.Homeserver.WSProxy - if addr == "" { - addr = br.Config.Homeserver.Address - } - for { - err := br.AS.StartWebsocket(addr, onConnect) - if errors.Is(err, appservice.ErrWebsocketManualStop) { - return - } else if closeCommand := (&appservice.CloseCommand{}); errors.As(err, &closeCommand) && closeCommand.Status == appservice.MeowConnectionReplaced { - log.Info().Msg("Appservice websocket closed by another instance of the bridge, shutting down...") - br.ManualStop(0) - return - } else if err != nil { - log.Err(err).Msg("Error in appservice websocket") - } - if br.Stopping { - return - } - now := time.Now().UnixNano() - if lastDisconnect+reconnectBackoffReset.Nanoseconds() < now { - reconnectBackoff = defaultReconnectBackoff - } else { - reconnectBackoff *= 2 - if reconnectBackoff > maxReconnectBackoff { - reconnectBackoff = maxReconnectBackoff - } - } - lastDisconnect = now - log.Info(). - Int("backoff_seconds", int(reconnectBackoff.Seconds())). - Msg("Websocket disconnected, reconnecting...") - select { - case <-br.wsShortCircuitReconnectBackoff: - log.Debug().Msg("Reconnect backoff was short-circuited") - case <-time.After(reconnectBackoff): - } - if br.Stopping { - return - } - } -} - -type wsPingData struct { - Timestamp int64 `json:"timestamp"` -} - -func (br *Bridge) PingServer() (start, serverTs, end time.Time) { - if !br.Websocket { - panic(fmt.Errorf("PingServer called without websocket enabled")) - } - if !br.AS.HasWebsocket() { - br.ZLog.Debug().Msg("Received server ping request, but no websocket connected. Trying to short-circuit backoff sleep") - select { - case br.wsShortCircuitReconnectBackoff <- struct{}{}: - default: - br.ZLog.Warn().Msg("Failed to ping websocket: not connected and no backoff?") - return - } - select { - case <-br.wsStarted: - case <-time.After(15 * time.Second): - if !br.AS.HasWebsocket() { - br.ZLog.Warn().Msg("Failed to ping websocket: didn't connect after 15 seconds of waiting") - return - } - } - } - start = time.Now() - var resp wsPingData - br.ZLog.Debug().Msg("Pinging appservice websocket") - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - err := br.AS.RequestWebsocket(ctx, &appservice.WebsocketRequest{ - Command: "ping", - Data: &wsPingData{Timestamp: start.UnixMilli()}, - }, &resp) - end = time.Now() - if err != nil { - br.ZLog.Warn().Err(err).Dur("duration", end.Sub(start)).Msg("Websocket ping returned error") - br.AS.StopWebsocket(fmt.Errorf("websocket ping returned error in %s: %w", end.Sub(start), err)) - } else { - serverTs = time.Unix(0, resp.Timestamp*int64(time.Millisecond)) - br.ZLog.Debug(). - Dur("duration", end.Sub(start)). - Dur("req_duration", serverTs.Sub(start)). - Dur("resp_duration", end.Sub(serverTs)). - Msg("Websocket ping returned success") - } - return -} - -func (br *Bridge) websocketServerPinger() { - interval := time.Duration(br.Config.Homeserver.WSPingInterval) * time.Second - clock := time.NewTicker(interval) - defer func() { - br.ZLog.Info().Msg("Stopping websocket pinger") - clock.Stop() - }() - br.ZLog.Info().Dur("interval_duration", interval).Msg("Starting websocket pinger") - for { - select { - case <-clock.C: - br.PingServer() - case <-br.wsStopPinger: - return - } - if br.Stopping { - return - } - } -} diff --git a/bridgev2/backfillqueue.go b/bridgev2/backfillqueue.go index fce4a1b0..61318d94 100644 --- a/bridgev2/backfillqueue.go +++ b/bridgev2/backfillqueue.go @@ -38,8 +38,10 @@ func (br *Bridge) RunBackfillQueue() { return } ctx, cancel := context.WithCancel(log.WithContext(context.Background())) + br.stopBackfillQueue.Clear() + stopChan := br.stopBackfillQueue.GetChan() go func() { - <-br.stopBackfillQueue + <-stopChan cancel() }() batchDelay := time.Duration(br.Config.Backfill.Queue.BatchDelay) * time.Second @@ -61,7 +63,7 @@ func (br *Bridge) RunBackfillQueue() { } } noTasksFoundCount = 0 - case <-br.stopBackfillQueue: + case <-stopChan: if !timer.Stop() { select { case <-timer.C: @@ -78,13 +80,13 @@ func (br *Bridge) RunBackfillQueue() { time.Sleep(BackfillQueueErrorBackoff) continue } else if backfillTask != nil { - br.doBackfillTask(ctx, backfillTask) + br.DoBackfillTask(ctx, backfillTask) noTasksFoundCount = 0 } } } -func (br *Bridge) doBackfillTask(ctx context.Context, task *database.BackfillTask) { +func (br *Bridge) DoBackfillTask(ctx context.Context, task *database.BackfillTask) { log := zerolog.Ctx(ctx).With(). Object("portal_key", task.PortalKey). Str("login_id", string(task.UserLoginID)). diff --git a/bridgev2/bridge.go b/bridgev2/bridge.go index b2151ee6..226adc90 100644 --- a/bridgev2/bridge.go +++ b/bridgev2/bridge.go @@ -9,15 +9,20 @@ package bridgev2 import ( "context" "fmt" + "os" "sync" + "sync/atomic" + "time" "github.com/rs/zerolog" "go.mau.fi/util/dbutil" + "go.mau.fi/util/exhttp" + "go.mau.fi/util/exsync" - "maunium.net/go/mautrix/bridge/status" "maunium.net/go/mautrix/bridgev2/bridgeconfig" "maunium.net/go/mautrix/bridgev2/database" "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/bridgev2/status" "maunium.net/go/mautrix/id" ) @@ -45,8 +50,17 @@ type Bridge struct { ghostsByID map[networkid.UserID]*Ghost cacheLock sync.Mutex + didSplitPortals bool + + Background bool + ExternallyManagedDB bool + stopping atomic.Bool + wakeupBackfillQueue chan struct{} - stopBackfillQueue chan struct{} + stopBackfillQueue *exsync.Event + + BackgroundCtx context.Context + cancelBackgroundCtx context.CancelFunc } func NewBridge( @@ -74,7 +88,7 @@ func NewBridge( ghostsByID: make(map[networkid.UserID]*Ghost), wakeupBackfillQueue: make(chan struct{}), - stopBackfillQueue: make(chan struct{}), + stopBackfillQueue: exsync.NewEvent(), } if br.Config == nil { br.Config = &bridgeconfig.BridgeConfig{CommandPrefix: "!bridge"} @@ -100,29 +114,89 @@ func (e DBUpgradeError) Unwrap() error { return e.Err } -func (br *Bridge) Start() error { - err := br.StartConnectors() +func (br *Bridge) Start(ctx context.Context) error { + ctx = br.Log.WithContext(ctx) + err := br.StartConnectors(ctx) if err != nil { return err } - err = br.StartLogins() + err = br.StartLogins(ctx) if err != nil { return err } + go br.PostStart(ctx) return nil } -func (br *Bridge) StartConnectors() error { - br.Log.Info().Msg("Starting bridge") - ctx := br.Log.WithContext(context.Background()) - - err := br.DB.Upgrade(ctx) +func (br *Bridge) RunOnce(ctx context.Context, loginID networkid.UserLoginID, params *ConnectBackgroundParams) error { + br.Background = true + br.stopping.Store(false) + err := br.StartConnectors(ctx) if err != nil { - return DBUpgradeError{Err: err, Section: "main"} + return err + } + + if loginID == "" { + br.Log.Info().Msg("No login ID provided to RunOnce, running all logins for 20 seconds") + err = br.StartLogins(ctx) + if err != nil { + return err + } + defer br.StopWithTimeout(5 * time.Second) + select { + case <-time.After(20 * time.Second): + case <-ctx.Done(): + } + return nil + } + + defer br.stop(true, 5*time.Second) + login, err := br.GetExistingUserLoginByID(ctx, loginID) + if err != nil { + return fmt.Errorf("failed to get user login: %w", err) + } else if login == nil { + return ErrNotLoggedIn + } + syncClient, ok := login.Client.(BackgroundSyncingNetworkAPI) + if !ok { + br.Log.Warn().Msg("Network connector doesn't implement background mode, using fallback mechanism for RunOnce") + login.Client.Connect(ctx) + defer login.DisconnectWithTimeout(5 * time.Second) + select { + case <-time.After(20 * time.Second): + case <-ctx.Done(): + } + br.stopping.Store(true) + return nil + } else { + br.Log.Info().Str("user_login_id", string(login.ID)).Msg("Starting individual user login in background mode") + return syncClient.ConnectBackground(login.Log.WithContext(ctx), params) + } +} + +func (br *Bridge) StartConnectors(ctx context.Context) error { + br.Log.Info().Msg("Starting bridge") + br.stopping.Store(false) + if br.BackgroundCtx == nil || br.BackgroundCtx.Err() != nil { + br.BackgroundCtx, br.cancelBackgroundCtx = context.WithCancel(context.Background()) + br.BackgroundCtx = br.Log.WithContext(br.BackgroundCtx) + } + + if !br.ExternallyManagedDB { + err := br.DB.Upgrade(ctx) + if err != nil { + return DBUpgradeError{Err: err, Section: "main"} + } + } + if !br.Background { + var postMigrate func() + br.didSplitPortals, postMigrate = br.MigrateToSplitPortals(ctx) + if postMigrate != nil { + defer postMigrate() + } } - didSplitPortals := br.MigrateToSplitPortals(ctx) br.Log.Info().Msg("Starting Matrix connector") - err = br.Matrix.Start(ctx) + err := br.Matrix.Start(ctx) if err != nil { return fmt.Errorf("failed to start Matrix connector: %w", err) } @@ -131,16 +205,40 @@ func (br *Bridge) StartConnectors() error { if err != nil { return fmt.Errorf("failed to start network connector: %w", err) } - if br.Network.GetCapabilities().DisappearingMessages { + if br.Network.GetCapabilities().DisappearingMessages && !br.Background { go br.DisappearLoop.Start() } - if didSplitPortals || br.Config.ResendBridgeInfo { - br.ResendBridgeInfo(ctx) - } return nil } -func (br *Bridge) ResendBridgeInfo(ctx context.Context) { +func (br *Bridge) PostStart(ctx context.Context) { + if br.Background { + return + } + rawBridgeInfoVer := br.DB.KV.Get(ctx, database.KeyBridgeInfoVersion) + bridgeInfoVer, capVer, err := parseBridgeInfoVersion(rawBridgeInfoVer) + if err != nil { + br.Log.Err(err).Str("db_bridge_info_version", rawBridgeInfoVer).Msg("Failed to parse bridge info version") + return + } + expectedBridgeInfoVer, expectedCapVer := br.Network.GetBridgeInfoVersion() + doResendBridgeInfo := bridgeInfoVer != expectedBridgeInfoVer || br.didSplitPortals || br.Config.ResendBridgeInfo + doResendCapabilities := capVer != expectedCapVer || br.didSplitPortals + if doResendBridgeInfo || doResendCapabilities { + br.ResendBridgeInfo(ctx, doResendBridgeInfo, doResendCapabilities) + } + br.DB.KV.Set(ctx, database.KeyBridgeInfoVersion, fmt.Sprintf("%d,%d", expectedBridgeInfoVer, expectedCapVer)) +} + +func parseBridgeInfoVersion(version string) (info, capabilities int, err error) { + _, err = fmt.Sscanf(version, "%d,%d", &info, &capabilities) + if version == "" { + err = nil + } + return +} + +func (br *Bridge) ResendBridgeInfo(ctx context.Context, resendInfo, resendCaps bool) { log := zerolog.Ctx(ctx).With().Str("action", "resend bridge info").Logger() portals, err := br.GetAllPortalsWithMXID(ctx) if err != nil { @@ -148,30 +246,103 @@ func (br *Bridge) ResendBridgeInfo(ctx context.Context) { return } for _, portal := range portals { - portal.UpdateBridgeInfo(ctx) + if resendInfo { + portal.UpdateBridgeInfo(ctx) + } + if resendCaps { + logins, err := br.GetUserLoginsInPortal(ctx, portal.PortalKey) + if err != nil { + log.Err(err). + Stringer("room_id", portal.MXID). + Object("portal_key", portal.PortalKey). + Msg("Failed to get user logins in portal") + } else { + found := false + for _, login := range logins { + if portal.CapState.ID == "" || login.ID == portal.CapState.Source { + portal.UpdateCapabilities(ctx, login, true) + found = true + } + } + if !found && len(logins) > 0 { + portal.CapState.Source = "" + portal.UpdateCapabilities(ctx, logins[0], true) + } else if !found { + log.Warn(). + Stringer("room_id", portal.MXID). + Object("portal_key", portal.PortalKey). + Msg("No user login found to update capabilities") + } + } + } } - log.Info().Msg("Resent bridge info to all portals") + log.Info(). + Bool("capabilities", resendCaps). + Bool("info", resendInfo). + Msg("Resent bridge info to all portals") } -func (br *Bridge) MigrateToSplitPortals(ctx context.Context) bool { +func (br *Bridge) MigrateToSplitPortals(ctx context.Context) (bool, func()) { log := zerolog.Ctx(ctx).With().Str("action", "migrate to split portals").Logger() ctx = log.WithContext(ctx) if !br.Config.SplitPortals || br.DB.KV.Get(ctx, database.KeySplitPortalsEnabled) == "true" { - return false + return false, nil } affected, err := br.DB.Portal.MigrateToSplitPortals(ctx) if err != nil { - log.Err(err).Msg("Failed to migrate portals") - return false + log.WithLevel(zerolog.FatalLevel).Err(err).Msg("Failed to migrate portals") + os.Exit(31) + return false, nil } log.Info().Int64("rows_affected", affected).Msg("Migrated to split portals") + affected2, err := br.DB.Portal.FixParentsAfterSplitPortalMigration(ctx) + if err != nil { + log.Err(err).Msg("Failed to fix parent portals after split portal migration") + os.Exit(31) + return false, nil + } + log.Info().Int64("rows_affected", affected2).Msg("Updated parent receivers after split portal migration") + withoutReceiver, err := br.DB.Portal.GetAllWithoutReceiver(ctx) + if err != nil { + log.WithLevel(zerolog.FatalLevel).Err(err).Msg("Failed to get portals that failed to migrate") + os.Exit(31) + return false, nil + } + var roomsToDelete []id.RoomID + log.Info().Int("remaining_portals", len(withoutReceiver)).Msg("Deleting remaining portals without receiver") + for _, portal := range withoutReceiver { + if err = br.DB.Portal.Delete(ctx, portal.PortalKey); err != nil { + log.Err(err). + Str("portal_id", string(portal.ID)). + Stringer("mxid", portal.MXID). + Msg("Failed to delete portal database row that failed to migrate") + } else if portal.MXID != "" { + log.Debug(). + Str("portal_id", string(portal.ID)). + Stringer("mxid", portal.MXID). + Msg("Marked portal room for deletion from homeserver") + roomsToDelete = append(roomsToDelete, portal.MXID) + } else { + log.Debug(). + Str("portal_id", string(portal.ID)). + Msg("Deleted portal row with no Matrix room") + } + } br.DB.KV.Set(ctx, database.KeySplitPortalsEnabled, "true") - return affected > 0 + log.Info().Msg("Finished split portal migration successfully") + return affected > 0, func() { + for _, roomID := range roomsToDelete { + if err = br.Bot.DeleteRoom(ctx, roomID, true); err != nil { + log.Err(err). + Stringer("mxid", roomID). + Msg("Failed to delete portal room that failed to migrate") + } + } + log.Info().Int("room_count", len(roomsToDelete)).Msg("Finished deleting rooms that failed to migrate") + } } -func (br *Bridge) StartLogins() error { - ctx := br.Log.WithContext(context.Background()) - +func (br *Bridge) StartLogins(ctx context.Context) error { userIDs, err := br.DB.UserLogin.GetAllUserIDsWithLogins(ctx) if err != nil { return fmt.Errorf("failed to get users with logins: %w", err) @@ -195,30 +366,93 @@ func (br *Bridge) StartLogins() error { br.Log.Info().Msg("No user logins found") br.SendGlobalBridgeState(status.BridgeState{StateEvent: status.StateUnconfigured}) } - go br.RunBackfillQueue() + if !br.Background { + go br.RunBackfillQueue() + } br.Log.Info().Msg("Bridge started") return nil } -func (br *Bridge) Stop() { - br.Log.Info().Msg("Shutting down bridge") - close(br.stopBackfillQueue) - br.Matrix.Stop() - br.cacheLock.Lock() - var wg sync.WaitGroup - wg.Add(len(br.userLoginsByID)) - for _, login := range br.userLoginsByID { - go login.Disconnect(wg.Done) +func (br *Bridge) ResetNetworkConnections() { + nrn, ok := br.Network.(NetworkResettingNetwork) + if ok { + br.Log.Info().Msg("Resetting network connections with NetworkConnector.ResetNetworkConnections") + nrn.ResetNetworkConnections() + return + } + + br.Log.Info().Msg("Network connector doesn't support ResetNetworkConnections, recreating clients manually") + for _, login := range br.GetAllCachedUserLogins() { + login.Log.Debug().Msg("Disconnecting and recreating client for network reset") + ctx := login.Log.WithContext(br.BackgroundCtx) + login.Client.Disconnect() + err := login.recreateClient(ctx) + if err != nil { + login.Log.Err(err).Msg("Failed to recreate client during network reset") + login.BridgeState.Send(status.BridgeState{ + StateEvent: status.StateUnknownError, + Error: "bridgev2-network-reset-fail", + Info: map[string]any{"go_error": err.Error()}, + }) + } else { + login.Client.Connect(ctx) + } + } + br.Log.Info().Msg("Finished resetting all user logins") +} + +func (br *Bridge) GetHTTPClientSettings() exhttp.ClientSettings { + mchs, ok := br.Matrix.(MatrixConnectorWithHTTPSettings) + if ok { + return mchs.GetHTTPClientSettings() + } + return exhttp.SensibleClientSettings +} + +func (br *Bridge) IsStopping() bool { + return br.stopping.Load() +} + +func (br *Bridge) Stop() { + br.stop(false, 0) +} + +func (br *Bridge) StopWithTimeout(timeout time.Duration) { + br.stop(false, timeout) +} + +func (br *Bridge) stop(isRunOnce bool, timeout time.Duration) { + br.Log.Info().Msg("Shutting down bridge") + br.stopping.Store(true) + br.DisappearLoop.Stop() + br.stopBackfillQueue.Set() + br.Matrix.PreStop() + if !isRunOnce { + br.cacheLock.Lock() + var wg sync.WaitGroup + wg.Add(len(br.userLoginsByID)) + for _, login := range br.userLoginsByID { + go func() { + login.DisconnectWithTimeout(timeout) + wg.Done() + }() + } + br.cacheLock.Unlock() + wg.Wait() + } + br.Matrix.Stop() + if br.cancelBackgroundCtx != nil { + br.cancelBackgroundCtx() } - wg.Wait() - br.cacheLock.Unlock() if stopNet, ok := br.Network.(StoppableNetwork); ok { stopNet.Stop() } - err := br.DB.Close() - if err != nil { - br.Log.Warn().Err(err).Msg("Failed to close database") + if !br.ExternallyManagedDB { + err := br.DB.Close() + if err != nil { + br.Log.Warn().Err(err).Msg("Failed to close database") + } } br.Log.Info().Msg("Shutdown complete") } diff --git a/bridgev2/bridgeconfig/appservice.go b/bridgev2/bridgeconfig/appservice.go index 89ce5677..f709c8e0 100644 --- a/bridgev2/bridgeconfig/appservice.go +++ b/bridgev2/bridgeconfig/appservice.go @@ -34,7 +34,6 @@ type AppserviceConfig struct { EphemeralEvents bool `yaml:"ephemeral_events"` AsyncTransactions bool `yaml:"async_transactions"` - MSC4190 bool `yaml:"msc4190"` UsernameTemplate string `yaml:"username_template"` usernameTemplate *template.Template `yaml:"-"` @@ -78,7 +77,11 @@ func (asc *AppserviceConfig) copyToRegistration(registration *appservice.Registr registration.RateLimited = &falseVal registration.EphemeralEvents = asc.EphemeralEvents registration.SoruEphemeralEvents = asc.EphemeralEvents - registration.MSC4190 = asc.MSC4190 +} + +func (ec *EncryptionConfig) applyUnstableFlags(registration *appservice.Registration) { + registration.MSC4190 = ec.MSC4190 + registration.MSC3202 = ec.Appservice } // GenerateRegistration generates a registration file for the homeserver. @@ -87,6 +90,7 @@ func (config *Config) GenerateRegistration() *appservice.Registration { config.AppService.HSToken = registration.ServerToken config.AppService.ASToken = registration.AppToken config.AppService.copyToRegistration(registration) + config.Encryption.applyUnstableFlags(registration) registration.SenderLocalpart = random.String(32) botRegex := regexp.MustCompile(fmt.Sprintf("^@%s:%s$", @@ -105,6 +109,7 @@ func (config *Config) MakeAppService() *appservice.AppService { as.Host.Hostname = config.AppService.Hostname as.Host.Port = config.AppService.Port as.Registration = config.AppService.GetRegistration() + config.Encryption.applyUnstableFlags(as.Registration) return as } diff --git a/bridgev2/bridgeconfig/backfill.go b/bridgev2/bridgeconfig/backfill.go index 44d2d588..eedae1e8 100644 --- a/bridgev2/bridgeconfig/backfill.go +++ b/bridgev2/bridgeconfig/backfill.go @@ -14,6 +14,11 @@ type BackfillConfig struct { Threads BackfillThreadsConfig `yaml:"threads"` Queue BackfillQueueConfig `yaml:"queue"` + + // Flag to indicate that the creator will not run the backfill queue but will still paginate + // backfill by calling DoBackfillTask directly. Note that this is not used anywhere within + // mautrix-go and exists so bridges can use it to decide when to drop backfill data. + WillPaginateManually bool `yaml:"will_paginate_manually"` } type BackfillThreadsConfig struct { @@ -29,10 +34,12 @@ type BackfillQueueConfig struct { MaxBatchesOverride map[string]int `yaml:"max_batches_override"` } -func (bqc *BackfillQueueConfig) GetOverride(name string) int { - override, ok := bqc.MaxBatchesOverride[name] - if !ok { - return bqc.MaxBatches +func (bqc *BackfillQueueConfig) GetOverride(names ...string) int { + for _, name := range names { + override, ok := bqc.MaxBatchesOverride[name] + if ok { + return override + } } - return override + return bqc.MaxBatches } diff --git a/bridgev2/bridgeconfig/config.go b/bridgev2/bridgeconfig/config.go index cf87864f..bd6b9c06 100644 --- a/bridgev2/bridgeconfig/config.go +++ b/bridgev2/bridgeconfig/config.go @@ -7,6 +7,8 @@ package bridgeconfig import ( + "time" + "go.mau.fi/util/dbutil" "go.mau.fi/zeroconfig" "gopkg.in/yaml.v3" @@ -31,6 +33,8 @@ type Config struct { Encryption EncryptionConfig `yaml:"encryption"` Logging zeroconfig.Config `yaml:"logging"` + EnvConfigPrefix string `yaml:"env_config_prefix"` + ManagementRoomTexts ManagementRoomTexts `yaml:"management_room_texts"` } @@ -58,30 +62,40 @@ type CleanupOnLogouts struct { } type BridgeConfig struct { - CommandPrefix string `yaml:"command_prefix"` - PersonalFilteringSpaces bool `yaml:"personal_filtering_spaces"` - PrivateChatPortalMeta bool `yaml:"private_chat_portal_meta"` - AsyncEvents bool `yaml:"async_events"` - SplitPortals bool `yaml:"split_portals"` - ResendBridgeInfo bool `yaml:"resend_bridge_info"` - BridgeMatrixLeave bool `yaml:"bridge_matrix_leave"` - TagOnlyOnCreate bool `yaml:"tag_only_on_create"` - OnlyBridgeTags []event.RoomTag `yaml:"only_bridge_tags"` - MuteOnlyOnCreate bool `yaml:"mute_only_on_create"` - OutgoingMessageReID bool `yaml:"outgoing_message_re_id"` - CleanupOnLogout CleanupOnLogouts `yaml:"cleanup_on_logout"` - Relay RelayConfig `yaml:"relay"` - Permissions PermissionConfig `yaml:"permissions"` - Backfill BackfillConfig `yaml:"backfill"` + CommandPrefix string `yaml:"command_prefix"` + PersonalFilteringSpaces bool `yaml:"personal_filtering_spaces"` + PrivateChatPortalMeta bool `yaml:"private_chat_portal_meta"` + AsyncEvents bool `yaml:"async_events"` + SplitPortals bool `yaml:"split_portals"` + ResendBridgeInfo bool `yaml:"resend_bridge_info"` + NoBridgeInfoStateKey bool `yaml:"no_bridge_info_state_key"` + BridgeStatusNotices string `yaml:"bridge_status_notices"` + UnknownErrorAutoReconnect time.Duration `yaml:"unknown_error_auto_reconnect"` + UnknownErrorMaxAutoReconnects int `yaml:"unknown_error_max_auto_reconnects"` + BridgeMatrixLeave bool `yaml:"bridge_matrix_leave"` + BridgeNotices bool `yaml:"bridge_notices"` + TagOnlyOnCreate bool `yaml:"tag_only_on_create"` + OnlyBridgeTags []event.RoomTag `yaml:"only_bridge_tags"` + MuteOnlyOnCreate bool `yaml:"mute_only_on_create"` + DeduplicateMatrixMessages bool `yaml:"deduplicate_matrix_messages"` + CrossRoomReplies bool `yaml:"cross_room_replies"` + OutgoingMessageReID bool `yaml:"outgoing_message_re_id"` + RevertFailedStateChanges bool `yaml:"revert_failed_state_changes"` + KickMatrixUsers bool `yaml:"kick_matrix_users"` + CleanupOnLogout CleanupOnLogouts `yaml:"cleanup_on_logout"` + Relay RelayConfig `yaml:"relay"` + Permissions PermissionConfig `yaml:"permissions"` + Backfill BackfillConfig `yaml:"backfill"` } type MatrixConfig struct { - MessageStatusEvents bool `yaml:"message_status_events"` - DeliveryReceipts bool `yaml:"delivery_receipts"` - MessageErrorNotices bool `yaml:"message_error_notices"` - SyncDirectChatList bool `yaml:"sync_direct_chat_list"` - FederateRooms bool `yaml:"federate_rooms"` - UploadFileThreshold int64 `yaml:"upload_file_threshold"` + MessageStatusEvents bool `yaml:"message_status_events"` + DeliveryReceipts bool `yaml:"delivery_receipts"` + MessageErrorNotices bool `yaml:"message_error_notices"` + SyncDirectChatList bool `yaml:"sync_direct_chat_list"` + FederateRooms bool `yaml:"federate_rooms"` + UploadFileThreshold int64 `yaml:"upload_file_threshold"` + GhostExtraProfileInfo bool `yaml:"ghost_extra_profile_info"` } type AnalyticsConfig struct { @@ -91,9 +105,9 @@ type AnalyticsConfig struct { } type ProvisioningConfig struct { - Prefix string `yaml:"prefix"` - SharedSecret string `yaml:"shared_secret"` - DebugEndpoints bool `yaml:"debug_endpoints"` + SharedSecret string `yaml:"shared_secret"` + DebugEndpoints bool `yaml:"debug_endpoints"` + EnableSessionTransfers bool `yaml:"enable_session_transfers"` } type DirectMediaConfig struct { @@ -103,10 +117,12 @@ type DirectMediaConfig struct { } type PublicMediaConfig struct { - Enabled bool `yaml:"enabled"` - SigningKey string `yaml:"signing_key"` - HashLength int `yaml:"hash_length"` - Expiry int `yaml:"expiry"` + Enabled bool `yaml:"enabled"` + SigningKey string `yaml:"signing_key"` + Expiry int `yaml:"expiry"` + HashLength int `yaml:"hash_length"` + PathPrefix string `yaml:"path_prefix"` + UseDatabase bool `yaml:"use_database"` } type DoublePuppetConfig struct { diff --git a/bridgev2/bridgeconfig/encryption.go b/bridgev2/bridgeconfig/encryption.go index 93a427d3..934613ca 100644 --- a/bridgev2/bridgeconfig/encryption.go +++ b/bridgev2/bridgeconfig/encryption.go @@ -15,6 +15,9 @@ type EncryptionConfig struct { Default bool `yaml:"default"` Require bool `yaml:"require"` Appservice bool `yaml:"appservice"` + MSC4190 bool `yaml:"msc4190"` + MSC4392 bool `yaml:"msc4392"` + SelfSign bool `yaml:"self_sign"` PlaintextMentions bool `yaml:"plaintext_mentions"` diff --git a/bridgev2/bridgeconfig/legacymigrate.go b/bridgev2/bridgeconfig/legacymigrate.go index fb2a86d6..954a37c3 100644 --- a/bridgev2/bridgeconfig/legacymigrate.go +++ b/bridgev2/bridgeconfig/legacymigrate.go @@ -133,9 +133,7 @@ func doMigrateLegacy(helper up.Helper, python bool) { CopyToOtherLocation(helper, up.Bool, []string{"bridge", "sync_direct_chat_list"}, []string{"matrix", "sync_direct_chat_list"}) CopyToOtherLocation(helper, up.Bool, []string{"bridge", "federate_rooms"}, []string{"matrix", "federate_rooms"}) - CopyToOtherLocation(helper, up.Str, []string{"bridge", "provisioning", "prefix"}, []string{"provisioning", "prefix"}) CopyToOtherLocation(helper, up.Str, []string{"bridge", "provisioning", "shared_secret"}, []string{"provisioning", "shared_secret"}) - CopyToOtherLocation(helper, up.Str, []string{"appservice", "provisioning", "prefix"}, []string{"provisioning", "prefix"}) CopyToOtherLocation(helper, up.Str, []string{"appservice", "provisioning", "shared_secret"}, []string{"provisioning", "shared_secret"}) CopyToOtherLocation(helper, up.Bool, []string{"bridge", "provisioning", "debug_endpoints"}, []string{"provisioning", "debug_endpoints"}) diff --git a/bridgev2/bridgeconfig/permissions.go b/bridgev2/bridgeconfig/permissions.go index 610051e0..9efe068e 100644 --- a/bridgev2/bridgeconfig/permissions.go +++ b/bridgev2/bridgeconfig/permissions.go @@ -24,6 +24,7 @@ type Permissions struct { DoublePuppet bool `yaml:"double_puppet"` Admin bool `yaml:"admin"` ManageRelay bool `yaml:"manage_relay"` + MaxLogins int `yaml:"max_logins"` } type PermissionConfig map[string]*Permissions @@ -40,10 +41,7 @@ func (pc PermissionConfig) IsConfigured() bool { _, hasExampleDomain := pc["example.com"] _, hasExampleUser := pc["@admin:example.com"] exampleLen := boolToInt(hasWildcard) + boolToInt(hasExampleUser) + boolToInt(hasExampleDomain) - if len(pc) <= exampleLen { - return false - } - return true + return len(pc) > exampleLen } func (pc PermissionConfig) Get(userID id.UserID) Permissions { diff --git a/bridgev2/bridgeconfig/upgrade.go b/bridgev2/bridgeconfig/upgrade.go index 776fa44d..92515ea0 100644 --- a/bridgev2/bridgeconfig/upgrade.go +++ b/bridgev2/bridgeconfig/upgrade.go @@ -30,10 +30,19 @@ func doUpgrade(helper up.Helper) { helper.Copy(up.Bool, "bridge", "async_events") helper.Copy(up.Bool, "bridge", "split_portals") helper.Copy(up.Bool, "bridge", "resend_bridge_info") + helper.Copy(up.Bool, "bridge", "no_bridge_info_state_key") + helper.Copy(up.Str|up.Null, "bridge", "bridge_status_notices") + helper.Copy(up.Str|up.Int|up.Null, "bridge", "unknown_error_auto_reconnect") + helper.Copy(up.Int, "bridge", "unknown_error_max_auto_reconnects") helper.Copy(up.Bool, "bridge", "bridge_matrix_leave") + helper.Copy(up.Bool, "bridge", "bridge_notices") helper.Copy(up.Bool, "bridge", "tag_only_on_create") helper.Copy(up.List, "bridge", "only_bridge_tags") helper.Copy(up.Bool, "bridge", "mute_only_on_create") + helper.Copy(up.Bool, "bridge", "deduplicate_matrix_messages") + helper.Copy(up.Bool, "bridge", "cross_room_replies") + helper.Copy(up.Bool, "bridge", "revert_failed_state_changes") + helper.Copy(up.Bool, "bridge", "kick_matrix_users") helper.Copy(up.Bool, "bridge", "cleanup_on_logout", "enabled") helper.Copy(up.Str, "bridge", "cleanup_on_logout", "manual", "private") helper.Copy(up.Str, "bridge", "cleanup_on_logout", "manual", "relayed") @@ -82,7 +91,6 @@ func doUpgrade(helper up.Helper) { helper.Copy(up.Str, "appservice", "bot", "avatar") helper.Copy(up.Bool, "appservice", "ephemeral_events") helper.Copy(up.Bool, "appservice", "async_transactions") - helper.Copy(up.Bool, "appservice", "msc4190") helper.Copy(up.Str, "appservice", "as_token") helper.Copy(up.Str, "appservice", "hs_token") helper.Copy(up.Str, "appservice", "username_template") @@ -93,12 +101,12 @@ func doUpgrade(helper up.Helper) { helper.Copy(up.Bool, "matrix", "sync_direct_chat_list") helper.Copy(up.Bool, "matrix", "federate_rooms") helper.Copy(up.Int, "matrix", "upload_file_threshold") + helper.Copy(up.Bool, "matrix", "ghost_extra_profile_info") helper.Copy(up.Str|up.Null, "analytics", "token") helper.Copy(up.Str|up.Null, "analytics", "url") helper.Copy(up.Str|up.Null, "analytics", "user_id") - helper.Copy(up.Str, "provisioning", "prefix") if secret, ok := helper.Get(up.Str, "provisioning", "shared_secret"); !ok || secret == "generate" { sharedSecret := random.String(64) helper.Set(up.Str, sharedSecret, "provisioning", "shared_secret") @@ -106,6 +114,7 @@ func doUpgrade(helper up.Helper) { helper.Copy(up.Str, "provisioning", "shared_secret") } helper.Copy(up.Bool, "provisioning", "debug_endpoints") + helper.Copy(up.Bool, "provisioning", "enable_session_transfers") helper.Copy(up.Bool, "direct_media", "enabled") helper.Copy(up.Str|up.Null, "direct_media", "media_id_prefix") @@ -127,6 +136,8 @@ func doUpgrade(helper up.Helper) { } helper.Copy(up.Int, "public_media", "expiry") helper.Copy(up.Int, "public_media", "hash_length") + helper.Copy(up.Str|up.Null, "public_media", "path_prefix") + helper.Copy(up.Bool, "public_media", "use_database") helper.Copy(up.Bool, "backfill", "enabled") helper.Copy(up.Int, "backfill", "max_initial_messages") @@ -147,6 +158,13 @@ func doUpgrade(helper up.Helper) { helper.Copy(up.Bool, "encryption", "default") helper.Copy(up.Bool, "encryption", "require") helper.Copy(up.Bool, "encryption", "appservice") + if val, ok := helper.Get(up.Bool, "appservice", "msc4190"); ok { + helper.Set(up.Bool, val, "encryption", "msc4190") + } else { + helper.Copy(up.Bool, "encryption", "msc4190") + } + helper.Copy(up.Bool, "encryption", "msc4392") + helper.Copy(up.Bool, "encryption", "self_sign") helper.Copy(up.Bool, "encryption", "allow_key_sharing") if secret, ok := helper.Get(up.Str, "encryption", "pickle_key"); !ok || secret == "generate" { helper.Set(up.Str, random.String(64), "encryption", "pickle_key") @@ -169,6 +187,8 @@ func doUpgrade(helper up.Helper) { helper.Copy(up.Int, "encryption", "rotation", "messages") helper.Copy(up.Bool, "encryption", "rotation", "disable_device_change_key_rotation") + helper.Copy(up.Str|up.Null, "env_config_prefix") + helper.Copy(up.Map, "logging") } @@ -196,6 +216,7 @@ var SpacedBlocks = [][]string{ {"backfill"}, {"double_puppet"}, {"encryption"}, + {"env_config_prefix"}, {"logging"}, } diff --git a/bridgev2/bridgestate.go b/bridgev2/bridgestate.go index 1cd6b0c5..96d9fd5c 100644 --- a/bridgev2/bridgestate.go +++ b/bridgev2/bridgestate.go @@ -8,20 +8,37 @@ package bridgev2 import ( "context" + "fmt" + "math/rand/v2" "runtime/debug" + "sync/atomic" "time" "github.com/rs/zerolog" + "go.mau.fi/util/exfmt" - "maunium.net/go/mautrix/bridge/status" + "maunium.net/go/mautrix/bridgev2/status" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/format" ) +var CatchBridgeStateQueuePanics = true + type BridgeStateQueue struct { prevUnsent *status.BridgeState prevSent *status.BridgeState + errorSent bool ch chan status.BridgeState bridge *Bridge - user status.StandaloneCustomBridgeStateFiller + login *UserLogin + + firstTransientDisconnect time.Time + cancelScheduledNotice atomic.Pointer[context.CancelFunc] + + stopChan chan struct{} + stopReconnect atomic.Pointer[context.CancelFunc] + + unknownErrorReconnects int } func (br *Bridge) SendGlobalBridgeState(state status.BridgeState) { @@ -41,51 +58,221 @@ func (br *Bridge) SendGlobalBridgeState(state status.BridgeState) { } } -func (br *Bridge) NewBridgeStateQueue(user status.StandaloneCustomBridgeStateFiller) *BridgeStateQueue { +func (br *Bridge) NewBridgeStateQueue(login *UserLogin) *BridgeStateQueue { bsq := &BridgeStateQueue{ - ch: make(chan status.BridgeState, 10), - bridge: br, - user: user, + ch: make(chan status.BridgeState, 10), + stopChan: make(chan struct{}), + bridge: br, + login: login, } go bsq.loop() return bsq } func (bsq *BridgeStateQueue) Destroy() { + close(bsq.stopChan) close(bsq.ch) + bsq.StopUnknownErrorReconnect() +} + +func (bsq *BridgeStateQueue) StopUnknownErrorReconnect() { + if bsq == nil { + return + } + if cancelFn := bsq.stopReconnect.Swap(nil); cancelFn != nil { + (*cancelFn)() + } + if cancelFn := bsq.cancelScheduledNotice.Swap(nil); cancelFn != nil { + (*cancelFn)() + } } func (bsq *BridgeStateQueue) loop() { - defer func() { - err := recover() - if err != nil { - bsq.bridge.Log.Error(). - Bytes(zerolog.ErrorStackFieldName, debug.Stack()). - Any(zerolog.ErrorFieldName, err). - Msg("Panic in bridge state loop") - } - }() + if CatchBridgeStateQueuePanics { + defer func() { + err := recover() + if err != nil { + bsq.login.Log.Error(). + Bytes(zerolog.ErrorStackFieldName, debug.Stack()). + Any(zerolog.ErrorFieldName, err). + Msg("Panic in bridge state loop") + } + }() + } for state := range bsq.ch { bsq.immediateSendBridgeState(state) } } +func (bsq *BridgeStateQueue) scheduleNotice(triggeredBy status.BridgeState) { + log := bsq.login.Log.With().Str("action", "transient disconnect notice").Logger() + ctx := log.WithContext(bsq.bridge.BackgroundCtx) + if !bsq.waitForTransientDisconnectReconnect(ctx) { + return + } + prevUnsent := bsq.GetPrevUnsent() + prev := bsq.GetPrev() + if triggeredBy.Timestamp != prev.Timestamp || len(bsq.ch) > 0 || bsq.errorSent || + prevUnsent.StateEvent != status.StateTransientDisconnect || prev.StateEvent != status.StateTransientDisconnect { + log.Trace().Any("triggered_by", triggeredBy).Msg("Not sending delayed transient disconnect notice") + return + } + log.Debug().Any("triggered_by", triggeredBy).Msg("Sending delayed transient disconnect notice") + bsq.sendNotice(ctx, triggeredBy, true) +} + +func (bsq *BridgeStateQueue) sendNotice(ctx context.Context, state status.BridgeState, isDelayed bool) { + noticeConfig := bsq.bridge.Config.BridgeStatusNotices + isError := state.StateEvent == status.StateBadCredentials || + state.StateEvent == status.StateUnknownError || + state.UserAction == status.UserActionOpenNative || + (isDelayed && state.StateEvent == status.StateTransientDisconnect) + sendNotice := noticeConfig == "all" || (noticeConfig == "errors" && + (isError || (bsq.errorSent && state.StateEvent == status.StateConnected))) + if state.StateEvent != status.StateTransientDisconnect && state.StateEvent != status.StateUnknownError { + bsq.firstTransientDisconnect = time.Time{} + } + if !sendNotice { + if !bsq.errorSent && !isDelayed && noticeConfig == "errors" && state.StateEvent == status.StateTransientDisconnect { + if bsq.firstTransientDisconnect.IsZero() { + bsq.firstTransientDisconnect = time.Now() + } + go bsq.scheduleNotice(state) + } + return + } + managementRoom, err := bsq.login.User.GetManagementRoom(ctx) + if err != nil { + bsq.login.Log.Err(err).Msg("Failed to get management room") + return + } + name := bsq.login.RemoteName + if name == "" { + name = fmt.Sprintf("`%s`", bsq.login.ID) + } + message := fmt.Sprintf("State update for %s: `%s`", name, state.StateEvent) + if state.Error != "" { + message += fmt.Sprintf(" (`%s`)", state.Error) + } + if isDelayed { + message += fmt.Sprintf(" not resolved after waiting %s", exfmt.Duration(TransientDisconnectNoticeDelay)) + } + if state.Message != "" { + message += fmt.Sprintf(": %s", state.Message) + } + content := format.RenderMarkdown(message, true, false) + if !isError { + content.MsgType = event.MsgNotice + } + _, err = bsq.bridge.Bot.SendMessage(ctx, managementRoom, event.EventMessage, &event.Content{ + Parsed: content, + Raw: map[string]any{ + "fi.mau.bridge_state": state, + }, + }, nil) + if err != nil { + bsq.login.Log.Err(err).Msg("Failed to send bridge state notice") + } else { + bsq.errorSent = isError + } +} + +func (bsq *BridgeStateQueue) unknownErrorReconnect(triggeredBy status.BridgeState) { + log := bsq.login.Log.With().Str("action", "unknown error reconnect").Logger() + ctx := log.WithContext(bsq.bridge.BackgroundCtx) + if !bsq.waitForUnknownErrorReconnect(ctx) { + return + } + prevUnsent := bsq.GetPrevUnsent() + prev := bsq.GetPrev() + if triggeredBy.Timestamp != prev.Timestamp { + log.Debug().Msg("Not reconnecting as a new bridge state was sent after the unknown error") + return + } else if len(bsq.ch) > 0 { + log.Warn().Msg("Not reconnecting as there are unsent bridge states") + return + } else if prevUnsent.StateEvent != status.StateUnknownError || prev.StateEvent != status.StateUnknownError { + log.Debug().Msg("Not reconnecting as the previous state was not an unknown error") + return + } else if bsq.unknownErrorReconnects > bsq.bridge.Config.UnknownErrorMaxAutoReconnects { + log.Warn().Msg("Not reconnecting as the maximum number of unknown error reconnects has been reached") + return + } + bsq.unknownErrorReconnects++ + log.Info(). + Int("reconnect_num", bsq.unknownErrorReconnects). + Msg("Disconnecting and reconnecting login due to unknown error") + bsq.login.Disconnect() + log.Debug().Msg("Disconnection finished, recreating client and reconnecting") + err := bsq.login.recreateClient(ctx) + if err != nil { + log.Err(err).Msg("Failed to recreate client after unknown error") + return + } + bsq.login.Client.Connect(ctx) + log.Debug().Msg("Reconnection finished") +} + +func (bsq *BridgeStateQueue) waitForUnknownErrorReconnect(ctx context.Context) bool { + reconnectIn := bsq.bridge.Config.UnknownErrorAutoReconnect + // Don't allow too low values + if reconnectIn < 1*time.Minute { + return false + } + reconnectIn += time.Duration(rand.Int64N(int64(float64(reconnectIn)*0.4)) - int64(float64(reconnectIn)*0.2)) + return bsq.waitForReconnect(ctx, reconnectIn, &bsq.stopReconnect) +} + +const TransientDisconnectNoticeDelay = 3 * time.Minute + +func (bsq *BridgeStateQueue) waitForTransientDisconnectReconnect(ctx context.Context) bool { + timeUntilSchedule := time.Until(bsq.firstTransientDisconnect.Add(TransientDisconnectNoticeDelay)) + zerolog.Ctx(ctx).Trace(). + Stringer("duration", timeUntilSchedule). + Msg("Waiting before sending notice about transient disconnect") + return bsq.waitForReconnect(ctx, timeUntilSchedule, &bsq.cancelScheduledNotice) +} + +func (bsq *BridgeStateQueue) waitForReconnect( + ctx context.Context, reconnectIn time.Duration, ptr *atomic.Pointer[context.CancelFunc], +) bool { + cancelCtx, cancel := context.WithCancel(ctx) + defer cancel() + if oldCancel := ptr.Swap(&cancel); oldCancel != nil { + (*oldCancel)() + } + select { + case <-time.After(reconnectIn): + return ptr.CompareAndSwap(&cancel, nil) + case <-cancelCtx.Done(): + return false + case <-bsq.stopChan: + return false + } +} + func (bsq *BridgeStateQueue) immediateSendBridgeState(state status.BridgeState) { + if bsq.prevSent != nil && bsq.prevSent.ShouldDeduplicate(&state) { + bsq.login.Log.Debug(). + Str("state_event", string(state.StateEvent)). + Msg("Not sending bridge state as it's a duplicate") + return + } + if state.StateEvent == status.StateUnknownError { + go bsq.unknownErrorReconnect(state) + } + + ctx := bsq.login.Log.WithContext(context.Background()) + bsq.sendNotice(ctx, state, false) + retryIn := 2 for { - if bsq.prevSent != nil && bsq.prevSent.ShouldDeduplicate(&state) { - bsq.bridge.Log.Debug(). - Str("state_event", string(state.StateEvent)). - Msg("Not sending bridge state as it's a duplicate") - return - } - - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + ctx, cancel := context.WithTimeout(ctx, 30*time.Second) err := bsq.bridge.Matrix.SendBridgeStatus(ctx, &state) cancel() if err != nil { - bsq.bridge.Log.Warn().Err(err). + bsq.login.Log.Warn().Err(err). Int("retry_in_seconds", retryIn). Msg("Failed to update bridge state") time.Sleep(time.Duration(retryIn) * time.Second) @@ -95,7 +282,7 @@ func (bsq *BridgeStateQueue) immediateSendBridgeState(state status.BridgeState) } } else { bsq.prevSent = &state - bsq.bridge.Log.Debug(). + bsq.login.Log.Debug(). Any("bridge_state", state). Msg("Sent new bridge state") return @@ -108,11 +295,11 @@ func (bsq *BridgeStateQueue) Send(state status.BridgeState) { return } - state = state.Fill(bsq.user) + state = state.Fill(bsq.login) bsq.prevUnsent = &state if len(bsq.ch) >= 8 { - bsq.bridge.Log.Warn().Msg("Bridge state queue is nearly full, discarding an item") + bsq.login.Log.Warn().Msg("Bridge state queue is nearly full, discarding an item") select { case <-bsq.ch: default: @@ -121,7 +308,7 @@ func (bsq *BridgeStateQueue) Send(state status.BridgeState) { select { case bsq.ch <- state: default: - bsq.bridge.Log.Error().Msg("Bridge state queue is full, dropped new state") + bsq.login.Log.Error().Msg("Bridge state queue is full, dropped new state") } } diff --git a/bridgev2/commands/cleanup.go b/bridgev2/commands/cleanup.go index f8ad1d23..dc21a16e 100644 --- a/bridgev2/commands/cleanup.go +++ b/bridgev2/commands/cleanup.go @@ -55,3 +55,43 @@ var CommandDeleteAllPortals = &FullHandler{ }, RequiresAdmin: true, } + +var CommandSetManagementRoom = &FullHandler{ + Func: func(ce *Event) { + if ce.User.ManagementRoom == ce.RoomID { + ce.Reply("This room is already your management room") + return + } else if ce.Portal != nil { + ce.Reply("This is a portal room: you can't set this as your management room") + return + } + members, err := ce.Bridge.Matrix.GetMembers(ce.Ctx, ce.RoomID) + if err != nil { + ce.Log.Err(err).Msg("Failed to get room members to check if room can be a management room") + ce.Reply("Failed to get room members") + return + } + _, hasBot := members[ce.Bot.GetMXID()] + if !hasBot { + // This reply will probably fail, but whatever + ce.Reply("The bridge bot must be in the room to set it as your management room") + return + } else if len(members) != 2 { + ce.Reply("Your management room must not have any members other than you and the bridge bot") + return + } + ce.User.ManagementRoom = ce.RoomID + err = ce.User.Save(ce.Ctx) + if err != nil { + ce.Log.Err(err).Msg("Failed to save management room") + ce.Reply("Failed to save management room") + } else { + ce.Reply("Management room updated") + } + }, + Name: "set-management-room", + Help: HelpMeta{ + Section: HelpSectionGeneral, + Description: "Mark this room as your management room", + }, +} diff --git a/bridgev2/commands/debug.go b/bridgev2/commands/debug.go index d00697ee..1cae98fe 100644 --- a/bridgev2/commands/debug.go +++ b/bridgev2/commands/debug.go @@ -7,10 +7,13 @@ package commands import ( + "encoding/json" "strings" + "time" "maunium.net/go/mautrix/bridgev2" "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/event" ) var CommandRegisterPush = &FullHandler{ @@ -57,4 +60,66 @@ var CommandRegisterPush = &FullHandler{ }, RequiresAdmin: true, RequiresLogin: true, + NetworkAPI: NetworkAPIImplements[bridgev2.PushableNetworkAPI], +} + +var CommandSendAccountData = &FullHandler{ + Func: func(ce *Event) { + if len(ce.Args) < 2 { + ce.Reply("Usage: `$cmdprefix debug-account-data ") + return + } + var content event.Content + evtType := event.Type{Type: ce.Args[0], Class: event.AccountDataEventType} + ce.RawArgs = strings.TrimSpace(strings.Trim(ce.RawArgs, ce.Args[0])) + err := json.Unmarshal([]byte(ce.RawArgs), &content) + if err != nil { + ce.Reply("Failed to parse JSON: %v", err) + return + } + err = content.ParseRaw(evtType) + if err != nil { + ce.Reply("Failed to deserialize content: %v", err) + return + } + res := ce.Bridge.QueueMatrixEvent(ce.Ctx, &event.Event{ + Sender: ce.User.MXID, + Type: evtType, + Timestamp: time.Now().UnixMilli(), + RoomID: ce.RoomID, + Content: content, + }) + ce.Reply("Result: %+v", res) + }, + Name: "debug-account-data", + Help: HelpMeta{ + Section: HelpSectionAdmin, + Description: "Send a room account data event to the bridge", + Args: "<_type_> <_content_>", + }, + RequiresAdmin: true, + RequiresPortal: true, + RequiresLogin: true, +} + +var CommandResetNetwork = &FullHandler{ + Func: func(ce *Event) { + if strings.Contains(strings.ToLower(ce.RawArgs), "--reset-transport") { + nrn, ok := ce.Bridge.Network.(bridgev2.NetworkResettingNetwork) + if ok { + nrn.ResetHTTPTransport() + } else { + ce.Reply("Network connector does not support resetting HTTP transport") + } + } + ce.Bridge.ResetNetworkConnections() + ce.React("✅️") + }, + Name: "debug-reset-network", + Help: HelpMeta{ + Section: HelpSectionAdmin, + Description: "Reset network connections to the remote network", + Args: "[--reset-transport]", + }, + RequiresAdmin: true, } diff --git a/bridgev2/commands/event.go b/bridgev2/commands/event.go index 78ed94bb..88ba9698 100644 --- a/bridgev2/commands/event.go +++ b/bridgev2/commands/event.go @@ -10,6 +10,7 @@ import ( "context" "fmt" "strings" + "time" "github.com/rs/zerolog" @@ -92,9 +93,8 @@ func (ce *Event) Redact(req ...mautrix.ReqRedact) { // MarkRead marks the command event as read. func (ce *Event) MarkRead() { - // TODO - //err := ce.Bot.SendReceipt(ce.Ctx, ce.RoomID, ce.EventID, event.ReceiptTypeRead, nil) - //if err != nil { - // ce.Log.Err(err).Msg("Failed to mark command as read") - //} + err := ce.Bot.MarkRead(ce.Ctx, ce.RoomID, ce.EventID, time.Now()) + if err != nil { + ce.Log.Err(err).Msg("Failed to mark command as read") + } } diff --git a/bridgev2/commands/handler.go b/bridgev2/commands/handler.go index c1daf1af..672c81dc 100644 --- a/bridgev2/commands/handler.go +++ b/bridgev2/commands/handler.go @@ -7,6 +7,7 @@ package commands import ( + "maunium.net/go/mautrix/bridgev2" "maunium.net/go/mautrix/event" ) @@ -37,6 +38,18 @@ type AliasedCommandHandler interface { GetAliases() []string } +func NetworkAPIImplements[T bridgev2.NetworkAPI](val bridgev2.NetworkAPI) bool { + _, ok := val.(T) + return ok +} + +func NetworkConnectorImplements[T bridgev2.NetworkConnector](val bridgev2.NetworkConnector) bool { + _, ok := val.(T) + return ok +} + +type ImplementationChecker[T any] func(val T) bool + type FullHandler struct { Func func(*Event) @@ -49,6 +62,9 @@ type FullHandler struct { RequiresLogin bool RequiresEventLevel event.Type RequiresLoginPermission bool + + NetworkAPI ImplementationChecker[bridgev2.NetworkAPI] + NetworkConnector ImplementationChecker[bridgev2.NetworkConnector] } func (fh *FullHandler) GetHelp() HelpMeta { @@ -64,9 +80,15 @@ func (fh *FullHandler) GetAliases() []string { return fh.Aliases } +func (fh *FullHandler) ImplementationsFulfilled(ce *Event) bool { + // TODO add dedicated method to get an empty NetworkAPI instead of getting default login + client := ce.User.GetDefaultLogin() + return (fh.NetworkAPI == nil || client == nil || fh.NetworkAPI(client.Client)) && + (fh.NetworkConnector == nil || fh.NetworkConnector(ce.Bridge.Network)) +} + func (fh *FullHandler) ShowInHelp(ce *Event) bool { - return true - //return !fh.RequiresAdmin || ce.User.GetPermissionLevel() >= bridgeconfig.PermissionLevelAdmin + return fh.ImplementationsFulfilled(ce) && (!fh.RequiresAdmin || ce.User.Permissions.Admin) } func (fh *FullHandler) userHasRoomPermission(ce *Event) bool { diff --git a/bridgev2/commands/login.go b/bridgev2/commands/login.go index 5c7ae57d..96d62d3e 100644 --- a/bridgev2/commands/login.go +++ b/bridgev2/commands/login.go @@ -19,9 +19,9 @@ import ( "github.com/skip2/go-qrcode" "go.mau.fi/util/curl" - "maunium.net/go/mautrix/bridge/status" "maunium.net/go/mautrix/bridgev2" "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/bridgev2/status" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" ) @@ -70,6 +70,15 @@ func fnLogin(ce *Event) { } ce.Args = ce.Args[1:] } + if reauth == nil && ce.User.HasTooManyLogins() { + ce.Reply( + "You have reached the maximum number of logins (%d). "+ + "Please logout from an existing login before creating a new one. "+ + "If you want to re-authenticate an existing login, use the `$cmdprefix relogin` command.", + ce.User.Permissions.MaxLogins, + ) + return + } flows := ce.Bridge.Network.GetLoginFlows() var chosenFlowID string if len(ce.Args) > 0 { @@ -112,6 +121,7 @@ func fnLogin(ce *Event) { ce.Reply("Failed to start login: %v", err) return } + ce.Log.Debug().Any("first_step", nextStep).Msg("Created login process") nextStep = checkLoginCommandDirectParams(ce, login, nextStep) if nextStep != nil { @@ -190,11 +200,14 @@ type userInputLoginCommandState struct { func (uilcs *userInputLoginCommandState) promptNext(ce *Event) { field := uilcs.RemainingFields[0] + parts := []string{fmt.Sprintf("Please enter your %s", field.Name)} if field.Description != "" { - ce.Reply("Please enter your %s\n%s", field.Name, field.Description) - } else { - ce.Reply("Please enter your %s", field.Name) + parts = append(parts, field.Description) } + if len(field.Options) > 0 { + parts = append(parts, fmt.Sprintf("Options: `%s`", strings.Join(field.Options, "`, `"))) + } + ce.Reply(strings.Join(parts, "\n")) StoreCommandState(ce.User, &CommandState{ Next: MinimalCommandHandlerFunc(uilcs.submitNext), Action: "Login", @@ -239,14 +252,19 @@ func sendQR(ce *Event, qr string, prevEventID *id.EventID) error { return fmt.Errorf("failed to upload image: %w", err) } content := &event.MessageEventContent{ - MsgType: event.MsgImage, - FileName: "qr.png", - URL: qrMXC, - File: qrFile, - + MsgType: event.MsgImage, + FileName: "qr.png", + URL: qrMXC, + File: qrFile, Body: qr, Format: event.FormatHTML, FormattedBody: fmt.Sprintf("
%s
", html.EscapeString(qr)), + Info: &event.FileInfo{ + MimeType: "image/png", + Width: qrSizePx, + Height: qrSizePx, + Size: len(qrData), + }, } if *prevEventID != "" { content.SetEdit(*prevEventID) @@ -261,6 +279,36 @@ func sendQR(ce *Event, qr string, prevEventID *id.EventID) error { return nil } +func sendUserInputAttachments(ce *Event, atts []*bridgev2.LoginUserInputAttachment) error { + for _, att := range atts { + if att.FileName == "" { + return fmt.Errorf("missing attachment filename") + } + mxc, file, err := ce.Bot.UploadMedia(ce.Ctx, ce.RoomID, att.Content, att.FileName, att.Info.MimeType) + if err != nil { + return fmt.Errorf("failed to upload attachment %q: %w", att.FileName, err) + } + content := &event.MessageEventContent{ + MsgType: att.Type, + FileName: att.FileName, + URL: mxc, + File: file, + Info: &event.FileInfo{ + MimeType: att.Info.MimeType, + Width: att.Info.Width, + Height: att.Info.Height, + Size: att.Info.Size, + }, + Body: att.FileName, + } + _, err = ce.Bot.SendMessage(ce.Ctx, ce.RoomID, event.EventMessage, &event.Content{Parsed: content}, nil) + if err != nil { + return nil + } + } + return nil +} + type contextKey int const ( @@ -273,6 +321,13 @@ func doLoginDisplayAndWait(ce *Event, login bridgev2.LoginProcessDisplayAndWait, prevEvent = new(id.EventID) ce.Ctx = context.WithValue(ce.Ctx, contextKeyPrevEventID, prevEvent) } + cancelCtx, cancelFunc := context.WithCancel(ce.Ctx) + defer cancelFunc() + StoreCommandState(ce.User, &CommandState{ + Action: "Login", + Cancel: cancelFunc, + }) + defer StoreCommandState(ce.User, nil) switch step.DisplayAndWaitParams.Type { case bridgev2.LoginDisplayTypeQR: err := sendQR(ce, step.DisplayAndWaitParams.Data, prevEvent) @@ -292,7 +347,7 @@ func doLoginDisplayAndWait(ce *Event, login bridgev2.LoginProcessDisplayAndWait, login.Cancel() return } - nextStep, err := login.Wait(ce.Ctx) + nextStep, err := login.Wait(cancelCtx) // Redact the QR code, unless the next step is refreshing the code (in which case the event is just edited) if *prevEvent != "" && (nextStep == nil || nextStep.StepID != step.StepID) { _, _ = ce.Bot.SendMessage(ce.Ctx, ce.RoomID, event.EventRedaction, &event.Content{ @@ -445,6 +500,7 @@ func maybeURLDecodeCookie(val string, field *bridgev2.LoginCookieField) string { } func doLoginStep(ce *Event, login bridgev2.LoginProcess, step *bridgev2.LoginStep, override *bridgev2.UserLogin) { + ce.Log.Debug().Any("next_step", step).Msg("Got next login step") if step.Instructions != "" { ce.Reply(step.Instructions) } @@ -459,6 +515,10 @@ func doLoginStep(ce *Event, login bridgev2.LoginProcess, step *bridgev2.LoginSte Override: override, }).prompt(ce) case bridgev2.LoginStepTypeUserInput: + err := sendUserInputAttachments(ce, step.UserInputParams.Attachments) + if err != nil { + ce.Reply("Failed to send attachments: %v", err) + } (&userInputLoginCommandState{ Login: login.(bridgev2.LoginProcessUserInput), RemainingFields: step.UserInputParams.Fields, diff --git a/bridgev2/commands/processor.go b/bridgev2/commands/processor.go index 3343e1ba..391c3685 100644 --- a/bridgev2/commands/processor.go +++ b/bridgev2/commands/processor.go @@ -17,8 +17,7 @@ import ( "github.com/rs/zerolog" "maunium.net/go/mautrix/bridgev2" - - "maunium.net/go/mautrix/bridge/status" + "maunium.net/go/mautrix/bridgev2/status" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" ) @@ -42,10 +41,11 @@ func NewProcessor(bridge *bridgev2.Bridge) bridgev2.CommandProcessor { } proc.AddHandlers( CommandHelp, CommandCancel, - CommandRegisterPush, CommandDeletePortal, CommandDeleteAllPortals, + CommandRegisterPush, CommandSendAccountData, CommandResetNetwork, + CommandDeletePortal, CommandDeleteAllPortals, CommandSetManagementRoom, CommandLogin, CommandRelogin, CommandListLogins, CommandLogout, CommandSetPreferredLogin, CommandSetRelay, CommandUnsetRelay, - CommandResolveIdentifier, CommandStartChat, CommandSearch, + CommandResolveIdentifier, CommandStartChat, CommandCreateGroup, CommandSearch, CommandSyncChat, CommandMute, CommandSudo, CommandDoIn, ) return proc diff --git a/bridgev2/commands/relay.go b/bridgev2/commands/relay.go index af756c87..94c19739 100644 --- a/bridgev2/commands/relay.go +++ b/bridgev2/commands/relay.go @@ -37,7 +37,7 @@ func fnSetRelay(ce *Event) { } onlySetDefaultRelays := !ce.User.Permissions.Admin && ce.Bridge.Config.Relay.AdminOnly var relay *bridgev2.UserLogin - if len(ce.Args) == 0 { + if len(ce.Args) == 0 && ce.Portal.Receiver == "" { relay = ce.User.GetDefaultLogin() isLoggedIn := relay != nil if onlySetDefaultRelays { @@ -73,9 +73,19 @@ func fnSetRelay(ce *Event) { } } } else { - relay = ce.Bridge.GetCachedUserLoginByID(networkid.UserLoginID(ce.Args[0])) + var targetID networkid.UserLoginID + if ce.Portal.Receiver != "" { + targetID = ce.Portal.Receiver + if len(ce.Args) > 0 && ce.Args[0] != string(targetID) { + ce.Reply("In split portals, only the receiver (%s) can be set as relay", targetID) + return + } + } else { + targetID = networkid.UserLoginID(ce.Args[0]) + } + relay = ce.Bridge.GetCachedUserLoginByID(targetID) if relay == nil { - ce.Reply("User login with ID `%s` not found", ce.Args[0]) + ce.Reply("User login with ID `%s` not found", targetID) return } else if slices.Contains(ce.Bridge.Config.Relay.DefaultRelays, relay.ID) { // All good diff --git a/bridgev2/commands/startchat.go b/bridgev2/commands/startchat.go index aa766c0e..c7b05a6e 100644 --- a/bridgev2/commands/startchat.go +++ b/bridgev2/commands/startchat.go @@ -1,4 +1,4 @@ -// Copyright (c) 2024 Tulir Asokan +// Copyright (c) 2025 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this @@ -8,13 +8,21 @@ package commands import ( "context" + "errors" "fmt" "html" + "maps" + "slices" "strings" "time" + "github.com/rs/zerolog" + "maunium.net/go/mautrix/bridgev2" "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/bridgev2/provisionutil" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/format" "maunium.net/go/mautrix/id" ) @@ -27,6 +35,36 @@ var CommandResolveIdentifier = &FullHandler{ Args: "[_login ID_] <_identifier_>", }, RequiresLogin: true, + NetworkAPI: NetworkAPIImplements[bridgev2.IdentifierResolvingNetworkAPI], +} + +var CommandSyncChat = &FullHandler{ + Func: func(ce *Event) { + login, _, err := ce.Portal.FindPreferredLogin(ce.Ctx, ce.User, false) + if err != nil { + ce.Log.Err(err).Msg("Failed to find login for sync") + ce.Reply("Failed to find login: %v", err) + return + } else if login == nil { + ce.Reply("No login found for sync") + return + } + info, err := login.Client.GetChatInfo(ce.Ctx, ce.Portal) + if err != nil { + ce.Log.Err(err).Msg("Failed to get chat info for sync") + ce.Reply("Failed to get chat info: %v", err) + return + } + ce.Portal.UpdateInfo(ce.Ctx, info, login, nil, time.Time{}) + ce.React("✅️") + }, + Name: "sync-portal", + Help: HelpMeta{ + Section: HelpSectionChats, + Description: "Sync the current portal room", + }, + RequiresPortal: true, + RequiresLogin: true, } var CommandStartChat = &FullHandler{ @@ -39,11 +77,18 @@ var CommandStartChat = &FullHandler{ Args: "[_login ID_] <_identifier_>", }, RequiresLogin: true, + NetworkAPI: NetworkAPIImplements[bridgev2.IdentifierResolvingNetworkAPI], } -func getClientForStartingChat[T bridgev2.IdentifierResolvingNetworkAPI](ce *Event, thing string) (*bridgev2.UserLogin, T, []string) { - remainingArgs := ce.Args[1:] - login := ce.Bridge.GetCachedUserLoginByID(networkid.UserLoginID(ce.Args[0])) +func getClientForStartingChat[T bridgev2.NetworkAPI](ce *Event, thing string) (*bridgev2.UserLogin, T, []string) { + var remainingArgs []string + if len(ce.Args) > 1 { + remainingArgs = ce.Args[1:] + } + var login *bridgev2.UserLogin + if len(ce.Args) > 0 { + login = ce.Bridge.GetCachedUserLoginByID(networkid.UserLoginID(ce.Args[0])) + } if login == nil || login.UserMXID != ce.User.MXID { remainingArgs = ce.Args login = ce.User.GetDefaultLogin() @@ -55,24 +100,13 @@ func getClientForStartingChat[T bridgev2.IdentifierResolvingNetworkAPI](ce *Even return login, api, remainingArgs } -func formatResolveIdentifierResult(ctx context.Context, resp *bridgev2.ResolveIdentifierResponse) string { - var targetName string - var targetMXID id.UserID - if resp.Ghost != nil { - if resp.UserInfo != nil { - resp.Ghost.UpdateInfo(ctx, resp.UserInfo) - } - targetName = resp.Ghost.Name - targetMXID = resp.Ghost.Intent.GetMXID() - } else if resp.UserInfo != nil && resp.UserInfo.Name != nil { - targetName = *resp.UserInfo.Name - } - if targetMXID != "" { - return fmt.Sprintf("`%s` / [%s](%s)", resp.UserID, targetName, targetMXID.URI().MatrixToURL()) - } else if targetName != "" { - return fmt.Sprintf("`%s` / %s", resp.UserID, targetName) +func formatResolveIdentifierResult(resp *provisionutil.RespResolveIdentifier) string { + if resp.MXID != "" { + return fmt.Sprintf("`%s` / [%s](%s)", resp.ID, resp.Name, resp.MXID.URI().MatrixToURL()) + } else if resp.Name != "" { + return fmt.Sprintf("`%s` / %s", resp.ID, resp.Name) } else { - return fmt.Sprintf("`%s`", resp.UserID) + return fmt.Sprintf("`%s`", resp.ID) } } @@ -85,65 +119,137 @@ func fnResolveIdentifier(ce *Event) { if api == nil { return } + allLogins := ce.User.GetUserLogins() createChat := ce.Command == "start-chat" || ce.Command == "pm" identifier := strings.Join(identifierParts, " ") - resp, err := api.ResolveIdentifier(ce.Ctx, identifier, createChat) + resp, err := provisionutil.ResolveIdentifier(ce.Ctx, login, identifier, createChat) + for i := 0; i < len(allLogins) && errors.Is(err, bridgev2.ErrResolveIdentifierTryNext); i++ { + resp, err = provisionutil.ResolveIdentifier(ce.Ctx, allLogins[i], identifier, createChat) + } if err != nil { - ce.Log.Err(err).Msg("Failed to resolve identifier") ce.Reply("Failed to resolve identifier: %v", err) return } else if resp == nil { ce.ReplyAdvanced(fmt.Sprintf("Identifier %s not found", html.EscapeString(identifier)), false, true) return } - formattedName := formatResolveIdentifierResult(ce.Ctx, resp) + formattedName := formatResolveIdentifierResult(resp) if createChat { - if resp.Chat == nil { - ce.Reply("Interface error: network connector did not return chat for create chat request") - return + name := resp.Portal.Name + if name == "" { + name = resp.Portal.MXID.String() } - portal := resp.Chat.Portal - if portal == nil { - portal, err = ce.Bridge.GetPortalByKey(ce.Ctx, resp.Chat.PortalKey) - if err != nil { - ce.Log.Err(err).Msg("Failed to get portal") - ce.Reply("Failed to get portal: %v", err) - return - } - } - if resp.Chat.PortalInfo == nil { - resp.Chat.PortalInfo, err = api.GetChatInfo(ce.Ctx, portal) - if err != nil { - ce.Log.Err(err).Msg("Failed to get portal info") - ce.Reply("Failed to get portal info: %v", err) - return - } - } - if portal.MXID != "" { - name := portal.Name - if name == "" { - name = portal.MXID.String() - } - portal.UpdateInfo(ce.Ctx, resp.Chat.PortalInfo, login, nil, time.Time{}) - ce.Reply("You already have a direct chat with %s at [%s](%s)", formattedName, name, portal.MXID.URI().MatrixToURL()) + if !resp.JustCreated { + ce.Reply("You already have a direct chat with %s at [%s](%s)", formattedName, name, resp.Portal.MXID.URI().MatrixToURL()) } else { - err = portal.CreateMatrixRoom(ce.Ctx, login, resp.Chat.PortalInfo) - if err != nil { - ce.Log.Err(err).Msg("Failed to create room") - ce.Reply("Failed to create room: %v", err) - return - } - name := portal.Name - if name == "" { - name = portal.MXID.String() - } - ce.Reply("Created chat with %s: [%s](%s)", formattedName, name, portal.MXID.URI().MatrixToURL()) + ce.Reply("Created chat with %s: [%s](%s)", formattedName, name, resp.Portal.MXID.URI().MatrixToURL()) } } else { ce.Reply("Found %s", formattedName) } } +var CommandCreateGroup = &FullHandler{ + Func: fnCreateGroup, + Name: "create-group", + Aliases: []string{"create"}, + Help: HelpMeta{ + Section: HelpSectionChats, + Description: "Create a new group chat for the current Matrix room", + Args: "[_group type_]", + }, + RequiresLogin: true, + NetworkAPI: NetworkAPIImplements[bridgev2.GroupCreatingNetworkAPI], +} + +func getState[T any](ctx context.Context, roomID id.RoomID, evtType event.Type, provider bridgev2.MatrixConnectorWithArbitraryRoomState) (content T) { + evt, err := provider.GetStateEvent(ctx, roomID, evtType, "") + if err != nil { + zerolog.Ctx(ctx).Err(err).Stringer("event_type", evtType).Msg("Failed to get state event for group creation") + } else if evt != nil { + content, _ = evt.Content.Parsed.(T) + } + return +} + +func fnCreateGroup(ce *Event) { + ce.Bridge.Matrix.GetCapabilities() + login, api, remainingArgs := getClientForStartingChat[bridgev2.GroupCreatingNetworkAPI](ce, "creating group") + if api == nil { + return + } + stateProvider, ok := ce.Bridge.Matrix.(bridgev2.MatrixConnectorWithArbitraryRoomState) + if !ok { + ce.Reply("Matrix connector doesn't support fetching room state") + return + } + members, err := ce.Bridge.Matrix.GetMembers(ce.Ctx, ce.RoomID) + if err != nil { + ce.Log.Err(err).Msg("Failed to get room members for group creation") + ce.Reply("Failed to get room members: %v", err) + return + } + caps := ce.Bridge.Network.GetCapabilities() + params := &bridgev2.GroupCreateParams{ + Username: "", + Participants: make([]networkid.UserID, 0, len(members)-2), + Parent: nil, // TODO check space parent event + Name: getState[*event.RoomNameEventContent](ce.Ctx, ce.RoomID, event.StateRoomName, stateProvider), + Avatar: getState[*event.RoomAvatarEventContent](ce.Ctx, ce.RoomID, event.StateRoomAvatar, stateProvider), + Topic: getState[*event.TopicEventContent](ce.Ctx, ce.RoomID, event.StateTopic, stateProvider), + Disappear: getState[*event.BeeperDisappearingTimer](ce.Ctx, ce.RoomID, event.StateBeeperDisappearingTimer, stateProvider), + RoomID: ce.RoomID, + } + for userID, member := range members { + if userID == ce.User.MXID || userID == ce.Bot.GetMXID() || !member.Membership.IsInviteOrJoin() { + continue + } + if parsedUserID, ok := ce.Bridge.Matrix.ParseGhostMXID(userID); ok { + params.Participants = append(params.Participants, parsedUserID) + } else if !ce.Bridge.Config.SplitPortals { + if user, err := ce.Bridge.GetExistingUserByMXID(ce.Ctx, userID); err != nil { + ce.Log.Err(err).Stringer("user_id", userID).Msg("Failed to get user for room member") + } else if user != nil { + // TODO add user logins to participants + //for _, login := range user.GetUserLogins() { + // params.Participants = append(params.Participants, login.GetUserID()) + //} + } + } + } + + if len(caps.Provisioning.GroupCreation) == 0 { + ce.Reply("No group creation types defined in network capabilities") + return + } else if len(remainingArgs) > 0 { + params.Type = remainingArgs[0] + } else if len(caps.Provisioning.GroupCreation) == 1 { + for params.Type = range caps.Provisioning.GroupCreation { + // The loop assigns the variable we want + } + } else { + types := strings.Join(slices.Collect(maps.Keys(caps.Provisioning.GroupCreation)), "`, `") + ce.Reply("Please specify type of group to create: `%s`", types) + return + } + resp, err := provisionutil.CreateGroup(ce.Ctx, login, params) + if err != nil { + ce.Reply("Failed to create group: %v", err) + return + } + var postfix string + if len(resp.FailedParticipants) > 0 { + failedParticipantsStrings := make([]string, len(resp.FailedParticipants)) + i := 0 + for participantID, meta := range resp.FailedParticipants { + failedParticipantsStrings[i] = fmt.Sprintf("* %s: %s", format.SafeMarkdownCode(participantID), meta.Reason) + i++ + } + postfix += "\n\nFailed to add some participants:\n" + strings.Join(failedParticipantsStrings, "\n") + } + ce.Reply("Successfully created group `%s`%s", resp.ID, postfix) +} + var CommandSearch = &FullHandler{ Func: fnSearch, Name: "search", @@ -153,6 +259,7 @@ var CommandSearch = &FullHandler{ Args: "<_query_>", }, RequiresLogin: true, + NetworkAPI: NetworkAPIImplements[bridgev2.UserSearchingNetworkAPI], } func fnSearch(ce *Event) { @@ -160,35 +267,67 @@ func fnSearch(ce *Event) { ce.Reply("Usage: `$cmdprefix search `") return } - _, api, queryParts := getClientForStartingChat[bridgev2.UserSearchingNetworkAPI](ce, "searching users") + login, api, queryParts := getClientForStartingChat[bridgev2.UserSearchingNetworkAPI](ce, "searching users") if api == nil { return } - results, err := api.SearchUsers(ce.Ctx, strings.Join(queryParts, " ")) + resp, err := provisionutil.SearchUsers(ce.Ctx, login, strings.Join(queryParts, " ")) if err != nil { - ce.Log.Err(err).Msg("Failed to search for users") ce.Reply("Failed to search for users: %v", err) return } - resultsString := make([]string, len(results)) - for i, res := range results { - formattedName := formatResolveIdentifierResult(ce.Ctx, res) + resultsString := make([]string, len(resp.Results)) + for i, res := range resp.Results { + formattedName := formatResolveIdentifierResult(res) resultsString[i] = fmt.Sprintf("* %s", formattedName) - if res.Chat != nil { - if res.Chat.Portal == nil { - res.Chat.Portal, err = ce.Bridge.GetExistingPortalByKey(ce.Ctx, res.Chat.PortalKey) - if err != nil { - ce.Log.Err(err).Object("portal_key", res.Chat.PortalKey).Msg("Failed to get DM portal") - } - } - if res.Chat.Portal != nil && res.Chat.Portal.MXID != "" { - portalName := res.Chat.Portal.Name - if portalName == "" { - portalName = res.Chat.Portal.MXID.String() - } - resultsString[i] = fmt.Sprintf("%s - DM portal: [%s](%s)", resultsString[i], portalName, res.Chat.Portal.MXID.URI().MatrixToURL()) + if res.Portal != nil && res.Portal.MXID != "" { + portalName := res.Portal.Name + if portalName == "" { + portalName = res.Portal.MXID.String() } + resultsString[i] = fmt.Sprintf("%s - DM portal: [%s](%s)", resultsString[i], portalName, res.Portal.MXID.URI().MatrixToURL()) } } ce.Reply("Search results:\n\n%s", strings.Join(resultsString, "\n")) } + +var CommandMute = &FullHandler{ + Func: fnMute, + Name: "mute", + Aliases: []string{"unmute"}, + Help: HelpMeta{ + Section: HelpSectionChats, + Description: "Mute or unmute a chat on the remote network", + Args: "[duration]", + }, + RequiresPortal: true, + RequiresLogin: true, + NetworkAPI: NetworkAPIImplements[bridgev2.MuteHandlingNetworkAPI], +} + +func fnMute(ce *Event) { + _, api, _ := getClientForStartingChat[bridgev2.MuteHandlingNetworkAPI](ce, "muting chats") + var mutedUntil int64 + if ce.Command == "mute" { + mutedUntil = -1 + if len(ce.Args) > 0 { + duration, err := time.ParseDuration(ce.Args[0]) + if err != nil { + ce.Reply("Invalid duration: %v", err) + return + } + mutedUntil = time.Now().Add(duration).UnixMilli() + } + } + err := api.HandleMute(ce.Ctx, &bridgev2.MatrixMute{ + MatrixEventBase: bridgev2.MatrixEventBase[*event.BeeperMuteEventContent]{ + Content: &event.BeeperMuteEventContent{MutedUntil: mutedUntil}, + Portal: ce.Portal, + }, + }) + if err != nil { + ce.Reply("Failed to %s chat: %v", ce.Command, err) + } else { + ce.React("✅️") + } +} diff --git a/bridgev2/database/backfillqueue.go b/bridgev2/database/backfillqueue.go index fed7452d..1f920640 100644 --- a/bridgev2/database/backfillqueue.go +++ b/bridgev2/database/backfillqueue.go @@ -78,6 +78,11 @@ const ( dispatched_at=$9, completed_at=$10, next_dispatch_min_ts=$11 WHERE bridge_id = $1 AND portal_id = $2 AND portal_receiver = $3 ` + markBackfillTaskNotDoneQuery = ` + UPDATE backfill_task + SET is_done = false + WHERE bridge_id = $1 AND portal_id = $2 AND portal_receiver = $3 AND user_login_id = $4 + ` getNextBackfillQuery = ` SELECT bridge_id, portal_id, portal_receiver, user_login_id, batch_count, is_done, @@ -86,6 +91,13 @@ const ( WHERE bridge_id = $1 AND next_dispatch_min_ts < $2 AND is_done = false AND user_login_id <> '' ORDER BY next_dispatch_min_ts LIMIT 1 ` + getNextBackfillQueryForPortal = ` + SELECT + bridge_id, portal_id, portal_receiver, user_login_id, batch_count, is_done, + cursor, oldest_message_id, dispatched_at, completed_at, next_dispatch_min_ts + FROM backfill_task + WHERE bridge_id = $1 AND portal_id = $2 AND portal_receiver = $3 AND is_done = false AND user_login_id <> '' + ` deleteBackfillQueueQuery = ` DELETE FROM backfill_task WHERE bridge_id = $1 AND portal_id = $2 AND portal_receiver = $3 @@ -120,10 +132,18 @@ func (btq *BackfillTaskQuery) Update(ctx context.Context, bq *BackfillTask) erro return btq.Exec(ctx, updateBackfillQueueQuery, bq.sqlVariables()...) } +func (btq *BackfillTaskQuery) MarkNotDone(ctx context.Context, portalKey networkid.PortalKey, userLoginID networkid.UserLoginID) error { + return btq.Exec(ctx, markBackfillTaskNotDoneQuery, btq.BridgeID, portalKey.ID, portalKey.Receiver, userLoginID) +} + func (btq *BackfillTaskQuery) GetNext(ctx context.Context) (*BackfillTask, error) { return btq.QueryOne(ctx, getNextBackfillQuery, btq.BridgeID, time.Now().UnixNano()) } +func (btq *BackfillTaskQuery) GetNextForPortal(ctx context.Context, portalKey networkid.PortalKey) (*BackfillTask, error) { + return btq.QueryOne(ctx, getNextBackfillQueryForPortal, btq.BridgeID, portalKey.ID, portalKey.Receiver) +} + func (btq *BackfillTaskQuery) Delete(ctx context.Context, portalKey networkid.PortalKey) error { return btq.Exec(ctx, deleteBackfillQueueQuery, btq.BridgeID, portalKey.ID, portalKey.Receiver) } diff --git a/bridgev2/database/database.go b/bridgev2/database/database.go index f1789441..05abddf0 100644 --- a/bridgev2/database/database.go +++ b/bridgev2/database/database.go @@ -7,13 +7,7 @@ package database import ( - "encoding/json" - "reflect" - "strings" - "go.mau.fi/util/dbutil" - "golang.org/x/exp/constraints" - "golang.org/x/exp/maps" "maunium.net/go/mautrix/bridgev2/networkid" @@ -34,6 +28,7 @@ type Database struct { UserPortal *UserPortalQuery BackfillTask *BackfillTaskQuery KV *KVQuery + PublicMedia *PublicMediaQuery } type MetaMerger interface { @@ -141,6 +136,12 @@ func New(bridgeID networkid.BridgeID, mt MetaTypes, db *dbutil.Database) *Databa BridgeID: bridgeID, Database: db, }, + PublicMedia: &PublicMediaQuery{ + BridgeID: bridgeID, + QueryHelper: dbutil.MakeQueryHelper(db, func(_ *dbutil.QueryHelper[*PublicMedia]) *PublicMedia { + return &PublicMedia{} + }), + }, } } @@ -151,55 +152,3 @@ func ensureBridgeIDMatches(ptr *networkid.BridgeID, expected networkid.BridgeID) panic("bridge ID mismatch") } } - -func GetNumberFromMap[T constraints.Integer | constraints.Float](m map[string]any, key string) (T, bool) { - if val, found := m[key]; found { - floatVal, ok := val.(float64) - if ok { - return T(floatVal), true - } - tVal, ok := val.(T) - if ok { - return tVal, true - } - } - return 0, false -} - -func unmarshalMerge(input []byte, data any, extra *map[string]any) error { - err := json.Unmarshal(input, data) - if err != nil { - return err - } - err = json.Unmarshal(input, extra) - if err != nil { - return err - } - if *extra == nil { - *extra = make(map[string]any) - } - return nil -} - -func marshalMerge(data any, extra map[string]any) ([]byte, error) { - if extra == nil { - return json.Marshal(data) - } - merged := make(map[string]any) - maps.Copy(merged, extra) - dataRef := reflect.ValueOf(data).Elem() - dataType := dataRef.Type() - for _, field := range reflect.VisibleFields(dataType) { - parts := strings.Split(field.Tag.Get("json"), ",") - if len(parts) == 0 || len(parts[0]) == 0 || parts[0] == "-" { - continue - } - fieldVal := dataRef.FieldByIndex(field.Index) - if fieldVal.IsZero() { - delete(merged, parts[0]) - } else { - merged[parts[0]] = fieldVal.Interface() - } - } - return json.Marshal(merged) -} diff --git a/bridgev2/database/disappear.go b/bridgev2/database/disappear.go index 23db1448..df36b205 100644 --- a/bridgev2/database/disappear.go +++ b/bridgev2/database/disappear.go @@ -12,56 +12,94 @@ import ( "time" "go.mau.fi/util/dbutil" + "go.mau.fi/util/jsontime" "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" ) -// DisappearingType represents the type of a disappearing message timer. -type DisappearingType string +// Deprecated: use [event.DisappearingType] +type DisappearingType = event.DisappearingType +// Deprecated: use constants in event package const ( - DisappearingTypeNone DisappearingType = "" - DisappearingTypeAfterRead DisappearingType = "after_read" - DisappearingTypeAfterSend DisappearingType = "after_send" + DisappearingTypeNone = event.DisappearingTypeNone + DisappearingTypeAfterRead = event.DisappearingTypeAfterRead + DisappearingTypeAfterSend = event.DisappearingTypeAfterSend ) // DisappearingSetting represents a disappearing message timer setting // by combining a type with a timer and an optional start timestamp. type DisappearingSetting struct { - Type DisappearingType + Type event.DisappearingType Timer time.Duration DisappearAt time.Time } +func DisappearingSettingFromEvent(evt *event.BeeperDisappearingTimer) DisappearingSetting { + if evt == nil || evt.Type == event.DisappearingTypeNone { + return DisappearingSetting{} + } + return DisappearingSetting{ + Type: evt.Type, + Timer: evt.Timer.Duration, + } +} + +func (ds DisappearingSetting) Normalize() DisappearingSetting { + if ds.Type == event.DisappearingTypeNone { + ds.Timer = 0 + } else if ds.Timer == 0 { + ds.Type = event.DisappearingTypeNone + } + return ds +} + +func (ds DisappearingSetting) StartingAt(start time.Time) DisappearingSetting { + ds.DisappearAt = start.Add(ds.Timer) + return ds +} + +func (ds DisappearingSetting) ToEventContent() *event.BeeperDisappearingTimer { + if ds.Type == event.DisappearingTypeNone || ds.Timer == 0 { + return &event.BeeperDisappearingTimer{} + } + return &event.BeeperDisappearingTimer{ + Type: ds.Type, + Timer: jsontime.MS(ds.Timer), + } +} + type DisappearingMessageQuery struct { BridgeID networkid.BridgeID *dbutil.QueryHelper[*DisappearingMessage] } type DisappearingMessage struct { - BridgeID networkid.BridgeID - RoomID id.RoomID - EventID id.EventID + BridgeID networkid.BridgeID + RoomID id.RoomID + EventID id.EventID + Timestamp time.Time DisappearingSetting } const ( upsertDisappearingMessageQuery = ` - INSERT INTO disappearing_message (bridge_id, mx_room, mxid, type, timer, disappear_at) - VALUES ($1, $2, $3, $4, $5, $6) + INSERT INTO disappearing_message (bridge_id, mx_room, mxid, timestamp, type, timer, disappear_at) + VALUES ($1, $2, $3, $4, $5, $6, $7) ON CONFLICT (bridge_id, mxid) DO UPDATE SET timer=excluded.timer, disappear_at=excluded.disappear_at ` startDisappearingMessagesQuery = ` UPDATE disappearing_message SET disappear_at=$1 + timer - WHERE bridge_id=$2 AND mx_room=$3 AND disappear_at IS NULL AND type='after_read' - RETURNING bridge_id, mx_room, mxid, type, timer, disappear_at + WHERE bridge_id=$2 AND mx_room=$3 AND disappear_at IS NULL AND type='after_read' AND timestamp<=$4 + RETURNING bridge_id, mx_room, mxid, timestamp, type, timer, disappear_at ` getUpcomingDisappearingMessagesQuery = ` - SELECT bridge_id, mx_room, mxid, type, timer, disappear_at + SELECT bridge_id, mx_room, mxid, timestamp, type, timer, disappear_at FROM disappearing_message WHERE bridge_id = $1 AND disappear_at IS NOT NULL AND disappear_at < $2 - ORDER BY disappear_at + ORDER BY disappear_at LIMIT $3 ` deleteDisappearingMessageQuery = ` DELETE FROM disappearing_message WHERE bridge_id=$1 AND mxid=$2 @@ -73,12 +111,12 @@ func (dmq *DisappearingMessageQuery) Put(ctx context.Context, dm *DisappearingMe return dmq.Exec(ctx, upsertDisappearingMessageQuery, dm.sqlVariables()...) } -func (dmq *DisappearingMessageQuery) StartAll(ctx context.Context, roomID id.RoomID) ([]*DisappearingMessage, error) { - return dmq.QueryMany(ctx, startDisappearingMessagesQuery, time.Now().UnixNano(), dmq.BridgeID, roomID) +func (dmq *DisappearingMessageQuery) StartAllBefore(ctx context.Context, roomID id.RoomID, beforeTS time.Time) ([]*DisappearingMessage, error) { + return dmq.QueryMany(ctx, startDisappearingMessagesQuery, time.Now().UnixNano(), dmq.BridgeID, roomID, beforeTS.UnixNano()) } -func (dmq *DisappearingMessageQuery) GetUpcoming(ctx context.Context, duration time.Duration) ([]*DisappearingMessage, error) { - return dmq.QueryMany(ctx, getUpcomingDisappearingMessagesQuery, dmq.BridgeID, time.Now().Add(duration).UnixNano()) +func (dmq *DisappearingMessageQuery) GetUpcoming(ctx context.Context, duration time.Duration, limit int) ([]*DisappearingMessage, error) { + return dmq.QueryMany(ctx, getUpcomingDisappearingMessagesQuery, dmq.BridgeID, time.Now().Add(duration).UnixNano(), limit) } func (dmq *DisappearingMessageQuery) Delete(ctx context.Context, eventID id.EventID) error { @@ -86,17 +124,19 @@ func (dmq *DisappearingMessageQuery) Delete(ctx context.Context, eventID id.Even } func (d *DisappearingMessage) Scan(row dbutil.Scannable) (*DisappearingMessage, error) { + var timestamp int64 var disappearAt sql.NullInt64 - err := row.Scan(&d.BridgeID, &d.RoomID, &d.EventID, &d.Type, &d.Timer, &disappearAt) + err := row.Scan(&d.BridgeID, &d.RoomID, &d.EventID, ×tamp, &d.Type, &d.Timer, &disappearAt) if err != nil { return nil, err } if disappearAt.Valid { d.DisappearAt = time.Unix(0, disappearAt.Int64) } + d.Timestamp = time.Unix(0, timestamp) return d, nil } func (d *DisappearingMessage) sqlVariables() []any { - return []any{d.BridgeID, d.RoomID, d.EventID, d.Type, d.Timer, dbutil.ConvertedPtr(d.DisappearAt, time.Time.UnixNano)} + return []any{d.BridgeID, d.RoomID, d.EventID, d.Timestamp.UnixNano(), d.Type, d.Timer, dbutil.ConvertedPtr(d.DisappearAt, time.Time.UnixNano)} } diff --git a/bridgev2/database/ghost.go b/bridgev2/database/ghost.go index c32929ad..16af35ca 100644 --- a/bridgev2/database/ghost.go +++ b/bridgev2/database/ghost.go @@ -7,12 +7,17 @@ package database import ( + "bytes" "context" "encoding/hex" + "encoding/json" + "fmt" "go.mau.fi/util/dbutil" + "go.mau.fi/util/exerrors" "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/crypto/canonicaljson" "maunium.net/go/mautrix/id" ) @@ -22,6 +27,55 @@ type GhostQuery struct { *dbutil.QueryHelper[*Ghost] } +type ExtraProfile map[string]json.RawMessage + +func (ep *ExtraProfile) Set(key string, value any) error { + if key == "displayname" || key == "avatar_url" { + return fmt.Errorf("cannot set reserved profile key %q", key) + } + marshaled, err := json.Marshal(value) + if err != nil { + return err + } + if *ep == nil { + *ep = make(ExtraProfile) + } + (*ep)[key] = canonicaljson.CanonicalJSONAssumeValid(marshaled) + return nil +} + +func (ep *ExtraProfile) With(key string, value any) *ExtraProfile { + exerrors.PanicIfNotNil(ep.Set(key, value)) + return ep +} + +func canonicalizeIfObject(data json.RawMessage) json.RawMessage { + if len(data) > 0 && (data[0] == '{' || data[0] == '[') { + return canonicaljson.CanonicalJSONAssumeValid(data) + } + return data +} + +func (ep *ExtraProfile) CopyTo(dest *ExtraProfile) (changed bool) { + if len(*ep) == 0 { + return + } + if *dest == nil { + *dest = make(ExtraProfile) + } + for key, val := range *ep { + if key == "displayname" || key == "avatar_url" { + continue + } + existing, exists := (*dest)[key] + if !exists || !bytes.Equal(canonicalizeIfObject(existing), val) { + (*dest)[key] = val + changed = true + } + } + return +} + type Ghost struct { BridgeID networkid.BridgeID ID networkid.UserID @@ -35,13 +89,14 @@ type Ghost struct { ContactInfoSet bool IsBot bool Identifiers []string + ExtraProfile ExtraProfile Metadata any } const ( getGhostBaseQuery = ` SELECT bridge_id, id, name, avatar_id, avatar_hash, avatar_mxc, - name_set, avatar_set, contact_info_set, is_bot, identifiers, metadata + name_set, avatar_set, contact_info_set, is_bot, identifiers, extra_profile, metadata FROM ghost ` getGhostByIDQuery = getGhostBaseQuery + `WHERE bridge_id=$1 AND id=$2` @@ -49,13 +104,14 @@ const ( insertGhostQuery = ` INSERT INTO ghost ( bridge_id, id, name, avatar_id, avatar_hash, avatar_mxc, - name_set, avatar_set, contact_info_set, is_bot, identifiers, metadata + name_set, avatar_set, contact_info_set, is_bot, identifiers, extra_profile, metadata ) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13) ` updateGhostQuery = ` UPDATE ghost SET name=$3, avatar_id=$4, avatar_hash=$5, avatar_mxc=$6, - name_set=$7, avatar_set=$8, contact_info_set=$9, is_bot=$10, identifiers=$11, metadata=$12 + name_set=$7, avatar_set=$8, contact_info_set=$9, is_bot=$10, + identifiers=$11, extra_profile=$12, metadata=$13 WHERE bridge_id=$1 AND id=$2 ` ) @@ -86,7 +142,7 @@ func (g *Ghost) Scan(row dbutil.Scannable) (*Ghost, error) { &g.BridgeID, &g.ID, &g.Name, &g.AvatarID, &avatarHash, &g.AvatarMXC, &g.NameSet, &g.AvatarSet, &g.ContactInfoSet, &g.IsBot, - dbutil.JSON{Data: &g.Identifiers}, dbutil.JSON{Data: g.Metadata}, + dbutil.JSON{Data: &g.Identifiers}, dbutil.JSON{Data: &g.ExtraProfile}, dbutil.JSON{Data: g.Metadata}, ) if err != nil { return nil, err @@ -116,6 +172,6 @@ func (g *Ghost) sqlVariables() []any { g.BridgeID, g.ID, g.Name, g.AvatarID, avatarHash, g.AvatarMXC, g.NameSet, g.AvatarSet, g.ContactInfoSet, g.IsBot, - dbutil.JSON{Data: &g.Identifiers}, dbutil.JSON{Data: g.Metadata}, + dbutil.JSON{Data: &g.Identifiers}, dbutil.JSON{Data: g.ExtraProfile}, dbutil.JSON{Data: g.Metadata}, } } diff --git a/bridgev2/database/kvstore.go b/bridgev2/database/kvstore.go index 3fc54f2c..bca26ed5 100644 --- a/bridgev2/database/kvstore.go +++ b/bridgev2/database/kvstore.go @@ -20,7 +20,10 @@ import ( type Key string const ( - KeySplitPortalsEnabled Key = "split_portals_enabled" + KeySplitPortalsEnabled Key = "split_portals_enabled" + KeyBridgeInfoVersion Key = "bridge_info_version" + KeyEncryptionStateResynced Key = "encryption_state_resynced" + KeyRecoveryKey Key = "recovery_key" ) type KVQuery struct { diff --git a/bridgev2/database/message.go b/bridgev2/database/message.go index 8daf7407..4fd599a8 100644 --- a/bridgev2/database/message.go +++ b/bridgev2/database/message.go @@ -11,9 +11,12 @@ import ( "crypto/sha256" "database/sql" "encoding/base64" + "fmt" "strings" + "sync" "time" + "github.com/rs/zerolog" "go.mau.fi/util/dbutil" "maunium.net/go/mautrix/bridgev2/networkid" @@ -24,6 +27,7 @@ type MessageQuery struct { BridgeID networkid.BridgeID MetaType MetaTypeCreator *dbutil.QueryHelper[*Message] + chunkDeleteLock sync.Mutex } type Message struct { @@ -33,37 +37,43 @@ type Message struct { PartID networkid.PartID MXID id.EventID - Room networkid.PortalKey - SenderID networkid.UserID - SenderMXID id.UserID - Timestamp time.Time - EditCount int + Room networkid.PortalKey + SenderID networkid.UserID + SenderMXID id.UserID + Timestamp time.Time + EditCount int + IsDoublePuppeted bool ThreadRoot networkid.MessageID ReplyTo networkid.MessageOptionalPartID + SendTxnID networkid.RawTransactionID + Metadata any } const ( getMessageBaseQuery = ` SELECT rowid, bridge_id, id, part_id, mxid, room_id, room_receiver, sender_id, sender_mxid, - timestamp, edit_count, thread_root_id, reply_to_id, reply_to_part_id, metadata + timestamp, edit_count, double_puppeted, thread_root_id, reply_to_id, reply_to_part_id, + send_txn_id, metadata FROM message ` getAllMessagePartsByIDQuery = getMessageBaseQuery + `WHERE bridge_id=$1 AND (room_receiver=$2 OR room_receiver='') AND id=$3` getMessagePartByIDQuery = getMessageBaseQuery + `WHERE bridge_id=$1 AND (room_receiver=$2 OR room_receiver='') AND id=$3 AND part_id=$4` getMessagePartByRowIDQuery = getMessageBaseQuery + `WHERE bridge_id=$1 AND rowid=$2` getMessageByMXIDQuery = getMessageBaseQuery + `WHERE bridge_id=$1 AND mxid=$2` + getMessageByTxnIDQuery = getMessageBaseQuery + `WHERE bridge_id=$1 AND (room_receiver=$2 OR room_receiver='') AND (mxid=$3 OR send_txn_id=$4)` getLastMessagePartByIDQuery = getMessageBaseQuery + `WHERE bridge_id=$1 AND (room_receiver=$2 OR room_receiver='') AND id=$3 ORDER BY part_id DESC LIMIT 1` getFirstMessagePartByIDQuery = getMessageBaseQuery + `WHERE bridge_id=$1 AND (room_receiver=$2 OR room_receiver='') AND id=$3 ORDER BY part_id ASC LIMIT 1` getMessagesBetweenTimeQuery = getMessageBaseQuery + `WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 AND timestamp>$4 AND timestamp<=$5` getOldestMessageInPortal = getMessageBaseQuery + `WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 ORDER BY timestamp ASC, part_id ASC LIMIT 1` - getFirstMessageInThread = getMessageBaseQuery + `WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 AND (id=$4 OR thread_root_id=$4) ORDER BY timestamp ASC, part_id ASC LIMIT 1` - getLastMessageInThread = getMessageBaseQuery + `WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 AND (id=$4 OR thread_root_id=$4) ORDER BY timestamp DESC, part_id DESC LIMIT 1` + getFirstMessageInThread = getMessageBaseQuery + `WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 AND (id=$4 OR thread_root_id=$4) ORDER BY thread_root_id NULLS FIRST, timestamp ASC, part_id ASC LIMIT 1` + getLastMessageInThread = getMessageBaseQuery + `WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 AND (id=$4 OR thread_root_id=$4) ORDER BY thread_root_id NULLS LAST, timestamp DESC, part_id DESC LIMIT 1` getLastNInPortal = getMessageBaseQuery + `WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 ORDER BY timestamp DESC, part_id DESC LIMIT $4` - getLastMessagePartAtOrBeforeTimeQuery = getMessageBaseQuery + `WHERE bridge_id = $1 AND room_id=$2 AND room_receiver=$3 AND timestamp<=$4 ORDER BY timestamp DESC, part_id DESC LIMIT 1` + getLastMessagePartAtOrBeforeTimeQuery = getMessageBaseQuery + `WHERE bridge_id = $1 AND room_id=$2 AND room_receiver=$3 AND timestamp<=$4 ORDER BY timestamp DESC, part_id DESC LIMIT 1` + getLastNonFakeMessagePartAtOrBeforeTimeQuery = getMessageBaseQuery + `WHERE bridge_id = $1 AND room_id=$2 AND room_receiver=$3 AND timestamp<=$4 AND mxid NOT LIKE '~fake:%' ORDER BY timestamp DESC, part_id DESC LIMIT 1` countMessagesInPortalQuery = ` SELECT COUNT(*) FROM message WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 @@ -72,15 +82,17 @@ const ( insertMessageQuery = ` INSERT INTO message ( bridge_id, id, part_id, mxid, room_id, room_receiver, sender_id, sender_mxid, - timestamp, edit_count, thread_root_id, reply_to_id, reply_to_part_id, metadata + timestamp, edit_count, double_puppeted, thread_root_id, reply_to_id, reply_to_part_id, + send_txn_id, metadata ) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16) RETURNING rowid ` updateMessageQuery = ` UPDATE message SET id=$2, part_id=$3, mxid=$4, room_id=$5, room_receiver=$6, sender_id=$7, sender_mxid=$8, - timestamp=$9, edit_count=$10, thread_root_id=$11, reply_to_id=$12, reply_to_part_id=$13, metadata=$14 - WHERE bridge_id=$1 AND rowid=$15 + timestamp=$9, edit_count=$10, double_puppeted=$11, thread_root_id=$12, reply_to_id=$13, + reply_to_part_id=$14, send_txn_id=$15, metadata=$16 + WHERE bridge_id=$1 AND rowid=$17 ` deleteAllMessagePartsByIDQuery = ` DELETE FROM message WHERE bridge_id=$1 AND (room_receiver=$2 OR room_receiver='') AND id=$3 @@ -88,6 +100,10 @@ const ( deleteMessagePartByRowIDQuery = ` DELETE FROM message WHERE bridge_id=$1 AND rowid=$2 ` + deleteMessageChunkQuery = ` + DELETE FROM message WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 AND rowid > $4 AND rowid <= $5 + ` + getMaxMessageRowIDQuery = `SELECT MAX(rowid) FROM message WHERE bridge_id=$1` ) func (mq *MessageQuery) GetAllPartsByID(ctx context.Context, receiver networkid.UserLoginID, id networkid.MessageID) ([]*Message, error) { @@ -102,6 +118,10 @@ func (mq *MessageQuery) GetPartByMXID(ctx context.Context, mxid id.EventID) (*Me return mq.QueryOne(ctx, getMessageByMXIDQuery, mq.BridgeID, mxid) } +func (mq *MessageQuery) GetPartByTxnID(ctx context.Context, receiver networkid.UserLoginID, mxid id.EventID, txnID networkid.RawTransactionID) (*Message, error) { + return mq.QueryOne(ctx, getMessageByTxnIDQuery, mq.BridgeID, receiver, mxid, txnID) +} + func (mq *MessageQuery) GetLastPartByID(ctx context.Context, receiver networkid.UserLoginID, id networkid.MessageID) (*Message, error) { return mq.QueryOne(ctx, getLastMessagePartByIDQuery, mq.BridgeID, receiver, id) } @@ -126,6 +146,10 @@ func (mq *MessageQuery) GetLastPartAtOrBeforeTime(ctx context.Context, portal ne return mq.QueryOne(ctx, getLastMessagePartAtOrBeforeTimeQuery, mq.BridgeID, portal.ID, portal.Receiver, maxTS.UnixNano()) } +func (mq *MessageQuery) GetLastNonFakePartAtOrBeforeTime(ctx context.Context, portal networkid.PortalKey, maxTS time.Time) (*Message, error) { + return mq.QueryOne(ctx, getLastNonFakeMessagePartAtOrBeforeTimeQuery, mq.BridgeID, portal.ID, portal.Receiver, maxTS.UnixNano()) +} + func (mq *MessageQuery) GetMessagesBetweenTimeQuery(ctx context.Context, portal networkid.PortalKey, start, end time.Time) ([]*Message, error) { return mq.QueryMany(ctx, getMessagesBetweenTimeQuery, mq.BridgeID, portal.ID, portal.Receiver, start.UnixNano(), end.UnixNano()) } @@ -164,6 +188,85 @@ func (mq *MessageQuery) Delete(ctx context.Context, rowID int64) error { return mq.Exec(ctx, deleteMessagePartByRowIDQuery, mq.BridgeID, rowID) } +func (mq *MessageQuery) deleteChunk(ctx context.Context, portal networkid.PortalKey, minRowID, maxRowID int64) (int64, error) { + res, err := mq.GetDB().Exec(ctx, deleteMessageChunkQuery, mq.BridgeID, portal.ID, portal.Receiver, minRowID, maxRowID) + if err != nil { + return 0, err + } + return res.RowsAffected() +} + +func (mq *MessageQuery) getMaxRowID(ctx context.Context) (maxRowID int64, err error) { + err = mq.GetDB().QueryRow(ctx, getMaxMessageRowIDQuery, mq.BridgeID).Scan(&maxRowID) + return +} + +const deleteChunkSize = 100_000 + +func (mq *MessageQuery) DeleteInChunks(ctx context.Context, portal networkid.PortalKey) error { + if mq.GetDB().Dialect != dbutil.SQLite { + return nil + } + log := zerolog.Ctx(ctx).With(). + Str("action", "delete messages in chunks"). + Stringer("portal_key", portal). + Logger() + if !mq.chunkDeleteLock.TryLock() { + log.Warn().Msg("Portal deletion lock is being held, waiting...") + mq.chunkDeleteLock.Lock() + log.Debug().Msg("Acquired portal deletion lock after waiting") + } + defer mq.chunkDeleteLock.Unlock() + total, err := mq.CountMessagesInPortal(ctx, portal) + if err != nil { + return fmt.Errorf("failed to count messages in portal: %w", err) + } else if total < deleteChunkSize/3 { + return nil + } + globalMaxRowID, err := mq.getMaxRowID(ctx) + if err != nil { + return fmt.Errorf("failed to get max row ID: %w", err) + } + log.Debug(). + Int("total_count", total). + Int64("global_max_row_id", globalMaxRowID). + Msg("Portal has lots of messages, deleting in chunks to avoid database locks") + maxRowID := int64(deleteChunkSize) + globalMaxRowID += deleteChunkSize * 1.2 + var dbTimeUsed time.Duration + globalStart := time.Now() + for total > 500 && maxRowID < globalMaxRowID { + start := time.Now() + count, err := mq.deleteChunk(ctx, portal, maxRowID-deleteChunkSize, maxRowID) + duration := time.Since(start) + dbTimeUsed += duration + if err != nil { + return fmt.Errorf("failed to delete chunk of messages before %d: %w", maxRowID, err) + } + total -= int(count) + maxRowID += deleteChunkSize + sleepTime := max(10*time.Millisecond, min(250*time.Millisecond, time.Duration(count/100)*time.Millisecond)) + log.Debug(). + Int64("max_row_id", maxRowID). + Int64("deleted_count", count). + Int("remaining_count", total). + Dur("duration", duration). + Dur("sleep_time", sleepTime). + Msg("Deleted chunk of messages") + select { + case <-time.After(sleepTime): + case <-ctx.Done(): + return ctx.Err() + } + } + log.Debug(). + Int("remaining_count", total). + Dur("db_time_used", dbTimeUsed). + Dur("total_duration", time.Since(globalStart)). + Msg("Finished chunked delete of messages in portal") + return nil +} + func (mq *MessageQuery) CountMessagesInPortal(ctx context.Context, key networkid.PortalKey) (count int, err error) { err = mq.GetDB().QueryRow(ctx, countMessagesInPortalQuery, mq.BridgeID, key.ID, key.Receiver).Scan(&count) return @@ -171,22 +274,28 @@ func (mq *MessageQuery) CountMessagesInPortal(ctx context.Context, key networkid func (m *Message) Scan(row dbutil.Scannable) (*Message, error) { var timestamp int64 - var threadRootID, replyToID, replyToPartID sql.NullString + var threadRootID, replyToID, replyToPartID, sendTxnID sql.NullString + var doublePuppeted sql.NullBool err := row.Scan( &m.RowID, &m.BridgeID, &m.ID, &m.PartID, &m.MXID, &m.Room.ID, &m.Room.Receiver, &m.SenderID, &m.SenderMXID, - ×tamp, &m.EditCount, &threadRootID, &replyToID, &replyToPartID, dbutil.JSON{Data: m.Metadata}, + ×tamp, &m.EditCount, &doublePuppeted, &threadRootID, &replyToID, &replyToPartID, &sendTxnID, + dbutil.JSON{Data: m.Metadata}, ) if err != nil { return nil, err } m.Timestamp = time.Unix(0, timestamp) m.ThreadRoot = networkid.MessageID(threadRootID.String) + m.IsDoublePuppeted = doublePuppeted.Valid if replyToID.Valid { m.ReplyTo.MessageID = networkid.MessageID(replyToID.String) if replyToPartID.Valid { m.ReplyTo.PartID = (*networkid.PartID)(&replyToPartID.String) } } + if sendTxnID.Valid { + m.SendTxnID = networkid.RawTransactionID(sendTxnID.String) + } return m, nil } @@ -200,7 +309,8 @@ func (m *Message) ensureHasMetadata(metaType MetaTypeCreator) *Message { func (m *Message) sqlVariables() []any { return []any{ m.BridgeID, m.ID, m.PartID, m.MXID, m.Room.ID, m.Room.Receiver, m.SenderID, m.SenderMXID, - m.Timestamp.UnixNano(), m.EditCount, dbutil.StrPtr(m.ThreadRoot), dbutil.StrPtr(m.ReplyTo.MessageID), m.ReplyTo.PartID, + m.Timestamp.UnixNano(), m.EditCount, m.IsDoublePuppeted, dbutil.StrPtr(m.ThreadRoot), + dbutil.StrPtr(m.ReplyTo.MessageID), m.ReplyTo.PartID, dbutil.StrPtr(m.SendTxnID), dbutil.JSON{Data: m.Metadata}, } } @@ -210,6 +320,9 @@ func (m *Message) updateSQLVariables() []any { } const FakeMXIDPrefix = "~fake:" +const TxnMXIDPrefix = "~txn:" +const NetworkTxnMXIDPrefix = TxnMXIDPrefix + "network:" +const RandomTxnMXIDPrefix = TxnMXIDPrefix + "random:" func (m *Message) SetFakeMXID() { hash := sha256.Sum256([]byte(m.ID)) diff --git a/bridgev2/database/portal.go b/bridgev2/database/portal.go index 72e31454..0e6be286 100644 --- a/bridgev2/database/portal.go +++ b/bridgev2/database/portal.go @@ -16,6 +16,7 @@ import ( "go.mau.fi/util/dbutil" "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" ) @@ -34,35 +35,53 @@ type PortalQuery struct { *dbutil.QueryHelper[*Portal] } +type CapStateFlags uint32 + +func (csf CapStateFlags) Has(flag CapStateFlags) bool { + return csf&flag != 0 +} + +const ( + CapStateFlagDisappearingTimerSet CapStateFlags = 1 << iota +) + +type CapabilityState struct { + Source networkid.UserLoginID `json:"source"` + ID string `json:"id"` + Flags CapStateFlags `json:"flags"` +} + type Portal struct { BridgeID networkid.BridgeID networkid.PortalKey MXID id.RoomID - ParentKey networkid.PortalKey - RelayLoginID networkid.UserLoginID - OtherUserID networkid.UserID - Name string - Topic string - AvatarID networkid.AvatarID - AvatarHash [32]byte - AvatarMXC id.ContentURIString - NameSet bool - TopicSet bool - AvatarSet bool - NameIsCustom bool - InSpace bool - RoomType RoomType - Disappear DisappearingSetting - Metadata any + ParentKey networkid.PortalKey + RelayLoginID networkid.UserLoginID + OtherUserID networkid.UserID + Name string + Topic string + AvatarID networkid.AvatarID + AvatarHash [32]byte + AvatarMXC id.ContentURIString + NameSet bool + TopicSet bool + AvatarSet bool + NameIsCustom bool + InSpace bool + MessageRequest bool + RoomType RoomType + Disappear DisappearingSetting + CapState CapabilityState + Metadata any } const ( getPortalBaseQuery = ` SELECT bridge_id, id, receiver, mxid, parent_id, parent_receiver, relay_login_id, other_user_id, name, topic, avatar_id, avatar_hash, avatar_mxc, - name_set, topic_set, avatar_set, name_is_custom, in_space, - room_type, disappear_type, disappear_timer, + name_set, topic_set, avatar_set, name_is_custom, in_space, message_request, + room_type, disappear_type, disappear_timer, cap_state, metadata FROM portal ` @@ -70,7 +89,9 @@ const ( getPortalByIDWithUncertainReceiverQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND id=$2 AND (receiver=$3 OR receiver='')` getPortalByMXIDQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND mxid=$2` getAllPortalsWithMXIDQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND mxid IS NOT NULL` + getAllPortalsWithoutReceiver = getPortalBaseQuery + `WHERE bridge_id=$1 AND (receiver='' OR (parent_id<>'' AND parent_receiver='')) ORDER BY parent_id DESC` getAllDMPortalsQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND room_type='dm' AND other_user_id=$2` + getDMPortalQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND room_type='dm' AND receiver=$2 AND other_user_id=$3` getAllPortalsQuery = getPortalBaseQuery + `WHERE bridge_id=$1` getChildPortalsQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND parent_id=$2 AND parent_receiver=$3` @@ -81,11 +102,11 @@ const ( bridge_id, id, receiver, mxid, parent_id, parent_receiver, relay_login_id, other_user_id, name, topic, avatar_id, avatar_hash, avatar_mxc, - name_set, avatar_set, topic_set, name_is_custom, in_space, - room_type, disappear_type, disappear_timer, + name_set, avatar_set, topic_set, name_is_custom, in_space, message_request, + room_type, disappear_type, disappear_timer, cap_state, metadata, relay_bridge_id ) VALUES ( - $1, $2, $3, $4, $5, $6, cast($7 AS TEXT), $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, + $1, $2, $3, $4, $5, $6, cast($7 AS TEXT), $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24, CASE WHEN cast($7 AS TEXT) IS NULL THEN NULL ELSE $1 END ) ` @@ -94,8 +115,8 @@ const ( SET mxid=$4, parent_id=$5, parent_receiver=$6, relay_login_id=cast($7 AS TEXT), relay_bridge_id=CASE WHEN cast($7 AS TEXT) IS NULL THEN NULL ELSE bridge_id END, other_user_id=$8, name=$9, topic=$10, avatar_id=$11, avatar_hash=$12, avatar_mxc=$13, - name_set=$14, avatar_set=$15, topic_set=$16, name_is_custom=$17, in_space=$18, - room_type=$19, disappear_type=$20, disappear_timer=$21, metadata=$22 + name_set=$14, avatar_set=$15, topic_set=$16, name_is_custom=$17, in_space=$18, message_request=$19, + room_type=$20, disappear_type=$21, disappear_timer=$22, cap_state=$23, metadata=$24 WHERE bridge_id=$1 AND id=$2 AND receiver=$3 ` deletePortalQuery = ` @@ -105,15 +126,33 @@ const ( reIDPortalQuery = `UPDATE portal SET id=$4, receiver=$5 WHERE bridge_id=$1 AND id=$2 AND receiver=$3` migrateToSplitPortalsQuery = ` UPDATE portal - SET receiver=COALESCE(( - SELECT login_id - FROM user_portal - WHERE bridge_id=portal.bridge_id AND portal_id=portal.id AND portal_receiver='' - LIMIT 1 - ), ( - SELECT id FROM user_login WHERE bridge_id=portal.bridge_id LIMIT 1 - ), '') - WHERE receiver='' AND bridge_id=$1 + SET receiver=new_receiver + FROM ( + SELECT bridge_id, id, COALESCE(( + SELECT login_id + FROM user_portal + WHERE bridge_id=portal.bridge_id AND portal_id=portal.id AND portal_receiver='' + LIMIT 1 + ), ( + SELECT login_id + FROM user_portal + WHERE portal.parent_id<>'' AND bridge_id=portal.bridge_id AND portal_id=portal.parent_id + LIMIT 1 + ), ( + SELECT id FROM user_login WHERE bridge_id=portal.bridge_id LIMIT 1 + ), '') AS new_receiver + FROM portal + WHERE receiver='' AND bridge_id=$1 + ) updates + WHERE portal.bridge_id=updates.bridge_id AND portal.id=updates.id AND portal.receiver='' AND NOT EXISTS ( + SELECT 1 FROM portal p2 WHERE p2.bridge_id=updates.bridge_id AND p2.id=updates.id AND p2.receiver=updates.new_receiver + ) + ` + fixParentsAfterSplitPortalMigrationQuery = ` + UPDATE portal + SET parent_receiver=receiver + WHERE bridge_id=$1 AND parent_receiver='' AND receiver<>'' AND parent_id<>'' + AND EXISTS(SELECT 1 FROM portal pp WHERE pp.bridge_id=$1 AND pp.id=portal.parent_id AND pp.receiver=portal.receiver); ` ) @@ -141,6 +180,10 @@ func (pq *PortalQuery) GetAllWithMXID(ctx context.Context) ([]*Portal, error) { return pq.QueryMany(ctx, getAllPortalsWithMXIDQuery, pq.BridgeID) } +func (pq *PortalQuery) GetAllWithoutReceiver(ctx context.Context) ([]*Portal, error) { + return pq.QueryMany(ctx, getAllPortalsWithoutReceiver, pq.BridgeID) +} + func (pq *PortalQuery) GetAll(ctx context.Context) ([]*Portal, error) { return pq.QueryMany(ctx, getAllPortalsQuery, pq.BridgeID) } @@ -149,6 +192,10 @@ func (pq *PortalQuery) GetAllDMsWith(ctx context.Context, otherUserID networkid. return pq.QueryMany(ctx, getAllDMPortalsQuery, pq.BridgeID, otherUserID) } +func (pq *PortalQuery) GetDM(ctx context.Context, receiver networkid.UserLoginID, otherUserID networkid.UserID) (*Portal, error) { + return pq.QueryOne(ctx, getDMPortalQuery, pq.BridgeID, receiver, otherUserID) +} + func (pq *PortalQuery) GetChildren(ctx context.Context, parentKey networkid.PortalKey) ([]*Portal, error) { return pq.QueryMany(ctx, getChildPortalsQuery, pq.BridgeID, parentKey.ID, parentKey.Receiver) } @@ -179,6 +226,14 @@ func (pq *PortalQuery) MigrateToSplitPortals(ctx context.Context) (int64, error) return res.RowsAffected() } +func (pq *PortalQuery) FixParentsAfterSplitPortalMigration(ctx context.Context) (int64, error) { + res, err := pq.GetDB().Exec(ctx, fixParentsAfterSplitPortalMigrationQuery, pq.BridgeID) + if err != nil { + return 0, err + } + return res.RowsAffected() +} + func (p *Portal) Scan(row dbutil.Scannable) (*Portal, error) { var mxid, parentID, parentReceiver, relayLoginID, otherUserID, disappearType sql.NullString var disappearTimer sql.NullInt64 @@ -187,9 +242,9 @@ func (p *Portal) Scan(row dbutil.Scannable) (*Portal, error) { &p.BridgeID, &p.ID, &p.Receiver, &mxid, &parentID, &parentReceiver, &relayLoginID, &otherUserID, &p.Name, &p.Topic, &p.AvatarID, &avatarHash, &p.AvatarMXC, - &p.NameSet, &p.TopicSet, &p.AvatarSet, &p.NameIsCustom, &p.InSpace, + &p.NameSet, &p.TopicSet, &p.AvatarSet, &p.NameIsCustom, &p.InSpace, &p.MessageRequest, &p.RoomType, &disappearType, &disappearTimer, - dbutil.JSON{Data: p.Metadata}, + dbutil.JSON{Data: &p.CapState}, dbutil.JSON{Data: p.Metadata}, ) if err != nil { return nil, err @@ -202,7 +257,7 @@ func (p *Portal) Scan(row dbutil.Scannable) (*Portal, error) { } if disappearType.Valid { p.Disappear = DisappearingSetting{ - Type: DisappearingType(disappearType.String), + Type: event.DisappearingType(disappearType.String), Timer: time.Duration(disappearTimer.Int64), } } @@ -234,8 +289,8 @@ func (p *Portal) sqlVariables() []any { p.BridgeID, p.ID, p.Receiver, dbutil.StrPtr(p.MXID), dbutil.StrPtr(p.ParentKey.ID), p.ParentKey.Receiver, dbutil.StrPtr(p.RelayLoginID), dbutil.StrPtr(p.OtherUserID), p.Name, p.Topic, p.AvatarID, avatarHash, p.AvatarMXC, - p.NameSet, p.TopicSet, p.AvatarSet, p.NameIsCustom, p.InSpace, + p.NameSet, p.TopicSet, p.AvatarSet, p.NameIsCustom, p.InSpace, p.MessageRequest, p.RoomType, dbutil.StrPtr(p.Disappear.Type), dbutil.NumPtr(p.Disappear.Timer), - dbutil.JSON{Data: p.Metadata}, + dbutil.JSON{Data: p.CapState}, dbutil.JSON{Data: p.Metadata}, } } diff --git a/bridgev2/database/publicmedia.go b/bridgev2/database/publicmedia.go new file mode 100644 index 00000000..b667399c --- /dev/null +++ b/bridgev2/database/publicmedia.go @@ -0,0 +1,72 @@ +// Copyright (c) 2025 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package database + +import ( + "context" + "database/sql" + "time" + + "go.mau.fi/util/dbutil" + + "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/crypto/attachment" + "maunium.net/go/mautrix/id" +) + +type PublicMediaQuery struct { + BridgeID networkid.BridgeID + *dbutil.QueryHelper[*PublicMedia] +} + +type PublicMedia struct { + BridgeID networkid.BridgeID + PublicID string + MXC id.ContentURI + Keys *attachment.EncryptedFile + MimeType string + Expiry time.Time +} + +const ( + upsertPublicMediaQuery = ` + INSERT INTO public_media (bridge_id, public_id, mxc, keys, mimetype, expiry) + VALUES ($1, $2, $3, $4, $5, $6) + ON CONFLICT (bridge_id, public_id) DO UPDATE SET expiry=EXCLUDED.expiry + ` + getPublicMediaQuery = ` + SELECT bridge_id, public_id, mxc, keys, mimetype, expiry + FROM public_media WHERE bridge_id=$1 AND public_id=$2 + ` +) + +func (pmq *PublicMediaQuery) Put(ctx context.Context, pm *PublicMedia) error { + ensureBridgeIDMatches(&pm.BridgeID, pmq.BridgeID) + return pmq.Exec(ctx, upsertPublicMediaQuery, pm.sqlVariables()...) +} + +func (pmq *PublicMediaQuery) Get(ctx context.Context, publicID string) (*PublicMedia, error) { + return pmq.QueryOne(ctx, getPublicMediaQuery, pmq.BridgeID, publicID) +} + +func (pm *PublicMedia) Scan(row dbutil.Scannable) (*PublicMedia, error) { + var expiry sql.NullInt64 + var mimetype sql.NullString + err := row.Scan(&pm.BridgeID, &pm.PublicID, &pm.MXC, dbutil.JSON{Data: &pm.Keys}, &mimetype, &expiry) + if err != nil { + return nil, err + } + if expiry.Valid { + pm.Expiry = time.Unix(0, expiry.Int64) + } + pm.MimeType = mimetype.String + return pm, nil +} + +func (pm *PublicMedia) sqlVariables() []any { + return []any{pm.BridgeID, pm.PublicID, &pm.MXC, dbutil.JSONPtr(pm.Keys), dbutil.StrPtr(pm.MimeType), dbutil.ConvertedPtr(pm.Expiry, time.Time.UnixNano)} +} diff --git a/bridgev2/database/upgrades/00-latest.sql b/bridgev2/database/upgrades/00-latest.sql index 6d6dcf2c..6092dc24 100644 --- a/bridgev2/database/upgrades/00-latest.sql +++ b/bridgev2/database/upgrades/00-latest.sql @@ -1,4 +1,4 @@ --- v0 -> v18 (compatible with v9+): Latest revision +-- v0 -> v27 (compatible with v9+): Latest revision CREATE TABLE "user" ( bridge_id TEXT NOT NULL, mxid TEXT NOT NULL, @@ -48,9 +48,11 @@ CREATE TABLE portal ( topic_set BOOLEAN NOT NULL, name_is_custom BOOLEAN NOT NULL DEFAULT false, in_space BOOLEAN NOT NULL, + message_request BOOLEAN NOT NULL DEFAULT false, room_type TEXT NOT NULL, disappear_type TEXT, disappear_timer BIGINT, + cap_state jsonb, metadata jsonb NOT NULL, PRIMARY KEY (bridge_id, id, receiver), @@ -62,6 +64,8 @@ CREATE TABLE portal ( REFERENCES user_login (bridge_id, id) ON DELETE SET NULL ON UPDATE CASCADE ); +CREATE UNIQUE INDEX portal_bridge_mxid_idx ON portal (bridge_id, mxid); +CREATE INDEX portal_parent_idx ON portal (bridge_id, parent_id, parent_receiver); CREATE TABLE ghost ( bridge_id TEXT NOT NULL, @@ -76,6 +80,7 @@ CREATE TABLE ghost ( contact_info_set BOOLEAN NOT NULL, is_bot BOOLEAN NOT NULL, identifiers jsonb NOT NULL, + extra_profile jsonb, metadata jsonb NOT NULL, PRIMARY KEY (bridge_id, id) @@ -87,7 +92,7 @@ CREATE TABLE message ( -- would try to set bridge_id to null as well. -- only: sqlite (line commented) --- rowid INTEGER PRIMARY KEY, +-- rowid INTEGER PRIMARY KEY, -- only: postgres rowid BIGINT PRIMARY KEY GENERATED BY DEFAULT AS IDENTITY, @@ -102,9 +107,11 @@ CREATE TABLE message ( sender_mxid TEXT NOT NULL, timestamp BIGINT NOT NULL, edit_count INTEGER NOT NULL, + double_puppeted BOOLEAN, thread_root_id TEXT, reply_to_id TEXT, reply_to_part_id TEXT, + send_txn_id TEXT, metadata jsonb NOT NULL, CONSTRAINT message_room_fkey FOREIGN KEY (bridge_id, room_id, room_receiver) @@ -114,7 +121,8 @@ CREATE TABLE message ( REFERENCES ghost (bridge_id, id) ON DELETE CASCADE ON UPDATE CASCADE, CONSTRAINT message_real_pkey UNIQUE (bridge_id, room_receiver, id, part_id), - CONSTRAINT message_mxid_unique UNIQUE (bridge_id, mxid) + CONSTRAINT message_mxid_unique UNIQUE (bridge_id, mxid), + CONSTRAINT message_txn_id_unique UNIQUE (bridge_id, room_receiver, send_txn_id) ); CREATE INDEX message_room_idx ON message (bridge_id, room_id, room_receiver); @@ -122,12 +130,18 @@ CREATE TABLE disappearing_message ( bridge_id TEXT NOT NULL, mx_room TEXT NOT NULL, mxid TEXT NOT NULL, + timestamp BIGINT NOT NULL DEFAULT 0, type TEXT NOT NULL, timer BIGINT NOT NULL, disappear_at BIGINT, - PRIMARY KEY (bridge_id, mxid) + PRIMARY KEY (bridge_id, mxid), + CONSTRAINT disappearing_message_portal_fkey + FOREIGN KEY (bridge_id, mx_room) + REFERENCES portal (bridge_id, mxid) + ON DELETE CASCADE ); +CREATE INDEX disappearing_message_portal_idx ON disappearing_message (bridge_id, mx_room); CREATE TABLE reaction ( bridge_id TEXT NOT NULL, @@ -206,3 +220,14 @@ CREATE TABLE kv_store ( PRIMARY KEY (bridge_id, key) ); + +CREATE TABLE public_media ( + bridge_id TEXT NOT NULL, + public_id TEXT NOT NULL, + mxc TEXT NOT NULL, + keys jsonb, + mimetype TEXT, + expiry BIGINT, + + PRIMARY KEY (bridge_id, public_id) +); diff --git a/bridgev2/database/upgrades/19-add-double-puppeted-to-message.sql b/bridgev2/database/upgrades/19-add-double-puppeted-to-message.sql new file mode 100644 index 00000000..ec6fe836 --- /dev/null +++ b/bridgev2/database/upgrades/19-add-double-puppeted-to-message.sql @@ -0,0 +1,2 @@ +-- v19 (compatible with v9+): Add double puppeted state to messages +ALTER TABLE message ADD COLUMN double_puppeted BOOLEAN; diff --git a/bridgev2/database/upgrades/20-portal-capabilities.sql b/bridgev2/database/upgrades/20-portal-capabilities.sql new file mode 100644 index 00000000..00bd96ca --- /dev/null +++ b/bridgev2/database/upgrades/20-portal-capabilities.sql @@ -0,0 +1,2 @@ +-- v20 (compatible with v9+): Add portal capability state +ALTER TABLE portal ADD COLUMN cap_state jsonb; diff --git a/bridgev2/database/upgrades/21-disappearing-message-fkey.postgres.sql b/bridgev2/database/upgrades/21-disappearing-message-fkey.postgres.sql new file mode 100644 index 00000000..d1c1ad9a --- /dev/null +++ b/bridgev2/database/upgrades/21-disappearing-message-fkey.postgres.sql @@ -0,0 +1,8 @@ +-- v21 (compatible with v9+): Add foreign key constraint from disappearing_message.mx_room to portals.mxid +CREATE UNIQUE INDEX portal_bridge_mxid_idx ON portal (bridge_id, mxid); +DELETE FROM disappearing_message WHERE mx_room NOT IN (SELECT mxid FROM portal WHERE mxid IS NOT NULL); +ALTER TABLE disappearing_message + ADD CONSTRAINT disappearing_message_portal_fkey + FOREIGN KEY (bridge_id, mx_room) + REFERENCES portal (bridge_id, mxid) + ON DELETE CASCADE; diff --git a/bridgev2/database/upgrades/21-disappearing-message-fkey.sqlite.sql b/bridgev2/database/upgrades/21-disappearing-message-fkey.sqlite.sql new file mode 100644 index 00000000..f5468c6b --- /dev/null +++ b/bridgev2/database/upgrades/21-disappearing-message-fkey.sqlite.sql @@ -0,0 +1,24 @@ +-- v21 (compatible with v9+): Add foreign key constraint from disappearing_message.mx_room to portals.mxid +CREATE UNIQUE INDEX portal_bridge_mxid_idx ON portal (bridge_id, mxid); +CREATE TABLE disappearing_message_new ( + bridge_id TEXT NOT NULL, + mx_room TEXT NOT NULL, + mxid TEXT NOT NULL, + type TEXT NOT NULL, + timer BIGINT NOT NULL, + disappear_at BIGINT, + + PRIMARY KEY (bridge_id, mxid), + CONSTRAINT disappearing_message_portal_fkey + FOREIGN KEY (bridge_id, mx_room) + REFERENCES portal (bridge_id, mxid) + ON DELETE CASCADE +); + +WITH portal_mxids AS (SELECT mxid FROM portal WHERE mxid IS NOT NULL) +INSERT INTO disappearing_message_new (bridge_id, mx_room, mxid, type, timer, disappear_at) +SELECT bridge_id, mx_room, mxid, type, timer, disappear_at +FROM disappearing_message WHERE mx_room IN portal_mxids; + +DROP TABLE disappearing_message; +ALTER TABLE disappearing_message_new RENAME TO disappearing_message; diff --git a/bridgev2/database/upgrades/22-message-send-txn-id.sql b/bridgev2/database/upgrades/22-message-send-txn-id.sql new file mode 100644 index 00000000..8933984e --- /dev/null +++ b/bridgev2/database/upgrades/22-message-send-txn-id.sql @@ -0,0 +1,6 @@ +-- v22 (compatible with v9+): Add message send transaction ID column +ALTER TABLE message ADD COLUMN send_txn_id TEXT; +-- only: postgres +ALTER TABLE message ADD CONSTRAINT message_txn_id_unique UNIQUE (bridge_id, room_receiver, send_txn_id); +-- only: sqlite +CREATE UNIQUE INDEX message_txn_id_unique ON message (bridge_id, room_receiver, send_txn_id); diff --git a/bridgev2/database/upgrades/23-disappearing-timer-ts.sql b/bridgev2/database/upgrades/23-disappearing-timer-ts.sql new file mode 100644 index 00000000..ecd00b8d --- /dev/null +++ b/bridgev2/database/upgrades/23-disappearing-timer-ts.sql @@ -0,0 +1,2 @@ +-- v23 (compatible with v9+): Add event timestamp for disappearing messages +ALTER TABLE disappearing_message ADD COLUMN timestamp BIGINT NOT NULL DEFAULT 0; diff --git a/bridgev2/database/upgrades/24-public-media.sql b/bridgev2/database/upgrades/24-public-media.sql new file mode 100644 index 00000000..c4290090 --- /dev/null +++ b/bridgev2/database/upgrades/24-public-media.sql @@ -0,0 +1,11 @@ +-- v24 (compatible with v9+): Custom URLs for public media +CREATE TABLE public_media ( + bridge_id TEXT NOT NULL, + public_id TEXT NOT NULL, + mxc TEXT NOT NULL, + keys jsonb, + mimetype TEXT, + expiry BIGINT, + + PRIMARY KEY (bridge_id, public_id) +); diff --git a/bridgev2/database/upgrades/25-message-requests.sql b/bridgev2/database/upgrades/25-message-requests.sql new file mode 100644 index 00000000..b9d82a7a --- /dev/null +++ b/bridgev2/database/upgrades/25-message-requests.sql @@ -0,0 +1,2 @@ +-- v25 (compatible with v9+): Flag for message request portals +ALTER TABLE portal ADD COLUMN message_request BOOLEAN NOT NULL DEFAULT false; diff --git a/bridgev2/database/upgrades/26-disappearing-message-portal-index.sql b/bridgev2/database/upgrades/26-disappearing-message-portal-index.sql new file mode 100644 index 00000000..ae5d8cad --- /dev/null +++ b/bridgev2/database/upgrades/26-disappearing-message-portal-index.sql @@ -0,0 +1,3 @@ +-- v26 (compatible with v9+): Add room index for disappearing message table and portal parents +CREATE INDEX disappearing_message_portal_idx ON disappearing_message (bridge_id, mx_room); +CREATE INDEX portal_parent_idx ON portal (bridge_id, parent_id, parent_receiver); diff --git a/bridgev2/database/upgrades/27-ghost-extra-profile.sql b/bridgev2/database/upgrades/27-ghost-extra-profile.sql new file mode 100644 index 00000000..e8e0549a --- /dev/null +++ b/bridgev2/database/upgrades/27-ghost-extra-profile.sql @@ -0,0 +1,2 @@ +-- v27 (compatible with v9+): Add column for extra ghost profile metadata +ALTER TABLE ghost ADD COLUMN extra_profile jsonb; diff --git a/bridgev2/database/userlogin.go b/bridgev2/database/userlogin.go index 610e7d60..00ff01c9 100644 --- a/bridgev2/database/userlogin.go +++ b/bridgev2/database/userlogin.go @@ -12,8 +12,8 @@ import ( "go.mau.fi/util/dbutil" - "maunium.net/go/mautrix/bridge/status" "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/bridgev2/status" "maunium.net/go/mautrix/id" ) @@ -116,7 +116,7 @@ func (u *UserLogin) ensureHasMetadata(metaType MetaTypeCreator) *UserLogin { func (u *UserLogin) sqlVariables() []any { var remoteProfile dbutil.JSON - if !u.RemoteProfile.IsEmpty() { + if !u.RemoteProfile.IsZero() { remoteProfile.Data = &u.RemoteProfile } return []any{u.BridgeID, u.UserMXID, u.ID, u.RemoteName, remoteProfile, dbutil.StrPtr(u.SpaceRoom), dbutil.JSON{Data: u.Metadata}} diff --git a/bridgev2/database/userportal.go b/bridgev2/database/userportal.go index 278b236b..e928a4c7 100644 --- a/bridgev2/database/userportal.go +++ b/bridgev2/database/userportal.go @@ -67,6 +67,9 @@ const ( markLoginAsPreferredQuery = ` UPDATE user_portal SET preferred=(login_id=$3) WHERE bridge_id=$1 AND user_mxid=$2 AND portal_id=$4 AND portal_receiver=$5 ` + markAllNotInSpaceQuery = ` + UPDATE user_portal SET in_space=false WHERE bridge_id=$1 AND portal_id=$2 AND portal_receiver=$3 + ` deleteUserPortalQuery = ` DELETE FROM user_portal WHERE bridge_id=$1 AND user_mxid=$2 AND login_id=$3 AND portal_id=$4 AND portal_receiver=$5 ` @@ -110,6 +113,10 @@ func (upq *UserPortalQuery) MarkAsPreferred(ctx context.Context, login *UserLogi return upq.Exec(ctx, markLoginAsPreferredQuery, upq.BridgeID, login.UserMXID, login.ID, portal.ID, portal.Receiver) } +func (upq *UserPortalQuery) MarkAllNotInSpace(ctx context.Context, portal networkid.PortalKey) error { + return upq.Exec(ctx, markAllNotInSpaceQuery, upq.BridgeID, portal.ID, portal.Receiver) +} + func (upq *UserPortalQuery) Delete(ctx context.Context, up *UserPortal) error { return upq.Exec(ctx, deleteUserPortalQuery, up.BridgeID, up.UserMXID, up.LoginID, up.Portal.ID, up.Portal.Receiver) } diff --git a/bridgev2/disappear.go b/bridgev2/disappear.go index 5f9900a5..b5c37e8f 100644 --- a/bridgev2/disappear.go +++ b/bridgev2/disappear.go @@ -8,6 +8,7 @@ package bridgev2 import ( "context" + "sync/atomic" "time" "github.com/rs/zerolog" @@ -20,27 +21,44 @@ import ( type DisappearLoop struct { br *Bridge - NextCheck time.Time - stop context.CancelFunc + nextCheck atomic.Pointer[time.Time] + stop atomic.Pointer[context.CancelFunc] } const DisappearCheckInterval = 1 * time.Hour func (dl *DisappearLoop) Start() { log := dl.br.Log.With().Str("component", "disappear loop").Logger() - ctx := log.WithContext(context.Background()) - ctx, dl.stop = context.WithCancel(ctx) + ctx, stop := context.WithCancel(log.WithContext(context.Background())) + if oldStop := dl.stop.Swap(&stop); oldStop != nil { + (*oldStop)() + } log.Debug().Msg("Disappearing message loop starting") for { - dl.NextCheck = time.Now().Add(DisappearCheckInterval) - messages, err := dl.br.DB.DisappearingMessage.GetUpcoming(ctx, DisappearCheckInterval) + nextCheck := time.Now().Add(DisappearCheckInterval) + dl.nextCheck.Store(&nextCheck) + const MessageLimit = 200 + messages, err := dl.br.DB.DisappearingMessage.GetUpcoming(ctx, DisappearCheckInterval, MessageLimit) if err != nil { log.Err(err).Msg("Failed to get upcoming disappearing messages") } else if len(messages) > 0 { + if len(messages) >= MessageLimit { + lastDisappearTime := messages[len(messages)-1].DisappearAt + log.Debug(). + Int("message_count", len(messages)). + Time("last_due", lastDisappearTime). + Msg("Deleting disappearing messages synchronously and checking again immediately") + // Store the expected next check time to avoid Add spawning unnecessary goroutines. + // This can be in the past, in which case Add will put everything in the db, which is also fine. + dl.nextCheck.Store(&lastDisappearTime) + // If there are many messages, process them synchronously and then check again. + dl.sleepAndDisappear(ctx, messages...) + continue + } go dl.sleepAndDisappear(ctx, messages...) } select { - case <-time.After(time.Until(dl.NextCheck)): + case <-time.After(time.Until(dl.GetNextCheck())): case <-ctx.Done(): log.Debug().Msg("Disappearing message loop stopping") return @@ -48,20 +66,34 @@ func (dl *DisappearLoop) Start() { } } +func (dl *DisappearLoop) GetNextCheck() time.Time { + if dl == nil { + return time.Time{} + } + nextCheck := dl.nextCheck.Load() + if nextCheck == nil { + return time.Time{} + } + return *nextCheck +} + func (dl *DisappearLoop) Stop() { - if dl.stop != nil { - dl.stop() + if dl == nil { + return + } + if stop := dl.stop.Load(); stop != nil { + (*stop)() } } -func (dl *DisappearLoop) StartAll(ctx context.Context, roomID id.RoomID) { - startedMessages, err := dl.br.DB.DisappearingMessage.StartAll(ctx, roomID) +func (dl *DisappearLoop) StartAllBefore(ctx context.Context, roomID id.RoomID, beforeTS time.Time) { + startedMessages, err := dl.br.DB.DisappearingMessage.StartAllBefore(ctx, roomID, beforeTS) if err != nil { zerolog.Ctx(ctx).Err(err).Msg("Failed to start disappearing messages") return } startedMessages = slices.DeleteFunc(startedMessages, func(dm *database.DisappearingMessage) bool { - return dm.DisappearAt.After(dl.NextCheck) + return dm.DisappearAt.After(dl.GetNextCheck()) }) slices.SortFunc(startedMessages, func(a, b *database.DisappearingMessage) int { return a.DisappearAt.Compare(b.DisappearAt) @@ -78,14 +110,25 @@ func (dl *DisappearLoop) Add(ctx context.Context, dm *database.DisappearingMessa Stringer("event_id", dm.EventID). Msg("Failed to save disappearing message") } - if !dm.DisappearAt.IsZero() && dm.DisappearAt.Before(dl.NextCheck) { - go dl.sleepAndDisappear(context.WithoutCancel(ctx), dm) + if !dm.DisappearAt.IsZero() && dm.DisappearAt.Before(dl.GetNextCheck()) { + go dl.sleepAndDisappear(zerolog.Ctx(ctx).WithContext(dl.br.BackgroundCtx), dm) } } func (dl *DisappearLoop) sleepAndDisappear(ctx context.Context, dms ...*database.DisappearingMessage) { for _, msg := range dms { - time.Sleep(time.Until(msg.DisappearAt)) + timeUntilDisappear := time.Until(msg.DisappearAt) + if timeUntilDisappear <= 0 { + if ctx.Err() != nil { + return + } + } else { + select { + case <-time.After(timeUntilDisappear): + case <-ctx.Done(): + return + } + } resp, err := dl.br.Bot.SendMessage(ctx, msg.RoomID, event.EventRedaction, &event.Content{ Parsed: &event.RedactionEventContent{ Redacts: msg.EventID, diff --git a/bridgev2/errors.go b/bridgev2/errors.go index 789d0026..f6677d2e 100644 --- a/bridgev2/errors.go +++ b/bridgev2/errors.go @@ -12,6 +12,7 @@ import ( "net/http" "maunium.net/go/mautrix" + "maunium.net/go/mautrix/event" ) // ErrIgnoringRemoteEvent can be returned by [RemoteMessage.ConvertMessage] or [RemoteEdit.ConvertEdit] @@ -37,32 +38,53 @@ var ErrNotLoggedIn = errors.New("not logged in") // but direct media is not enabled. var ErrDirectMediaNotEnabled = errors.New("direct media is not enabled") +var ErrPortalIsDeleted = errors.New("portal is deleted") +var ErrPortalNotFoundInEventHandler = errors.New("portal not found to handle remote event") + // Common message status errors var ( - ErrPanicInEventHandler error = WrapErrorInStatus(errors.New("panic in event handler")).WithSendNotice(true).WithErrorAsMessage() - ErrNoPortal error = WrapErrorInStatus(errors.New("room is not a portal")).WithIsCertain(true).WithSendNotice(false) - ErrIgnoringReactionFromRelayedUser error = WrapErrorInStatus(errors.New("ignoring reaction event from relayed user")).WithIsCertain(true).WithSendNotice(false) - ErrIgnoringPollFromRelayedUser error = WrapErrorInStatus(errors.New("ignoring poll event from relayed user")).WithIsCertain(true).WithSendNotice(false) - ErrEditsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support edits")).WithIsCertain(true).WithErrorAsMessage() - ErrEditsNotSupportedInPortal error = WrapErrorInStatus(errors.New("edits are not allowed in this chat")).WithIsCertain(true).WithErrorAsMessage() - ErrCaptionsNotAllowed error = WrapErrorInStatus(errors.New("captions are not supported here")).WithIsCertain(true).WithErrorAsMessage() - ErrLocationMessagesNotAllowed error = WrapErrorInStatus(errors.New("location messages are not supported here")).WithIsCertain(true).WithErrorAsMessage() - ErrEditTargetTooOld error = WrapErrorInStatus(errors.New("the message is too old to be edited")).WithIsCertain(true).WithErrorAsMessage() - ErrEditTargetTooManyEdits error = WrapErrorInStatus(errors.New("the message has been edited too many times")).WithIsCertain(true).WithErrorAsMessage() - ErrReactionsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support reactions")).WithIsCertain(true).WithErrorAsMessage() - ErrPollsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support polls")).WithIsCertain(true).WithErrorAsMessage() - ErrRoomMetadataNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support changing room metadata")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false) - ErrRedactionsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support deleting messages")).WithIsCertain(true).WithErrorAsMessage() - ErrUnexpectedParsedContentType error = WrapErrorInStatus(errors.New("unexpected parsed content type")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(true) - ErrDatabaseError error = WrapErrorInStatus(errors.New("database error")).WithMessage("internal database error").WithIsCertain(true).WithSendNotice(true) - ErrTargetMessageNotFound error = WrapErrorInStatus(errors.New("target message not found")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(false) - ErrUnsupportedMessageType error = WrapErrorInStatus(errors.New("unsupported message type")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(true) - ErrUnsupportedMediaType error = WrapErrorInStatus(errors.New("unsupported media type")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(true) - ErrMediaDownloadFailed error = WrapErrorInStatus(errors.New("failed to download media")).WithMessage("failed to download media").WithIsCertain(true).WithSendNotice(true) - ErrMediaReuploadFailed error = WrapErrorInStatus(errors.New("failed to reupload media")).WithMessage("failed to reupload media").WithIsCertain(true).WithSendNotice(true) - ErrMediaConvertFailed error = WrapErrorInStatus(errors.New("failed to convert media")).WithMessage("failed to convert media").WithIsCertain(true).WithSendNotice(true) - ErrMembershipNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support changing group membership")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false) - ErrPowerLevelsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support changing group power levels")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false) + ErrPanicInEventHandler error = WrapErrorInStatus(errors.New("panic in event handler")).WithSendNotice(true).WithErrorAsMessage() + ErrNoPortal error = WrapErrorInStatus(errors.New("room is not a portal")).WithIsCertain(true).WithSendNotice(false) + ErrIgnoringReactionFromRelayedUser error = WrapErrorInStatus(errors.New("ignoring reaction event from relayed user")).WithIsCertain(true).WithSendNotice(false) + ErrIgnoringPollFromRelayedUser error = WrapErrorInStatus(errors.New("ignoring poll event from relayed user")).WithIsCertain(true).WithSendNotice(false) + ErrIgnoringDeleteChatRelayedUser error = WrapErrorInStatus(errors.New("ignoring delete chat event from relayed user")).WithIsCertain(true).WithSendNotice(false) + ErrIgnoringAcceptRequestRelayedUser error = WrapErrorInStatus(errors.New("ignoring accept message request event from relayed user")).WithIsCertain(true).WithSendNotice(false) + ErrEditsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support edits")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported) + ErrEditsNotSupportedInPortal error = WrapErrorInStatus(errors.New("edits are not allowed in this chat")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported) + ErrCaptionsNotAllowed error = WrapErrorInStatus(errors.New("captions are not supported here")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported) + ErrLocationMessagesNotAllowed error = WrapErrorInStatus(errors.New("location messages are not supported here")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported) + ErrEditTargetTooOld error = WrapErrorInStatus(errors.New("the message is too old to be edited")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported) + ErrEditTargetTooManyEdits error = WrapErrorInStatus(errors.New("the message has been edited too many times")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported) + ErrReactionsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support reactions")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported) + ErrPollsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support polls")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported) + ErrRoomMetadataNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support changing room metadata")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false).WithErrorReason(event.MessageStatusUnsupported) + ErrRoomMetadataNotAllowed error = WrapErrorInStatus(errors.New("changes are not allowed here")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false).WithErrorReason(event.MessageStatusUnsupported) + ErrRedactionsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support deleting messages")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported) + ErrUnexpectedParsedContentType error = WrapErrorInStatus(errors.New("unexpected parsed content type")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(true) + ErrInvalidStateKey error = WrapErrorInStatus(errors.New("room metadata state key is unset or non-empty")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(false) + ErrDatabaseError error = WrapErrorInStatus(errors.New("database error")).WithMessage("internal database error").WithIsCertain(true).WithSendNotice(true) + ErrTargetMessageNotFound error = WrapErrorInStatus(errors.New("target message not found")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(false) + ErrUnsupportedMessageType error = WrapErrorInStatus(errors.New("unsupported message type")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(true).WithErrorReason(event.MessageStatusUnsupported) + ErrUnsupportedMediaType error = WrapErrorInStatus(errors.New("unsupported media type")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(true).WithErrorReason(event.MessageStatusUnsupported) + ErrMediaDurationTooLong error = WrapErrorInStatus(errors.New("media duration too long")).WithErrorAsMessage().WithSendNotice(true).WithErrorReason(event.MessageStatusUnsupported) + ErrVoiceMessageDurationTooLong error = WrapErrorInStatus(errors.New("voice message too long")).WithErrorAsMessage().WithSendNotice(true).WithErrorReason(event.MessageStatusUnsupported) + ErrMediaTooLarge error = WrapErrorInStatus(errors.New("media too large")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(true).WithErrorReason(event.MessageStatusUnsupported) + ErrIgnoringMNotice error = WrapErrorInStatus(errors.New("ignoring m.notice message")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false) + ErrMediaDownloadFailed error = WrapErrorInStatus(errors.New("failed to download media")).WithMessage("failed to download media").WithIsCertain(true).WithSendNotice(true) + ErrMediaReuploadFailed error = WrapErrorInStatus(errors.New("failed to reupload media")).WithMessage("failed to reupload media").WithIsCertain(true).WithSendNotice(true) + ErrMediaConvertFailed error = WrapErrorInStatus(errors.New("failed to convert media")).WithMessage("failed to convert media").WithIsCertain(true).WithSendNotice(true) + ErrMembershipNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support changing group membership")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false).WithErrorReason(event.MessageStatusUnsupported) + ErrDeleteChatNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support deleting chats")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false).WithErrorReason(event.MessageStatusUnsupported) + ErrBeeperAIStreamNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support Beeper AI stream events")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false).WithErrorReason(event.MessageStatusUnsupported) + ErrPowerLevelsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support changing group power levels")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false).WithErrorReason(event.MessageStatusUnsupported) + ErrRemoteEchoTimeout = WrapErrorInStatus(errors.New("remote echo timed out")).WithIsCertain(false).WithSendNotice(true).WithErrorReason(event.MessageStatusTooOld) + ErrRemoteAckTimeout = WrapErrorInStatus(errors.New("remote ack timed out")).WithIsCertain(false).WithSendNotice(true).WithErrorReason(event.MessageStatusTooOld) + + ErrPublicMediaDisabled = WrapErrorInStatus(errors.New("public media is not enabled in the bridge config")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported).WithSendNotice(true) + ErrPublicMediaDatabaseDisabled = WrapErrorInStatus(errors.New("public media database storage is disabled")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported).WithSendNotice(true) + ErrPublicMediaGenerateFailed = WrapErrorInStatus(errors.New("failed to generate public media URL")).WithIsCertain(true).WithMessage("failed to generate public media URL").WithErrorReason(event.MessageStatusUnsupported).WithSendNotice(true) + + ErrDisappearingTimerUnsupported error = WrapErrorInStatus(errors.New("invalid disappearing timer")).WithIsCertain(true) ) // Common login interface errors diff --git a/bridgev2/ghost.go b/bridgev2/ghost.go index e4e007cd..590dd1dc 100644 --- a/bridgev2/ghost.go +++ b/bridgev2/ghost.go @@ -9,12 +9,15 @@ package bridgev2 import ( "context" "crypto/sha256" + "encoding/json" "fmt" + "maps" "net/http" + "slices" "github.com/rs/zerolog" + "go.mau.fi/util/exerrors" "go.mau.fi/util/exmime" - "golang.org/x/exp/slices" "maunium.net/go/mautrix/bridgev2/database" "maunium.net/go/mautrix/bridgev2/networkid" @@ -85,7 +88,13 @@ func (br *Bridge) GetGhostByMXID(ctx context.Context, mxid id.UserID) (*Ghost, e func (br *Bridge) GetGhostByID(ctx context.Context, id networkid.UserID) (*Ghost, error) { br.cacheLock.Lock() defer br.cacheLock.Unlock() - return br.unlockedGetGhostByID(ctx, id, false) + ghost, err := br.unlockedGetGhostByID(ctx, id, false) + if err != nil { + return nil, err + } else if ghost == nil { + panic(fmt.Errorf("unlockedGetGhostByID(ctx, %q, false) returned nil", id)) + } + return ghost, nil } func (br *Bridge) GetExistingGhostByID(ctx context.Context, id networkid.UserID) (*Ghost, error) { @@ -128,10 +137,11 @@ func (a *Avatar) Reupload(ctx context.Context, intent MatrixAPI, currentHash [32 } type UserInfo struct { - Identifiers []string - Name *string - Avatar *Avatar - IsBot *bool + Identifiers []string + Name *string + Avatar *Avatar + IsBot *bool + ExtraProfile database.ExtraProfile ExtraUpdates ExtraUpdater[*Ghost] } @@ -152,7 +162,7 @@ func (ghost *Ghost) UpdateName(ctx context.Context, name string) bool { } func (ghost *Ghost) UpdateAvatar(ctx context.Context, avatar *Avatar) bool { - if ghost.AvatarID == avatar.ID && ghost.AvatarSet { + if ghost.AvatarID == avatar.ID && (avatar.Remove || ghost.AvatarMXC != "") && ghost.AvatarSet { return false } ghost.AvatarID = avatar.ID @@ -162,7 +172,7 @@ func (ghost *Ghost) UpdateAvatar(ctx context.Context, avatar *Avatar) bool { ghost.AvatarSet = false zerolog.Ctx(ctx).Err(err).Msg("Failed to reupload avatar") return true - } else if newHash == ghost.AvatarHash && ghost.AvatarSet { + } else if newHash == ghost.AvatarHash && ghost.AvatarMXC != "" && ghost.AvatarSet { return true } ghost.AvatarHash = newHash @@ -179,23 +189,9 @@ func (ghost *Ghost) UpdateAvatar(ctx context.Context, avatar *Avatar) bool { return true } -func (ghost *Ghost) UpdateContactInfo(ctx context.Context, identifiers []string, isBot *bool) bool { - if identifiers != nil { - slices.Sort(identifiers) - } - if ghost.ContactInfoSet && - (identifiers == nil || slices.Equal(identifiers, ghost.Identifiers)) && - (isBot == nil || *isBot == ghost.IsBot) { - return false - } - if identifiers != nil { - ghost.Identifiers = identifiers - } - if isBot != nil { - ghost.IsBot = *isBot - } +func (ghost *Ghost) getExtraProfileMeta() any { bridgeName := ghost.Bridge.Network.GetName() - meta := &event.BeeperProfileExtra{ + baseExtra := &event.BeeperProfileExtra{ RemoteID: string(ghost.ID), Identifiers: ghost.Identifiers, Service: bridgeName.BeeperBridgeType, @@ -203,7 +199,36 @@ func (ghost *Ghost) UpdateContactInfo(ctx context.Context, identifiers []string, IsBridgeBot: false, IsNetworkBot: ghost.IsBot, } - err := ghost.Intent.SetExtraProfileMeta(ctx, meta) + if len(ghost.ExtraProfile) == 0 { + return baseExtra + } + mergedExtra := maps.Clone(ghost.ExtraProfile) + baseExtraMarshaled := exerrors.Must(json.Marshal(baseExtra)) + exerrors.PanicIfNotNil(json.Unmarshal(baseExtraMarshaled, &mergedExtra)) + return mergedExtra +} + +func (ghost *Ghost) UpdateContactInfo(ctx context.Context, identifiers []string, isBot *bool, extraProfile database.ExtraProfile) bool { + if !ghost.Bridge.Matrix.GetCapabilities().ExtraProfileMeta { + ghost.ContactInfoSet = false + return false + } + if identifiers != nil { + slices.Sort(identifiers) + } + changed := extraProfile.CopyTo(&ghost.ExtraProfile) + if identifiers != nil { + changed = changed || !slices.Equal(identifiers, ghost.Identifiers) + ghost.Identifiers = identifiers + } + if isBot != nil { + changed = changed || *isBot != ghost.IsBot + ghost.IsBot = *isBot + } + if ghost.ContactInfoSet && !changed { + return false + } + err := ghost.Intent.SetExtraProfileMeta(ctx, ghost.getExtraProfileMeta()) if err != nil { zerolog.Ctx(ctx).Err(err).Msg("Failed to set extra profile metadata") } else { @@ -225,7 +250,7 @@ func (br *Bridge) allowAggressiveUpdateForType(evtType RemoteEventType) bool { } func (ghost *Ghost) UpdateInfoIfNecessary(ctx context.Context, source *UserLogin, evtType RemoteEventType) { - if ghost.Name != "" && ghost.NameSet && !ghost.Bridge.allowAggressiveUpdateForType(evtType) { + if ghost.Name != "" && ghost.NameSet && ghost.AvatarSet && !ghost.Bridge.allowAggressiveUpdateForType(evtType) { return } info, err := source.Client.GetUserInfo(ctx, ghost) @@ -235,12 +260,16 @@ func (ghost *Ghost) UpdateInfoIfNecessary(ctx context.Context, source *UserLogin zerolog.Ctx(ctx).Debug(). Bool("has_name", ghost.Name != ""). Bool("name_set", ghost.NameSet). + Bool("has_avatar", ghost.AvatarMXC != ""). + Bool("avatar_set", ghost.AvatarSet). Msg("Updating ghost info in IfNecessary call") ghost.UpdateInfo(ctx, info) } else { zerolog.Ctx(ctx).Trace(). Bool("has_name", ghost.Name != ""). Bool("name_set", ghost.NameSet). + Bool("has_avatar", ghost.AvatarMXC != ""). + Bool("avatar_set", ghost.AvatarSet). Msg("No ghost info received in IfNecessary call") } } @@ -268,9 +297,14 @@ func (ghost *Ghost) UpdateInfo(ctx context.Context, info *UserInfo) { } if info.Avatar != nil { update = ghost.UpdateAvatar(ctx, info.Avatar) || update + } else if oldAvatar == "" && !ghost.AvatarSet { + // Special case: nil avatar means we're not expecting one ever, if we don't currently have + // one we flag it as set to avoid constantly refetching in UpdateInfoIfNecessary. + ghost.AvatarSet = true + update = true } - if info.Identifiers != nil || info.IsBot != nil { - update = ghost.UpdateContactInfo(ctx, info.Identifiers, info.IsBot) || update + if info.Identifiers != nil || info.IsBot != nil || info.ExtraProfile != nil { + update = ghost.UpdateContactInfo(ctx, info.Identifiers, info.IsBot, info.ExtraProfile) || update } if info.ExtraUpdates != nil { update = info.ExtraUpdates(ctx, ghost) || update diff --git a/bridgev2/login.go b/bridgev2/login.go index b28ccfdb..b8321719 100644 --- a/bridgev2/login.go +++ b/bridgev2/login.go @@ -13,6 +13,7 @@ import ( "strings" "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/event" ) // LoginProcess represents a single occurrence of a user logging into the remote network. @@ -159,6 +160,12 @@ type LoginCookiesParams struct { // The snippet will evaluate to a promise that resolves when the relevant fields are found. // Fields that are not present in the promise result must be extracted another way. ExtractJS string `json:"extract_js,omitempty"` + // A regex pattern that the URL should match before the client closes the webview. + // + // The client may submit the login if the user closes the webview after all cookies are collected + // even if this URL is not reached, but it should only automatically close the webview after + // both cookies and the URL match. + WaitForURLPattern string `json:"wait_for_url_pattern,omitempty"` } type LoginInputFieldType string @@ -172,6 +179,8 @@ const ( LoginInputFieldTypeToken LoginInputFieldType = "token" LoginInputFieldTypeURL LoginInputFieldType = "url" LoginInputFieldTypeDomain LoginInputFieldType = "domain" + LoginInputFieldTypeSelect LoginInputFieldType = "select" + LoginInputFieldTypeCaptchaCode LoginInputFieldType = "captcha_code" ) type LoginInputDataField struct { @@ -183,8 +192,13 @@ type LoginInputDataField struct { Name string `json:"name"` // The description of the field shown to the user. Description string `json:"description"` + // A default value that the client can pre-fill the field with. + DefaultValue string `json:"default_value,omitempty"` // A regex pattern that the client can use to validate input client-side. Pattern string `json:"pattern,omitempty"` + // For fields of type select, the valid options. + // Pattern may also be filled with a regex that matches the same options. + Options []string `json:"options,omitempty"` // A function that validates the input and optionally cleans it up before it's submitted to the connector. Validate func(string) (string, error) `json:"-"` } @@ -259,6 +273,23 @@ func (f *LoginInputDataField) FillDefaultValidate() { type LoginUserInputParams struct { // The fields that the user needs to fill in. Fields []LoginInputDataField `json:"fields"` + + // Attachments to display alongside the input fields. + Attachments []*LoginUserInputAttachment `json:"attachments"` +} + +type LoginUserInputAttachment struct { + Type event.MessageType `json:"type,omitempty"` + FileName string `json:"filename,omitempty"` + Content []byte `json:"content,omitempty"` + Info LoginUserInputAttachmentInfo `json:"info,omitempty"` +} + +type LoginUserInputAttachmentInfo struct { + MimeType string `json:"mimetype,omitempty"` + Width int `json:"w,omitempty"` + Height int `json:"h,omitempty"` + Size int `json:"size,omitempty"` } type LoginCompleteParams struct { diff --git a/bridgev2/matrix/connector.go b/bridgev2/matrix/connector.go index 94fdd97c..5a2df953 100644 --- a/bridgev2/matrix/connector.go +++ b/bridgev2/matrix/connector.go @@ -10,34 +10,34 @@ import ( "context" "crypto/sha256" "encoding/base64" - "encoding/json" "errors" "fmt" + "net/http" "net/url" "os" "regexp" "strings" "sync" "time" - "unsafe" - "github.com/gorilla/mux" _ "github.com/lib/pq" "github.com/rs/zerolog" "go.mau.fi/util/dbutil" _ "go.mau.fi/util/dbutil/litestream" + "go.mau.fi/util/exbytes" "go.mau.fi/util/exsync" + "go.mau.fi/util/ptr" "go.mau.fi/util/random" "golang.org/x/sync/semaphore" "maunium.net/go/mautrix" "maunium.net/go/mautrix/appservice" - "maunium.net/go/mautrix/bridge/status" "maunium.net/go/mautrix/bridgev2" "maunium.net/go/mautrix/bridgev2/bridgeconfig" "maunium.net/go/mautrix/bridgev2/commands" "maunium.net/go/mautrix/bridgev2/database" "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/bridgev2/status" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" "maunium.net/go/mautrix/mediaproxy" @@ -81,6 +81,8 @@ type Connector struct { MediaConfig mautrix.RespMediaConfig SpecVersions *mautrix.RespVersions + SpecCaps *mautrix.RespCapabilities + specCapsLock sync.Mutex Capabilities *bridgev2.MatrixCapabilities IgnoreUnsupportedServer bool @@ -102,6 +104,7 @@ type Connector struct { var ( _ bridgev2.MatrixConnector = (*Connector)(nil) _ bridgev2.MatrixConnectorWithServer = (*Connector)(nil) + _ bridgev2.MatrixConnectorWithArbitraryRoomState = (*Connector)(nil) _ bridgev2.MatrixConnectorWithPostRoomBridgeHandling = (*Connector)(nil) _ bridgev2.MatrixConnectorWithPublicMedia = (*Connector)(nil) _ bridgev2.MatrixConnectorWithNameDisambiguation = (*Connector)(nil) @@ -141,13 +144,20 @@ func (br *Connector) Init(bridge *bridgev2.Bridge) { br.EventProcessor.On(event.EventReaction, br.handleRoomEvent) br.EventProcessor.On(event.EventRedaction, br.handleRoomEvent) br.EventProcessor.On(event.EventEncrypted, br.handleEncryptedEvent) + br.EventProcessor.On(event.EphemeralEventEncrypted, br.handleEncryptedEvent) br.EventProcessor.On(event.StateMember, br.handleRoomEvent) br.EventProcessor.On(event.StatePowerLevels, br.handleRoomEvent) br.EventProcessor.On(event.StateRoomName, br.handleRoomEvent) + br.EventProcessor.On(event.BeeperSendState, br.handleRoomEvent) br.EventProcessor.On(event.StateRoomAvatar, br.handleRoomEvent) br.EventProcessor.On(event.StateTopic, br.handleRoomEvent) + br.EventProcessor.On(event.StateTombstone, br.handleRoomEvent) + br.EventProcessor.On(event.StateBeeperDisappearingTimer, br.handleRoomEvent) + br.EventProcessor.On(event.BeeperDeleteChat, br.handleRoomEvent) + br.EventProcessor.On(event.BeeperAcceptMessageRequest, br.handleRoomEvent) br.EventProcessor.On(event.EphemeralEventReceipt, br.handleEphemeralEvent) br.EventProcessor.On(event.EphemeralEventTyping, br.handleEphemeralEvent) + br.EventProcessor.On(event.BeeperEphemeralEventAIStream, br.handleEphemeralEvent) br.Bot = br.AS.BotIntent() br.Crypto = NewCryptoHelper(br) br.Bridge.Commands.(*commands.Processor).AddHandlers( @@ -169,6 +179,17 @@ func (br *Connector) Start(ctx context.Context) error { if err != nil { return err } + needsStateResync := br.Config.Encryption.Default && + br.Bridge.DB.KV.Get(ctx, database.KeyEncryptionStateResynced) != "true" + if needsStateResync { + dbExists, err := br.StateStore.TableExists(ctx, "mx_version") + if err != nil { + return fmt.Errorf("failed to check if mx_version table exists: %w", err) + } else if !dbExists { + needsStateResync = false + br.Bridge.DB.KV.Set(ctx, database.KeyEncryptionStateResynced, "true") + } + } err = br.StateStore.Upgrade(ctx) if err != nil { return bridgev2.DBUpgradeError{Section: "matrix_state", Err: err} @@ -212,17 +233,59 @@ func (br *Connector) Start(ctx context.Context) error { br.wsStopPinger = make(chan struct{}, 1) go br.websocketServerPinger() } + if needsStateResync { + br.ResyncEncryptionState(ctx) + } return nil } +func (br *Connector) ResyncEncryptionState(ctx context.Context) { + log := zerolog.Ctx(ctx) + roomIDScanner := dbutil.ConvertRowFn[id.RoomID](dbutil.ScanSingleColumn[id.RoomID]) + rooms, err := roomIDScanner.NewRowIter(br.Bridge.DB.Query(ctx, ` + SELECT rooms.room_id + FROM (SELECT DISTINCT(room_id) FROM mx_user_profile WHERE room_id<>'') rooms + LEFT JOIN mx_room_state ON rooms.room_id = mx_room_state.room_id + WHERE mx_room_state.encryption IS NULL + `)).AsList() + if err != nil { + log.Err(err).Msg("Failed to get room list to resync state") + return + } + var failedCount, successCount, forbiddenCount int + for _, roomID := range rooms { + if roomID == "" { + continue + } + var outContent *event.EncryptionEventContent + err = br.Bot.Client.StateEvent(ctx, roomID, event.StateEncryption, "", &outContent) + if errors.Is(err, mautrix.MForbidden) { + // Most likely non-existent room + log.Debug().Err(err).Stringer("room_id", roomID).Msg("Failed to get state for room") + forbiddenCount++ + } else if err != nil { + log.Err(err).Stringer("room_id", roomID).Msg("Failed to get state for room") + failedCount++ + } else { + successCount++ + } + } + br.Bridge.DB.KV.Set(ctx, database.KeyEncryptionStateResynced, "true") + log.Info(). + Int("success_count", successCount). + Int("forbidden_count", forbiddenCount). + Int("failed_count", failedCount). + Msg("Resynced rooms") +} + func (br *Connector) GetPublicAddress() string { if br.Config.AppService.PublicAddress == "https://bridge.example.com" { return "" } - return br.Config.AppService.PublicAddress + return strings.TrimRight(br.Config.AppService.PublicAddress, "/") } -func (br *Connector) GetRouter() *mux.Router { +func (br *Connector) GetRouter() *http.ServeMux { if br.GetPublicAddress() != "" { return br.AS.Router } @@ -233,13 +296,37 @@ func (br *Connector) GetCapabilities() *bridgev2.MatrixCapabilities { return br.Capabilities } -func (br *Connector) Stop() { +func sendStopSignal(ch chan struct{}) { + if ch != nil { + select { + case ch <- struct{}{}: + default: + } + } +} + +func (br *Connector) PreStop() { br.stopping = true br.AS.Stop() + if stopWebsocket := br.AS.StopWebsocket; stopWebsocket != nil { + stopWebsocket(appservice.ErrWebsocketManualStop) + } + sendStopSignal(br.wsStopPinger) + sendStopSignal(br.wsShortCircuitReconnectBackoff) +} + +func (br *Connector) Stop() { br.EventProcessor.Stop() if br.Crypto != nil { br.Crypto.Stop() } + if wsStopChan := br.wsStopped; wsStopChan != nil { + select { + case <-wsStopChan: + case <-time.After(4 * time.Second): + br.Log.Warn().Msg("Timed out waiting for websocket to close") + } + } } var MinSpecVersion = mautrix.SpecV14 @@ -257,16 +344,18 @@ func (br *Connector) logInitialRequestError(err error, defaultMessage string) { } func (br *Connector) ensureConnection(ctx context.Context) { + triedToRegister := false for { versions, err := br.Bot.Versions(ctx) if err != nil { - if errors.Is(err, mautrix.MForbidden) { + if errors.Is(err, mautrix.MForbidden) && !triedToRegister { br.Log.Debug().Msg("M_FORBIDDEN in /versions, trying to register before retrying") err = br.Bot.EnsureRegistered(ctx) if err != nil { br.logInitialRequestError(err, "Failed to register after /versions failed with M_FORBIDDEN") os.Exit(16) } + triedToRegister = true } else if errors.Is(err, mautrix.MUnknownToken) || errors.Is(err, mautrix.MExclusive) { br.logInitialRequestError(err, "/versions request failed with auth error") os.Exit(16) @@ -279,6 +368,9 @@ func (br *Connector) ensureConnection(ctx context.Context) { *br.AS.SpecVersions = *versions br.Capabilities.AutoJoinInvites = br.SpecVersions.Supports(mautrix.BeeperFeatureAutojoinInvites) br.Capabilities.BatchSending = br.SpecVersions.Supports(mautrix.BeeperFeatureBatchSending) + br.Capabilities.ArbitraryMemberChange = br.SpecVersions.Supports(mautrix.BeeperFeatureArbitraryMemberChange) + br.Capabilities.ExtraProfileMeta = br.SpecVersions.Supports(mautrix.BeeperFeatureArbitraryProfileMeta) || + (br.SpecVersions.Supports(mautrix.FeatureArbitraryProfileFields) && br.Config.Matrix.GhostExtraProfileInfo) break } } @@ -319,50 +411,23 @@ func (br *Connector) ensureConnection(ctx context.Context) { br.Log.Debug().Msg("Homeserver does not support checking status of homeserver -> bridge connection") return } - var pingResp *mautrix.RespAppservicePing - var txnID string - var retryCount int - const maxRetries = 6 - for { - txnID = br.Bot.TxnID() - pingResp, err = br.Bot.AppservicePing(ctx, br.Config.AppService.ID, txnID) - if err == nil { - break - } - var httpErr mautrix.HTTPError - var pingErrBody string - if errors.As(err, &httpErr) && httpErr.RespError != nil { - if val, ok := httpErr.RespError.ExtraData["body"].(string); ok { - pingErrBody = strings.TrimSpace(val) - } - } - outOfRetries := retryCount >= maxRetries - level := zerolog.ErrorLevel - if outOfRetries { - level = zerolog.FatalLevel - } - evt := br.Log.WithLevel(level).Err(err).Str("txn_id", txnID) - if pingErrBody != "" { - bodyBytes := []byte(pingErrBody) - if json.Valid(bodyBytes) { - evt.RawJSON("body", bodyBytes) - } else { - evt.Str("body", pingErrBody) - } - } - if outOfRetries { - evt.Msg("Homeserver -> bridge connection is not working") - br.Log.Info().Msg("See https://docs.mau.fi/faq/as-ping for more info") - os.Exit(13) - } - evt.Msg("Homeserver -> bridge connection is not working, retrying in 5 seconds...") - time.Sleep(5 * time.Second) - retryCount++ + + br.Bot.EnsureAppserviceConnection(ctx) +} + +func (br *Connector) fetchCapabilities(ctx context.Context) *mautrix.RespCapabilities { + br.specCapsLock.Lock() + defer br.specCapsLock.Unlock() + if br.SpecCaps != nil { + return br.SpecCaps } - br.Log.Debug(). - Str("txn_id", txnID). - Int64("duration_ms", pingResp.DurationMS). - Msg("Homeserver -> bridge connection works") + caps, err := br.Bot.Capabilities(ctx) + if err != nil { + br.Log.Err(err).Msg("Failed to fetch capabilities from homeserver") + return nil + } + br.SpecCaps = caps + return caps } func (br *Connector) fetchMediaConfig(ctx context.Context) { @@ -430,11 +495,15 @@ func (br *Connector) GhostIntent(userID networkid.UserID) bridgev2.MatrixAPI { func (br *Connector) SendBridgeStatus(ctx context.Context, state *status.BridgeState) error { if br.Websocket { br.hasSentAnyStates = true - return br.AS.SendWebsocket(&appservice.WebsocketRequest{ + return br.AS.SendWebsocket(ctx, &appservice.WebsocketRequest{ Command: "bridge_status", Data: state, }) } else if br.Config.Homeserver.StatusEndpoint != "" { + // Connecting states aren't really relevant unless the bridge runs somewhere with an unreliable network + if state.StateEvent == status.StateConnecting { + return nil + } return state.SendHTTP(ctx, br.Config.Homeserver.StatusEndpoint, br.Config.AppService.ASToken) } else { return nil @@ -450,13 +519,17 @@ func (br *Connector) internalSendMessageStatus(ctx context.Context, ms *bridgev2 return "" } log := zerolog.Ctx(ctx) - err := br.SendMessageCheckpoints([]*status.MessageCheckpoint{ms.ToCheckpoint(evt)}) - if err != nil { - log.Err(err).Msg("Failed to send message checkpoint") + + if !evt.IsSourceEventDoublePuppeted { + err := br.SendMessageCheckpoints(ctx, []*status.MessageCheckpoint{ms.ToCheckpoint(evt)}) + if err != nil { + log.Err(err).Msg("Failed to send message checkpoint") + } } + if !ms.DisableMSS && br.Config.Matrix.MessageStatusEvents { mssEvt := ms.ToMSSEvent(evt) - _, err = br.Bot.SendMessageEvent(ctx, evt.RoomID, event.BeeperMessageStatus, mssEvt) + _, err := br.Bot.SendMessageEvent(ctx, evt.RoomID, event.BeeperMessageStatus, mssEvt) if err != nil { log.Err(err). Stringer("room_id", evt.RoomID). @@ -465,7 +538,8 @@ func (br *Connector) internalSendMessageStatus(ctx context.Context, ms *bridgev2 Msg("Failed to send MSS event") } } - if ms.SendNotice && br.Config.Matrix.MessageErrorNotices && (ms.Status == event.MessageStatusFail || ms.Status == event.MessageStatusRetriable || ms.Step == status.MsgStepDecrypted) { + if ms.SendNotice && br.Config.Matrix.MessageErrorNotices && evt.MessageType != event.MsgNotice && + (ms.Status == event.MessageStatusFail || ms.Status == event.MessageStatusRetriable || ms.Step == status.MsgStepDecrypted) { content := ms.ToNoticeEvent(evt) if editEvent != "" { content.SetEdit(editEvent) @@ -482,7 +556,7 @@ func (br *Connector) internalSendMessageStatus(ctx context.Context, ms *bridgev2 } } if ms.Status == event.MessageStatusSuccess && br.Config.Matrix.DeliveryReceipts { - err = br.Bot.SendReceipt(ctx, evt.RoomID, evt.SourceEventID, event.ReceiptTypeRead, nil) + err := br.Bot.SendReceipt(ctx, evt.RoomID, evt.SourceEventID, event.ReceiptTypeRead, nil) if err != nil { log.Err(err). Stringer("room_id", evt.RoomID). @@ -493,11 +567,11 @@ func (br *Connector) internalSendMessageStatus(ctx context.Context, ms *bridgev2 return "" } -func (br *Connector) SendMessageCheckpoints(checkpoints []*status.MessageCheckpoint) error { +func (br *Connector) SendMessageCheckpoints(ctx context.Context, checkpoints []*status.MessageCheckpoint) error { checkpointsJSON := status.CheckpointsJSON{Checkpoints: checkpoints} if br.Websocket { - return br.AS.SendWebsocket(&appservice.WebsocketRequest{ + return br.AS.SendWebsocket(ctx, &appservice.WebsocketRequest{ Command: "message_checkpoint", Data: checkpointsJSON, }) @@ -508,7 +582,7 @@ func (br *Connector) SendMessageCheckpoints(checkpoints []*status.MessageCheckpo return nil } - return checkpointsJSON.SendHTTP(endpoint, br.AS.Registration.AppToken) + return checkpointsJSON.SendHTTP(ctx, br.AS.HTTPClient, endpoint, br.AS.Registration.AppToken) } func (br *Connector) ParseGhostMXID(userID id.UserID) (networkid.UserID, bool) { @@ -548,6 +622,31 @@ func (br *Connector) GetPowerLevels(ctx context.Context, roomID id.RoomID) (*eve return br.Bot.PowerLevels(ctx, roomID) } +func (br *Connector) GetStateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string) (*event.Event, error) { + if stateKey == "" { + switch eventType { + case event.StateCreate: + createEvt, err := br.Bot.StateStore.GetCreate(ctx, roomID) + if err != nil || createEvt != nil { + return createEvt, err + } + case event.StateJoinRules: + joinRulesContent, err := br.Bot.StateStore.GetJoinRules(ctx, roomID) + if err != nil { + return nil, err + } else if joinRulesContent != nil { + return &event.Event{ + Type: event.StateJoinRules, + RoomID: roomID, + StateKey: ptr.Ptr(""), + Content: event.Content{Parsed: joinRulesContent}, + }, nil + } + } + } + return br.Bot.FullStateEvent(ctx, roomID, eventType, stateKey) +} + func (br *Connector) GetMembers(ctx context.Context, roomID id.RoomID) (map[id.UserID]*event.MemberEventContent, error) { fetched, err := br.Bot.StateStore.HasFetchedMembers(ctx, roomID) if err != nil { @@ -575,6 +674,10 @@ func (br *Connector) IsConfusableName(ctx context.Context, roomID id.RoomID, use return br.AS.StateStore.IsConfusableName(ctx, roomID, userID, name) } +func (br *Connector) GetUniqueBridgeID() string { + return fmt.Sprintf("%s/%s", br.Config.Homeserver.Domain, br.Config.AppService.ID) +} + func (br *Connector) BatchSend(ctx context.Context, roomID id.RoomID, req *mautrix.ReqBeeperBatchSend, extras []*bridgev2.MatrixSendExtra) (*mautrix.RespBeeperBatchSend, error) { if encrypted, err := br.StateStore.IsEncrypted(ctx, roomID); err != nil { return nil, fmt.Errorf("failed to check if room is encrypted: %w", err) @@ -584,7 +687,7 @@ func (br *Connector) BatchSend(ctx context.Context, roomID id.RoomID, req *mautr if intent != nil { intent.AddDoublePuppetValueWithTS(&evt.Content, evt.Timestamp) } - if evt.Type != event.EventEncrypted { + if evt.Type != event.EventEncrypted && evt.Type != event.EventReaction { err = br.Crypto.Encrypt(ctx, roomID, evt.Type, &evt.Content) if err != nil { return nil, err @@ -616,7 +719,11 @@ func (br *Connector) GenerateDeterministicEventID(roomID id.RoomID, _ networkid. eventID[1+hashB64Len] = ':' copy(eventID[1+hashB64Len+1:], br.deterministicEventIDServer) - return id.EventID(unsafe.String(unsafe.SliceData(eventID), len(eventID))) + return id.EventID(exbytes.UnsafeString(eventID)) +} + +func (br *Connector) GenerateDeterministicRoomID(key networkid.PortalKey) id.RoomID { + return id.RoomID(fmt.Sprintf("!%s.%s:%s", key.ID, key.Receiver, br.ServerName())) } func (br *Connector) GenerateReactionEventID(roomID id.RoomID, targetMessage *database.Message, sender networkid.UserID, emojiID networkid.EmojiID) id.EventID { diff --git a/bridgev2/matrix/crypto.go b/bridgev2/matrix/crypto.go index 3cb16e52..7f18f1f5 100644 --- a/bridgev2/matrix/crypto.go +++ b/bridgev2/matrix/crypto.go @@ -14,6 +14,7 @@ import ( "fmt" "os" "runtime/debug" + "strings" "sync" "time" @@ -23,6 +24,7 @@ import ( "maunium.net/go/mautrix" "maunium.net/go/mautrix/bridgev2" + "maunium.net/go/mautrix/bridgev2/database" "maunium.net/go/mautrix/crypto" "maunium.net/go/mautrix/crypto/olm" "maunium.net/go/mautrix/event" @@ -36,9 +38,9 @@ func init() { var _ crypto.StateStore = (*sqlstatestore.SQLStateStore)(nil) -var NoSessionFound = crypto.NoSessionFound -var DuplicateMessageIndex = crypto.DuplicateMessageIndex -var UnknownMessageIndex = olm.UnknownMessageIndex +var NoSessionFound = crypto.ErrNoSessionFound +var DuplicateMessageIndex = crypto.ErrDuplicateMessageIndex +var UnknownMessageIndex = olm.ErrUnknownMessageIndex type CryptoHelper struct { bridge *Connector @@ -77,7 +79,7 @@ func (helper *CryptoHelper) Init(ctx context.Context) error { dbutil.ZeroLogger(helper.bridge.Log.With().Str("db_section", "crypto").Logger()), string(helper.bridge.Bridge.ID), helper.bridge.AS.BotMXID(), - fmt.Sprintf("@%s:%s", helper.bridge.Config.AppService.FormatUsername("%"), helper.bridge.AS.HomeserverDomain), + fmt.Sprintf("@%s:%s", strings.ReplaceAll(helper.bridge.Config.AppService.FormatUsername("%"), "_", `\_`), helper.bridge.AS.HomeserverDomain), helper.bridge.Config.Encryption.PickleKey, ) @@ -134,7 +136,19 @@ func (helper *CryptoHelper) Init(ctx context.Context) error { return err } if isExistingDevice { - helper.verifyKeysAreOnServer(ctx) + if !helper.verifyKeysAreOnServer(ctx) { + return nil + } + } else { + err = helper.ShareKeys(ctx) + if err != nil { + return fmt.Errorf("failed to share device keys: %w", err) + } + } + if helper.bridge.Config.Encryption.SelfSign { + if !helper.doSelfSign(ctx) { + os.Exit(34) + } } go helper.resyncEncryptionInfo(context.TODO()) @@ -142,30 +156,66 @@ func (helper *CryptoHelper) Init(ctx context.Context) error { return nil } +func (helper *CryptoHelper) doSelfSign(ctx context.Context) bool { + log := zerolog.Ctx(ctx) + hasKeys, isVerified, err := helper.mach.GetOwnVerificationStatus(ctx) + if err != nil { + log.WithLevel(zerolog.FatalLevel).Err(err).Msg("Failed to check verification status") + return false + } + log.Debug().Bool("has_keys", hasKeys).Bool("is_verified", isVerified).Msg("Checked verification status") + keyInDB := helper.bridge.Bridge.DB.KV.Get(ctx, database.KeyRecoveryKey) + if !hasKeys || keyInDB == "overwrite" { + if keyInDB != "" && keyInDB != "overwrite" { + log.WithLevel(zerolog.FatalLevel). + Msg("No keys on server, but database already has recovery key. Delete `recovery_key` from `kv_store` manually to continue.") + return false + } + recoveryKey, err := helper.mach.GenerateAndVerifyWithRecoveryKey(ctx) + if recoveryKey != "" { + helper.bridge.Bridge.DB.KV.Set(ctx, database.KeyRecoveryKey, recoveryKey) + } + if err != nil { + log.WithLevel(zerolog.FatalLevel).Err(err).Msg("Failed to generate recovery key and self-sign") + return false + } + log.Info().Msg("Generated new recovery key and self-signed bot device") + } else if !isVerified { + if keyInDB == "" { + log.WithLevel(zerolog.FatalLevel). + Msg("Server already has cross-signing keys, but no key in database. Add `recovery_key` to `kv_store`, or set it to `overwrite` to generate new keys.") + return false + } + err = helper.mach.VerifyWithRecoveryKey(ctx, keyInDB) + if err != nil { + log.WithLevel(zerolog.FatalLevel).Err(err).Msg("Failed to verify with recovery key") + return false + } + log.Info().Msg("Verified bot device with existing recovery key") + } + return true +} + func (helper *CryptoHelper) resyncEncryptionInfo(ctx context.Context) { log := helper.log.With().Str("action", "resync encryption event").Logger() rows, err := helper.store.DB.Query(ctx, `SELECT room_id FROM mx_room_state WHERE encryption='{"resync":true}'`) + roomIDs, err := dbutil.NewRowIterWithError(rows, dbutil.ScanSingleColumn[id.RoomID], err).AsList() if err != nil { log.Err(err).Msg("Failed to query rooms for resync") return } - roomIDs, err := dbutil.NewRowIter(rows, dbutil.ScanSingleColumn[id.RoomID]).AsList() - if err != nil { - log.Err(err).Msg("Failed to scan rooms for resync") - return - } if len(roomIDs) > 0 { log.Debug().Interface("room_ids", roomIDs).Msg("Resyncing rooms") for _, roomID := range roomIDs { var evt event.EncryptionEventContent err = helper.client.StateEvent(ctx, roomID, event.StateEncryption, "", &evt) if err != nil { - log.Err(err).Str("room_id", roomID.String()).Msg("Failed to get encryption event") + log.Err(err).Stringer("room_id", roomID).Msg("Failed to get encryption event") _, err = helper.store.DB.Exec(ctx, ` UPDATE mx_room_state SET encryption=NULL WHERE room_id=$1 AND encryption='{"resync":true}' `, roomID) if err != nil { - log.Err(err).Str("room_id", roomID.String()).Msg("Failed to unmark room for resync after failed sync") + log.Err(err).Stringer("room_id", roomID).Msg("Failed to unmark room for resync after failed sync") } } else { maxAge := evt.RotationPeriodMillis @@ -188,9 +238,9 @@ func (helper *CryptoHelper) resyncEncryptionInfo(ctx context.Context) { WHERE room_id=$3 AND max_age IS NULL AND max_messages IS NULL `, maxAge, maxMessages, roomID) if err != nil { - log.Err(err).Str("room_id", roomID.String()).Msg("Failed to update megolm session table") + log.Err(err).Stringer("room_id", roomID).Msg("Failed to update megolm session table") } else { - log.Debug().Str("room_id", roomID.String()).Msg("Updated megolm session table") + log.Debug().Stringer("room_id", roomID).Msg("Updated megolm session table") } } } @@ -203,7 +253,7 @@ func (helper *CryptoHelper) allowKeyShare(ctx context.Context, device *id.Device return &crypto.KeyShareRejectNoResponse } else if device.Trust == id.TrustStateBlacklisted { return &crypto.KeyShareRejectBlacklisted - } else if trustState := helper.mach.ResolveTrust(device); trustState >= cfg.VerificationLevels.Share { + } else if trustState, _ := helper.mach.ResolveTrustContext(ctx, device); trustState >= cfg.VerificationLevels.Share { portal, err := helper.bridge.Bridge.GetPortalByMXID(ctx, info.RoomID) if err != nil { zerolog.Ctx(ctx).Err(err).Msg("Failed to get portal to handle key request") @@ -236,14 +286,14 @@ func (helper *CryptoHelper) loginBot(ctx context.Context) (*mautrix.Client, bool if err != nil { return nil, false, fmt.Errorf("failed to find existing device ID: %w", err) } else if len(deviceID) > 0 { - helper.log.Debug().Str("device_id", deviceID.String()).Msg("Found existing device ID for bot in database") + helper.log.Debug().Stringer("device_id", deviceID).Msg("Found existing device ID for bot in database") } // Create a new client instance with the default AS settings (including as_token), // the Login call will then override the access token in the client. client := helper.bridge.AS.NewMautrixClient(helper.bridge.AS.BotMXID()) initialDeviceDisplayName := fmt.Sprintf("%s bridge", helper.bridge.Bridge.Network.GetName().DisplayName) - if helper.bridge.Config.AppService.MSC4190 { + if helper.bridge.Config.Encryption.MSC4190 { helper.log.Debug().Msg("Creating bot device with MSC4190") err = client.CreateDeviceMSC4190(ctx, deviceID, initialDeviceDisplayName) if err != nil { @@ -277,7 +327,7 @@ func (helper *CryptoHelper) loginBot(ctx context.Context) (*mautrix.Client, bool return client, deviceID != "", nil } -func (helper *CryptoHelper) verifyKeysAreOnServer(ctx context.Context) { +func (helper *CryptoHelper) verifyKeysAreOnServer(ctx context.Context) bool { helper.log.Debug().Msg("Making sure keys are still on server") resp, err := helper.client.QueryKeys(ctx, &mautrix.ReqQueryKeys{ DeviceKeys: map[id.UserID]mautrix.DeviceIDList{ @@ -290,10 +340,11 @@ func (helper *CryptoHelper) verifyKeysAreOnServer(ctx context.Context) { } device, ok := resp.DeviceKeys[helper.client.UserID][helper.client.DeviceID] if ok && len(device.Keys) > 0 { - return + return true } helper.log.Warn().Msg("Existing device doesn't have keys on server, resetting crypto") helper.Reset(ctx, false) + return false } func (helper *CryptoHelper) Start() { @@ -388,7 +439,7 @@ func (helper *CryptoHelper) Encrypt(ctx context.Context, roomID id.RoomID, evtTy var encrypted *event.EncryptedEventContent encrypted, err = helper.mach.EncryptMegolmEvent(ctx, roomID, evtType, content) if err != nil { - if !errors.Is(err, crypto.SessionExpired) && !errors.Is(err, crypto.SessionNotShared) && !errors.Is(err, crypto.NoGroupSession) { + if !errors.Is(err, crypto.ErrSessionExpired) && !errors.Is(err, crypto.ErrSessionNotShared) && !errors.Is(err, crypto.ErrNoGroupSession) { return } helper.log.Debug().Err(err). @@ -503,14 +554,14 @@ func (syncer *cryptoSyncer) OnFailedSync(_ *mautrix.RespSync, err error) (time.D func (syncer *cryptoSyncer) GetFilterJSON(_ id.UserID) *mautrix.Filter { everything := []event.Type{{Type: "*"}} return &mautrix.Filter{ - Presence: mautrix.FilterPart{NotTypes: everything}, - AccountData: mautrix.FilterPart{NotTypes: everything}, - Room: mautrix.RoomFilter{ + Presence: &mautrix.FilterPart{NotTypes: everything}, + AccountData: &mautrix.FilterPart{NotTypes: everything}, + Room: &mautrix.RoomFilter{ IncludeLeave: false, - Ephemeral: mautrix.FilterPart{NotTypes: everything}, - AccountData: mautrix.FilterPart{NotTypes: everything}, - State: mautrix.FilterPart{NotTypes: everything}, - Timeline: mautrix.FilterPart{NotTypes: everything}, + Ephemeral: &mautrix.FilterPart{NotTypes: everything}, + AccountData: &mautrix.FilterPart{NotTypes: everything}, + State: &mautrix.FilterPart{NotTypes: everything}, + Timeline: &mautrix.FilterPart{NotTypes: everything}, }, } } diff --git a/bridgev2/matrix/cryptoerror.go b/bridgev2/matrix/cryptoerror.go index 55110429..ea29703a 100644 --- a/bridgev2/matrix/cryptoerror.go +++ b/bridgev2/matrix/cryptoerror.go @@ -11,8 +11,8 @@ import ( "errors" "fmt" - "maunium.net/go/mautrix/bridge/status" "maunium.net/go/mautrix/bridgev2" + "maunium.net/go/mautrix/bridgev2/status" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" ) diff --git a/bridgev2/matrix/cryptostore.go b/bridgev2/matrix/cryptostore.go index 234797a6..4c3b5d30 100644 --- a/bridgev2/matrix/cryptostore.go +++ b/bridgev2/matrix/cryptostore.go @@ -45,7 +45,7 @@ func (store *SQLCryptoStore) GetRoomJoinedOrInvitedMembers(ctx context.Context, WHERE room_id=$1 AND (membership='join' OR membership='invite') AND user_id<>$2 - AND user_id NOT LIKE $3 + AND user_id NOT LIKE $3 ESCAPE '\' `, roomID, store.UserID, store.GhostIDFormat) if err != nil { return diff --git a/bridgev2/matrix/directmedia.go b/bridgev2/matrix/directmedia.go index 71c01078..0667981a 100644 --- a/bridgev2/matrix/directmedia.go +++ b/bridgev2/matrix/directmedia.go @@ -39,7 +39,7 @@ func (br *Connector) initDirectMedia() error { if err != nil { return fmt.Errorf("failed to initialize media proxy: %w", err) } - br.MediaProxy.RegisterRoutes(br.AS.Router) + br.MediaProxy.RegisterRoutes(br.AS.Router, br.Log.With().Str("component", "media proxy").Logger()) br.dmaSigKey = sha256.Sum256(br.MediaProxy.GetServerKey().Priv.Seed()) dmn.SetUseDirectMedia() br.Log.Debug().Str("server_name", br.MediaProxy.GetServerName()).Msg("Enabled direct media access") diff --git a/bridgev2/matrix/intent.go b/bridgev2/matrix/intent.go index 9f6c520e..f7254bd4 100644 --- a/bridgev2/matrix/intent.go +++ b/bridgev2/matrix/intent.go @@ -9,6 +9,7 @@ package matrix import ( "bytes" "context" + "encoding/json" "errors" "fmt" "io" @@ -27,6 +28,7 @@ import ( "maunium.net/go/mautrix/bridgev2" "maunium.net/go/mautrix/bridgev2/bridgeconfig" "maunium.net/go/mautrix/crypto/attachment" + "maunium.net/go/mautrix/crypto/canonicaljson" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" "maunium.net/go/mautrix/pushrules" @@ -43,13 +45,13 @@ type ASIntent struct { var _ bridgev2.MatrixAPI = (*ASIntent)(nil) var _ bridgev2.MarkAsDMMatrixAPI = (*ASIntent)(nil) +var _ bridgev2.EphemeralSendingMatrixAPI = (*ASIntent)(nil) func (as *ASIntent) SendMessage(ctx context.Context, roomID id.RoomID, eventType event.Type, content *event.Content, extra *bridgev2.MatrixSendExtra) (*mautrix.RespSendEvent, error) { if extra == nil { extra = &bridgev2.MatrixSendExtra{} } - // TODO remove this once hungryserv and synapse support sending m.room.redactions directly in all room versions - if eventType == event.EventRedaction { + if eventType == event.EventRedaction && !as.Connector.SpecVersions.Supports(mautrix.FeatureRedactSendAsEvent) { parsedContent := content.Parsed.(*event.RedactionEventContent) as.Matrix.AddDoublePuppetValue(content) return as.Matrix.RedactEvent(ctx, roomID, parsedContent.Redacts, mautrix.ReqRedact{ @@ -57,7 +59,11 @@ func (as *ASIntent) SendMessage(ctx context.Context, roomID id.RoomID, eventType Extra: content.Raw, }) } - if eventType != event.EventReaction && eventType != event.EventRedaction { + if (eventType != event.EventReaction || as.Connector.Config.Encryption.MSC4392) && eventType != event.EventRedaction { + msgContent, ok := content.Parsed.(*event.MessageEventContent) + if ok { + msgContent.AddPerMessageProfileFallback() + } if encrypted, err := as.Matrix.StateStore.IsEncrypted(ctx, roomID); err != nil { return nil, fmt.Errorf("failed to check if room is encrypted: %w", err) } else if encrypted { @@ -78,16 +84,27 @@ func (as *ASIntent) SendMessage(ctx context.Context, roomID id.RoomID, eventType eventType = event.EventEncrypted } } - if extra.Timestamp.IsZero() { - return as.Matrix.SendMessageEvent(ctx, roomID, eventType, content) - } else { - return as.Matrix.SendMassagedMessageEvent(ctx, roomID, eventType, content, extra.Timestamp.UnixMilli()) + return as.Matrix.SendMessageEvent(ctx, roomID, eventType, content, mautrix.ReqSendEvent{Timestamp: extra.Timestamp.UnixMilli()}) +} + +func (as *ASIntent) BeeperSendEphemeralEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, content *event.Content, txnID string) (*mautrix.RespSendEvent, error) { + if !as.Connector.SpecVersions.Supports(mautrix.BeeperFeatureEphemeralEvents) { + return nil, mautrix.MUnrecognized.WithMessage("Homeserver does not advertise com.beeper.ephemeral support") } + if encrypted, err := as.Matrix.StateStore.IsEncrypted(ctx, roomID); err != nil { + return nil, fmt.Errorf("failed to check if room is encrypted: %w", err) + } else if encrypted && as.Connector.Crypto != nil { + if err = as.Connector.Crypto.Encrypt(ctx, roomID, eventType, content); err != nil { + return nil, err + } + eventType = event.EventEncrypted + } + return as.Matrix.BeeperSendEphemeralEvent(ctx, roomID, eventType, content, mautrix.ReqSendEvent{TransactionID: txnID}) } func (as *ASIntent) fillMemberEvent(ctx context.Context, roomID id.RoomID, userID id.UserID, content *event.Content) { - targetContent := content.Parsed.(*event.MemberEventContent) - if targetContent.Displayname != "" || targetContent.AvatarURL != "" { + targetContent, ok := content.Parsed.(*event.MemberEventContent) + if !ok || targetContent.Displayname != "" || targetContent.AvatarURL != "" { return } memberContent, err := as.Matrix.StateStore.TryGetMember(ctx, roomID, userID) @@ -122,11 +139,7 @@ func (as *ASIntent) SendState(ctx context.Context, roomID id.RoomID, eventType e if eventType == event.StateMember { as.fillMemberEvent(ctx, roomID, id.UserID(stateKey), content) } - if ts.IsZero() { - resp, err = as.Matrix.SendStateEvent(ctx, roomID, eventType, stateKey, content) - } else { - resp, err = as.Matrix.SendMassagedStateEvent(ctx, roomID, eventType, stateKey, content, ts.UnixMilli()) - } + resp, err = as.Matrix.SendStateEvent(ctx, roomID, eventType, stateKey, content, mautrix.ReqSendEvent{Timestamp: ts.UnixMilli()}) if err != nil && eventType == event.StateMember { var httpErr mautrix.HTTPError if errors.As(err, &httpErr) && httpErr.RespError != nil && @@ -393,9 +406,13 @@ func (as *ASIntent) UploadMediaStream( err = fmt.Errorf("failed to get temp file info: %w", err) return } + size = info.Size() + if size > as.Connector.MediaConfig.UploadSize { + return "", nil, fmt.Errorf("file too large (%.2f MB > %.2f MB)", float64(size)/1000/1000, float64(as.Connector.MediaConfig.UploadSize)/1000/1000) + } req := mautrix.ReqUploadMedia{ Content: replFile, - ContentLength: info.Size(), + ContentLength: size, ContentType: res.MimeType, FileName: res.FileName, } @@ -404,6 +421,7 @@ func (as *ASIntent) UploadMediaStream( removeAndClose(replFile) removeAndClose(tempFile) } + req.AsyncContext = zerolog.Ctx(ctx).WithContext(as.Connector.Bridge.BackgroundCtx) startedAsyncUpload = true var resp *mautrix.RespCreateMXC resp, err = as.Matrix.UploadAsync(ctx, req) @@ -436,6 +454,7 @@ func (as *ASIntent) doUploadReq(ctx context.Context, file *event.EncryptedFileIn as.Connector.uploadSema.Release(int64(len(req.ContentBytes))) } } + req.AsyncContext = zerolog.Ctx(ctx).WithContext(as.Connector.Bridge.BackgroundCtx) var resp *mautrix.RespCreateMXC resp, err = as.Matrix.UploadAsync(ctx, req) if resp != nil { @@ -467,19 +486,78 @@ func (as *ASIntent) SetAvatarURL(ctx context.Context, avatarURL id.ContentURIStr return as.Matrix.SetAvatarURL(ctx, parsedAvatarURL) } -func (as *ASIntent) SetExtraProfileMeta(ctx context.Context, data any) error { - if !as.Connector.SpecVersions.Supports(mautrix.BeeperFeatureArbitraryProfileMeta) { - return nil +func dataToFields(data any) (map[string]json.RawMessage, error) { + fields, ok := data.(map[string]json.RawMessage) + if ok { + return fields, nil } - return as.Matrix.BeeperUpdateProfile(ctx, data) + d, err := json.Marshal(data) + if err != nil { + return nil, err + } + d = canonicaljson.CanonicalJSONAssumeValid(d) + err = json.Unmarshal(d, &fields) + return fields, err +} + +func marshalField(val any) json.RawMessage { + data, _ := json.Marshal(val) + if len(data) > 0 && (data[0] == '{' || data[0] == '[') { + return canonicaljson.CanonicalJSONAssumeValid(data) + } + return data +} + +var nullJSON = json.RawMessage("null") + +func (as *ASIntent) SetExtraProfileMeta(ctx context.Context, data any) error { + if as.Connector.SpecVersions.Supports(mautrix.BeeperFeatureArbitraryProfileMeta) { + return as.Matrix.BeeperUpdateProfile(ctx, data) + } else if as.Connector.SpecVersions.Supports(mautrix.FeatureArbitraryProfileFields) && as.Connector.Config.Matrix.GhostExtraProfileInfo { + fields, err := dataToFields(data) + if err != nil { + return fmt.Errorf("failed to marshal fields: %w", err) + } + currentProfile, err := as.Matrix.GetProfile(ctx, as.Matrix.UserID) + if err != nil { + return fmt.Errorf("failed to get current profile: %w", err) + } + for key, val := range fields { + existing, ok := currentProfile.Extra[key] + if !ok { + if bytes.Equal(val, nullJSON) { + continue + } + err = as.Matrix.SetProfileField(ctx, key, val) + } else if !bytes.Equal(marshalField(existing), val) { + if bytes.Equal(val, nullJSON) { + err = as.Matrix.DeleteProfileField(ctx, key) + } else { + err = as.Matrix.SetProfileField(ctx, key, val) + } + } + if err != nil { + return fmt.Errorf("failed to set profile field %q: %w", key, err) + } + } + } + return nil } func (as *ASIntent) GetMXID() id.UserID { return as.Matrix.UserID } -func (as *ASIntent) EnsureJoined(ctx context.Context, roomID id.RoomID) error { - err := as.Matrix.EnsureJoined(ctx, roomID) +func (as *ASIntent) IsDoublePuppet() bool { + return as.Matrix.IsDoublePuppet() +} + +func (as *ASIntent) EnsureJoined(ctx context.Context, roomID id.RoomID, extra ...bridgev2.EnsureJoinedParams) error { + var params bridgev2.EnsureJoinedParams + if len(extra) > 0 { + params = extra[0] + } + err := as.Matrix.EnsureJoined(ctx, roomID, appservice.EnsureJoinedParams{Via: params.Via}) if err != nil { return err } @@ -505,6 +583,39 @@ func (br *Connector) getDefaultEncryptionEvent() *event.EncryptionEventContent { return content } +func (as *ASIntent) filterCreateRequestForV12(ctx context.Context, req *mautrix.ReqCreateRoom) { + if as.Connector.Config.Homeserver.Software == bridgeconfig.SoftwareHungry { + // Hungryserv doesn't override the capabilities endpoint nor do room versions + return + } + caps := as.Connector.fetchCapabilities(ctx) + roomVer := req.RoomVersion + if roomVer == "" && caps != nil && caps.RoomVersions != nil { + roomVer = id.RoomVersion(caps.RoomVersions.Default) + } + if roomVer != "" && !roomVer.PrivilegedRoomCreators() { + return + } + creators, _ := req.CreationContent["additional_creators"].([]id.UserID) + creators = append(slices.Clone(creators), as.GetMXID()) + if req.PowerLevelOverride != nil { + for _, creator := range creators { + delete(req.PowerLevelOverride.Users, creator) + } + } + for _, evt := range req.InitialState { + if evt.Type != event.StatePowerLevels { + continue + } + content, ok := evt.Content.Parsed.(*event.PowerLevelsEventContent) + if ok { + for _, creator := range creators { + delete(content.Users, creator) + } + } + } +} + func (as *ASIntent) CreateRoom(ctx context.Context, req *mautrix.ReqCreateRoom) (id.RoomID, error) { if as.Connector.Config.Encryption.Default { req.InitialState = append(req.InitialState, &event.Event{ @@ -520,6 +631,7 @@ func (as *ASIntent) CreateRoom(ctx context.Context, req *mautrix.ReqCreateRoom) } req.CreationContent["m.federate"] = false } + as.filterCreateRequestForV12(ctx, req) resp, err := as.Matrix.CreateRoom(ctx, req) if err != nil { return "", err @@ -561,8 +673,19 @@ func (as *ASIntent) MarkAsDM(ctx context.Context, roomID id.RoomID, withUser id. } func (as *ASIntent) DeleteRoom(ctx context.Context, roomID id.RoomID, puppetsOnly bool) error { + if roomID == "" { + return nil + } if as.Connector.SpecVersions.Supports(mautrix.BeeperFeatureRoomYeeting) { - return as.Matrix.BeeperDeleteRoom(ctx, roomID) + err := as.Matrix.BeeperDeleteRoom(ctx, roomID) + if err != nil { + return err + } + err = as.Matrix.StateStore.ClearCachedMembers(ctx, roomID) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to clear cached members while cleaning up portal") + } + return nil } members, err := as.Matrix.JoinedMembers(ctx, roomID) if err != nil { @@ -650,3 +773,23 @@ func (as *ASIntent) MuteRoom(ctx context.Context, roomID id.RoomID, until time.T }) } } + +func (as *ASIntent) GetEvent(ctx context.Context, roomID id.RoomID, eventID id.EventID) (*event.Event, error) { + evt, err := as.Matrix.Client.GetEvent(ctx, roomID, eventID) + if err != nil { + return nil, err + } + err = evt.Content.ParseRaw(evt.Type) + if err != nil { + zerolog.Ctx(ctx).Err(err).Stringer("room_id", roomID).Stringer("event_id", eventID).Msg("failed to parse event content") + } + + if evt.Type == event.EventEncrypted { + if as.Connector.Crypto == nil || as.Connector.Config.Encryption.DeleteKeys.RatchetOnDecrypt { + return nil, errors.New("can't decrypt the event") + } + return as.Connector.Crypto.Decrypt(ctx, evt) + } + + return evt, nil +} diff --git a/bridgev2/matrix/matrix.go b/bridgev2/matrix/matrix.go index 1117fca2..954d0ad9 100644 --- a/bridgev2/matrix/matrix.go +++ b/bridgev2/matrix/matrix.go @@ -17,8 +17,8 @@ import ( "go.mau.fi/util/jsontime" "maunium.net/go/mautrix/appservice" - "maunium.net/go/mautrix/bridge/status" "maunium.net/go/mautrix/bridgev2" + "maunium.net/go/mautrix/bridgev2/status" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" ) @@ -27,6 +27,11 @@ func (br *Connector) handleRoomEvent(ctx context.Context, evt *event.Event) { if br.shouldIgnoreEvent(evt) { return } + if !br.Config.Bridge.Permissions.Get(evt.Sender).SendEvents && evt.Type != event.StateMember { + zerolog.Ctx(ctx).Debug().Msg("Dropping event from user with no permission to send events") + br.SendMessageStatus(ctx, &bridgev2.ErrNoPermissionToInteract, bridgev2.StatusEventInfoFromEvent(evt)) + return + } if (evt.Type == event.EventMessage || evt.Type == event.EventSticker) && !evt.Mautrix.WasEncrypted && br.Config.Encryption.Require { zerolog.Ctx(ctx).Warn().Msg("Dropping unencrypted event as encryption is configured to be required") br.sendCryptoStatusError(ctx, evt, errMessageNotEncrypted, nil, 0, true) @@ -63,6 +68,10 @@ func (br *Connector) handleEphemeralEvent(ctx context.Context, evt *event.Event) case event.EphemeralEventTyping: typingContent := evt.Content.AsTyping() typingContent.UserIDs = slices.DeleteFunc(typingContent.UserIDs, br.shouldIgnoreEventFromUser) + case event.BeeperEphemeralEventAIStream: + if br.shouldIgnoreEvent(evt) { + return + } } br.Bridge.QueueMatrixEvent(ctx, evt) } @@ -76,6 +85,11 @@ func (br *Connector) handleEncryptedEvent(ctx context.Context, evt *event.Event) Str("event_id", evt.ID.String()). Str("session_id", content.SessionID.String()). Logger() + if !br.Config.Bridge.Permissions.Get(evt.Sender).SendEvents { + log.Debug().Msg("Dropping event from user with no permission to send events") + br.SendMessageStatus(ctx, &bridgev2.ErrNoPermissionToInteract, bridgev2.StatusEventInfoFromEvent(evt)) + return + } ctx = log.WithContext(ctx) if br.Crypto == nil { br.sendCryptoStatusError(ctx, evt, errNoCrypto, nil, 0, true) @@ -87,17 +101,18 @@ func (br *Connector) handleEncryptedEvent(ctx context.Context, evt *event.Event) decryptionStart := time.Now() decrypted, err := br.Crypto.Decrypt(ctx, evt) decryptionRetryCount := 0 + var errorEventID id.EventID if errors.Is(err, NoSessionFound) { decryptionRetryCount = 1 log.Debug(). Int("wait_seconds", int(initialSessionWaitTimeout.Seconds())). Msg("Couldn't find session, waiting for keys to arrive...") - go br.sendCryptoStatusError(ctx, evt, err, nil, 0, false) + go br.sendCryptoStatusError(ctx, evt, err, &errorEventID, 0, false) if br.Crypto.WaitForSession(ctx, evt.RoomID, content.SenderKey, content.SessionID, initialSessionWaitTimeout) { log.Debug().Msg("Got keys after waiting, trying to decrypt event again") decrypted, err = br.Crypto.Decrypt(ctx, evt) } else { - go br.waitLongerForSession(ctx, evt, decryptionStart) + go br.waitLongerForSession(ctx, evt, decryptionStart, &errorEventID) return } } @@ -106,18 +121,18 @@ func (br *Connector) handleEncryptedEvent(ctx context.Context, evt *event.Event) go br.sendCryptoStatusError(ctx, evt, err, nil, decryptionRetryCount, true) return } - br.postDecrypt(ctx, evt, decrypted, decryptionRetryCount, nil, time.Since(decryptionStart)) + br.postDecrypt(ctx, evt, decrypted, decryptionRetryCount, &errorEventID, time.Since(decryptionStart)) } -func (br *Connector) waitLongerForSession(ctx context.Context, evt *event.Event, decryptionStart time.Time) { +func (br *Connector) waitLongerForSession(ctx context.Context, evt *event.Event, decryptionStart time.Time, errorEventID *id.EventID) { log := zerolog.Ctx(ctx) content := evt.Content.AsEncrypted() log.Debug(). Int("wait_seconds", int(extendedSessionWaitTimeout.Seconds())). Msg("Couldn't find session, requesting keys and waiting longer...") + //lint:ignore SA1019 RequestSession will gracefully request from all devices if DeviceID is blank go br.Crypto.RequestSession(ctx, evt.RoomID, content.SenderKey, content.SessionID, evt.Sender, content.DeviceID) - var errorEventID *id.EventID go br.sendCryptoStatusError(ctx, evt, fmt.Errorf("%w. The bridge will retry for %d seconds", errNoDecryptionKeys, int(extendedSessionWaitTimeout.Seconds())), errorEventID, 1, false) if !br.Crypto.WaitForSession(ctx, evt.RoomID, content.SenderKey, content.SessionID, extendedSessionWaitTimeout) { @@ -142,7 +157,7 @@ type CommandProcessor interface { } func (br *Connector) sendSuccessCheckpoint(ctx context.Context, evt *event.Event, step status.MessageCheckpointStep, retryNum int) { - err := br.SendMessageCheckpoints([]*status.MessageCheckpoint{{ + err := br.SendMessageCheckpoints(ctx, []*status.MessageCheckpoint{{ RoomID: evt.RoomID, EventID: evt.ID, EventType: evt.Type, @@ -169,7 +184,7 @@ func (br *Connector) shouldIgnoreEventFromUser(userID id.UserID) bool { } func (br *Connector) shouldIgnoreEvent(evt *event.Event) bool { - if br.shouldIgnoreEventFromUser(evt.Sender) { + if br.shouldIgnoreEventFromUser(evt.Sender) && evt.Type != event.StateTombstone { return true } dpVal, ok := evt.Content.Raw[appservice.DoublePuppetKey] @@ -220,7 +235,6 @@ func (br *Connector) postDecrypt(ctx context.Context, original, decrypted *event go br.sendSuccessCheckpoint(ctx, decrypted, status.MsgStepDecrypted, retryCount) decrypted.Mautrix.CheckpointSent = true decrypted.Mautrix.DecryptionDuration = duration - decrypted.Mautrix.EventSource |= event.SourceDecrypted br.EventProcessor.Dispatch(ctx, decrypted) if errorEventID != nil && *errorEventID != "" { _, _ = br.Bot.RedactEvent(ctx, decrypted.RoomID, *errorEventID) diff --git a/bridgev2/matrix/mxmain/dberror.go b/bridgev2/matrix/mxmain/dberror.go index 0f6aa68c..f5e438de 100644 --- a/bridgev2/matrix/mxmain/dberror.go +++ b/bridgev2/matrix/mxmain/dberror.go @@ -66,7 +66,12 @@ func (br *BridgeMain) LogDBUpgradeErrorAndExit(name string, err error, message s } else if errors.Is(err, dbutil.ErrForeignTables) { br.Log.Info().Msg("See https://docs.mau.fi/faq/foreign-tables for more info") } else if errors.Is(err, dbutil.ErrNotOwned) { - br.Log.Info().Msg("Sharing the same database with different programs is not supported") + var noe dbutil.NotOwnedError + if errors.As(err, &noe) && noe.Owner == br.Name { + br.Log.Info().Msg("The database appears to be on a very old pre-megabridge schema. Perhaps you need to run an older version of the bridge with migration support first?") + } else { + br.Log.Info().Msg("Sharing the same database with different programs is not supported") + } } else if errors.Is(err, dbutil.ErrUnsupportedDatabaseVersion) { br.Log.Info().Msg("Downgrading the bridge is not supported") } diff --git a/bridgev2/matrix/mxmain/envconfig.go b/bridgev2/matrix/mxmain/envconfig.go new file mode 100644 index 00000000..1b4f1467 --- /dev/null +++ b/bridgev2/matrix/mxmain/envconfig.go @@ -0,0 +1,161 @@ +// Copyright (c) 2025 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package mxmain + +import ( + "fmt" + "iter" + "os" + "reflect" + "strconv" + "strings" + + "go.mau.fi/util/random" +) + +var randomParseFilePrefix = random.String(16) + "READFILE:" + +func parseEnv(prefix string) iter.Seq2[[]string, string] { + return func(yield func([]string, string) bool) { + for _, s := range os.Environ() { + if !strings.HasPrefix(s, prefix) { + continue + } + kv := strings.SplitN(s, "=", 2) + key := strings.TrimPrefix(kv[0], prefix) + value := kv[1] + if strings.HasSuffix(key, "_FILE") { + key = strings.TrimSuffix(key, "_FILE") + value = randomParseFilePrefix + value + } + key = strings.ToLower(key) + if !strings.ContainsRune(key, '.') { + key = strings.ReplaceAll(key, "__", ".") + } + if !yield(strings.Split(key, "."), value) { + return + } + } + } +} + +func reflectYAMLFieldName(f *reflect.StructField) string { + parts := strings.SplitN(f.Tag.Get("yaml"), ",", 2) + fieldName := parts[0] + if fieldName == "-" && len(parts) == 1 { + return "" + } + if fieldName == "" { + return strings.ToLower(f.Name) + } + return fieldName +} + +type reflectGetResult struct { + val reflect.Value + valKind reflect.Kind + remainingPath []string +} + +func reflectGetYAML(rv reflect.Value, path []string) (*reflectGetResult, bool) { + if len(path) == 0 { + return &reflectGetResult{val: rv, valKind: rv.Kind()}, true + } + if rv.Kind() == reflect.Ptr { + rv = rv.Elem() + } + switch rv.Kind() { + case reflect.Map: + return &reflectGetResult{val: rv, remainingPath: path, valKind: rv.Type().Elem().Kind()}, true + case reflect.Struct: + fields := reflect.VisibleFields(rv.Type()) + for _, field := range fields { + fieldName := reflectYAMLFieldName(&field) + if fieldName != "" && fieldName == path[0] { + return reflectGetYAML(rv.FieldByIndex(field.Index), path[1:]) + } + } + default: + } + return nil, false +} + +func reflectGetFromMainOrNetwork(main, network reflect.Value, path []string) (*reflectGetResult, bool) { + if len(path) > 0 && path[0] == "network" { + return reflectGetYAML(network, path[1:]) + } + return reflectGetYAML(main, path) +} + +func formatKeyString(key []string) string { + return strings.Join(key, "->") +} + +func UpdateConfigFromEnv(cfg, networkData any, prefix string) error { + cfgVal := reflect.ValueOf(cfg) + networkVal := reflect.ValueOf(networkData) + for key, value := range parseEnv(prefix) { + field, ok := reflectGetFromMainOrNetwork(cfgVal, networkVal, key) + if !ok { + return fmt.Errorf("%s not found", formatKeyString(key)) + } + if strings.HasPrefix(value, randomParseFilePrefix) { + filepath := strings.TrimPrefix(value, randomParseFilePrefix) + fileData, err := os.ReadFile(filepath) + if err != nil { + return fmt.Errorf("failed to read file %s for %s: %w", filepath, formatKeyString(key), err) + } + value = strings.TrimSpace(string(fileData)) + } + var parsedVal any + var err error + switch field.valKind { + case reflect.String: + parsedVal = value + case reflect.Bool: + parsedVal, err = strconv.ParseBool(value) + if err != nil { + return fmt.Errorf("invalid value for %s: %w", formatKeyString(key), err) + } + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + parsedVal, err = strconv.ParseInt(value, 10, 64) + if err != nil { + return fmt.Errorf("invalid value for %s: %w", formatKeyString(key), err) + } + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + parsedVal, err = strconv.ParseUint(value, 10, 64) + if err != nil { + return fmt.Errorf("invalid value for %s: %w", formatKeyString(key), err) + } + case reflect.Float32, reflect.Float64: + parsedVal, err = strconv.ParseFloat(value, 64) + if err != nil { + return fmt.Errorf("invalid value for %s: %w", formatKeyString(key), err) + } + default: + return fmt.Errorf("unsupported type %s in %s", field.valKind, formatKeyString(key)) + } + if field.val.Kind() == reflect.Ptr { + if field.val.IsNil() { + field.val.Set(reflect.New(field.val.Type().Elem())) + } + field.val = field.val.Elem() + } + if field.val.Kind() == reflect.Map { + key = key[:len(key)-len(field.remainingPath)] + mapKeyStr := strings.Join(field.remainingPath, ".") + key = append(key, mapKeyStr) + if field.val.Type().Key().Kind() != reflect.String { + return fmt.Errorf("unsupported map key type %s in %s", field.val.Type().Key().Kind(), formatKeyString(key)) + } + field.val.SetMapIndex(reflect.ValueOf(mapKeyStr), reflect.ValueOf(parsedVal)) + } else { + field.val.Set(reflect.ValueOf(parsedVal)) + } + } + return nil +} diff --git a/bridgev2/matrix/mxmain/example-config.yaml b/bridgev2/matrix/mxmain/example-config.yaml index 8f7655fc..ccc81c4b 100644 --- a/bridgev2/matrix/mxmain/example-config.yaml +++ b/bridgev2/matrix/mxmain/example-config.yaml @@ -15,12 +15,28 @@ bridge: # By default, users who are in the same group on the remote network will be # in the same Matrix room bridged to that group. If this is set to true, # every user will get their own Matrix room instead. + # SETTING THIS IS IRREVERSIBLE AND POTENTIALLY DESTRUCTIVE IF PORTALS ALREADY EXIST. split_portals: false # Should the bridge resend `m.bridge` events to all portals on startup? resend_bridge_info: false + # Should `m.bridge` events be sent without a state key? + # By default, the bridge uses a unique key that won't conflict with other bridges. + no_bridge_info_state_key: false + # Should bridge connection status be sent to the management room as `m.notice` events? + # These contain the same data that can be posted to an external HTTP server using homeserver -> status_endpoint. + # Allowed values: none, errors, all + bridge_status_notices: errors + # How long after an unknown error should the bridge attempt a full reconnect? + # Must be at least 1 minute. The bridge will add an extra ±20% jitter to this value. + unknown_error_auto_reconnect: null + # Maximum number of times to do the auto-reconnect above. + # The counter is per login, but is never reset except on logout and restart. + unknown_error_max_auto_reconnects: 10 # Should leaving Matrix rooms be bridged as leaving groups on the remote network? bridge_matrix_leave: false + # Should `m.notice` messages be bridged? + bridge_notices: false # Should room tags only be synced when creating the portal? Tags mean things like favorite/pin and archive/low priority. # Tags currently can't be synced back to the remote network, so a continuous sync means tagging from Matrix will be undone. tag_only_on_create: true @@ -29,6 +45,16 @@ bridge: # Should room mute status only be synced when creating the portal? # Like tags, mutes can't currently be synced back to the remote network. mute_only_on_create: true + # Should the bridge check the db to ensure that incoming events haven't been handled before + deduplicate_matrix_messages: false + # Should cross-room reply metadata be bridged? + # Most Matrix clients don't support this and servers may reject such messages too. + cross_room_replies: false + # If a state event fails to bridge, should the bridge revert any state changes made by that event? + revert_failed_state_changes: false + # In portals with no relay set, should Matrix users be kicked if they're + # not logged into an account that's in the remote chat? + kick_matrix_users: true # What should be done to portal rooms when a user logs out or is logged out? # Permitted values: @@ -193,10 +219,6 @@ appservice: # However, messages will not be guaranteed to be bridged in the same order they were sent in. # This value doesn't affect the registration file. async_transactions: false - # Whether to use MSC4190 instead of appservice login to create the bridge bot device. - # Requires the homeserver to support MSC4190 and the device masquerading parts of MSC3202. - # Only relevant when using end-to-bridge encryption, required when using encryption with next-gen auth (MSC3861). - msc4190: false # Authentication tokens for AS <-> HS communication. Autogenerated; do not modify. as_token: "This value is generated when generating the registration" @@ -222,6 +244,9 @@ matrix: # The threshold as bytes after which the bridge should roundtrip uploads via the disk # rather than keeping the whole file in memory. upload_file_threshold: 5242880 + # Should the bridge set additional custom profile info for ghosts? + # This can make a lot of requests, as there's no batch profile update endpoint. + ghost_extra_profile_info: false # Segment-compatible analytics endpoint for tracking some events, like provisioning API login and encryption errors. analytics: @@ -234,10 +259,8 @@ analytics: # Settings for provisioning API provisioning: - # Prefix for the provisioning API paths. - prefix: /_matrix/provision # Shared secret for authentication. If set to "generate" or null, a random secret will be generated, - # or if set to "disable", the provisioning API will be disabled. + # or if set to "disable", the provisioning API will be disabled. Must be at least 16 characters. shared_secret: generate # Whether to allow provisioning API requests to be authed using Matrix access tokens. # This follows the same rules as double puppeting to determine which server to contact to check the token, @@ -245,6 +268,9 @@ provisioning: allow_matrix_auth: true # Enable debug API at /debug with provisioning authentication. debug_endpoints: false + # Enable session transfers between bridges. Note that this only validates Matrix or shared secret + # auth before passing live network client credentials down in the response. + enable_session_transfers: false # Some networks require publicly accessible media download links (e.g. for user avatars when using Discord webhooks). # These settings control whether the bridge will provide such public media access. @@ -260,6 +286,14 @@ public_media: expiry: 0 # Length of hash to use for public media URLs. Must be between 0 and 32. hash_length: 32 + # The path prefix for generated URLs. Note that this will NOT change the path where media is actually served. + # If you change this, you must configure your reverse proxy to rewrite the path accordingly. + path_prefix: /_mautrix/publicmedia + # Should the bridge store media metadata in the database in order to support encrypted media and generate shorter URLs? + # If false, the generated URLs will just have the MXC URI and a HMAC signature. + # The hash_length field will be used to decide the length of the generated URL. + # This also allows invalidating URLs by deleting the database entry. + use_database: false # Settings for converting remote media to custom mxc:// URIs instead of reuploading. # More details can be found at https://docs.mau.fi/bridges/go/discord/direct-media.html @@ -341,9 +375,21 @@ encryption: default: false # Whether to require all messages to be encrypted and drop any unencrypted messages. require: false - # Whether to use MSC2409/MSC3202 instead of /sync long polling for receiving encryption-related data. + # Whether to use MSC3202/MSC4203 instead of /sync long polling for receiving encryption-related data. # This option is not yet compatible with standard Matrix servers like Synapse and should not be used. + # Changing this option requires updating the appservice registration file. appservice: false + # Whether to use MSC4190 instead of appservice login to create the bridge bot device. + # Requires the homeserver to support MSC4190 and the device masquerading parts of MSC3202. + # Only relevant when using end-to-bridge encryption, required when using encryption with next-gen auth (MSC3861). + # Changing this option requires updating the appservice registration file. + msc4190: false + # Whether to encrypt reactions and reply metadata as per MSC4392. + msc4392: false + # Should the bridge bot generate a recovery key and cross-signing keys and verify itself? + # Note that without the latest version of MSC4190, this will fail if you reset the bridge database. + # The generated recovery key will be saved in the kv_store table under `recovery_key`. + self_sign: false # Enable key sharing? If enabled, key requests for rooms where users are in will be fulfilled. # You must use a client that supports requesting keys from other users to use this feature. allow_key_sharing: true @@ -406,6 +452,16 @@ encryption: # You should not enable this option unless you understand all the implications. disable_device_change_key_rotation: false +# Prefix for environment variables. All variables with this prefix must map to valid config fields. +# Nesting in variable names is represented with a dot (.). +# If there are no dots in the name, two underscores (__) are replaced with a dot. +# +# e.g. if the prefix is set to `BRIDGE_`, then `BRIDGE_APPSERVICE__AS_TOKEN` will set appservice.as_token. +# `BRIDGE_appservice.as_token` would work as well, but can't be set in a shell as easily. +# +# If this is null, reading config fields from environment will be disabled. +env_config_prefix: null + # Logging config. See https://github.com/tulir/zeroconfig for details. logging: min_level: debug diff --git a/bridgev2/matrix/mxmain/legacymigrate.go b/bridgev2/matrix/mxmain/legacymigrate.go index d33dd8cd..97cdeddf 100644 --- a/bridgev2/matrix/mxmain/legacymigrate.go +++ b/bridgev2/matrix/mxmain/legacymigrate.go @@ -135,7 +135,10 @@ func (br *BridgeMain) CheckLegacyDB( } var dbVersion int err = br.DB.QueryRow(ctx, "SELECT version FROM version").Scan(&dbVersion) - if dbVersion < expectedVersion { + if err != nil { + log.Fatal().Err(err).Msg("Failed to get database version") + return + } else if dbVersion < expectedVersion { log.Fatal(). Int("expected_version", expectedVersion). Int("version", dbVersion). @@ -208,36 +211,46 @@ func (br *BridgeMain) postMigrateDMPortal(ctx context.Context, portal *bridgev2. } func (br *BridgeMain) PostMigrate(ctx context.Context) error { + log := br.Log.With().Str("action", "post-migrate").Logger() wasMigrated, err := br.DB.TableExists(ctx, "database_was_migrated") if err != nil { return fmt.Errorf("failed to check if database_was_migrated table exists: %w", err) } else if !wasMigrated { return nil } - zerolog.Ctx(ctx).Info().Msg("Doing post-migration updates to Matrix rooms") + log.Info().Msg("Doing post-migration updates to Matrix rooms") portals, err := br.Bridge.GetAllPortalsWithMXID(ctx) if err != nil { return fmt.Errorf("failed to get all portals: %w", err) } for _, portal := range portals { - zerolog.Ctx(ctx).Debug(). + log := log.With(). Stringer("room_id", portal.MXID). Object("portal_key", portal.PortalKey). Str("room_type", string(portal.RoomType)). - Msg("Migrating portal") - switch portal.RoomType { - case database.RoomTypeDM: - err = br.postMigrateDMPortal(ctx, portal) + Logger() + log.Debug().Msg("Migrating portal") + if br.PostMigratePortal != nil { + err = br.PostMigratePortal(ctx, portal) if err != nil { - return fmt.Errorf("failed to update DM portal %s: %w", portal.MXID, err) + log.Err(err).Msg("Failed to run post-migrate portal hook") + continue + } + } else { + switch portal.RoomType { + case database.RoomTypeDM: + err = br.postMigrateDMPortal(ctx, portal) + if err != nil { + return fmt.Errorf("failed to update DM portal %s: %w", portal.MXID, err) + } } } _, err = br.Matrix.Bot.SendStateEvent(ctx, portal.MXID, event.StateElementFunctionalMembers, "", &event.ElementFunctionalMembersContent{ ServiceMembers: []id.UserID{br.Matrix.Bot.UserID}, }) if err != nil { - zerolog.Ctx(ctx).Warn().Err(err).Stringer("room_id", portal.MXID).Msg("Failed to set service members") + log.Warn().Err(err).Stringer("room_id", portal.MXID).Msg("Failed to set service members") } } @@ -245,6 +258,6 @@ func (br *BridgeMain) PostMigrate(ctx context.Context) error { if err != nil { return fmt.Errorf("failed to drop database_was_migrated table: %w", err) } - zerolog.Ctx(ctx).Info().Msg("Post-migration updates complete") + log.Info().Msg("Post-migration updates complete") return nil } diff --git a/bridgev2/matrix/mxmain/main.go b/bridgev2/matrix/mxmain/main.go index 2c0c07b9..1e8b51d1 100644 --- a/bridgev2/matrix/mxmain/main.go +++ b/bridgev2/matrix/mxmain/main.go @@ -26,6 +26,7 @@ import ( "go.mau.fi/util/dbutil" "go.mau.fi/util/exerrors" "go.mau.fi/util/exzerolog" + "go.mau.fi/util/progver" "gopkg.in/yaml.v3" flag "maunium.net/go/mauflag" @@ -62,11 +63,18 @@ type BridgeMain struct { // git tag to see if the built version is the release or a dev build. // You can either bump this right after a release or right before, as long as it matches on the release commit. Version string + // SemCalVer defines whether this bridge uses a mix of semantic and calendar versioning, + // such that the Version field is YY.0M.patch, while git tags are major.YY0M.patch. + SemCalVer bool // PostInit is a function that will be called after the bridge has been initialized but before it is started. PostInit func() PostStart func() + // PostMigratePortal is a function that will be called during a legacy + // migration for each portal. + PostMigratePortal func(context.Context, *bridgev2.Portal) error + // Connector is the network connector for the bridge. Connector bridgev2.NetworkConnector @@ -82,11 +90,7 @@ type BridgeMain struct { RegistrationPath string SaveConfig bool - baseVersion string - commit string - LinkifiedVersion string - VersionDesc string - BuildTime time.Time + ver progver.ProgramVersion AdditionalShortFlags string AdditionalLongFlags string @@ -95,14 +99,7 @@ type BridgeMain struct { } type VersionJSONOutput struct { - Name string - URL string - - Version string - IsRelease bool - Commit string - FormattedVersion string - BuildTime time.Time + progver.ProgramVersion OS string Arch string @@ -143,18 +140,11 @@ func (br *BridgeMain) PreInit() { flag.PrintHelp() os.Exit(0) } else if *version { - fmt.Println(br.VersionDesc) + fmt.Println(br.ver.VersionDescription) os.Exit(0) } else if *versionJSON { output := VersionJSONOutput{ - URL: br.URL, - Name: br.Name, - - Version: br.baseVersion, - IsRelease: br.Version == br.baseVersion, - Commit: br.commit, - FormattedVersion: br.Version, - BuildTime: br.BuildTime, + ProgramVersion: br.ver, OS: runtime.GOOS, Arch: runtime.GOARCH, @@ -236,8 +226,8 @@ func (br *BridgeMain) Init() { br.Log.Info(). Str("name", br.Name). - Str("version", br.Version). - Time("built_at", br.BuildTime). + Str("version", br.ver.FormattedVersion). + Time("built_at", br.ver.BuildTime). Str("go_version", runtime.Version()). Msg("Initializing bridge") @@ -251,7 +241,7 @@ func (br *BridgeMain) Init() { br.Matrix.AS.DoublePuppetValue = br.Name br.Bridge.Commands.(*commands.Processor).AddHandler(&commands.FullHandler{ Func: func(ce *commands.Event) { - ce.Reply("[%s](%s) %s (%s)", br.Name, br.URL, br.LinkifiedVersion, br.BuildTime.Format(time.RFC1123)) + ce.Reply(br.ver.MarkdownDescription()) }, Name: "version", Help: commands.HelpMeta{ @@ -310,7 +300,7 @@ func (br *BridgeMain) validateConfig() error { case br.Config.AppService.HSToken == "This value is generated when generating the registration": return errors.New("appservice.hs_token not configured. Did you forget to generate the registration? ") case br.Config.Database.URI == "postgres://user:password@host/database?sslmode=disable": - return errors.New("appservice.database not configured") + return errors.New("database.uri not configured") case !br.Config.Bridge.Permissions.IsConfigured(): return errors.New("bridge.permissions not configured") case !strings.Contains(br.Config.AppService.FormatUsername("1234567890"), "1234567890"): @@ -364,13 +354,21 @@ func (br *BridgeMain) LoadConfig() { } } cfg.Bridge.Backfill = cfg.Backfill + if cfg.EnvConfigPrefix != "" { + err = UpdateConfigFromEnv(&cfg, networkData, cfg.EnvConfigPrefix) + if err != nil { + _, _ = fmt.Fprintln(os.Stderr, "Failed to parse environment variables:", err) + os.Exit(10) + } + } br.Config = &cfg } // Start starts the bridge after everything has been initialized. // This is called by [Run] and does not need to be called manually. func (br *BridgeMain) Start() { - err := br.Bridge.StartConnectors() + ctx := br.Log.WithContext(context.Background()) + err := br.Bridge.StartConnectors(ctx) if err != nil { var dbUpgradeErr bridgev2.DBUpgradeError if errors.As(err, &dbUpgradeErr) { @@ -379,14 +377,15 @@ func (br *BridgeMain) Start() { br.Log.Fatal().Err(err).Msg("Failed to start bridge") } } - err = br.PostMigrate(br.Log.WithContext(context.Background())) + err = br.PostMigrate(ctx) if err != nil { br.Log.Fatal().Err(err).Msg("Failed to run post-migration updates") } - err = br.Bridge.StartLogins() + err = br.Bridge.StartLogins(ctx) if err != nil { br.Log.Fatal().Err(err).Msg("Failed to start existing user logins") } + br.Bridge.PostStart(ctx) if br.PostStart != nil { br.PostStart() } @@ -415,7 +414,7 @@ func (br *BridgeMain) TriggerStop(exitCode int) { // Stop cleanly stops the bridge. This is called by [Run] and does not need to be called manually. func (br *BridgeMain) Stop() { - br.Bridge.Stop() + br.Bridge.StopWithTimeout(5 * time.Second) } // InitVersion formats the bridge version and build time nicely for things like @@ -440,42 +439,12 @@ func (br *BridgeMain) Stop() { // // (to use both at the same time, simply merge the ldflags into one, `-ldflags "-X '...' -X ..."`) func (br *BridgeMain) InitVersion(tag, commit, rawBuildTime string) { - br.baseVersion = br.Version - if len(tag) > 0 && tag[0] == 'v' { - tag = tag[1:] - } - if tag != br.Version { - suffix := "" - if !strings.HasSuffix(br.Version, "+dev") { - suffix = "+dev" - } - if len(commit) > 8 { - br.Version = fmt.Sprintf("%s%s.%s", br.Version, suffix, commit[:8]) - } else { - br.Version = fmt.Sprintf("%s%s.unknown", br.Version, suffix) - } - } - - br.LinkifiedVersion = fmt.Sprintf("v%s", br.Version) - if tag == br.Version { - br.LinkifiedVersion = fmt.Sprintf("[v%s](%s/releases/v%s)", br.Version, br.URL, tag) - } else if len(commit) > 8 { - br.LinkifiedVersion = strings.Replace(br.LinkifiedVersion, commit[:8], fmt.Sprintf("[%s](%s/commit/%s)", commit[:8], br.URL, commit), 1) - } - var buildTime time.Time - if rawBuildTime != "unknown" { - buildTime, _ = time.Parse(time.RFC3339, rawBuildTime) - } - var builtWith string - if buildTime.IsZero() { - rawBuildTime = "unknown" - builtWith = runtime.Version() - } else { - rawBuildTime = buildTime.Format(time.RFC1123) - builtWith = fmt.Sprintf("built at %s with %s", rawBuildTime, runtime.Version()) - } - mautrix.DefaultUserAgent = fmt.Sprintf("%s/%s %s", br.Name, br.Version, mautrix.DefaultUserAgent) - br.VersionDesc = fmt.Sprintf("%s %s (%s)", br.Name, br.Version, builtWith) - br.commit = commit - br.BuildTime = buildTime + br.ver = progver.ProgramVersion{ + Name: br.Name, + URL: br.URL, + BaseVersion: br.Version, + SemCalVer: br.SemCalVer, + }.Init(tag, commit, rawBuildTime) + mautrix.DefaultUserAgent = fmt.Sprintf("%s/%s %s", br.Name, br.ver.FormattedVersion, mautrix.DefaultUserAgent) + br.Version = br.ver.FormattedVersion } diff --git a/bridgev2/matrix/provisioning.go b/bridgev2/matrix/provisioning.go index 87f6576d..243b91da 100644 --- a/bridgev2/matrix/provisioning.go +++ b/bridgev2/matrix/provisioning.go @@ -1,4 +1,4 @@ -// Copyright (c) 2024 Tulir Asokan +// Copyright (c) 2025 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this @@ -17,17 +17,21 @@ import ( "sync" "time" - "github.com/gorilla/mux" "github.com/rs/xid" "github.com/rs/zerolog" "github.com/rs/zerolog/hlog" + "go.mau.fi/util/exerrors" + "go.mau.fi/util/exhttp" + "go.mau.fi/util/exstrings" "go.mau.fi/util/jsontime" + "go.mau.fi/util/ptr" "go.mau.fi/util/requestlog" "maunium.net/go/mautrix" - "maunium.net/go/mautrix/bridge/status" "maunium.net/go/mautrix/bridgev2" "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/bridgev2/provisionutil" + "maunium.net/go/mautrix/bridgev2/status" "maunium.net/go/mautrix/federation" "maunium.net/go/mautrix/id" ) @@ -38,7 +42,7 @@ type matrixAuthCacheEntry struct { } type ProvisioningAPI struct { - Router *mux.Router + Router *http.ServeMux br *Connector log zerolog.Logger @@ -52,6 +56,11 @@ type ProvisioningAPI struct { matrixAuthCache map[string]matrixAuthCacheEntry matrixAuthCacheLock sync.Mutex + // Set for a given login once credentials have been exported, once in this state the finish + // API is available which will call logout on the client in question. + sessionTransfers map[networkid.UserLoginID]struct{} + sessionTransfersLock sync.Mutex + // GetAuthFromRequest is a custom function for getting the auth token from // the request if the Authorization header is not present. GetAuthFromRequest func(r *http.Request) string @@ -76,86 +85,84 @@ const ( provisioningUserKey provisioningContextKey = iota provisioningUserLoginKey provisioningLoginProcessKey + ProvisioningKeyRequest ) -const ProvisioningKeyRequest = "fi.mau.provision.request" - func (prov *ProvisioningAPI) GetUser(r *http.Request) *bridgev2.User { return r.Context().Value(provisioningUserKey).(*bridgev2.User) } -func (prov *ProvisioningAPI) GetRouter() *mux.Router { +func (prov *ProvisioningAPI) GetRouter() *http.ServeMux { return prov.Router } -type IProvisioningAPI interface { - GetRouter() *mux.Router - GetUser(r *http.Request) *bridgev2.User -} - -func (br *Connector) GetProvisioning() IProvisioningAPI { +func (br *Connector) GetProvisioning() bridgev2.IProvisioningAPI { return br.Provisioning } func (prov *ProvisioningAPI) Init() { prov.matrixAuthCache = make(map[string]matrixAuthCacheEntry) prov.logins = make(map[string]*ProvLogin) + prov.sessionTransfers = make(map[networkid.UserLoginID]struct{}) prov.net = prov.br.Bridge.Network prov.log = prov.br.Log.With().Str("component", "provisioning").Logger() - prov.fedClient = federation.NewClient("", nil) + prov.fedClient = federation.NewClient("", nil, nil) prov.fedClient.HTTP.Timeout = 20 * time.Second tp := prov.fedClient.HTTP.Transport.(*federation.ServerResolvingTransport) tp.Dialer.Timeout = 10 * time.Second tp.Transport.ResponseHeaderTimeout = 10 * time.Second tp.Transport.TLSHandshakeTimeout = 10 * time.Second - prov.Router = prov.br.AS.Router.PathPrefix(prov.br.Config.Provisioning.Prefix).Subrouter() - prov.Router.Use(hlog.NewHandler(prov.log)) - prov.Router.Use(hlog.RequestIDHandler("request_id", "Request-Id")) - prov.Router.Use(corsMiddleware) - prov.Router.Use(requestlog.AccessLogger(false)) - prov.Router.Use(prov.AuthMiddleware) - prov.Router.Path("/v3/whoami").Methods(http.MethodGet, http.MethodOptions).HandlerFunc(prov.GetWhoami) - prov.Router.Path("/v3/login/flows").Methods(http.MethodGet, http.MethodOptions).HandlerFunc(prov.GetLoginFlows) - prov.Router.Path("/v3/login/start/{flowID}").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostLoginStart) - prov.Router.Path("/v3/login/step/{loginProcessID}/{stepID}/{stepType:user_input|cookies}").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostLoginSubmitInput) - prov.Router.Path("/v3/login/step/{loginProcessID}/{stepID}/{stepType:display_and_wait}").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostLoginWait) - prov.Router.Path("/v3/logout/{loginID}").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostLogout) - prov.Router.Path("/v3/logins").Methods(http.MethodGet, http.MethodOptions).HandlerFunc(prov.GetLogins) - prov.Router.Path("/v3/contacts").Methods(http.MethodGet, http.MethodOptions).HandlerFunc(prov.GetContactList) - prov.Router.Path("/v3/search_users").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostSearchUsers) - prov.Router.Path("/v3/resolve_identifier/{identifier}").Methods(http.MethodGet, http.MethodOptions).HandlerFunc(prov.GetResolveIdentifier) - prov.Router.Path("/v3/create_dm/{identifier}").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostCreateDM) - prov.Router.Path("/v3/create_group").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostCreateGroup) + prov.Router = http.NewServeMux() + prov.Router.HandleFunc("GET /v3/whoami", prov.GetWhoami) + prov.Router.HandleFunc("GET /v3/capabilities", prov.GetCapabilities) + prov.Router.HandleFunc("GET /v3/login/flows", prov.GetLoginFlows) + prov.Router.HandleFunc("POST /v3/login/start/{flowID}", prov.PostLoginStart) + prov.Router.HandleFunc("POST /v3/login/step/{loginProcessID}/{stepID}/{stepType}", prov.PostLoginStep) + prov.Router.HandleFunc("POST /v3/logout/{loginID}", prov.PostLogout) + prov.Router.HandleFunc("GET /v3/logins", prov.GetLogins) + prov.Router.HandleFunc("GET /v3/contacts", prov.GetContactList) + prov.Router.HandleFunc("POST /v3/search_users", prov.PostSearchUsers) + prov.Router.HandleFunc("GET /v3/resolve_identifier/{identifier}", prov.GetResolveIdentifier) + prov.Router.HandleFunc("POST /v3/create_dm/{identifier}", prov.PostCreateDM) + prov.Router.HandleFunc("POST /v3/create_group/{type}", prov.PostCreateGroup) + + if prov.br.Config.Provisioning.EnableSessionTransfers { + prov.log.Debug().Msg("Enabling session transfer API") + prov.Router.HandleFunc("POST /v3/session_transfer/init", prov.PostInitSessionTransfer) + prov.Router.HandleFunc("POST /v3/session_transfer/finish", prov.PostFinishSessionTransfer) + } if prov.br.Config.Provisioning.DebugEndpoints { prov.log.Debug().Msg("Enabling debug API at /debug") - r := prov.br.AS.Router.PathPrefix("/debug").Subrouter() - r.Use(prov.DebugAuthMiddleware) - r.HandleFunc("/pprof/cmdline", pprof.Cmdline).Methods(http.MethodGet) - r.HandleFunc("/pprof/profile", pprof.Profile).Methods(http.MethodGet) - r.HandleFunc("/pprof/symbol", pprof.Symbol).Methods(http.MethodGet) - r.HandleFunc("/pprof/trace", pprof.Trace).Methods(http.MethodGet) - r.PathPrefix("/pprof/").HandlerFunc(pprof.Index) + debugRouter := http.NewServeMux() + debugRouter.HandleFunc("GET /pprof/cmdline", pprof.Cmdline) + debugRouter.HandleFunc("GET /pprof/profile", pprof.Profile) + debugRouter.HandleFunc("GET /pprof/symbol", pprof.Symbol) + debugRouter.HandleFunc("GET /pprof/trace", pprof.Trace) + debugRouter.HandleFunc("/pprof/", pprof.Index) + prov.br.AS.Router.Handle("/debug/", exhttp.ApplyMiddleware( + debugRouter, + exhttp.StripPrefix("/debug"), + hlog.NewHandler(prov.br.Log.With().Str("component", "debug api").Logger()), + requestlog.AccessLogger(requestlog.Options{TrustXForwardedFor: true}), + prov.DebugAuthMiddleware, + )) } -} -func corsMiddleware(handler http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Access-Control-Allow-Origin", "*") - w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") - w.Header().Set("Access-Control-Allow-Headers", "X-Requested-With, Content-Type, Authorization") - if r.Method == http.MethodOptions { - w.WriteHeader(http.StatusOK) - return - } - handler.ServeHTTP(w, r) - }) -} - -func jsonResponse(w http.ResponseWriter, status int, response any) { - w.Header().Add("Content-Type", "application/json") - w.WriteHeader(status) - _ = json.NewEncoder(w).Encode(response) + errorBodies := exhttp.ErrorBodies{ + NotFound: exerrors.Must(ptr.Ptr(mautrix.MUnrecognized.WithMessage("Unrecognized endpoint")).MarshalJSON()), + MethodNotAllowed: exerrors.Must(ptr.Ptr(mautrix.MUnrecognized.WithMessage("Invalid method for endpoint")).MarshalJSON()), + } + prov.br.AS.Router.Handle("/_matrix/provision/", exhttp.ApplyMiddleware( + prov.Router, + exhttp.StripPrefix("/_matrix/provision"), + hlog.NewHandler(prov.log), + hlog.RequestIDHandler("request_id", "Request-Id"), + exhttp.CORSMiddleware, + requestlog.AccessLogger(requestlog.Options{TrustXForwardedFor: true}), + exhttp.HandleErrors(errorBodies), + prov.AuthMiddleware, + )) } func (prov *ProvisioningAPI) checkMatrixAuth(ctx context.Context, userID id.UserID, token string) error { @@ -199,19 +206,21 @@ func (prov *ProvisioningAPI) checkFederatedMatrixAuth(ctx context.Context, userI } } +func disabledAuth(w http.ResponseWriter, r *http.Request) { + mautrix.MForbidden.WithMessage("Provisioning API is disabled").Write(w) +} + func (prov *ProvisioningAPI) DebugAuthMiddleware(h http.Handler) http.Handler { + secret := prov.br.Config.Provisioning.SharedSecret + if len(secret) < 16 { + return http.HandlerFunc(disabledAuth) + } return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { auth := strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer ") if auth == "" { - jsonResponse(w, http.StatusUnauthorized, &mautrix.RespError{ - Err: "Missing auth token", - ErrCode: mautrix.MMissingToken.ErrCode, - }) - } else if auth != prov.br.Config.Provisioning.SharedSecret { - jsonResponse(w, http.StatusUnauthorized, &mautrix.RespError{ - Err: "Invalid auth token", - ErrCode: mautrix.MUnknownToken.ErrCode, - }) + mautrix.MMissingToken.WithMessage("Missing auth token").Write(w) + } else if !exstrings.ConstantTimeEqual(auth, secret) { + mautrix.MUnknownToken.WithMessage("Invalid auth token").Write(w) } else { h.ServeHTTP(w, r) } @@ -219,23 +228,24 @@ func (prov *ProvisioningAPI) DebugAuthMiddleware(h http.Handler) http.Handler { } func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler { + secret := prov.br.Config.Provisioning.SharedSecret + if len(secret) < 16 { + return http.HandlerFunc(disabledAuth) + } return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { auth := strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer ") if auth == "" && prov.GetAuthFromRequest != nil { auth = prov.GetAuthFromRequest(r) } if auth == "" { - jsonResponse(w, http.StatusUnauthorized, &mautrix.RespError{ - Err: "Missing auth token", - ErrCode: mautrix.MMissingToken.ErrCode, - }) + mautrix.MMissingToken.WithMessage("Missing auth token").Write(w) return } userID := id.UserID(r.URL.Query().Get("user_id")) if userID == "" && prov.GetUserIDFromRequest != nil { userID = prov.GetUserIDFromRequest(r) } - if auth != prov.br.Config.Provisioning.SharedSecret { + if !exstrings.ConstantTimeEqual(auth, secret) { var err error if strings.HasPrefix(auth, "openid:") { err = prov.checkFederatedMatrixAuth(r.Context(), userID, strings.TrimPrefix(auth, "openid:")) @@ -245,75 +255,25 @@ func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler { if err != nil { zerolog.Ctx(r.Context()).Warn().Err(err). Msg("Provisioning API request contained invalid auth") - jsonResponse(w, http.StatusUnauthorized, &mautrix.RespError{ - Err: "Invalid auth token", - ErrCode: mautrix.MUnknownToken.ErrCode, - }) + mautrix.MUnknownToken.WithMessage("Invalid auth token").Write(w) return } } user, err := prov.br.Bridge.GetUserByMXID(r.Context(), userID) if err != nil { zerolog.Ctx(r.Context()).Err(err).Msg("Failed to get user") - jsonResponse(w, http.StatusInternalServerError, &mautrix.RespError{ - Err: "Failed to get user", - ErrCode: "M_UNKNOWN", - }) + mautrix.MUnknown.WithMessage("Failed to get user").Write(w) return } // TODO handle user being nil? // TODO per-endpoint permissions? if !user.Permissions.Login { - jsonResponse(w, http.StatusForbidden, &mautrix.RespError{ - Err: "User does not have login permissions", - ErrCode: mautrix.MForbidden.ErrCode, - }) + mautrix.MForbidden.WithMessage("User does not have login permissions").Write(w) return } ctx := context.WithValue(r.Context(), ProvisioningKeyRequest, r) ctx = context.WithValue(ctx, provisioningUserKey, user) - if loginID, ok := mux.Vars(r)["loginProcessID"]; ok { - prov.loginsLock.RLock() - login, ok := prov.logins[loginID] - prov.loginsLock.RUnlock() - if !ok { - zerolog.Ctx(r.Context()).Warn().Str("login_id", loginID).Msg("Login not found") - jsonResponse(w, http.StatusNotFound, &mautrix.RespError{ - Err: "Login not found", - ErrCode: mautrix.MNotFound.ErrCode, - }) - return - } - login.Lock.Lock() - // This will only unlock after the handler runs - defer login.Lock.Unlock() - stepID := mux.Vars(r)["stepID"] - if login.NextStep.StepID != stepID { - zerolog.Ctx(r.Context()).Warn(). - Str("request_step_id", stepID). - Str("expected_step_id", login.NextStep.StepID). - Msg("Step ID does not match") - jsonResponse(w, http.StatusBadRequest, &mautrix.RespError{ - Err: "Step ID does not match", - ErrCode: mautrix.MBadState.ErrCode, - }) - return - } - stepType := mux.Vars(r)["stepType"] - if login.NextStep.Type != bridgev2.LoginStepType(stepType) { - zerolog.Ctx(r.Context()).Warn(). - Str("request_step_type", stepType). - Str("expected_step_type", string(login.NextStep.Type)). - Msg("Step type does not match") - jsonResponse(w, http.StatusBadRequest, &mautrix.RespError{ - Err: "Step type does not match", - ErrCode: mautrix.MBadState.ErrCode, - }) - return - } - ctx = context.WithValue(ctx, provisioningLoginProcessKey, login) - } h.ServeHTTP(w, r.WithContext(ctx)) }) } @@ -364,7 +324,7 @@ func (prov *ProvisioningAPI) GetWhoami(w http.ResponseWriter, r *http.Request) { prevState.UserID = "" prevState.RemoteID = "" prevState.RemoteName = "" - prevState.RemoteProfile = nil + prevState.RemoteProfile = status.RemoteProfile{} resp.Logins[i] = RespWhoamiLogin{ StateEvent: prevState.StateEvent, StateTS: prevState.Timestamp, @@ -378,7 +338,7 @@ func (prov *ProvisioningAPI) GetWhoami(w http.ResponseWriter, r *http.Request) { SpaceRoom: login.SpaceRoom, } } - jsonResponse(w, http.StatusOK, resp) + exhttp.WriteJSONResponse(w, http.StatusOK, resp) } type RespLoginFlows struct { @@ -391,21 +351,29 @@ type RespSubmitLogin struct { } func (prov *ProvisioningAPI) GetLoginFlows(w http.ResponseWriter, r *http.Request) { - jsonResponse(w, http.StatusOK, &RespLoginFlows{ + exhttp.WriteJSONResponse(w, http.StatusOK, &RespLoginFlows{ Flows: prov.net.GetLoginFlows(), }) } +func (prov *ProvisioningAPI) GetCapabilities(w http.ResponseWriter, r *http.Request) { + exhttp.WriteJSONResponse(w, http.StatusOK, &prov.net.GetCapabilities().Provisioning) +} + +var ErrNilStep = errors.New("bridge returned nil step with no error") +var ErrTooManyLogins = bridgev2.RespError{ErrCode: "FI.MAU.BRIDGE.TOO_MANY_LOGINS", Err: "Maximum number of logins exceeded"} + func (prov *ProvisioningAPI) PostLoginStart(w http.ResponseWriter, r *http.Request) { overrideLogin, failed := prov.GetExplicitLoginForRequest(w, r) if failed { return } - login, err := prov.net.CreateLogin( - r.Context(), - prov.GetUser(r), - mux.Vars(r)["flowID"], - ) + user := prov.GetUser(r) + if overrideLogin == nil && user.HasTooManyLogins() { + ErrTooManyLogins.AppendMessage(" (%d)", user.Permissions.MaxLogins).Write(w) + return + } + login, err := prov.net.CreateLogin(r.Context(), user, r.PathValue("flowID")) if err != nil { zerolog.Ctx(r.Context()).Err(err).Msg("Failed to create login process") RespondWithError(w, err, "Internal error creating login process") @@ -418,6 +386,9 @@ func (prov *ProvisioningAPI) PostLoginStart(w http.ResponseWriter, r *http.Reque } else { firstStep, err = login.Start(r.Context()) } + if err == nil && firstStep == nil { + err = ErrNilStep + } if err != nil { zerolog.Ctx(r.Context()).Err(err).Msg("Failed to start login") RespondWithError(w, err, "Internal error starting login") @@ -432,10 +403,18 @@ func (prov *ProvisioningAPI) PostLoginStart(w http.ResponseWriter, r *http.Reque Override: overrideLogin, } prov.loginsLock.Unlock() - jsonResponse(w, http.StatusOK, &RespSubmitLogin{LoginID: loginID, LoginStep: firstStep}) + zerolog.Ctx(r.Context()).Info(). + Any("first_step", firstStep). + Msg("Created login process") + exhttp.WriteJSONResponse(w, http.StatusOK, &RespSubmitLogin{LoginID: loginID, LoginStep: firstStep}) } func (prov *ProvisioningAPI) handleCompleteStep(ctx context.Context, login *ProvLogin, step *bridgev2.LoginStep) { + zerolog.Ctx(ctx).Info(). + Str("step_id", step.StepID). + Str("user_login_id", string(step.CompleteParams.UserLoginID)). + Msg("Login completed successfully") + prov.deleteLogin(login, false) if login.Override == nil || login.Override.ID == step.CompleteParams.UserLoginID { return } @@ -449,15 +428,67 @@ func (prov *ProvisioningAPI) handleCompleteStep(ctx context.Context, login *Prov }, bridgev2.DeleteOpts{LogoutRemote: true}) } +func (prov *ProvisioningAPI) deleteLogin(login *ProvLogin, cancel bool) { + if cancel { + login.Process.Cancel() + } + prov.loginsLock.Lock() + delete(prov.logins, login.ID) + prov.loginsLock.Unlock() +} + +func (prov *ProvisioningAPI) PostLoginStep(w http.ResponseWriter, r *http.Request) { + loginID := r.PathValue("loginProcessID") + prov.loginsLock.RLock() + login, ok := prov.logins[loginID] + prov.loginsLock.RUnlock() + if !ok { + zerolog.Ctx(r.Context()).Warn().Str("login_id", loginID).Msg("Login not found") + mautrix.MNotFound.WithMessage("Login not found").Write(w) + return + } + login.Lock.Lock() + // This will only unlock after the handler runs + defer login.Lock.Unlock() + stepID := r.PathValue("stepID") + if login.NextStep.StepID != stepID { + zerolog.Ctx(r.Context()).Warn(). + Str("request_step_id", stepID). + Str("expected_step_id", login.NextStep.StepID). + Msg("Step ID does not match") + mautrix.MBadState.WithMessage("Step ID does not match").Write(w) + return + } + stepType := r.PathValue("stepType") + if login.NextStep.Type != bridgev2.LoginStepType(stepType) { + zerolog.Ctx(r.Context()).Warn(). + Str("request_step_type", stepType). + Str("expected_step_type", string(login.NextStep.Type)). + Msg("Step type does not match") + mautrix.MBadState.WithMessage("Step type does not match").Write(w) + return + } + ctx := context.WithValue(r.Context(), provisioningLoginProcessKey, login) + r = r.WithContext(ctx) + switch bridgev2.LoginStepType(r.PathValue("stepType")) { + case bridgev2.LoginStepTypeUserInput, bridgev2.LoginStepTypeCookies: + prov.PostLoginSubmitInput(w, r) + case bridgev2.LoginStepTypeDisplayAndWait: + prov.PostLoginWait(w, r) + case bridgev2.LoginStepTypeComplete: + fallthrough + default: + // This is probably impossible because of the above check that the next step type matches the request. + mautrix.MUnrecognized.WithMessage("Invalid step type %q", r.PathValue("stepType")).Write(w) + } +} + func (prov *ProvisioningAPI) PostLoginSubmitInput(w http.ResponseWriter, r *http.Request) { var params map[string]string err := json.NewDecoder(r.Body).Decode(¶ms) if err != nil { zerolog.Ctx(r.Context()).Err(err).Msg("Failed to decode request body") - jsonResponse(w, http.StatusBadRequest, &mautrix.RespError{ - Err: "Failed to decode request body", - ErrCode: mautrix.MNotJSON.ErrCode, - }) + mautrix.MNotJSON.WithMessage("Failed to decode request body").Write(w) return } login := r.Context().Value(provisioningLoginProcessKey).(*ProvLogin) @@ -470,39 +501,48 @@ func (prov *ProvisioningAPI) PostLoginSubmitInput(w http.ResponseWriter, r *http default: panic("Impossible state") } + if err == nil && nextStep == nil { + err = ErrNilStep + } if err != nil { zerolog.Ctx(r.Context()).Err(err).Msg("Failed to submit input") RespondWithError(w, err, "Internal error submitting input") + prov.deleteLogin(login, true) return } login.NextStep = nextStep if nextStep.Type == bridgev2.LoginStepTypeComplete { prov.handleCompleteStep(r.Context(), login, nextStep) + } else { + zerolog.Ctx(r.Context()).Debug().Any("next_step", nextStep).Msg("Returning next login step") } - jsonResponse(w, http.StatusOK, &RespSubmitLogin{LoginID: login.ID, LoginStep: nextStep}) + exhttp.WriteJSONResponse(w, http.StatusOK, &RespSubmitLogin{LoginID: login.ID, LoginStep: nextStep}) } func (prov *ProvisioningAPI) PostLoginWait(w http.ResponseWriter, r *http.Request) { login := r.Context().Value(provisioningLoginProcessKey).(*ProvLogin) nextStep, err := login.Process.(bridgev2.LoginProcessDisplayAndWait).Wait(r.Context()) + if err == nil && nextStep == nil { + err = ErrNilStep + } if err != nil { zerolog.Ctx(r.Context()).Err(err).Msg("Failed to wait") - jsonResponse(w, http.StatusInternalServerError, &mautrix.RespError{ - Err: "Failed to wait", - ErrCode: "M_UNKNOWN", - }) + RespondWithError(w, err, "Internal error waiting for login") + prov.deleteLogin(login, true) return } login.NextStep = nextStep if nextStep.Type == bridgev2.LoginStepTypeComplete { prov.handleCompleteStep(r.Context(), login, nextStep) + } else { + zerolog.Ctx(r.Context()).Debug().Any("next_step", nextStep).Msg("Returning next login step") } - jsonResponse(w, http.StatusOK, &RespSubmitLogin{LoginID: login.ID, LoginStep: nextStep}) + exhttp.WriteJSONResponse(w, http.StatusOK, &RespSubmitLogin{LoginID: login.ID, LoginStep: nextStep}) } func (prov *ProvisioningAPI) PostLogout(w http.ResponseWriter, r *http.Request) { user := prov.GetUser(r) - userLoginID := networkid.UserLoginID(mux.Vars(r)["loginID"]) + userLoginID := networkid.UserLoginID(r.PathValue("loginID")) if userLoginID == "all" { for { login := user.GetDefaultLogin() @@ -514,15 +554,12 @@ func (prov *ProvisioningAPI) PostLogout(w http.ResponseWriter, r *http.Request) } else { userLogin := prov.br.Bridge.GetCachedUserLoginByID(userLoginID) if userLogin == nil || userLogin.UserMXID != user.MXID { - jsonResponse(w, http.StatusNotFound, &mautrix.RespError{ - Err: "Login not found", - ErrCode: mautrix.MNotFound.ErrCode, - }) + mautrix.MNotFound.WithMessage("Login not found").Write(w) return } userLogin.Logout(r.Context()) } - jsonResponse(w, http.StatusOK, json.RawMessage("{}")) + exhttp.WriteEmptyJSONResponse(w, http.StatusOK) } type RespGetLogins struct { @@ -531,7 +568,7 @@ type RespGetLogins struct { func (prov *ProvisioningAPI) GetLogins(w http.ResponseWriter, r *http.Request) { user := prov.GetUser(r) - jsonResponse(w, http.StatusOK, &RespGetLogins{LoginIDs: user.GetUserLoginIDs()}) + exhttp.WriteJSONResponse(w, http.StatusOK, &RespGetLogins{LoginIDs: user.GetUserLoginIDs()}) } func (prov *ProvisioningAPI) GetExplicitLoginForRequest(w http.ResponseWriter, r *http.Request) (*bridgev2.UserLogin, bool) { @@ -541,15 +578,21 @@ func (prov *ProvisioningAPI) GetExplicitLoginForRequest(w http.ResponseWriter, r } userLogin := prov.br.Bridge.GetCachedUserLoginByID(userLoginID) if userLogin == nil || userLogin.UserMXID != prov.GetUser(r).MXID { - jsonResponse(w, http.StatusNotFound, &mautrix.RespError{ - Err: "Login not found", - ErrCode: mautrix.MNotFound.ErrCode, - }) + hlog.FromRequest(r).Warn(). + Str("login_id", string(userLoginID)). + Msg("Tried to use non-existent login, returning 404") + mautrix.MNotFound.WithMessage("Login not found").Write(w) return nil, true } return userLogin, false } +var ErrNotLoggedIn = mautrix.RespError{ + Err: "Not logged in", + ErrCode: "FI.MAU.NOT_LOGGED_IN", + StatusCode: http.StatusBadRequest, +} + func (prov *ProvisioningAPI) GetLoginForRequest(w http.ResponseWriter, r *http.Request) *bridgev2.UserLogin { userLogin, failed := prov.GetExplicitLoginForRequest(w, r) if userLogin != nil || failed { @@ -557,10 +600,7 @@ func (prov *ProvisioningAPI) GetLoginForRequest(w http.ResponseWriter, r *http.R } userLogin = prov.GetUser(r).GetDefaultLogin() if userLogin == nil { - jsonResponse(w, http.StatusBadRequest, &mautrix.RespError{ - Err: "Not logged in", - ErrCode: "FI.MAU.NOT_LOGGED_IN", - }) + ErrNotLoggedIn.Write(w) return nil } return userLogin @@ -575,135 +615,27 @@ func RespondWithError(w http.ResponseWriter, err error, message string) { if errors.As(err, &we) { we.Write(w) } else { - mautrix.RespError{ - Err: message, - ErrCode: "M_UNKNOWN", - StatusCode: http.StatusInternalServerError, - }.Write(w) + mautrix.MUnknown.WithMessage(message).Write(w) } } -type RespResolveIdentifier struct { - ID networkid.UserID `json:"id"` - Name string `json:"name,omitempty"` - AvatarURL id.ContentURIString `json:"avatar_url,omitempty"` - Identifiers []string `json:"identifiers,omitempty"` - MXID id.UserID `json:"mxid,omitempty"` - DMRoomID id.RoomID `json:"dm_room_mxid,omitempty"` -} - func (prov *ProvisioningAPI) doResolveIdentifier(w http.ResponseWriter, r *http.Request, createChat bool) { login := prov.GetLoginForRequest(w, r) if login == nil { return } - api, ok := login.Client.(bridgev2.IdentifierResolvingNetworkAPI) - if !ok { - jsonResponse(w, http.StatusNotImplemented, &mautrix.RespError{ - Err: "This bridge does not support resolving identifiers", - ErrCode: mautrix.MUnrecognized.ErrCode, - }) - return - } - resp, err := api.ResolveIdentifier(r.Context(), mux.Vars(r)["identifier"], createChat) + resp, err := provisionutil.ResolveIdentifier(r.Context(), login, r.PathValue("identifier"), createChat) if err != nil { - zerolog.Ctx(r.Context()).Err(err).Msg("Failed to resolve identifier") RespondWithError(w, err, "Internal error resolving identifier") - return } else if resp == nil { - jsonResponse(w, http.StatusNotFound, &mautrix.RespError{ - ErrCode: mautrix.MNotFound.ErrCode, - Err: "Identifier not found", - }) - return - } - apiResp := &RespResolveIdentifier{ - ID: resp.UserID, - } - status := http.StatusOK - if resp.Ghost != nil { - if resp.UserInfo != nil { - resp.Ghost.UpdateInfo(r.Context(), resp.UserInfo) - } - apiResp.Name = resp.Ghost.Name - apiResp.AvatarURL = resp.Ghost.AvatarMXC - apiResp.Identifiers = resp.Ghost.Identifiers - apiResp.MXID = resp.Ghost.Intent.GetMXID() - } else if resp.UserInfo != nil && resp.UserInfo.Name != nil { - apiResp.Name = *resp.UserInfo.Name - } - if resp.Chat != nil { - if resp.Chat.Portal == nil { - resp.Chat.Portal, err = prov.br.Bridge.GetPortalByKey(r.Context(), resp.Chat.PortalKey) - if err != nil { - zerolog.Ctx(r.Context()).Err(err).Msg("Failed to get portal") - jsonResponse(w, http.StatusInternalServerError, &mautrix.RespError{ - Err: "Failed to get portal", - ErrCode: "M_UNKNOWN", - }) - return - } - } - if createChat && resp.Chat.Portal.MXID == "" { + mautrix.MNotFound.WithMessage("Identifier not found").Write(w) + } else { + status := http.StatusOK + if resp.JustCreated { status = http.StatusCreated - err = resp.Chat.Portal.CreateMatrixRoom(r.Context(), login, resp.Chat.PortalInfo) - if err != nil { - zerolog.Ctx(r.Context()).Err(err).Msg("Failed to create portal room") - jsonResponse(w, http.StatusInternalServerError, &mautrix.RespError{ - Err: "Failed to create portal room", - ErrCode: "M_UNKNOWN", - }) - return - } } - apiResp.DMRoomID = resp.Chat.Portal.MXID + exhttp.WriteJSONResponse(w, status, resp) } - jsonResponse(w, status, apiResp) -} - -type RespGetContactList struct { - Contacts []*RespResolveIdentifier `json:"contacts"` -} - -func (prov *ProvisioningAPI) processResolveIdentifiers(ctx context.Context, resp []*bridgev2.ResolveIdentifierResponse) (apiResp []*RespResolveIdentifier) { - apiResp = make([]*RespResolveIdentifier, len(resp)) - for i, contact := range resp { - apiContact := &RespResolveIdentifier{ - ID: contact.UserID, - } - apiResp[i] = apiContact - if contact.UserInfo != nil { - if contact.UserInfo.Name != nil { - apiContact.Name = *contact.UserInfo.Name - } - if contact.UserInfo.Identifiers != nil { - apiContact.Identifiers = contact.UserInfo.Identifiers - } - } - if contact.Ghost != nil { - if contact.Ghost.Name != "" { - apiContact.Name = contact.Ghost.Name - } - if len(contact.Ghost.Identifiers) >= len(apiContact.Identifiers) { - apiContact.Identifiers = contact.Ghost.Identifiers - } - apiContact.AvatarURL = contact.Ghost.AvatarMXC - apiContact.MXID = contact.Ghost.Intent.GetMXID() - } - if contact.Chat != nil { - if contact.Chat.Portal == nil { - var err error - contact.Chat.Portal, err = prov.br.Bridge.GetPortalByKey(ctx, contact.Chat.PortalKey) - if err != nil { - zerolog.Ctx(ctx).Err(err).Msg("Failed to get portal") - } - } - if contact.Chat.Portal != nil { - apiContact.DMRoomID = contact.Chat.Portal.MXID - } - } - } - return } func (prov *ProvisioningAPI) GetContactList(w http.ResponseWriter, r *http.Request) { @@ -711,65 +643,36 @@ func (prov *ProvisioningAPI) GetContactList(w http.ResponseWriter, r *http.Reque if login == nil { return } - api, ok := login.Client.(bridgev2.ContactListingNetworkAPI) - if !ok { - jsonResponse(w, http.StatusNotImplemented, &mautrix.RespError{ - Err: "This bridge does not support listing contacts", - ErrCode: mautrix.MUnrecognized.ErrCode, - }) - return - } - resp, err := api.GetContactList(r.Context()) + resp, err := provisionutil.GetContactList(r.Context(), login) if err != nil { - zerolog.Ctx(r.Context()).Err(err).Msg("Failed to get contact list") - RespondWithError(w, err, "Internal error fetching contact list") + RespondWithError(w, err, "Internal error getting contact list") return } - jsonResponse(w, http.StatusOK, &RespGetContactList{ - Contacts: prov.processResolveIdentifiers(r.Context(), resp), - }) + exhttp.WriteJSONResponse(w, http.StatusOK, resp) } type ReqSearchUsers struct { Query string `json:"query"` } -type RespSearchUsers struct { - Results []*RespResolveIdentifier `json:"results"` -} - func (prov *ProvisioningAPI) PostSearchUsers(w http.ResponseWriter, r *http.Request) { var req ReqSearchUsers err := json.NewDecoder(r.Body).Decode(&req) if err != nil { zerolog.Ctx(r.Context()).Err(err).Msg("Failed to decode request body") - jsonResponse(w, http.StatusBadRequest, &mautrix.RespError{ - Err: "Failed to decode request body", - ErrCode: mautrix.MNotJSON.ErrCode, - }) + mautrix.MNotJSON.WithMessage("Failed to decode request body").Write(w) return } login := prov.GetLoginForRequest(w, r) if login == nil { return } - api, ok := login.Client.(bridgev2.UserSearchingNetworkAPI) - if !ok { - jsonResponse(w, http.StatusNotImplemented, &mautrix.RespError{ - Err: "This bridge does not support searching for users", - ErrCode: mautrix.MUnrecognized.ErrCode, - }) - return - } - resp, err := api.SearchUsers(r.Context(), req.Query) + resp, err := provisionutil.SearchUsers(r.Context(), login, req.Query) if err != nil { - zerolog.Ctx(r.Context()).Err(err).Msg("Failed to get contact list") - RespondWithError(w, err, "Internal error fetching contact list") + RespondWithError(w, err, "Internal error searching users") return } - jsonResponse(w, http.StatusOK, &RespSearchUsers{ - Results: prov.processResolveIdentifiers(r.Context(), resp), - }) + exhttp.WriteJSONResponse(w, http.StatusOK, resp) } func (prov *ProvisioningAPI) GetResolveIdentifier(w http.ResponseWriter, r *http.Request) { @@ -781,12 +684,114 @@ func (prov *ProvisioningAPI) PostCreateDM(w http.ResponseWriter, r *http.Request } func (prov *ProvisioningAPI) PostCreateGroup(w http.ResponseWriter, r *http.Request) { + var req bridgev2.GroupCreateParams + err := json.NewDecoder(r.Body).Decode(&req) + if err != nil { + zerolog.Ctx(r.Context()).Err(err).Msg("Failed to decode request body") + mautrix.MNotJSON.WithMessage("Failed to decode request body").Write(w) + return + } + req.Type = r.PathValue("type") login := prov.GetLoginForRequest(w, r) if login == nil { return } - jsonResponse(w, http.StatusNotImplemented, &mautrix.RespError{ - Err: "Creating groups is not yet implemented", - ErrCode: mautrix.MUnrecognized.ErrCode, + resp, err := provisionutil.CreateGroup(r.Context(), login, &req) + if err != nil { + RespondWithError(w, err, "Internal error creating group") + return + } + exhttp.WriteJSONResponse(w, http.StatusOK, resp) +} + +type ReqExportCredentials struct { + RemoteID networkid.UserLoginID `json:"remote_id"` +} + +type RespExportCredentials struct { + Credentials any `json:"credentials"` +} + +func (prov *ProvisioningAPI) PostInitSessionTransfer(w http.ResponseWriter, r *http.Request) { + prov.sessionTransfersLock.Lock() + defer prov.sessionTransfersLock.Unlock() + + var req ReqExportCredentials + err := json.NewDecoder(r.Body).Decode(&req) + if err != nil { + zerolog.Ctx(r.Context()).Err(err).Msg("Failed to decode request body") + mautrix.MNotJSON.WithMessage("Failed to decode request body").Write(w) + return + } + + user := prov.GetUser(r) + logins := user.GetUserLogins() + var loginToExport *bridgev2.UserLogin + for _, login := range logins { + if login.ID == req.RemoteID { + loginToExport = login + break + } + } + if loginToExport == nil { + mautrix.MNotFound.WithMessage("No matching user login found").Write(w) + return + } + + client, ok := loginToExport.Client.(bridgev2.CredentialExportingNetworkAPI) + if !ok { + mautrix.MUnrecognized.WithMessage("This bridge does not support exporting credentials").Write(w) + return + } + + if _, ok := prov.sessionTransfers[loginToExport.ID]; ok { + // Warn, but allow, double exports. This might happen if a client crashes handling creds, + // and should be safe to call multiple times. + zerolog.Ctx(r.Context()).Warn().Msg("Exporting already exported credentials") + } + + // Disconnect now so we don't use the same network session in two places at once + client.Disconnect() + exhttp.WriteJSONResponse(w, http.StatusOK, &RespExportCredentials{ + Credentials: client.ExportCredentials(r.Context()), }) } + +func (prov *ProvisioningAPI) PostFinishSessionTransfer(w http.ResponseWriter, r *http.Request) { + prov.sessionTransfersLock.Lock() + defer prov.sessionTransfersLock.Unlock() + + var req ReqExportCredentials + err := json.NewDecoder(r.Body).Decode(&req) + if err != nil { + zerolog.Ctx(r.Context()).Err(err).Msg("Failed to decode request body") + mautrix.MNotJSON.WithMessage("Failed to decode request body").Write(w) + return + } + + user := prov.GetUser(r) + logins := user.GetUserLogins() + var loginToExport *bridgev2.UserLogin + for _, login := range logins { + if login.ID == req.RemoteID { + loginToExport = login + break + } + } + if loginToExport == nil { + mautrix.MNotFound.WithMessage("No matching user login found").Write(w) + return + } else if _, ok := prov.sessionTransfers[loginToExport.ID]; !ok { + mautrix.MBadState.WithMessage("No matching credential export found").Write(w) + return + } + + zerolog.Ctx(r.Context()).Info(). + Str("remote_name", string(req.RemoteID)). + Msg("Logging out remote after finishing credential export") + + loginToExport.Client.LogoutRemote(r.Context()) + delete(prov.sessionTransfers, req.RemoteID) + + exhttp.WriteEmptyJSONResponse(w, http.StatusOK) +} diff --git a/bridgev2/matrix/provisioning.yaml b/bridgev2/matrix/provisioning.yaml index bf6c6f3d..26068db4 100644 --- a/bridgev2/matrix/provisioning.yaml +++ b/bridgev2/matrix/provisioning.yaml @@ -361,14 +361,25 @@ paths: $ref: '#/components/responses/InternalError' 501: $ref: '#/components/responses/NotSupported' - /v3/create_group: + /v3/create_group/{type}: post: tags: [ snc ] summary: Create a group chat on the remote network. operationId: createGroup parameters: - $ref: "#/components/parameters/loginID" + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/GroupCreateParams' responses: + 200: + description: Identifier resolved successfully + content: + application/json: + schema: + $ref: '#/components/schemas/CreatedGroup' 401: $ref: '#/components/responses/Unauthorized' 404: @@ -389,7 +400,7 @@ components: - username - meow@example.com loginID: - name: loginID + name: login_id in: query description: An optional explicit login ID to do the action through. required: false @@ -572,6 +583,74 @@ components: description: The Matrix room ID of the direct chat with the user. examples: - '!OKhS0I5q2fCzdnl2qgeozDQw:t2bot.io' + GroupCreateParams: + type: object + description: | + Parameters for creating a group chat. + The /capabilities endpoint response must be checked to see which fields are actually allowed. + properties: + type: + type: string + description: The type of group to create. + examples: + - channel + username: + type: string + description: The public username for the created group. + participants: + type: array + description: The users to add to the group initially. + items: + type: string + parent: + type: object + name: + type: object + description: The `m.room.name` event content for the room. + properties: + name: + type: string + avatar: + type: object + description: The `m.room.avatar` event content for the room. + properties: + url: + type: string + format: mxc + topic: + type: object + description: The `m.room.topic` event content for the room. + properties: + topic: + type: string + disappear: + type: object + description: The `com.beeper.disappearing_timer` event content for the room. + properties: + type: + type: string + timer: + type: number + room_id: + type: string + format: matrix_room_id + description: | + An existing Matrix room ID to bridge to. + The other parameters must be already in sync with the room state when using this parameter. + CreatedGroup: + type: object + description: A successfully created group chat. + required: [id, mxid] + properties: + id: + type: string + description: The internal chat ID of the created group. + mxid: + type: string + format: matrix_room_id + description: The Matrix room ID of the portal. + examples: + - '!OKhS0I5q2fCzdnl2qgeozDQw:t2bot.io' LoginStep: type: object description: A step in a login process. @@ -635,7 +714,7 @@ components: type: type: string description: The type of field. - enum: [ username, phone_number, email, password, 2fa_code, token, url, domain ] + enum: [ username, phone_number, email, password, 2fa_code, token, url, domain, select ] id: type: string description: The internal ID of the field. This must be used as the key in the object when submitting the data back to the bridge. @@ -649,10 +728,53 @@ components: description: A more detailed description of the field shown to the user. examples: - Include the country code with a + + default_value: + type: string + description: A default value that the client can pre-fill the field with. pattern: type: string format: regex description: A regular expression that the field value must match. + options: + type: array + description: For fields of type select, the valid options. + items: + type: string + attachments: + type: array + description: A list of media attachments to show the user alongside the form fields. + items: + type: object + description: A media attachment to show the user. + required: [ type, filename, content ] + properties: + type: + type: string + description: The type of media attachment, using the same media type identifiers as Matrix attachments. Only some are supported. + enum: [ m.image, m.audio ] + filename: + type: string + description: The filename for the media attachment. + content: + type: string + description: The raw file content for the attachment encoded in base64. + info: + type: object + description: Optional but recommended metadata for the attachment. Can generally be derived from the raw content if omitted. + properties: + mimetype: + type: string + description: The MIME type for the media content. + examples: [ image/png, audio/mpeg ] + w: + type: number + description: The width of the media in pixels. Only applicable for images and videos. + h: + type: number + description: The height of the media in pixels. Only applicable for images and videos. + size: + type: number + description: The size of the media content in number of bytes. Strongly recommended to include. - description: Cookie login step required: [ type, cookies ] properties: @@ -671,6 +793,20 @@ components: user_agent: type: string description: An optional user agent that the webview should use. + wait_for_url_pattern: + type: string + description: | + A regex pattern that the URL should match before the client closes the webview. + + The client may submit the login if the user closes the webview after all cookies are collected + even if this URL is not reached, but it should only automatically close the webview after + both cookies and the URL match. + extract_js: + type: string + description: | + A JavaScript snippet that can extract some or all of the fields. + The snippet will evaluate to a promise that resolves when the relevant fields are found. + Fields that are not present in the promise result must be extracted another way. fields: type: array description: The list of cookies or other stored data that must be extracted. diff --git a/bridgev2/matrix/publicmedia.go b/bridgev2/matrix/publicmedia.go index 9db5f442..82ea8c2b 100644 --- a/bridgev2/matrix/publicmedia.go +++ b/bridgev2/matrix/publicmedia.go @@ -1,4 +1,4 @@ -// Copyright (c) 2024 Tulir Asokan +// Copyright (c) 2025 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this @@ -7,18 +7,26 @@ package matrix import ( + "context" "crypto/hmac" "crypto/sha256" "encoding/base64" "encoding/binary" "fmt" "io" + "mime" "net/http" + "net/url" + "slices" + "strings" "time" - "github.com/gorilla/mux" + "github.com/rs/zerolog" "maunium.net/go/mautrix/bridgev2" + "maunium.net/go/mautrix/bridgev2/database" + "maunium.net/go/mautrix/crypto/attachment" + "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" ) @@ -35,7 +43,10 @@ func (br *Connector) initPublicMedia() error { return fmt.Errorf("public media hash length is negative") } br.pubMediaSigKey = []byte(br.Config.PublicMedia.SigningKey) - br.AS.Router.HandleFunc("/_mautrix/publicmedia/{server}/{mediaID}/{checksum}", br.servePublicMedia).Methods(http.MethodGet) + br.AS.Router.HandleFunc("GET /_mautrix/publicmedia/{customID}", br.serveDatabasePublicMedia) + br.AS.Router.HandleFunc("GET /_mautrix/publicmedia/{customID}/{filename}", br.serveDatabasePublicMedia) + br.AS.Router.HandleFunc("GET /_mautrix/publicmedia/{server}/{mediaID}/{checksum}", br.servePublicMedia) + br.AS.Router.HandleFunc("GET /_mautrix/publicmedia/{server}/{mediaID}/{checksum}/{filename}", br.servePublicMedia) return nil } @@ -46,6 +57,20 @@ func (br *Connector) hashContentURI(uri id.ContentURI, expiry []byte) []byte { return hasher.Sum(expiry)[:br.Config.PublicMedia.HashLength+len(expiry)] } +func (br *Connector) hashDBPublicMedia(pm *database.PublicMedia) []byte { + hasher := hmac.New(sha256.New, br.pubMediaSigKey) + hasher.Write([]byte(pm.MXC.String())) + hasher.Write([]byte(pm.MimeType)) + if pm.Keys != nil { + hasher.Write([]byte(pm.Keys.Version)) + hasher.Write([]byte(pm.Keys.Key.Algorithm)) + hasher.Write([]byte(pm.Keys.Key.Key)) + hasher.Write([]byte(pm.Keys.InitVector)) + hasher.Write([]byte(pm.Keys.Hashes.SHA256)) + } + return hasher.Sum(nil)[:br.Config.PublicMedia.HashLength] +} + func (br *Connector) makePublicMediaChecksum(uri id.ContentURI) []byte { var expiresAt []byte if br.Config.PublicMedia.Expiry > 0 { @@ -76,16 +101,15 @@ var proxyHeadersToCopy = []string{ } func (br *Connector) servePublicMedia(w http.ResponseWriter, r *http.Request) { - vars := mux.Vars(r) contentURI := id.ContentURI{ - Homeserver: vars["server"], - FileID: vars["mediaID"], + Homeserver: r.PathValue("server"), + FileID: r.PathValue("mediaID"), } if !contentURI.IsValid() { http.Error(w, "invalid content URI", http.StatusBadRequest) return } - checksum, err := base64.RawURLEncoding.DecodeString(vars["checksum"]) + checksum, err := base64.RawURLEncoding.DecodeString(r.PathValue("checksum")) if err != nil || !hmac.Equal(checksum, br.makePublicMediaChecksum(contentURI)) { http.Error(w, "invalid base64 in checksum", http.StatusBadRequest) return @@ -96,9 +120,47 @@ func (br *Connector) servePublicMedia(w http.ResponseWriter, r *http.Request) { http.Error(w, "checksum expired", http.StatusGone) return } + br.doProxyMedia(w, r, contentURI, nil, "") +} + +func (br *Connector) serveDatabasePublicMedia(w http.ResponseWriter, r *http.Request) { + if !br.Config.PublicMedia.UseDatabase { + http.Error(w, "public media short links are disabled", http.StatusNotFound) + return + } + log := zerolog.Ctx(r.Context()) + media, err := br.Bridge.DB.PublicMedia.Get(r.Context(), r.PathValue("customID")) + if err != nil { + log.Err(err).Msg("Failed to get public media from database") + http.Error(w, "failed to get media metadata", http.StatusInternalServerError) + return + } else if media == nil { + http.Error(w, "media ID not found", http.StatusNotFound) + return + } else if !media.Expiry.IsZero() && media.Expiry.Before(time.Now()) { + // This is not gone as it can still be refreshed in the DB + http.Error(w, "media expired", http.StatusNotFound) + return + } else if media.Keys != nil && media.Keys.PrepareForDecryption() != nil { + http.Error(w, "media keys are malformed", http.StatusInternalServerError) + return + } + br.doProxyMedia(w, r, media.MXC, media.Keys, media.MimeType) +} + +var safeMimes = []string{ + "text/css", "text/plain", "text/csv", + "application/json", "application/ld+json", + "image/jpeg", "image/gif", "image/png", "image/apng", "image/webp", "image/avif", + "video/mp4", "video/webm", "video/ogg", "video/quicktime", + "audio/mp4", "audio/webm", "audio/aac", "audio/mpeg", "audio/ogg", "audio/wave", + "audio/wav", "audio/x-wav", "audio/x-pn-wav", "audio/flac", "audio/x-flac", +} + +func (br *Connector) doProxyMedia(w http.ResponseWriter, r *http.Request, contentURI id.ContentURI, encInfo *attachment.EncryptedFile, mimeType string) { resp, err := br.Bot.Download(r.Context(), contentURI) if err != nil { - br.Log.Warn().Stringer("uri", contentURI).Err(err).Msg("Failed to download media to proxy") + zerolog.Ctx(r.Context()).Warn().Stringer("uri", contentURI).Err(err).Msg("Failed to download media to proxy") http.Error(w, "failed to download media", http.StatusInternalServerError) return } @@ -106,11 +168,41 @@ func (br *Connector) servePublicMedia(w http.ResponseWriter, r *http.Request) { for _, hdr := range proxyHeadersToCopy { w.Header()[hdr] = resp.Header[hdr] } + stream := resp.Body + if encInfo != nil { + if mimeType == "" { + mimeType = "application/octet-stream" + } + contentDisposition := "attachment" + if slices.Contains(safeMimes, mimeType) { + contentDisposition = "inline" + } + dispositionArgs := map[string]string{} + if filename := r.PathValue("filename"); filename != "" { + dispositionArgs["filename"] = filename + } + w.Header().Set("Content-Type", mimeType) + w.Header().Set("Content-Disposition", mime.FormatMediaType(contentDisposition, dispositionArgs)) + // Note: this won't check the Close result like it should, but it's probably not a big deal here + stream = encInfo.DecryptStream(stream) + } else if filename := r.PathValue("filename"); filename != "" { + contentDisposition, _, _ := mime.ParseMediaType(resp.Header.Get("Content-Disposition")) + if contentDisposition == "" { + contentDisposition = "attachment" + } + w.Header().Set("Content-Disposition", mime.FormatMediaType(contentDisposition, map[string]string{ + "filename": filename, + })) + } w.WriteHeader(http.StatusOK) - _, _ = io.Copy(w, resp.Body) + _, _ = io.Copy(w, stream) } func (br *Connector) GetPublicMediaAddress(contentURI id.ContentURIString) string { + return br.getPublicMediaAddressWithFileName(contentURI, "") +} + +func (br *Connector) getPublicMediaAddressWithFileName(contentURI id.ContentURIString, fileName string) string { if br.pubMediaSigKey == nil { return "" } @@ -118,11 +210,69 @@ func (br *Connector) GetPublicMediaAddress(contentURI id.ContentURIString) strin if err != nil || !parsed.IsValid() { return "" } - return fmt.Sprintf( - "%s/_mautrix/publicmedia/%s/%s/%s", + fileName = url.PathEscape(strings.ReplaceAll(fileName, "/", "_")) + if fileName == ".." { + fileName = "" + } + parts := []string{ br.GetPublicAddress(), + strings.Trim(br.Config.PublicMedia.PathPrefix, "/"), parsed.Homeserver, parsed.FileID, base64.RawURLEncoding.EncodeToString(br.makePublicMediaChecksum(parsed)), - ) + fileName, + } + if fileName == "" { + parts = parts[:len(parts)-1] + } + return strings.Join(parts, "/") +} + +func (br *Connector) GetPublicMediaAddressForEvent(ctx context.Context, evt *event.MessageEventContent) (string, error) { + if br.pubMediaSigKey == nil { + return "", bridgev2.ErrPublicMediaDisabled + } + if !br.Config.PublicMedia.UseDatabase { + if evt.File != nil { + return "", fmt.Errorf("can't generate address for encrypted file: %w", bridgev2.ErrPublicMediaDatabaseDisabled) + } + return br.getPublicMediaAddressWithFileName(evt.URL, evt.GetFileName()), nil + } + mxc := evt.URL + var keys *attachment.EncryptedFile + if evt.File != nil { + mxc = evt.File.URL + keys = &evt.File.EncryptedFile + } + parsedMXC, err := mxc.Parse() + if err != nil { + return "", fmt.Errorf("%w: failed to parse MXC: %w", bridgev2.ErrPublicMediaGenerateFailed, err) + } + pm := &database.PublicMedia{ + MXC: parsedMXC, + Keys: keys, + MimeType: evt.GetInfo().MimeType, + } + if br.Config.PublicMedia.Expiry > 0 { + pm.Expiry = time.Now().Add(time.Duration(br.Config.PublicMedia.Expiry) * time.Second) + } + pm.PublicID = base64.RawURLEncoding.EncodeToString(br.hashDBPublicMedia(pm)) + err = br.Bridge.DB.PublicMedia.Put(ctx, pm) + if err != nil { + return "", fmt.Errorf("%w: failed to store public media in database: %w", bridgev2.ErrPublicMediaGenerateFailed, err) + } + fileName := url.PathEscape(strings.ReplaceAll(evt.GetFileName(), "/", "_")) + if fileName == ".." { + fileName = "" + } + parts := []string{ + br.GetPublicAddress(), + strings.Trim(br.Config.PublicMedia.PathPrefix, "/"), + pm.PublicID, + fileName, + } + if fileName == "" { + parts = parts[:len(parts)-1] + } + return strings.Join(parts, "/"), nil } diff --git a/bridgev2/matrix/websocket.go b/bridgev2/matrix/websocket.go index c679f960..b498cacd 100644 --- a/bridgev2/matrix/websocket.go +++ b/bridgev2/matrix/websocket.go @@ -57,7 +57,7 @@ func (br *Connector) startWebsocket(wg *sync.WaitGroup) { addr = br.Config.Homeserver.Address } for { - err := br.AS.StartWebsocket(addr, onConnect) + err := br.AS.StartWebsocket(br.Bridge.BackgroundCtx, addr, onConnect) if errors.Is(err, appservice.ErrWebsocketManualStop) { return } else if closeCommand := (&appservice.CloseCommand{}); errors.As(err, &closeCommand) && closeCommand.Status == appservice.MeowConnectionReplaced { diff --git a/bridgev2/matrixinterface.go b/bridgev2/matrixinterface.go index 699ce07b..be26db49 100644 --- a/bridgev2/matrixinterface.go +++ b/bridgev2/matrixinterface.go @@ -1,4 +1,4 @@ -// Copyright (c) 2024 Tulir Asokan +// Copyright (c) 2025 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this @@ -10,27 +10,31 @@ import ( "context" "fmt" "io" + "net/http" "os" "time" - "github.com/gorilla/mux" + "go.mau.fi/util/exhttp" "maunium.net/go/mautrix" - "maunium.net/go/mautrix/bridge/status" "maunium.net/go/mautrix/bridgev2/database" "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/bridgev2/status" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" ) type MatrixCapabilities struct { - AutoJoinInvites bool - BatchSending bool + AutoJoinInvites bool + BatchSending bool + ArbitraryMemberChange bool + ExtraProfileMeta bool } type MatrixConnector interface { Init(*Bridge) Start(ctx context.Context) error + PreStop() Stop() GetCapabilities() *MatrixCapabilities @@ -50,37 +54,85 @@ type MatrixConnector interface { GetMemberInfo(ctx context.Context, roomID id.RoomID, userID id.UserID) (*event.MemberEventContent, error) BatchSend(ctx context.Context, roomID id.RoomID, req *mautrix.ReqBeeperBatchSend, extras []*MatrixSendExtra) (*mautrix.RespBeeperBatchSend, error) + GenerateDeterministicRoomID(portalKey networkid.PortalKey) id.RoomID GenerateDeterministicEventID(roomID id.RoomID, portalKey networkid.PortalKey, messageID networkid.MessageID, partID networkid.PartID) id.EventID GenerateReactionEventID(roomID id.RoomID, targetMessage *database.Message, sender networkid.UserID, emojiID networkid.EmojiID) id.EventID ServerName() string } +type MatrixConnectorWithArbitraryRoomState interface { + MatrixConnector + GetStateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string) (*event.Event, error) +} + type MatrixConnectorWithServer interface { + MatrixConnector GetPublicAddress() string - GetRouter() *mux.Router + GetRouter() *http.ServeMux +} + +type IProvisioningAPI interface { + GetRouter() *http.ServeMux + GetUser(r *http.Request) *User +} + +type MatrixConnectorWithProvisioning interface { + MatrixConnector + GetProvisioning() IProvisioningAPI } type MatrixConnectorWithPublicMedia interface { + MatrixConnector GetPublicMediaAddress(contentURI id.ContentURIString) string + GetPublicMediaAddressForEvent(ctx context.Context, evt *event.MessageEventContent) (string, error) } type MatrixConnectorWithNameDisambiguation interface { + MatrixConnector IsConfusableName(ctx context.Context, roomID id.RoomID, userID id.UserID, name string) ([]id.UserID, error) } +type MatrixConnectorWithBridgeIdentifier interface { + MatrixConnector + GetUniqueBridgeID() string +} + type MatrixConnectorWithURLPreviews interface { + MatrixConnector GetURLPreview(ctx context.Context, url string) (*event.LinkPreview, error) } type MatrixConnectorWithPostRoomBridgeHandling interface { + MatrixConnector HandleNewlyBridgedRoom(ctx context.Context, roomID id.RoomID) error } type MatrixConnectorWithAnalytics interface { + MatrixConnector TrackAnalytics(userID id.UserID, event string, properties map[string]any) } +type DirectNotificationData struct { + Portal *Portal + Sender *Ghost + MessageID networkid.MessageID + Message string + + FormattedNotification string + FormattedTitle string +} + +type MatrixConnectorWithNotifications interface { + MatrixConnector + DisplayNotification(ctx context.Context, data *DirectNotificationData) +} + +type MatrixConnectorWithHTTPSettings interface { + MatrixConnector + GetHTTPClientSettings() exhttp.ClientSettings +} + type MatrixSendExtra struct { Timestamp time.Time MessageMeta *database.Message @@ -124,8 +176,13 @@ func (ce CallbackError) Unwrap() error { return ce.Wrapped } +type EnsureJoinedParams struct { + Via []string +} + type MatrixAPI interface { GetMXID() id.UserID + IsDoublePuppet() bool SendMessage(ctx context.Context, roomID id.RoomID, eventType event.Type, content *event.Content, extra *MatrixSendExtra) (*mautrix.RespSendEvent, error) SendState(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, content *event.Content, ts time.Time) (*mautrix.RespSendEvent, error) @@ -143,13 +200,26 @@ type MatrixAPI interface { CreateRoom(ctx context.Context, req *mautrix.ReqCreateRoom) (id.RoomID, error) DeleteRoom(ctx context.Context, roomID id.RoomID, puppetsOnly bool) error - EnsureJoined(ctx context.Context, roomID id.RoomID) error + EnsureJoined(ctx context.Context, roomID id.RoomID, params ...EnsureJoinedParams) error EnsureInvited(ctx context.Context, roomID id.RoomID, userID id.UserID) error TagRoom(ctx context.Context, roomID id.RoomID, tag event.RoomTag, isTagged bool) error MuteRoom(ctx context.Context, roomID id.RoomID, until time.Time) error + + GetEvent(ctx context.Context, roomID id.RoomID, eventID id.EventID) (*event.Event, error) +} + +type StreamOrderReadingMatrixAPI interface { + MatrixAPI + MarkStreamOrderRead(ctx context.Context, roomID id.RoomID, streamOrder int64, ts time.Time) error } type MarkAsDMMatrixAPI interface { + MatrixAPI MarkAsDM(ctx context.Context, roomID id.RoomID, otherUser id.UserID) error } + +type EphemeralSendingMatrixAPI interface { + MatrixAPI + BeeperSendEphemeralEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, content *event.Content, txnID string) (*mautrix.RespSendEvent, error) +} diff --git a/bridgev2/matrixinvite.go b/bridgev2/matrixinvite.go index f8217700..75c00cb0 100644 --- a/bridgev2/matrixinvite.go +++ b/bridgev2/matrixinvite.go @@ -19,17 +19,17 @@ import ( "maunium.net/go/mautrix/id" ) -func (br *Bridge) handleBotInvite(ctx context.Context, evt *event.Event, sender *User) { +func (br *Bridge) handleBotInvite(ctx context.Context, evt *event.Event, sender *User) EventHandlingResult { log := zerolog.Ctx(ctx) // These invites should already be rejected in QueueMatrixEvent if !sender.Permissions.Commands { log.Warn().Msg("Received bot invite from user without permission to send commands") - return + return EventHandlingResultIgnored } err := br.Bot.EnsureJoined(ctx, evt.RoomID) if err != nil { log.Err(err).Msg("Failed to accept invite to room") - return + return EventHandlingResultFailed } log.Debug().Msg("Accepted invite to room as bot") members, err := br.Matrix.GetMembers(ctx, evt.RoomID) @@ -55,6 +55,7 @@ func (br *Bridge) handleBotInvite(ctx context.Context, evt *event.Event, sender log.Err(err).Msg("Failed to send welcome message to room") } } + return EventHandlingResultSuccess } func sendNotice(ctx context.Context, evt *event.Event, intent MatrixAPI, message string, args ...any) { @@ -87,12 +88,42 @@ func sendErrorAndLeave(ctx context.Context, evt *event.Event, intent MatrixAPI, rejectInvite(ctx, evt, intent, "") } -func (br *Bridge) handleGhostDMInvite(ctx context.Context, evt *event.Event, sender *User) { +func (portal *Portal) CleanupOrphanedDM(ctx context.Context, userMXID id.UserID) { + if portal.MXID == "" { + return + } + log := zerolog.Ctx(ctx) + existingPortalMembers, err := portal.Bridge.Matrix.GetMembers(ctx, portal.MXID) + if err != nil { + log.Err(err). + Stringer("old_portal_mxid", portal.MXID). + Msg("Failed to check existing portal members, deleting room") + } else if targetUserMember, ok := existingPortalMembers[userMXID]; !ok { + log.Debug(). + Stringer("old_portal_mxid", portal.MXID). + Msg("Inviter has no member event in old portal, deleting room") + } else if targetUserMember.Membership.IsInviteOrJoin() { + return + } else { + log.Debug(). + Stringer("old_portal_mxid", portal.MXID). + Str("membership", string(targetUserMember.Membership)). + Msg("Inviter is not in old portal, deleting room") + } + + if err = portal.RemoveMXID(ctx); err != nil { + log.Err(err).Msg("Failed to delete old portal mxid") + } else if err = portal.Bridge.Bot.DeleteRoom(ctx, portal.MXID, true); err != nil { + log.Err(err).Msg("Failed to clean up old portal room") + } +} + +func (br *Bridge) handleGhostDMInvite(ctx context.Context, evt *event.Event, sender *User) EventHandlingResult { ghostID, _ := br.Matrix.ParseGhostMXID(id.UserID(evt.GetStateKey())) validator, ok := br.Network.(IdentifierValidatingNetwork) if ghostID == "" || (ok && !validator.ValidateUserID(ghostID)) { rejectInvite(ctx, evt, br.Matrix.GhostIntent(ghostID), "Malformed user ID") - return + return EventHandlingResultIgnored } log := zerolog.Ctx(ctx).With(). Str("invitee_network_id", string(ghostID)). @@ -102,22 +133,22 @@ func (br *Bridge) handleGhostDMInvite(ctx context.Context, evt *event.Event, sen logins := sender.GetUserLogins() if len(logins) == 0 { rejectInvite(ctx, evt, br.Matrix.GhostIntent(ghostID), "You're not logged in") - return + return EventHandlingResultIgnored } _, ok = logins[0].Client.(IdentifierResolvingNetworkAPI) if !ok { rejectInvite(ctx, evt, br.Matrix.GhostIntent(ghostID), "This bridge does not support starting chats") - return + return EventHandlingResultIgnored } invitedGhost, err := br.GetGhostByID(ctx, ghostID) if err != nil { log.Err(err).Msg("Failed to get invited ghost") - return + return EventHandlingResultFailed } err = invitedGhost.Intent.EnsureJoined(ctx, evt.RoomID) if err != nil { log.Err(err).Msg("Failed to accept invite to room") - return + return EventHandlingResultFailed } var resp *CreateChatResponse var sourceLogin *UserLogin @@ -144,7 +175,7 @@ func (br *Bridge) handleGhostDMInvite(ctx context.Context, evt *event.Event, sen } else if err != nil { log.Err(err).Msg("Failed to resolve identifier") sendErrorAndLeave(ctx, evt, invitedGhost.Intent, "Failed to create chat") - return + return EventHandlingResultFailed } else { sourceLogin = login break @@ -153,7 +184,7 @@ func (br *Bridge) handleGhostDMInvite(ctx context.Context, evt *event.Event, sen if resp == nil { log.Warn().Msg("No login could resolve the identifier") sendErrorAndLeave(ctx, evt, br.Matrix.GhostIntent(ghostID), "Failed to create chat via any login") - return + return EventHandlingResultFailed } portal := resp.Portal if portal == nil { @@ -161,61 +192,85 @@ func (br *Bridge) handleGhostDMInvite(ctx context.Context, evt *event.Event, sen if err != nil { log.Err(err).Msg("Failed to get portal by key") sendErrorAndLeave(ctx, evt, br.Matrix.GhostIntent(ghostID), "Failed to create portal entry") - return + return EventHandlingResultFailed } } + portal.CleanupOrphanedDM(ctx, sender.MXID) err = invitedGhost.Intent.EnsureInvited(ctx, evt.RoomID, br.Bot.GetMXID()) if err != nil { log.Err(err).Msg("Failed to ensure bot is invited to room") sendErrorAndLeave(ctx, evt, invitedGhost.Intent, "Failed to invite bridge bot") - return + return EventHandlingResultFailed } err = br.Bot.EnsureJoined(ctx, evt.RoomID) if err != nil { log.Err(err).Msg("Failed to ensure bot is joined to room") sendErrorAndLeave(ctx, evt, invitedGhost.Intent, "Failed to join with bridge bot") - return + return EventHandlingResultFailed } - didSetPortal := portal.setMXIDToExistingRoom(evt.RoomID) - if resp.PortalInfo != nil { - portal.UpdateInfo(ctx, resp.PortalInfo, sourceLogin, nil, time.Time{}) + portal.roomCreateLock.Lock() + defer portal.roomCreateLock.Unlock() + portalMXID := portal.MXID + if portalMXID != "" { + sendErrorAndLeave(ctx, evt, invitedGhost.Intent, "You already have a direct chat with me at [%s](%s)", portalMXID, portalMXID.URI(br.Matrix.ServerName()).MatrixToURL()) + rejectInvite(ctx, evt, br.Bot, "") + return EventHandlingResultSuccess } - if didSetPortal { - // TODO this might become unnecessary if UpdateInfo starts taking care of it - _, err = br.Bot.SendState(ctx, portal.MXID, event.StateElementFunctionalMembers, "", &event.Content{ - Parsed: &event.ElementFunctionalMembersContent{ - ServiceMembers: []id.UserID{br.Bot.GetMXID()}, + err = br.givePowerToBot(ctx, evt.RoomID, invitedGhost.Intent) + if err != nil { + log.Err(err).Msg("Failed to give permissions to bridge bot") + sendErrorAndLeave(ctx, evt, invitedGhost.Intent, "Failed to give permissions to bridge bot") + rejectInvite(ctx, evt, br.Bot, "") + return EventHandlingResultSuccess + } + overrideIntent := invitedGhost.Intent + if resp.DMRedirectedTo != "" && resp.DMRedirectedTo != invitedGhost.ID { + log.Debug(). + Str("dm_redirected_to_id", string(resp.DMRedirectedTo)). + Msg("Created DM was redirected to another user ID") + _, err = invitedGhost.Intent.SendState(ctx, evt.RoomID, event.StateMember, invitedGhost.Intent.GetMXID().String(), &event.Content{ + Parsed: &event.MemberEventContent{ + Membership: event.MembershipLeave, + Reason: "Direct chat redirected to another internal user ID", }, }, time.Time{}) if err != nil { - log.Warn().Err(err).Msg("Failed to set service members in room") + log.Err(err).Msg("Failed to make incorrect ghost leave new DM room") } - message := "Private chat portal created" - err = br.givePowerToBot(ctx, evt.RoomID, invitedGhost.Intent) - hasWarning := false - if err != nil { - log.Warn().Err(err).Msg("Failed to give power to bot in new DM") - message += "\n\nWarning: failed to promote bot" - hasWarning = true + if resp.DMRedirectedTo == SpecialValueDMRedirectedToBot { + overrideIntent = br.Bot + } else if otherUserGhost, err := br.GetGhostByID(ctx, resp.DMRedirectedTo); err != nil { + log.Err(err).Msg("Failed to get ghost of real portal other user ID") + } else { + invitedGhost = otherUserGhost + overrideIntent = otherUserGhost.Intent } - mx, ok := br.Matrix.(MatrixConnectorWithPostRoomBridgeHandling) - if ok { - err = mx.HandleNewlyBridgedRoom(ctx, evt.RoomID) - if err != nil { - if hasWarning { - message += fmt.Sprintf(", %s", err.Error()) - } else { - message += fmt.Sprintf("\n\nWarning: %s", err.Error()) - } - } - } - sendNotice(ctx, evt, invitedGhost.Intent, message) - } else { - // TODO ensure user is invited even if PortalInfo wasn't provided? - sendErrorAndLeave(ctx, evt, invitedGhost.Intent, "You already have a direct chat with me at [%s](%s)", portal.MXID, portal.MXID.URI(br.Matrix.ServerName()).MatrixToURL()) - rejectInvite(ctx, evt, br.Bot, "") } + err = portal.UpdateMatrixRoomID(ctx, evt.RoomID, UpdateMatrixRoomIDParams{ + // We locked it before checking the mxid + RoomCreateAlreadyLocked: true, + + FailIfMXIDSet: true, + ChatInfo: resp.PortalInfo, + ChatInfoSource: sourceLogin, + }) + if err != nil { + log.Err(err).Msg("Failed to update Matrix room ID for new DM portal") + sendNotice(ctx, evt, overrideIntent, "Failed to finish configuring portal. The chat may or may not work") + return EventHandlingResultSuccess + } + message := "Private chat portal created" + mx, ok := br.Matrix.(MatrixConnectorWithPostRoomBridgeHandling) + if ok { + err = mx.HandleNewlyBridgedRoom(ctx, evt.RoomID) + if err != nil { + log.Err(err).Msg("Error in connector newly bridged room handler") + message += fmt.Sprintf("\n\nWarning: %s", err.Error()) + } + } + sendNotice(ctx, evt, overrideIntent, message) + return EventHandlingResultSuccess } func (br *Bridge) givePowerToBot(ctx context.Context, roomID id.RoomID, userWithPower MatrixAPI) error { @@ -225,6 +280,9 @@ func (br *Bridge) givePowerToBot(ctx context.Context, roomID id.RoomID, userWith } userLevel := powers.GetUserLevel(userWithPower.GetMXID()) if powers.EnsureUserLevelAs(userWithPower.GetMXID(), br.Bot.GetMXID(), userLevel) { + if userLevel > powers.UsersDefault { + powers.SetUserLevel(userWithPower.GetMXID(), userLevel-1) + } _, err = userWithPower.SendState(ctx, roomID, event.StatePowerLevels, "", &event.Content{ Parsed: powers, }, time.Time{}) @@ -234,17 +292,3 @@ func (br *Bridge) givePowerToBot(ctx context.Context, roomID id.RoomID, userWith } return nil } - -func (portal *Portal) setMXIDToExistingRoom(roomID id.RoomID) bool { - portal.roomCreateLock.Lock() - defer portal.roomCreateLock.Unlock() - if portal.MXID != "" { - return false - } - portal.MXID = roomID - portal.updateLogger() - portal.Bridge.cacheLock.Lock() - portal.Bridge.portalsByMXID[portal.MXID] = portal - portal.Bridge.cacheLock.Unlock() - return true -} diff --git a/bridgev2/messagestatus.go b/bridgev2/messagestatus.go index 1983b4de..df0c9e4d 100644 --- a/bridgev2/messagestatus.go +++ b/bridgev2/messagestatus.go @@ -12,13 +12,15 @@ import ( "go.mau.fi/util/jsontime" - "maunium.net/go/mautrix/bridge/status" + "maunium.net/go/mautrix/appservice" + "maunium.net/go/mautrix/bridgev2/status" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" ) type MessageStatusEventInfo struct { RoomID id.RoomID + TransactionID string SourceEventID id.EventID NewEventID id.EventID EventType event.Type @@ -26,6 +28,8 @@ type MessageStatusEventInfo struct { Sender id.UserID ThreadRoot id.EventID StreamOrder int64 + + IsSourceEventDoublePuppeted bool } func StatusEventInfoFromEvent(evt *event.Event) *MessageStatusEventInfo { @@ -33,13 +37,19 @@ func StatusEventInfoFromEvent(evt *event.Event) *MessageStatusEventInfo { if relatable, ok := evt.Content.Parsed.(event.Relatable); ok { threadRoot = relatable.OptionalGetRelatesTo().GetThreadParent() } + + _, isDoublePuppeted := evt.Content.Raw[appservice.DoublePuppetKey] + return &MessageStatusEventInfo{ RoomID: evt.RoomID, + TransactionID: evt.Unsigned.TransactionID, SourceEventID: evt.ID, EventType: evt.Type, MessageType: evt.Content.AsMessage().MsgType, Sender: evt.Sender, ThreadRoot: threadRoot, + + IsSourceEventDoublePuppeted: isDoublePuppeted, } } @@ -174,9 +184,10 @@ func (ms *MessageStatus) ToMSSEvent(evt *MessageStatusEventInfo) *event.BeeperMe Type: event.RelReference, EventID: evt.SourceEventID, }, - Status: ms.Status, - Reason: ms.ErrorReason, - Message: ms.Message, + TargetTxnID: evt.TransactionID, + Status: ms.Status, + Reason: ms.ErrorReason, + Message: ms.Message, } if ms.InternalError != nil { content.InternalError = ms.InternalError.Error() @@ -211,7 +222,7 @@ func (ms *MessageStatus) ToNoticeEvent(evt *MessageStatusEventInfo) *event.Messa messagePrefix = "Handling your command panicked" } content := &event.MessageEventContent{ - MsgType: event.MsgText, + MsgType: event.MsgNotice, Body: fmt.Sprintf("\u26a0\ufe0f %s: %s", messagePrefix, msg), RelatesTo: &event.RelatesTo{}, Mentions: &event.Mentions{}, diff --git a/bridgev2/networkid/bridgeid.go b/bridgev2/networkid/bridgeid.go index d78813eb..e3a6df70 100644 --- a/bridgev2/networkid/bridgeid.go +++ b/bridgev2/networkid/bridgeid.go @@ -47,8 +47,8 @@ type PortalID string // As a special case, Receiver MUST be set if the Bridge.Config.SplitPortals flag is set to true. // The flag is intended for puppeting-only bridges which want multiple logins to create separate portals for each user. type PortalKey struct { - ID PortalID - Receiver UserLoginID + ID PortalID `json:"portal_id"` + Receiver UserLoginID `json:"portal_receiver,omitempty"` } func (pk PortalKey) IsEmpty() bool { @@ -94,6 +94,11 @@ type MessageID string // Transaction IDs must be unique across users in a room, but don't need to be unique across different rooms. type TransactionID string +// RawTransactionID is a client-generated identifier for a message send operation on the remote network. +// +// Unlike TransactionID, RawTransactionID's are only used for sending and don't have any uniqueness requirements. +type RawTransactionID string + // PartID is the ID of a message part on the remote network (e.g. index of image in album). // // Part IDs are only unique within a message, not globally. diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index 8ddf1269..efc5f100 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -8,6 +8,7 @@ package bridgev2 import ( "context" + "encoding/json" "fmt" "strings" "time" @@ -15,7 +16,9 @@ import ( "github.com/rs/zerolog" "go.mau.fi/util/configupgrade" "go.mau.fi/util/ptr" + "go.mau.fi/util/random" + "maunium.net/go/mautrix" "maunium.net/go/mautrix/bridgev2/database" "maunium.net/go/mautrix/bridgev2/networkid" "maunium.net/go/mautrix/event" @@ -76,8 +79,28 @@ type EventSender struct { ForceDMUser bool } +func (es EventSender) MarshalZerologObject(evt *zerolog.Event) { + evt.Str("user_id", string(es.Sender)) + if string(es.SenderLogin) != string(es.Sender) { + evt.Str("sender_login", string(es.SenderLogin)) + } + if es.IsFromMe { + evt.Bool("is_from_me", true) + } + if es.ForceDMUser { + evt.Bool("force_dm_user", true) + } +} + type ConvertedMessage struct { - ReplyTo *networkid.MessageOptionalPartID + ReplyTo *networkid.MessageOptionalPartID + // Optional additional info about the reply. This is only used when backfilling messages + // on Beeper, where replies may target messages that haven't been bridged yet. + // Standard Matrix servers can't backwards backfill, so these are never used. + ReplyToRoom networkid.PortalKey + ReplyToUser networkid.UserID + ReplyToLogin networkid.UserLoginID + ThreadRoot *networkid.MessageID Parts []*ConvertedMessagePart Disappear database.DisappearingSetting @@ -96,11 +119,15 @@ func MergeCaption(textPart, mediaPart *ConvertedMessagePart) *ConvertedMessagePa mediaPart.Content.EnsureHasHTML() mediaPart.Content.Body += "\n\n" + textPart.Content.Body mediaPart.Content.FormattedBody += "

" + textPart.Content.FormattedBody + mediaPart.Content.Mentions = mediaPart.Content.Mentions.Merge(textPart.Content.Mentions) + mediaPart.Content.BeeperLinkPreviews = append(mediaPart.Content.BeeperLinkPreviews, textPart.Content.BeeperLinkPreviews...) } else { mediaPart.Content.FileName = mediaPart.Content.Body mediaPart.Content.Body = textPart.Content.Body mediaPart.Content.Format = textPart.Content.Format mediaPart.Content.FormattedBody = textPart.Content.FormattedBody + mediaPart.Content.Mentions = textPart.Content.Mentions + mediaPart.Content.BeeperLinkPreviews = textPart.Content.BeeperLinkPreviews } if metaMerger, ok := mediaPart.DBMetadata.(database.MetaMerger); ok { metaMerger.CopyFrom(textPart.DBMetadata) @@ -227,9 +254,14 @@ type NetworkConnector interface { // This should generally not do any work, it should just return a LoginProcess that remembers // the user and will execute the requested flow. The actual work should start when [LoginProcess.Start] is called. CreateLogin(ctx context.Context, user *User, flowID string) (LoginProcess, error) + + // GetBridgeInfoVersion returns version numbers for bridge info and room capabilities respectively. + // When the versions change, the bridge will automatically resend bridge info to all rooms. + GetBridgeInfoVersion() (info, capabilities int) } type StoppableNetwork interface { + NetworkConnector // Stop is called when the bridge is stopping, after all network clients have been disconnected. Stop() } @@ -254,6 +286,11 @@ type IdentifierValidatingNetwork interface { ValidateUserID(id networkid.UserID) bool } +type TransactionIDGeneratingNetwork interface { + NetworkConnector + GenerateTransactionID(userID id.UserID, roomID id.RoomID, eventType event.Type) networkid.RawTransactionID +} + type PortalBridgeInfoFillingNetwork interface { NetworkConnector FillPortalBridgeInfo(portal *Portal, content *event.BridgeEventContent) @@ -281,6 +318,16 @@ type MaxFileSizeingNetwork interface { SetMaxFileSize(maxSize int64) } +type NetworkResettingNetwork interface { + NetworkConnector + // ResetHTTPTransport should recreate the HTTP client used by the bridge. + // It should refetch settings from the Matrix connector using GetHTTPClientSettings if applicable. + ResetHTTPTransport() + // ResetNetworkConnections should forcefully disconnect and restart any persistent network connections. + // ResetHTTPTransport will usually be called before this, so resetting the transport is not necessary here. + ResetNetworkConnections() +} + type RemoteEchoHandler func(RemoteMessage, *database.Message) (bool, error) type MatrixMessageResponse struct { @@ -297,9 +344,12 @@ type MatrixMessageResponse struct { PostSave func(context.Context, *database.Message) } -type FileRestriction struct { - MaxSize int64 - MimeTypes []string +type OutgoingTimeoutConfig struct { + CheckInterval time.Duration + NoEchoTimeout time.Duration + NoEchoMessage string + NoAckTimeout time.Duration + NoAckMessage string } type NetworkGeneralCapabilities struct { @@ -309,35 +359,16 @@ type NetworkGeneralCapabilities struct { // Should the bridge re-request user info on incoming messages even if the ghost already has info? // By default, info is only requested for ghosts with no name, and other updating is left to events. AggressiveUpdateInfo bool -} - -type NetworkRoomCapabilities struct { - FormattedText bool - UserMentions bool - RoomMentions bool - - LocationMessages bool - Captions bool - MaxTextLength int - MaxCaptionLength int - Polls bool - - Threads bool - Replies bool - Edits bool - EditMaxCount int - EditMaxAge time.Duration - Deletes bool - DeleteMaxAge time.Duration - - DefaultFileRestriction *FileRestriction - Files map[event.MessageType]FileRestriction - - ReadReceipts bool - - Reactions bool - ReactionCount int - AllowedReactions []string + // Should the bridge call HandleMatrixReadReceipt with fake data when receiving a new message? + // This should be enabled if the network requires each message to be marked as read independently, + // and doesn't automatically do it when sending a message. + ImplicitReadReceipts bool + // If the bridge uses the pending message mechanism ([MatrixMessage.AddPendingToSave]) + // to handle asynchronous message responses, this field can be set to enable + // automatic timeout errors in case the asynchronous response never arrives. + OutgoingMessageTimeouts *OutgoingTimeoutConfig + // Capabilities related to the provisioning API. + Provisioning ProvisioningCapabilities } // NetworkAPI is an interface representing a remote network client for a single user login. @@ -372,7 +403,7 @@ type NetworkAPI interface { // GetCapabilities returns the bridging capabilities in a given room. // This can simply return a static list if the remote network has no per-chat capability differences, // but all calls will include the portal, because some networks do have per-chat differences. - GetCapabilities(ctx context.Context, portal *Portal) *NetworkRoomCapabilities + GetCapabilities(ctx context.Context, portal *Portal) *event.RoomFeatures // HandleMatrixMessage is called when a message is sent from Matrix in an existing portal room. // This function should convert the message as appropriate, send it over to the remote network, @@ -382,6 +413,30 @@ type NetworkAPI interface { HandleMatrixMessage(ctx context.Context, msg *MatrixMessage) (message *MatrixMessageResponse, err error) } +type ConnectBackgroundParams struct { + // RawData is the raw data in the push that triggered the background connection. + RawData json.RawMessage + // ExtraData is the data returned by [PushParsingNetwork.ParsePushNotification]. + // It's only present for native pushes. Relayed pushes will only have the raw data. + ExtraData any +} + +// BackgroundSyncingNetworkAPI is an optional interface that network connectors can implement to support background resyncs. +type BackgroundSyncingNetworkAPI interface { + NetworkAPI + // ConnectBackground is called in place of Connect for background resyncs. + // The client should connect to the remote network, handle pending messages, and then disconnect. + // This call should block until the entire sync is complete and the client is disconnected. + ConnectBackground(ctx context.Context, params *ConnectBackgroundParams) error +} + +// CredentialExportingNetworkAPI is an optional interface that networks connectors can implement to support export of +// the credentials associated with that login. Credential type is bridge specific. +type CredentialExportingNetworkAPI interface { + NetworkAPI + ExportCredentials(ctx context.Context) any +} + // FetchMessagesParams contains the parameters for a message history pagination request. type FetchMessagesParams struct { // The portal to fetch messages in. Always present. @@ -518,11 +573,18 @@ type FetchMessagesResponse struct { // BackfillingNetworkAPI is an optional interface that network connectors can implement to support backfilling message history. type BackfillingNetworkAPI interface { NetworkAPI + // FetchMessages returns a batch of messages to backfill in a portal room. + // For details on the input and output, see the documentation of [FetchMessagesParams] and [FetchMessagesResponse]. FetchMessages(ctx context.Context, fetchParams FetchMessagesParams) (*FetchMessagesResponse, error) } +// BackfillingNetworkAPIWithLimits is an optional interface that network connectors can implement to customize +// the limit for backwards backfilling tasks. It is recommended to implement this by reading the MaxBatchesOverride +// config field with network-specific keys for different room types. type BackfillingNetworkAPIWithLimits interface { BackfillingNetworkAPI + // GetBackfillMaxBatchCount is called before a backfill task is executed to determine the maximum number of batches + // that should be backfilled. Return values less than 0 are treated as unlimited. GetBackfillMaxBatchCount(ctx context.Context, portal *Portal, task *database.BackfillTask) int } @@ -576,6 +638,16 @@ type ReadReceiptHandlingNetworkAPI interface { HandleMatrixReadReceipt(ctx context.Context, msg *MatrixReadReceipt) error } +// ChatViewingNetworkAPI is an optional interface that network connectors can implement to handle viewing chat status. +type ChatViewingNetworkAPI interface { + NetworkAPI + // HandleMatrixViewingChat is called when the user opens a portal room. + // This will never be called by the standard appservice connector, + // as Matrix doesn't have any standard way of signaling chat open status. + // Clients are expected to call this every 5 seconds. There is no signal for closing a chat. + HandleMatrixViewingChat(ctx context.Context, msg *MatrixViewingChat) error +} + // TypingHandlingNetworkAPI is an optional interface that network connectors can implement to handle typing events. type TypingHandlingNetworkAPI interface { NetworkAPI @@ -630,6 +702,35 @@ type RoomTopicHandlingNetworkAPI interface { HandleMatrixRoomTopic(ctx context.Context, msg *MatrixRoomTopic) (bool, error) } +type DisappearTimerChangingNetworkAPI interface { + NetworkAPI + // HandleMatrixDisappearingTimer is called when the disappearing timer of a portal room is changed. + // This method should update the Disappear field of the Portal with the new timer and return true + // if the change was successful. If the change is not successful, then the field should not be updated. + HandleMatrixDisappearingTimer(ctx context.Context, msg *MatrixDisappearingTimer) (bool, error) +} + +// DeleteChatHandlingNetworkAPI is an optional interface that network connectors +// can implement to delete a chat from the remote network. +type DeleteChatHandlingNetworkAPI interface { + NetworkAPI + // HandleMatrixDeleteChat is called when the user explicitly deletes a chat. + HandleMatrixDeleteChat(ctx context.Context, msg *MatrixDeleteChat) error +} + +// MessageRequestAcceptingNetworkAPI is an optional interface that network connectors +// can implement to accept message requests from the remote network. +type MessageRequestAcceptingNetworkAPI interface { + NetworkAPI + // HandleMatrixAcceptMessageRequest is called when the user accepts a message request. + HandleMatrixAcceptMessageRequest(ctx context.Context, msg *MatrixAcceptMessageRequest) error +} + +type BeeperAIStreamHandlingNetworkAPI interface { + NetworkAPI + HandleMatrixBeeperAIStream(ctx context.Context, msg *MatrixBeeperAIStream) error +} + type ResolveIdentifierResponse struct { // Ghost is the ghost of the user that the identifier resolves to. // This field should be set whenever possible. However, it is not required, @@ -649,11 +750,27 @@ type ResolveIdentifierResponse struct { Chat *CreateChatResponse } +var SpecialValueDMRedirectedToBot = networkid.UserID("__fi.mau.bridgev2.dm_redirected_to_bot::" + random.String(10)) + type CreateChatResponse struct { PortalKey networkid.PortalKey // Portal and PortalInfo are not required, the caller will fetch them automatically based on PortalKey if necessary. Portal *Portal PortalInfo *ChatInfo + // If a start DM request (CreateChatWithGhost or ResolveIdentifier) returns the DM to a different user, + // this field should have the user ID of said different user. + DMRedirectedTo networkid.UserID + + FailedParticipants map[networkid.UserID]*CreateChatFailedParticipant +} + +type CreateChatFailedParticipant struct { + Reason string `json:"reason"` + InviteEventType string `json:"invite_event_type,omitempty"` + InviteContent *event.Content `json:"invite_content,omitempty"` + + UserMXID id.UserID `json:"user_mxid,omitempty"` + DMRoomMXID id.RoomID `json:"dm_room_mxid,omitempty"` } // IdentifierResolvingNetworkAPI is an optional interface that network connectors can implement to support starting new direct chats. @@ -688,7 +805,83 @@ type UserSearchingNetworkAPI interface { type GroupCreatingNetworkAPI interface { IdentifierResolvingNetworkAPI - CreateGroup(ctx context.Context, name string, users ...networkid.UserID) (*CreateChatResponse, error) + CreateGroup(ctx context.Context, params *GroupCreateParams) (*CreateChatResponse, error) +} + +type PersonalFilteringCustomizingNetworkAPI interface { + NetworkAPI + CustomizePersonalFilteringSpace(req *mautrix.ReqCreateRoom) +} + +type ProvisioningCapabilities struct { + ResolveIdentifier ResolveIdentifierCapabilities `json:"resolve_identifier"` + GroupCreation map[string]GroupTypeCapabilities `json:"group_creation"` +} + +type ResolveIdentifierCapabilities struct { + // Can DMs be created after resolving an identifier? + CreateDM bool `json:"create_dm"` + // Can users be looked up by phone number? + LookupPhone bool `json:"lookup_phone"` + // Can users be looked up by email address? + LookupEmail bool `json:"lookup_email"` + // Can users be looked up by network-specific username? + LookupUsername bool `json:"lookup_username"` + // Can any phone number be contacted without having to validate it via lookup first? + AnyPhone bool `json:"any_phone"` + // Can a contact list be retrieved from the bridge? + ContactList bool `json:"contact_list"` + // Can users be searched by name on the remote network? + Search bool `json:"search"` +} + +type GroupTypeCapabilities struct { + TypeDescription string `json:"type_description"` + + Name GroupFieldCapability `json:"name"` + Username GroupFieldCapability `json:"username"` + Avatar GroupFieldCapability `json:"avatar"` + Topic GroupFieldCapability `json:"topic"` + Disappear GroupFieldCapability `json:"disappear"` + Participants GroupFieldCapability `json:"participants"` + Parent GroupFieldCapability `json:"parent"` +} + +type GroupFieldCapability struct { + // Is setting this field allowed at all in the create request? + // Even if false, the network connector should attempt to set the metadata after group creation, + // as the allowed flag can't be enforced properly when creating a group for an existing Matrix room. + Allowed bool `json:"allowed"` + // Is setting this field mandatory for the creation to succeed? + Required bool `json:"required,omitempty"` + // The minimum/maximum length of the field, if applicable. + // For members, length means the number of members excluding the creator. + MinLength int `json:"min_length,omitempty"` + MaxLength int `json:"max_length,omitempty"` + + // Only for the disappear field: allowed disappearing settings + DisappearSettings *event.DisappearingTimerCapability `json:"settings,omitempty"` + + // This can be used to tell provisionutil not to call ValidateUserID on each participant. + // It only meant to allow hacks where ResolveIdentifier returns a fake ID that isn't actually valid for MXIDs. + SkipIdentifierValidation bool `json:"-"` +} + +type GroupCreateParams struct { + Type string `json:"type,omitempty"` + + Username string `json:"username,omitempty"` + // Clients may also provide MXIDs here, but provisionutil will normalize them, so bridges only need to handle network IDs + Participants []networkid.UserID `json:"participants,omitempty"` + Parent *networkid.PortalKey `json:"parent,omitempty"` + + Name *event.RoomNameEventContent `json:"name,omitempty"` + Avatar *event.RoomAvatarEventContent `json:"avatar,omitempty"` + Topic *event.TopicEventContent `json:"topic,omitempty"` + Disappear *event.BeeperDisappearingTimer `json:"disappear,omitempty"` + + // An existing room ID to bridge to. If unset, a new room will be created. + RoomID id.RoomID `json:"room_id,omitempty"` } type MembershipChangeType struct { @@ -728,16 +921,15 @@ type MatrixMembershipChange struct { MatrixRoomMeta[*event.MemberEventContent] Target GhostOrUserLogin Type MembershipChangeType +} - // Deprecated: Use Target instead - TargetGhost *Ghost - // Deprecated: Use Target instead - TargetUserLogin *UserLogin +type MatrixMembershipResult struct { + RedirectTo networkid.UserID } type MembershipHandlingNetworkAPI interface { NetworkAPI - HandleMatrixMembership(ctx context.Context, msg *MatrixMembershipChange) (bool, error) + HandleMatrixMembership(ctx context.Context, msg *MatrixMembershipChange) (*MatrixMembershipResult, error) } type SinglePowerLevelChange struct { @@ -823,17 +1015,35 @@ type APNsPushConfig struct { } type PushConfig struct { - Web *WebPushConfig `json:"web,omitempty"` - FCM *FCMPushConfig `json:"fcm,omitempty"` - APNs *APNsPushConfig `json:"apns,omitempty"` - Native bool `json:"native,omitempty"` + Web *WebPushConfig `json:"web,omitempty"` + FCM *FCMPushConfig `json:"fcm,omitempty"` + APNs *APNsPushConfig `json:"apns,omitempty"` + // If Native is true, it means the network supports registering for pushes + // that are delivered directly to the app without the use of a push relay. + Native bool `json:"native,omitempty"` } +// PushableNetworkAPI is an optional interface that network connectors can implement +// to support waking up the wrapper app using push notifications. type PushableNetworkAPI interface { + NetworkAPI + + // RegisterPushNotifications is called when the wrapper app wants to register a push token with the remote network. RegisterPushNotifications(ctx context.Context, pushType PushType, token string) error + // GetPushConfigs is used to find which types of push notifications the remote network can provide. GetPushConfigs() *PushConfig } +// PushParsingNetwork is an optional interface that network connectors can implement +// to support parsing native push notifications from networks. +type PushParsingNetwork interface { + NetworkConnector + + // ParsePushNotification is called when a native push is received. + // It must return the corresponding user login ID to wake up, plus optionally data to pass to the wakeup call. + ParsePushNotification(ctx context.Context, data json.RawMessage) (networkid.UserLoginID, any, error) +} + type RemoteEventType int func (ret RemoteEventType) String() string { @@ -958,6 +1168,11 @@ type RemoteChatDelete interface { RemoteDeleteOnlyForMe } +type RemoteChatDeleteWithChildren interface { + RemoteChatDelete + DeleteChildren() bool +} + type RemoteEventThatMayCreatePortal interface { RemoteEvent ShouldCreatePortal() bool @@ -1077,6 +1292,11 @@ type RemoteReadReceipt interface { GetReadUpTo() time.Time } +type RemoteReadReceiptWithStreamOrder interface { + RemoteReadReceipt + GetReadUpToStreamOrder() int64 +} + type RemoteDeliveryReceipt interface { RemoteEvent GetReceiptTargets() []networkid.MessageID @@ -1112,6 +1332,7 @@ type OrigSender struct { RequiresDisambiguation bool DisambiguatedName string FormattedName string + PerMessageProfile event.BeeperPerMessageProfile event.MemberEventContent } @@ -1126,12 +1347,16 @@ type MatrixEventBase[ContentType any] struct { // The original sender user ID. Only present in case the event is being relayed (and Sender is not the same user). OrigSender *OrigSender + + InputTransactionID networkid.RawTransactionID } type MatrixMessage struct { MatrixEventBase[*event.MessageEventContent] ThreadRoot *database.Message ReplyTo *database.Message + + pendingSaves []*outgoingMessage } type MatrixEdit struct { @@ -1180,12 +1405,14 @@ type MatrixMessageRemove struct { type MatrixRoomMeta[ContentType any] struct { MatrixEventBase[ContentType] - PrevContent ContentType + PrevContent ContentType + IsStateRequest bool } type MatrixRoomName = MatrixRoomMeta[*event.RoomNameEventContent] type MatrixRoomAvatar = MatrixRoomMeta[*event.RoomAvatarEventContent] type MatrixRoomTopic = MatrixRoomMeta[*event.TopicEventContent] +type MatrixDisappearingTimer = MatrixRoomMeta[*event.BeeperDisappearingTimer] type MatrixReadReceipt struct { Portal *Portal @@ -1200,6 +1427,8 @@ type MatrixReadReceipt struct { LastRead time.Time // The receipt metadata. Receipt event.ReadReceipt + // Whether the receipt is implicit, i.e. triggered by an incoming timeline event rather than an explicit receipt. + Implicit bool } type MatrixTyping struct { @@ -1208,6 +1437,14 @@ type MatrixTyping struct { Type TypingType } +type MatrixViewingChat struct { + // The portal that the user is viewing. This will be nil when the user switches to a chat from a different bridge. + Portal *Portal +} + +type MatrixDeleteChat = MatrixEventBase[*event.BeeperChatDeleteEventContent] +type MatrixAcceptMessageRequest = MatrixEventBase[*event.BeeperAcceptMessageRequestEventContent] +type MatrixBeeperAIStream = MatrixEventBase[*event.BeeperAIStreamEventContent] type MatrixMarkedUnread = MatrixRoomMeta[*event.MarkedUnreadEventContent] type MatrixMute = MatrixRoomMeta[*event.BeeperMuteEventContent] type MatrixRoomTag = MatrixRoomMeta[*event.TagEventContent] diff --git a/bridgev2/portal.go b/bridgev2/portal.go index b1aae9e7..16aa703b 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -19,7 +19,9 @@ import ( "github.com/rs/zerolog" "go.mau.fi/util/exfmt" + "go.mau.fi/util/exmaps" "go.mau.fi/util/exslices" + "go.mau.fi/util/exsync" "go.mau.fi/util/ptr" "go.mau.fi/util/variationselector" "golang.org/x/exp/maps" @@ -59,10 +61,12 @@ type portalEvent interface { } type outgoingMessage struct { - db *database.Message - evt *event.Event - ignore bool - handle func(RemoteMessage, *database.Message) (bool, error) + db *database.Message + evt *event.Event + ignore bool + handle func(RemoteMessage, *database.Message) (bool, error) + ackedAt time.Time + timeouted bool } type Portal struct { @@ -75,16 +79,28 @@ type Portal struct { currentlyTyping []id.UserID currentlyTypingLogins map[id.UserID]*UserLogin currentlyTypingLock sync.Mutex + currentlyTypingGhosts *exsync.Set[id.UserID] - outgoingMessages map[networkid.TransactionID]outgoingMessage + outgoingMessages map[networkid.TransactionID]*outgoingMessage outgoingMessagesLock sync.Mutex - roomCreateLock sync.Mutex + lastCapUpdate time.Time - events chan portalEvent + roomCreateLock sync.Mutex + cancelRoomCreate atomic.Pointer[context.CancelFunc] + RoomCreated *exsync.Event + + functionalMembersLock sync.Mutex + functionalMembersCache *event.ElementFunctionalMembersContent + + events chan portalEvent + deleted *exsync.Event + + eventsLock sync.Mutex + eventIdx int } -const PortalEventBuffer = 64 +var PortalEventBuffer = 64 func (br *Bridge) loadPortal(ctx context.Context, dbPortal *database.Portal, queryErr error, key *networkid.PortalKey) (*Portal, error) { if queryErr != nil { @@ -107,10 +123,18 @@ func (br *Bridge) loadPortal(ctx context.Context, dbPortal *database.Portal, que Portal: dbPortal, Bridge: br, - events: make(chan portalEvent, PortalEventBuffer), currentlyTypingLogins: make(map[id.UserID]*UserLogin), - outgoingMessages: make(map[networkid.TransactionID]outgoingMessage), + currentlyTypingGhosts: exsync.NewSet[id.UserID](), + outgoingMessages: make(map[networkid.TransactionID]*outgoingMessage), + + RoomCreated: exsync.NewEvent(), + deleted: exsync.NewEvent(), } + if portal.MXID != "" { + portal.RoomCreated.Set() + } + // Putting the portal in the cache before it's fully initialized is mildly dangerous, + // but loading the relay user login may depend on it. br.portalsByKey[portal.PortalKey] = portal if portal.MXID != "" { br.portalsByMXID[portal.MXID] = portal @@ -119,22 +143,35 @@ func (br *Bridge) loadPortal(ctx context.Context, dbPortal *database.Portal, que if portal.ParentKey.ID != "" { portal.Parent, err = br.UnlockedGetPortalByKey(ctx, portal.ParentKey, false) if err != nil { + delete(br.portalsByKey, portal.PortalKey) + if portal.MXID != "" { + delete(br.portalsByMXID, portal.MXID) + } return nil, fmt.Errorf("failed to load parent portal (%s): %w", portal.ParentKey, err) } } if portal.RelayLoginID != "" { portal.Relay, err = br.unlockedGetExistingUserLoginByID(ctx, portal.RelayLoginID) if err != nil { + delete(br.portalsByKey, portal.PortalKey) + if portal.MXID != "" { + delete(br.portalsByMXID, portal.MXID) + } return nil, fmt.Errorf("failed to load relay login (%s): %w", portal.RelayLoginID, err) } } portal.updateLogger() - go portal.eventLoop() + if PortalEventBuffer != 0 { + portal.events = make(chan portalEvent, PortalEventBuffer) + go portal.eventLoop() + } return portal, nil } func (portal *Portal) updateLogger() { - logWith := portal.Bridge.Log.With().Str("portal_id", string(portal.ID)) + logWith := portal.Bridge.Log.With(). + Str("portal_id", string(portal.ID)). + Str("portal_receiver", string(portal.Receiver)) if portal.MXID != "" { logWith = logWith.Stringer("portal_mxid", portal.MXID) } @@ -158,6 +195,16 @@ func (br *Bridge) loadManyPortals(ctx context.Context, portals []*database.Porta return output, nil } +func (br *Bridge) loadPortalWithCacheCheck(ctx context.Context, dbPortal *database.Portal) (*Portal, error) { + if dbPortal == nil { + return nil, nil + } else if cached, ok := br.portalsByKey[dbPortal.PortalKey]; ok { + return cached, nil + } else { + return br.loadPortal(ctx, dbPortal, nil, nil) + } +} + func (br *Bridge) UnlockedGetPortalByKey(ctx context.Context, key networkid.PortalKey, onlyIfExists bool) (*Portal, error) { if br.Config.SplitPortals && key.Receiver == "" { return nil, fmt.Errorf("receiver must always be set when split portals is enabled") @@ -247,6 +294,26 @@ func (br *Bridge) GetDMPortalsWith(ctx context.Context, otherUserID networkid.Us return br.loadManyPortals(ctx, rows) } +func (br *Bridge) GetChildPortals(ctx context.Context, parent networkid.PortalKey) ([]*Portal, error) { + br.cacheLock.Lock() + defer br.cacheLock.Unlock() + rows, err := br.DB.Portal.GetChildren(ctx, parent) + if err != nil { + return nil, err + } + return br.loadManyPortals(ctx, rows) +} + +func (br *Bridge) GetDMPortal(ctx context.Context, receiver networkid.UserLoginID, otherUserID networkid.UserID) (*Portal, error) { + br.cacheLock.Lock() + defer br.cacheLock.Unlock() + dbPortal, err := br.DB.Portal.GetDM(ctx, receiver, otherUserID) + if err != nil { + return nil, err + } + return br.loadPortalWithCacheCheck(ctx, dbPortal) +} + func (br *Bridge) GetPortalByKey(ctx context.Context, key networkid.PortalKey) (*Portal, error) { br.cacheLock.Lock() defer br.cacheLock.Unlock() @@ -271,61 +338,114 @@ func (br *Bridge) GetExistingPortalByKey(ctx context.Context, key networkid.Port return br.loadPortal(ctx, db, err, nil) } -func (portal *Portal) queueEvent(ctx context.Context, evt portalEvent) { - select { - case portal.events <- evt: - default: - zerolog.Ctx(ctx).Error(). - Str("portal_id", string(portal.ID)). - Msg("Portal event channel is full") +func (portal *Portal) queueEvent(ctx context.Context, evt portalEvent) EventHandlingResult { + if portal.deleted.IsSet() { + return EventHandlingResultIgnored + } + if PortalEventBuffer == 0 { + portal.eventsLock.Lock() + defer portal.eventsLock.Unlock() + portal.eventIdx++ + return portal.handleSingleEventWithDelayLogging(portal.eventIdx, evt) + } else { + if portal.events == nil { + panic(fmt.Errorf("queueEvent into uninitialized portal %s", portal.PortalKey)) + } + select { + case portal.events <- evt: + return EventHandlingResultQueued + case <-portal.deleted.GetChan(): + return EventHandlingResultIgnored + default: + zerolog.Ctx(ctx).Error(). + Str("portal_id", string(portal.ID)). + Msg("Portal event channel is full, queue will block") + for { + select { + case portal.events <- evt: + return EventHandlingResultQueued + case <-time.After(5 * time.Second): + zerolog.Ctx(ctx).Error(). + Str("portal_id", string(portal.ID)). + Msg("Portal event channel is still full") + } + } + } } } func (portal *Portal) eventLoop() { - i := 0 - for rawEvt := range portal.events { - i++ - portal.handleSingleEventAsync(i, rawEvt) + if cfg := portal.Bridge.Network.GetCapabilities().OutgoingMessageTimeouts; cfg != nil { + ctx, cancel := context.WithCancel(portal.Log.WithContext(portal.Bridge.BackgroundCtx)) + go portal.pendingMessageTimeoutLoop(ctx, cfg) + defer cancel() + } + deleteCh := portal.deleted.GetChan() + for i := 0; ; i++ { + select { + case rawEvt := <-portal.events: + if rawEvt == nil { + return + } + if portal.Bridge.Config.AsyncEvents { + go portal.handleSingleEventWithDelayLogging(i, rawEvt) + } else { + portal.handleSingleEventWithDelayLogging(i, rawEvt) + } + case <-deleteCh: + return + } } } -func (portal *Portal) handleSingleEventAsync(idx int, rawEvt any) { +func (portal *Portal) handleSingleEventWithDelayLogging(idx int, rawEvt any) (outerRes EventHandlingResult) { ctx := portal.getEventCtxWithLog(rawEvt, idx) - if _, isCreate := rawEvt.(*portalCreateEvent); isCreate { - portal.handleSingleEvent(ctx, rawEvt, func() {}) - } else if portal.Bridge.Config.AsyncEvents { - go portal.handleSingleEvent(ctx, rawEvt, func() {}) - } else { - log := zerolog.Ctx(ctx) - doneCh := make(chan struct{}) - var backgrounded atomic.Bool - start := time.Now() - var handleDuration time.Duration - go portal.handleSingleEvent(ctx, rawEvt, func() { - handleDuration = time.Since(start) - close(doneCh) - if backgrounded.Load() { - log.Debug().Stringer("duration", handleDuration). - Msg("Event that took too long finally finished handling") + log := zerolog.Ctx(ctx) + doneCh := make(chan struct{}) + var backgrounded atomic.Bool + start := time.Now() + var handleDuration time.Duration + // Note: this will not set the success flag if the handler times out + outerRes = EventHandlingResult{Queued: true} + go portal.handleSingleEvent(ctx, rawEvt, func(res EventHandlingResult) { + outerRes = res + handleDuration = time.Since(start) + close(doneCh) + if backgrounded.Load() { + log.Debug(). + Time("started_at", start). + Stringer("duration", handleDuration). + Msg("Event that took too long finally finished handling") + } + }) + tick := time.NewTicker(30 * time.Second) + _, isCreate := rawEvt.(*portalCreateEvent) + defer tick.Stop() + for i := 0; i < 10; i++ { + select { + case <-doneCh: + if i > 0 { + log.Debug(). + Time("started_at", start). + Stringer("duration", handleDuration). + Msg("Event that took long finished handling") } - }) - tick := time.NewTicker(30 * time.Second) - defer tick.Stop() - for i := 0; i < 10; i++ { - select { - case <-doneCh: - if i > 0 { - log.Debug().Stringer("duration", handleDuration). - Msg("Event that took long finished handling") - } - return - case <-tick.C: - log.Warn().Msg("Event handling is taking long") + return + case <-tick.C: + log.Warn(). + Time("started_at", start). + Msg("Event handling is taking long") + if isCreate { + // Never background portal creation events + i = 1 } } - log.Warn().Msg("Event handling is taking too long, continuing in background") - backgrounded.Store(true) } + log.Warn(). + Time("started_at", start). + Msg("Event handling is taking too long, continuing in background") + backgrounded.Store(true) + return } func (portal *Portal) getEventCtxWithLog(rawEvt any, idx int) context.Context { @@ -348,22 +468,49 @@ func (portal *Portal) getEventCtxWithLog(rawEvt any, idx int) context.Context { Str("source_id", string(evt.source.ID)). Stringer("bridge_evt_type", evt.evtType) logWith = evt.evt.AddLogContext(logWith) + if remoteSender := evt.evt.GetSender(); remoteSender.Sender != "" || remoteSender.IsFromMe { + logWith = logWith.Object("remote_sender", remoteSender) + } + if remoteMsg, ok := evt.evt.(RemoteMessage); ok { + if remoteMsgID := remoteMsg.GetID(); remoteMsgID != "" { + logWith = logWith.Str("remote_message_id", string(remoteMsgID)) + } + } + if remoteMsg, ok := evt.evt.(RemoteEventWithTargetMessage); ok { + if targetMsgID := remoteMsg.GetTargetMessage(); targetMsgID != "" { + logWith = logWith.Str("remote_target_message_id", string(targetMsgID)) + } + } + if remoteMsg, ok := evt.evt.(RemoteEventWithStreamOrder); ok { + if remoteStreamOrder := remoteMsg.GetStreamOrder(); remoteStreamOrder != 0 { + logWith = logWith.Int64("remote_stream_order", remoteStreamOrder) + } + } + if remoteMsg, ok := evt.evt.(RemoteEventWithTimestamp); ok { + if remoteTimestamp := remoteMsg.GetTimestamp(); !remoteTimestamp.IsZero() { + logWith = logWith.Time("remote_timestamp", remoteTimestamp) + } + } case *portalCreateEvent: return evt.ctx } - return logWith.Logger().WithContext(context.Background()) + return logWith.Logger().WithContext(portal.Bridge.BackgroundCtx) } -func (portal *Portal) handleSingleEvent(ctx context.Context, rawEvt any, doneCallback func()) { +func (portal *Portal) handleSingleEvent(ctx context.Context, rawEvt any, doneCallback func(res EventHandlingResult)) { log := zerolog.Ctx(ctx) + var res EventHandlingResult defer func() { - doneCallback() + doneCallback(res) if err := recover(); err != nil { logEvt := log.Error() + var errorString string if realErr, ok := err.(error); ok { logEvt = logEvt.Err(realErr) + errorString = realErr.Error() } else { logEvt = logEvt.Any(zerolog.ErrorFieldName, err) + errorString = fmt.Sprintf("%v", err) } logEvt. Bytes("stack", debug.Stack()). @@ -376,32 +523,92 @@ func (portal *Portal) handleSingleEvent(ctx context.Context, rawEvt any, doneCal case *portalCreateEvent: evt.cb(fmt.Errorf("portal creation panicked")) } + portal.Bridge.TrackAnalytics("", "Bridge Event Handler Panic", map[string]any{ + "error": errorString, + }) } }() switch evt := rawEvt.(type) { case *portalMatrixEvent: - portal.handleMatrixEvent(ctx, evt.sender, evt.evt) + isStateRequest := evt.evt.Type == event.BeeperSendState + if isStateRequest { + if err := portal.unwrapBeeperSendState(ctx, evt.evt); err != nil { + portal.sendErrorStatus(ctx, evt.evt, err) + return + } + } + res = portal.handleMatrixEvent(ctx, evt.sender, evt.evt, isStateRequest) + if res.SendMSS { + if res.Error != nil { + portal.sendErrorStatus(ctx, evt.evt, res.Error) + } else { + portal.sendSuccessStatus(ctx, evt.evt, 0, "") + } + } + if !isStateRequest && res.Error != nil && evt.evt.StateKey != nil { + portal.revertRoomMeta(ctx, evt.evt) + } + if isStateRequest && res.Success && !res.SkipStateEcho { + portal.sendRoomMeta( + ctx, + evt.sender.DoublePuppet(ctx), + time.UnixMilli(evt.evt.Timestamp), + evt.evt.Type, + evt.evt.GetStateKey(), + evt.evt.Content.Parsed, + false, + evt.evt.Content.Raw, + ) + } case *portalRemoteEvent: - portal.handleRemoteEvent(ctx, evt.source, evt.evtType, evt.evt) + res = portal.handleRemoteEvent(ctx, evt.source, evt.evtType, evt.evt) case *portalCreateEvent: - evt.cb(portal.createMatrixRoomInLoop(evt.ctx, evt.source, evt.info, nil)) + err := portal.createMatrixRoomInLoop(evt.ctx, evt.source, evt.info, nil) + res.Success = err == nil + evt.cb(err) default: panic(fmt.Errorf("illegal type %T in eventLoop", evt)) } } +func (portal *Portal) unwrapBeeperSendState(ctx context.Context, evt *event.Event) error { + content, ok := evt.Content.Parsed.(*event.BeeperSendStateEventContent) + if !ok { + return fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed) + } + evt.Content = content.Content + evt.StateKey = &content.StateKey + evt.Type = event.Type{Type: content.Type, Class: event.StateEventType} + _ = evt.Content.ParseRaw(evt.Type) + mx, ok := portal.Bridge.Matrix.(MatrixConnectorWithArbitraryRoomState) + if !ok { + return fmt.Errorf("matrix connector doesn't support fetching state") + } + prevEvt, err := mx.GetStateEvent(ctx, portal.MXID, evt.Type, evt.GetStateKey()) + if err != nil && !errors.Is(err, mautrix.MNotFound) { + return fmt.Errorf("failed to get prev event: %w", err) + } else if prevEvt != nil { + evt.Unsigned.PrevContent = &prevEvt.Content + evt.Unsigned.PrevSender = prevEvt.Sender + } + return nil +} + func (portal *Portal) FindPreferredLogin(ctx context.Context, user *User, allowRelay bool) (*UserLogin, *database.UserPortal, error) { if portal.Receiver != "" { login, err := portal.Bridge.GetExistingUserLoginByID(ctx, portal.Receiver) if err != nil { return nil, nil, err } - if login == nil || login.UserMXID != user.MXID || !login.Client.IsLoggedIn() { + if login == nil { + return nil, nil, fmt.Errorf("%w (receiver login is nil)", ErrNotLoggedIn) + } else if !login.Client.IsLoggedIn() { + return nil, nil, fmt.Errorf("%w (receiver login is not logged in)", ErrNotLoggedIn) + } else if login.UserMXID != user.MXID { if allowRelay && portal.Relay != nil { return nil, nil, nil } - // TODO different error for this case? - return nil, nil, ErrNotLoggedIn + return nil, nil, fmt.Errorf("%w (relay not set and receiver login is owned by %s, not %s)", ErrNotLoggedIn, login.UserMXID, user.MXID) } up, err := portal.Bridge.DB.UserPortal.Get(ctx, login.UserLogin, portal.PortalKey) return login, up, err @@ -482,29 +689,45 @@ func (portal *Portal) checkConfusableName(ctx context.Context, userID id.UserID, return false } -func (portal *Portal) handleMatrixEvent(ctx context.Context, sender *User, evt *event.Event) { +var fakePerMessageProfileEventType = event.Type{Class: event.StateEventType, Type: "m.per_message_profile"} + +func (portal *Portal) handleMatrixEvent(ctx context.Context, sender *User, evt *event.Event, isStateRequest bool) EventHandlingResult { log := zerolog.Ctx(ctx) if evt.Mautrix.EventSource&event.SourceEphemeral != 0 { switch evt.Type { case event.EphemeralEventReceipt: - portal.handleMatrixReceipts(ctx, evt) + return portal.handleMatrixReceipts(ctx, evt) case event.EphemeralEventTyping: - portal.handleMatrixTyping(ctx, evt) + return portal.handleMatrixTyping(ctx, evt) + case event.BeeperEphemeralEventAIStream: + return portal.handleMatrixAIStream(ctx, sender, evt) + default: + return EventHandlingResultIgnored } - return } - login, _, err := portal.FindPreferredLogin(ctx, sender, true) + if evt.Type == event.StateTombstone { + // Tombstones aren't bridged so they don't need a login + return portal.handleMatrixTombstone(ctx, evt) + } + login, userPortal, err := portal.FindPreferredLogin(ctx, sender, true) if err != nil { log.Err(err).Msg("Failed to get user login to handle Matrix event") if errors.Is(err, ErrNotLoggedIn) { - portal.sendErrorStatus(ctx, evt, WrapErrorInStatus(err).WithMessage("You're not logged in").WithIsCertain(true).WithSendNotice(true)) + shouldSendNotice := evt.Content.AsMessage().MsgType != event.MsgNotice + return EventHandlingResultFailed.WithMSSError( + WrapErrorInStatus(err).WithMessage("You're not logged in").WithIsCertain(true).WithSendNotice(shouldSendNotice), + ) } else { - portal.sendErrorStatus(ctx, evt, WrapErrorInStatus(err).WithMessage("Failed to get login to handle event").WithIsCertain(true).WithSendNotice(true)) + return EventHandlingResultFailed.WithMSSError( + WrapErrorInStatus(err).WithMessage("Failed to get login to handle event").WithIsCertain(true).WithSendNotice(true), + ) } - return } var origSender *OrigSender if login == nil { + if isStateRequest { + return EventHandlingResultFailed.WithMSSError(ErrCantRelayStateRequest) + } login = portal.Relay origSender = &OrigSender{ User: sender, @@ -526,47 +749,88 @@ func (portal *Portal) handleMatrixEvent(ctx context.Context, sender *User, evt * } else { origSender.DisambiguatedName = sender.MXID.String() } + msg := evt.Content.AsMessage() + if msg != nil && msg.BeeperPerMessageProfile != nil && msg.BeeperPerMessageProfile.Displayname != "" { + pmp := msg.BeeperPerMessageProfile + origSender.PerMessageProfile = *pmp + roomPLs, err := portal.Bridge.Matrix.GetPowerLevels(ctx, portal.MXID) + if err != nil { + log.Warn().Err(err).Msg("Failed to get power levels to check relay profile") + } + if roomPLs != nil && + roomPLs.GetUserLevel(sender.MXID) >= roomPLs.GetEventLevel(fakePerMessageProfileEventType) && + !portal.checkConfusableName(ctx, sender.MXID, pmp.Displayname) { + origSender.DisambiguatedName = pmp.Displayname + origSender.RequiresDisambiguation = false + } else { + origSender.DisambiguatedName = fmt.Sprintf("%s via %s", pmp.Displayname, origSender.DisambiguatedName) + } + } + origSender.FormattedName = portal.Bridge.Config.Relay.FormatName(origSender) } // Copy logger because many of the handlers will use UpdateContext ctx = log.With().Str("login_id", string(login.ID)).Logger().WithContext(ctx) + + if origSender == nil && portal.Bridge.Network.GetCapabilities().ImplicitReadReceipts && !evt.Type.IsAccountData() { + rrLog := log.With().Str("subaction", "implicit read receipt").Logger() + rrCtx := rrLog.WithContext(ctx) + rrLog.Debug().Msg("Sending implicit read receipt for event") + evtTS := time.UnixMilli(evt.Timestamp) + portal.callReadReceiptHandler(rrCtx, login, nil, &MatrixReadReceipt{ + Portal: portal, + EventID: evt.ID, + Implicit: true, + ReadUpTo: evtTS, + Receipt: event.ReadReceipt{Timestamp: evtTS}, + }, userPortal) + } + switch evt.Type { case event.EventMessage, event.EventSticker, event.EventUnstablePollStart, event.EventUnstablePollResponse: - portal.handleMatrixMessage(ctx, login, origSender, evt) + return portal.handleMatrixMessage(ctx, login, origSender, evt) case event.EventReaction: if origSender != nil { log.Debug().Msg("Ignoring reaction event from relayed user") - portal.sendErrorStatus(ctx, evt, ErrIgnoringReactionFromRelayedUser) - return + return EventHandlingResultIgnored.WithMSSError(ErrIgnoringReactionFromRelayedUser) } - portal.handleMatrixReaction(ctx, login, evt) + return portal.handleMatrixReaction(ctx, login, evt) case event.EventRedaction: - portal.handleMatrixRedaction(ctx, login, origSender, evt) + return portal.handleMatrixRedaction(ctx, login, origSender, evt) case event.StateRoomName: - handleMatrixRoomMeta(portal, ctx, login, origSender, evt, RoomNameHandlingNetworkAPI.HandleMatrixRoomName) + return handleMatrixRoomMeta(portal, ctx, login, origSender, evt, isStateRequest, RoomNameHandlingNetworkAPI.HandleMatrixRoomName) case event.StateTopic: - handleMatrixRoomMeta(portal, ctx, login, origSender, evt, RoomTopicHandlingNetworkAPI.HandleMatrixRoomTopic) + return handleMatrixRoomMeta(portal, ctx, login, origSender, evt, isStateRequest, RoomTopicHandlingNetworkAPI.HandleMatrixRoomTopic) case event.StateRoomAvatar: - handleMatrixRoomMeta(portal, ctx, login, origSender, evt, RoomAvatarHandlingNetworkAPI.HandleMatrixRoomAvatar) + return handleMatrixRoomMeta(portal, ctx, login, origSender, evt, isStateRequest, RoomAvatarHandlingNetworkAPI.HandleMatrixRoomAvatar) + case event.StateBeeperDisappearingTimer: + return handleMatrixRoomMeta(portal, ctx, login, origSender, evt, isStateRequest, DisappearTimerChangingNetworkAPI.HandleMatrixDisappearingTimer) case event.StateEncryption: // TODO? + return EventHandlingResultIgnored case event.AccountDataMarkedUnread: - handleMatrixAccountData(portal, ctx, login, evt, MarkedUnreadHandlingNetworkAPI.HandleMarkedUnread) + return handleMatrixAccountData(portal, ctx, login, evt, MarkedUnreadHandlingNetworkAPI.HandleMarkedUnread) case event.AccountDataRoomTags: - handleMatrixAccountData(portal, ctx, login, evt, TagHandlingNetworkAPI.HandleRoomTag) + return handleMatrixAccountData(portal, ctx, login, evt, TagHandlingNetworkAPI.HandleRoomTag) case event.AccountDataBeeperMute: - handleMatrixAccountData(portal, ctx, login, evt, MuteHandlingNetworkAPI.HandleMute) + return handleMatrixAccountData(portal, ctx, login, evt, MuteHandlingNetworkAPI.HandleMute) case event.StateMember: - portal.handleMatrixMembership(ctx, login, origSender, evt) + return portal.handleMatrixMembership(ctx, login, origSender, evt, isStateRequest) case event.StatePowerLevels: - portal.handleMatrixPowerLevels(ctx, login, origSender, evt) + return portal.handleMatrixPowerLevels(ctx, login, origSender, evt, isStateRequest) + case event.BeeperDeleteChat: + return portal.handleMatrixDeleteChat(ctx, login, origSender, evt) + case event.BeeperAcceptMessageRequest: + return portal.handleMatrixAcceptMessageRequest(ctx, login, origSender, evt) + default: + return EventHandlingResultIgnored } } -func (portal *Portal) handleMatrixReceipts(ctx context.Context, evt *event.Event) { +func (portal *Portal) handleMatrixReceipts(ctx context.Context, evt *event.Event) EventHandlingResult { content, ok := evt.Content.Parsed.(*event.ReceiptEventContent) if !ok { - return + return EventHandlingResultFailed } for evtID, receipts := range *content { readReceipts, ok := receipts[event.ReceiptTypeRead] @@ -576,12 +840,14 @@ func (portal *Portal) handleMatrixReceipts(ctx context.Context, evt *event.Event for userID, receipt := range readReceipts { sender, err := portal.Bridge.GetUserByMXID(ctx, userID) if err != nil { - // TODO log - return + zerolog.Ctx(ctx).Err(err).Msg("Failed to get user to handle read receipt") + return EventHandlingResultFailed.WithError(err) } portal.handleMatrixReadReceipt(ctx, sender, evtID, receipt) } } + // TODO actual status + return EventHandlingResultSuccess } func (portal *Portal) handleMatrixReadReceipt(ctx context.Context, user *User, eventID id.EventID, receipt event.ReadReceipt) { @@ -613,15 +879,10 @@ func (portal *Portal) handleMatrixReadReceipt(ctx context.Context, user *User, e EventID: eventID, Receipt: receipt, } - if userPortal == nil { - userPortal = database.UserPortalFor(login.UserLogin, portal.PortalKey) - } else { - evt.LastRead = userPortal.LastRead - userPortal = userPortal.CopyWithoutValues() - } evt.ExactMessage, err = portal.Bridge.DB.Message.GetPartByMXID(ctx, eventID) if err != nil { log.Err(err).Msg("Failed to get exact message from database") + evt.ReadUpTo = receipt.Timestamp } else if evt.ExactMessage != nil { log.UpdateContext(func(c zerolog.Context) zerolog.Context { return c.Str("exact_message_id", string(evt.ExactMessage.ID)).Time("exact_message_ts", evt.ExactMessage.Timestamp) @@ -630,27 +891,46 @@ func (portal *Portal) handleMatrixReadReceipt(ctx context.Context, user *User, e } else { evt.ReadUpTo = receipt.Timestamp } - err = rrClient.HandleMatrixReadReceipt(ctx, evt) - if err != nil { - log.Err(err).Msg("Failed to handle read receipt") - return - } - if evt.ExactMessage != nil { - userPortal.LastRead = evt.ExactMessage.Timestamp - } else { - userPortal.LastRead = receipt.Timestamp - } - err = portal.Bridge.DB.UserPortal.Put(ctx, userPortal) - if err != nil { - log.Err(err).Msg("Failed to save user portal metadata") - } - portal.Bridge.DisappearLoop.StartAll(ctx, portal.MXID) + portal.callReadReceiptHandler(ctx, login, rrClient, evt, userPortal) } -func (portal *Portal) handleMatrixTyping(ctx context.Context, evt *event.Event) { +func (portal *Portal) callReadReceiptHandler( + ctx context.Context, + login *UserLogin, + rrClient ReadReceiptHandlingNetworkAPI, + evt *MatrixReadReceipt, + userPortal *database.UserPortal, +) { + if rrClient == nil { + var ok bool + rrClient, ok = login.Client.(ReadReceiptHandlingNetworkAPI) + if !ok { + return + } + } + if userPortal == nil { + userPortal = database.UserPortalFor(login.UserLogin, portal.PortalKey) + } else { + evt.LastRead = userPortal.LastRead + userPortal = userPortal.CopyWithoutValues() + } + err := rrClient.HandleMatrixReadReceipt(ctx, evt) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to handle read receipt") + return + } + userPortal.LastRead = evt.ReadUpTo + err = portal.Bridge.DB.UserPortal.Put(ctx, userPortal) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to save user portal metadata") + } + portal.Bridge.DisappearLoop.StartAllBefore(ctx, portal.MXID, evt.ReadUpTo) +} + +func (portal *Portal) handleMatrixTyping(ctx context.Context, evt *event.Event) EventHandlingResult { content, ok := evt.Content.Parsed.(*event.TypingEventContent) if !ok { - return + return EventHandlingResultFailed } portal.currentlyTypingLock.Lock() defer portal.currentlyTypingLock.Unlock() @@ -661,6 +941,52 @@ func (portal *Portal) handleMatrixTyping(ctx context.Context, evt *event.Event) portal.sendTypings(ctx, stoppedTyping, false) portal.sendTypings(ctx, startedTyping, true) portal.currentlyTyping = content.UserIDs + // TODO actual status + return EventHandlingResultSuccess +} + +func (portal *Portal) handleMatrixAIStream(ctx context.Context, sender *User, evt *event.Event) EventHandlingResult { + log := zerolog.Ctx(ctx) + if sender == nil { + log.Error().Msg("Missing sender for Matrix AI stream event") + return EventHandlingResultIgnored + } + login, _, err := portal.FindPreferredLogin(ctx, sender, true) + if err != nil { + log.Err(err).Msg("Failed to get user login to handle Matrix AI stream event") + return EventHandlingResultFailed.WithMSSError(err) + } + var origSender *OrigSender + if login == nil { + if portal.Relay == nil { + return EventHandlingResultIgnored + } + login = portal.Relay + origSender = &OrigSender{ + User: sender, + UserID: sender.MXID, + } + } + content, ok := evt.Content.Parsed.(*event.BeeperAIStreamEventContent) + if !ok { + log.Error().Type("content_type", evt.Content.Parsed).Msg("Unexpected parsed content type") + return EventHandlingResultFailed.WithMSSError(fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed)) + } + api, ok := login.Client.(BeeperAIStreamHandlingNetworkAPI) + if !ok { + return EventHandlingResultIgnored.WithMSSError(ErrBeeperAIStreamNotSupported) + } + err = api.HandleMatrixBeeperAIStream(ctx, &MatrixBeeperAIStream{ + Event: evt, + Content: content, + Portal: portal, + OrigSender: origSender, + }) + if err != nil { + log.Err(err).Msg("Failed to handle Matrix AI stream event") + return EventHandlingResultFailed.WithMSSError(err) + } + return EventHandlingResultSuccess.WithMSS() } func (portal *Portal) sendTypings(ctx context.Context, userIDs []id.UserID, typing bool) { @@ -751,28 +1077,55 @@ func (portal *Portal) periodicTypingUpdater() { } } -func (portal *Portal) checkMessageContentCaps(ctx context.Context, caps *NetworkRoomCapabilities, content *event.MessageEventContent, evt *event.Event) bool { +func (portal *Portal) checkMessageContentCaps(caps *event.RoomFeatures, content *event.MessageEventContent) error { switch content.MsgType { case event.MsgText, event.MsgNotice, event.MsgEmote: // No checks for now, message length is safer to check after conversion inside connector case event.MsgLocation: - if !caps.LocationMessages { - portal.sendErrorStatus(ctx, evt, ErrLocationMessagesNotAllowed) - return false + if caps.LocationMessage.Reject() { + return ErrLocationMessagesNotAllowed } - case event.MsgImage, event.MsgAudio, event.MsgVideo, event.MsgFile: - if content.FileName != "" && content.Body != content.FileName { - if !caps.Captions { - portal.sendErrorStatus(ctx, evt, ErrCaptionsNotAllowed) - return false + case event.MsgImage, event.MsgAudio, event.MsgVideo, event.MsgFile, event.CapMsgSticker: + capMsgType := content.GetCapMsgType() + feat, ok := caps.File[capMsgType] + if !ok { + return ErrUnsupportedMessageType + } + if content.MsgType != event.CapMsgSticker && + content.FileName != "" && + content.Body != content.FileName && + feat.Caption.Reject() { + return ErrCaptionsNotAllowed + } + if content.Info != nil { + dur := time.Duration(content.Info.Duration) * time.Millisecond + if feat.MaxDuration != nil && dur > feat.MaxDuration.Duration { + if capMsgType == event.CapMsgVoice { + return fmt.Errorf("%w: %s supports voice messages up to %s long", ErrVoiceMessageDurationTooLong, portal.Bridge.Network.GetName().DisplayName, exfmt.Duration(feat.MaxDuration.Duration)) + } + return fmt.Errorf("%w: %s is longer than the maximum of %s", ErrMediaDurationTooLong, exfmt.Duration(dur), exfmt.Duration(feat.MaxDuration.Duration)) + } + if feat.MaxSize != 0 && int64(content.Info.Size) > feat.MaxSize { + return fmt.Errorf("%w: %.1f MiB is larger than the maximum of %.1f MiB", ErrMediaTooLarge, float64(content.Info.Size)/1024/1024, float64(feat.MaxSize)/1024/1024) + } + if content.Info.MimeType != "" && feat.GetMimeSupport(content.Info.MimeType).Reject() { + return fmt.Errorf("%w (%s in %s)", ErrUnsupportedMediaType, content.Info.MimeType, capMsgType) } } + fallthrough default: } - return true + return nil } -func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event) { +func (portal *Portal) parseInputTransactionID(origSender *OrigSender, evt *event.Event) networkid.RawTransactionID { + if origSender != nil || !strings.HasPrefix(evt.ID.String(), database.NetworkTxnMXIDPrefix) { + return "" + } + return networkid.RawTransactionID(strings.TrimPrefix(evt.ID.String(), database.NetworkTxnMXIDPrefix)) +} + +func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event) EventHandlingResult { log := zerolog.Ctx(ctx) var relatesTo *event.RelatesTo var msgContent *event.MessageEventContent @@ -788,46 +1141,49 @@ func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin } else { msgContent, ok = evt.Content.Parsed.(*event.MessageEventContent) relatesTo = msgContent.RelatesTo + if evt.Type == event.EventSticker { + msgContent.MsgType = event.CapMsgSticker + } + if msgContent.MsgType == event.MsgNotice && !portal.Bridge.Config.BridgeNotices { + return EventHandlingResultIgnored.WithMSSError(ErrIgnoringMNotice) + } } if !ok { log.Error().Type("content_type", evt.Content.Parsed).Msg("Unexpected parsed content type") - portal.sendErrorStatus(ctx, evt, fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed)) - return + return EventHandlingResultFailed. + WithMSSError(fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed)) } caps := sender.Client.GetCapabilities(ctx, portal) if relatesTo.GetReplaceID() != "" { if msgContent == nil { log.Warn().Msg("Ignoring edit of poll") - portal.sendErrorStatus(ctx, evt, fmt.Errorf("%w of polls", ErrEditsNotSupported)) - return + return EventHandlingResultFailed.WithMSSError(fmt.Errorf("%w of polls", ErrEditsNotSupported)) } - portal.handleMatrixEdit(ctx, sender, origSender, evt, msgContent, caps) - return + return portal.handleMatrixEdit(ctx, sender, origSender, evt, msgContent, caps) } var err error if origSender != nil { if msgContent == nil { log.Debug().Msg("Ignoring poll event from relayed user") - portal.sendErrorStatus(ctx, evt, ErrIgnoringPollFromRelayedUser) - return + return EventHandlingResultIgnored.WithMSSError(ErrIgnoringPollFromRelayedUser) } - msgContent, err = portal.Bridge.Config.Relay.FormatMessage(msgContent, origSender) - if err != nil { - log.Err(err).Msg("Failed to format message for relaying") - portal.sendErrorStatus(ctx, evt, err) - return + if !caps.PerMessageProfileRelay { + msgContent, err = portal.Bridge.Config.Relay.FormatMessage(msgContent, origSender) + if err != nil { + log.Err(err).Msg("Failed to format message for relaying") + return EventHandlingResultFailed.WithMSSError(err) + } } } if msgContent != nil { - if !portal.checkMessageContentCaps(ctx, caps, msgContent, evt) { - return + if err = portal.checkMessageContentCaps(caps, msgContent); err != nil { + return EventHandlingResultFailed.WithMSSError(err) } } else if pollResponseContent != nil || pollContent != nil { if _, ok = sender.Client.(PollHandlingNetworkAPI); !ok { log.Debug().Msg("Ignoring poll event as network connector doesn't implement PollHandlingNetworkAPI") - portal.sendErrorStatus(ctx, evt, ErrPollsNotSupported) - return + return EventHandlingResultIgnored.WithMSSError(ErrPollsNotSupported) } } @@ -837,29 +1193,29 @@ func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin if err != nil { log.Err(err).Msg("Failed to get poll target message from database") // TODO send status - return + return EventHandlingResultFailed } else if voteTo == nil { log.Warn().Stringer("vote_to_id", relatesTo.GetReferenceID()).Msg("Poll target message not found") // TODO send status - return + return EventHandlingResultFailed } } var replyToID id.EventID - if caps.Threads { + threadRootID := relatesTo.GetThreadParent() + if caps.Thread.Partial() { replyToID = relatesTo.GetNonFallbackReplyTo() + if threadRootID != "" { + threadRoot, err = portal.Bridge.DB.Message.GetPartByMXID(ctx, threadRootID) + if err != nil { + log.Err(err).Msg("Failed to get thread root message from database") + } else if threadRoot == nil { + log.Warn().Stringer("thread_root_id", threadRootID).Msg("Thread root message not found") + } + } } else { replyToID = relatesTo.GetReplyTo() } - threadRootID := relatesTo.GetThreadParent() - if caps.Threads && threadRootID != "" { - threadRoot, err = portal.Bridge.DB.Message.GetPartByMXID(ctx, threadRootID) - if err != nil { - log.Err(err).Msg("Failed to get thread root message from database") - } else if threadRoot == nil { - log.Warn().Stringer("thread_root_id", threadRootID).Msg("Thread root message not found") - } - } - if replyToID != "" && (caps.Replies || caps.Threads) { + if replyToID != "" && (caps.Reply.Partial() || caps.Thread.Partial()) { replyTo, err = portal.Bridge.DB.Message.GetPartByMXID(ctx, replyToID) if err != nil { log.Err(err).Msg("Failed to get reply target message from database") @@ -870,7 +1226,7 @@ func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin // The fallback happens if the message is not a Matrix thread and either // * the replied-to message is in a thread, or // * the network only supports threads (assume the user wants to start a new thread) - if caps.Threads && threadRoot == nil && (replyTo.ThreadRoot != "" || !caps.Replies) { + if caps.Thread.Partial() && threadRoot == nil && (replyTo.ThreadRoot != "" || !caps.Reply.Partial()) { threadRootRemoteID := replyTo.ThreadRoot if threadRootRemoteID == "" { threadRootRemoteID = replyTo.ID @@ -880,11 +1236,21 @@ func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin log.Err(err).Msg("Failed to get thread root message from database (via reply fallback)") } } - if !caps.Replies { + if !caps.Reply.Partial() { replyTo = nil } } } + var messageTimer *event.BeeperDisappearingTimer + if msgContent != nil { + messageTimer = msgContent.BeeperDisappearingTimer + } + if messageTimer != nil && *portal.Disappear.ToEventContent() != *messageTimer { + log.Warn(). + Any("event_timer", messageTimer). + Any("portal_timer", portal.Disappear.ToEventContent()). + Msg("Mismatching disappearing timer in event") + } wrappedMsgEvt := &MatrixMessage{ MatrixEventBase: MatrixEventBase[*event.MessageEventContent]{ @@ -892,10 +1258,30 @@ func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin Content: msgContent, OrigSender: origSender, Portal: portal, + + InputTransactionID: portal.parseInputTransactionID(origSender, evt), }, ThreadRoot: threadRoot, ReplyTo: replyTo, } + if portal.Bridge.Config.DeduplicateMatrixMessages { + if part, err := portal.Bridge.DB.Message.GetPartByTxnID(ctx, portal.Receiver, evt.ID, wrappedMsgEvt.InputTransactionID); err != nil { + log.Err(err).Msg("Failed to check db if message is already sent") + } else if part != nil { + log.Debug(). + Stringer("message_mxid", part.MXID). + Stringer("input_event_id", evt.ID). + Msg("Message already sent, ignoring") + return EventHandlingResultIgnored + } + } + + err = portal.autoAcceptMessageRequest(ctx, evt, sender, origSender, caps) + if err != nil { + log.Warn().Err(err).Msg("Failed to auto-accept message request on message") + // TODO stop processing? + } + var resp *MatrixMessageResponse if msgContent != nil { resp, err = sender.Client.HandleMatrixMessage(ctx, wrappedMsgEvt) @@ -912,16 +1298,18 @@ func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin }) } else { log.Error().Msg("Failed to handle Matrix message: all contents are nil?") - portal.sendErrorStatus(ctx, evt, fmt.Errorf("all contents are nil")) - return + return EventHandlingResultFailed.WithMSSError(fmt.Errorf("all contents are nil")) } if err != nil { log.Err(err).Msg("Failed to handle Matrix message") - portal.sendErrorStatus(ctx, evt, err) - return + return EventHandlingResultFailed.WithMSSError(err) } message := wrappedMsgEvt.fillDBMessage(resp.DB) - if !resp.Pending { + if resp.Pending { + for _, save := range wrappedMsgEvt.pendingSaves { + save.ackedAt = time.Now() + } + } else { if resp.DB == nil { log.Error().Msg("Network connector didn't return a message to save") } else { @@ -945,17 +1333,23 @@ func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin } portal.sendSuccessStatus(ctx, evt, resp.StreamOrder, message.MXID) } - if portal.Disappear.Type != database.DisappearingTypeNone { + ds := portal.Disappear + if messageTimer != nil { + ds = database.DisappearingSettingFromEvent(messageTimer) + } + if ds.Type != event.DisappearingTypeNone { go portal.Bridge.DisappearLoop.Add(ctx, &database.DisappearingMessage{ - RoomID: portal.MXID, - EventID: message.MXID, - DisappearingSetting: database.DisappearingSetting{ - Type: portal.Disappear.Type, - Timer: portal.Disappear.Timer, - DisappearAt: message.Timestamp.Add(portal.Disappear.Timer), - }, + RoomID: portal.MXID, + EventID: message.MXID, + Timestamp: message.Timestamp, + DisappearingSetting: ds.StartingAt(message.Timestamp), }) } + if resp.Pending { + // Not exactly queued, but not finished either + return EventHandlingResultQueued + } + return EventHandlingResultSuccess.WithEventID(message.MXID).WithStreamOrder(resp.StreamOrder) } // AddPendingToIgnore adds a transaction ID that should be ignored if encountered as a new message. @@ -967,7 +1361,7 @@ func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin // See also: [MatrixMessage.AddPendingToSave] func (evt *MatrixMessage) AddPendingToIgnore(txnID networkid.TransactionID) { evt.Portal.outgoingMessagesLock.Lock() - evt.Portal.outgoingMessages[txnID] = outgoingMessage{ + evt.Portal.outgoingMessages[txnID] = &outgoingMessage{ ignore: true, } evt.Portal.outgoingMessagesLock.Unlock() @@ -981,12 +1375,14 @@ func (evt *MatrixMessage) AddPendingToIgnore(txnID networkid.TransactionID) { // // The provided function will be called when the message is encountered. func (evt *MatrixMessage) AddPendingToSave(message *database.Message, txnID networkid.TransactionID, handleEcho RemoteEchoHandler) { - evt.Portal.outgoingMessagesLock.Lock() - evt.Portal.outgoingMessages[txnID] = outgoingMessage{ + pending := &outgoingMessage{ db: evt.fillDBMessage(message), evt: evt.Event, handle: handleEcho, } + evt.Portal.outgoingMessagesLock.Lock() + evt.Portal.outgoingMessages[txnID] = pending + evt.pendingSaves = append(evt.pendingSaves, pending) evt.Portal.outgoingMessagesLock.Unlock() } @@ -994,6 +1390,12 @@ func (evt *MatrixMessage) AddPendingToSave(message *database.Message, txnID netw // This should only be called if sending the message fails. func (evt *MatrixMessage) RemovePending(txnID networkid.TransactionID) { evt.Portal.outgoingMessagesLock.Lock() + pendingSave := evt.Portal.outgoingMessages[txnID] + if pendingSave != nil { + evt.pendingSaves = slices.DeleteFunc(evt.pendingSaves, func(save *outgoingMessage) bool { + return save == pendingSave + }) + } delete(evt.Portal.outgoingMessages, txnID) evt.Portal.outgoingMessagesLock.Unlock() } @@ -1024,10 +1426,49 @@ func (evt *MatrixMessage) fillDBMessage(message *database.Message) *database.Mes if message.SenderMXID == "" { message.SenderMXID = evt.Event.Sender } + if message.SendTxnID != "" { + message.SendTxnID = evt.InputTransactionID + } return message } -func (portal *Portal) handleMatrixEdit(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event, content *event.MessageEventContent, caps *NetworkRoomCapabilities) { +func (portal *Portal) pendingMessageTimeoutLoop(ctx context.Context, cfg *OutgoingTimeoutConfig) { + ticker := time.NewTicker(cfg.CheckInterval) + defer ticker.Stop() + for { + select { + case <-ticker.C: + portal.checkPendingMessages(ctx, cfg) + case <-ctx.Done(): + return + } + } +} + +func (portal *Portal) checkPendingMessages(ctx context.Context, cfg *OutgoingTimeoutConfig) { + portal.outgoingMessagesLock.Lock() + defer portal.outgoingMessagesLock.Unlock() + for _, msg := range portal.outgoingMessages { + if msg.evt != nil && !msg.timeouted { + if cfg.NoEchoTimeout > 0 && !msg.ackedAt.IsZero() && time.Since(msg.ackedAt) > cfg.NoEchoTimeout { + msg.timeouted = true + portal.sendErrorStatus(ctx, msg.evt, ErrRemoteEchoTimeout.WithMessage(cfg.NoEchoMessage)) + } else if cfg.NoAckTimeout > 0 && time.Since(msg.db.Timestamp) > cfg.NoAckTimeout { + msg.timeouted = true + portal.sendErrorStatus(ctx, msg.evt, ErrRemoteAckTimeout.WithMessage(cfg.NoAckMessage)) + } + } + } +} + +func (portal *Portal) handleMatrixEdit( + ctx context.Context, + sender *UserLogin, + origSender *OrigSender, + evt *event.Event, + content *event.MessageEventContent, + caps *event.RoomFeatures, +) EventHandlingResult { log := zerolog.Ctx(ctx) editTargetID := content.RelatesTo.GetReplaceID() log.UpdateContext(func(c zerolog.Context) zerolog.Context { @@ -1035,44 +1476,40 @@ func (portal *Portal) handleMatrixEdit(ctx context.Context, sender *UserLogin, o }) if content.NewContent != nil { content = content.NewContent + if evt.Type == event.EventSticker { + content.MsgType = event.CapMsgSticker + } } if origSender != nil { var err error content, err = portal.Bridge.Config.Relay.FormatMessage(content, origSender) if err != nil { log.Err(err).Msg("Failed to format message for relaying") - portal.sendErrorStatus(ctx, evt, err) - return + return EventHandlingResultFailed.WithMSSError(err) } } editingAPI, ok := sender.Client.(EditHandlingNetworkAPI) if !ok { log.Debug().Msg("Ignoring edit as network connector doesn't implement EditHandlingNetworkAPI") - portal.sendErrorStatus(ctx, evt, ErrEditsNotSupported) - return - } else if !caps.Edits { + return EventHandlingResultIgnored.WithMSSError(ErrEditsNotSupported) + } else if !caps.Edit.Partial() { log.Debug().Msg("Ignoring edit as room doesn't support edits") - portal.sendErrorStatus(ctx, evt, ErrEditsNotSupportedInPortal) - return - } else if !portal.checkMessageContentCaps(ctx, caps, content, evt) { - return + return EventHandlingResultIgnored.WithMSSError(ErrEditsNotSupportedInPortal) + } else if err := portal.checkMessageContentCaps(caps, content); err != nil { + return EventHandlingResultFailed.WithMSSError(err) } editTarget, err := portal.Bridge.DB.Message.GetPartByMXID(ctx, editTargetID) if err != nil { log.Err(err).Msg("Failed to get edit target message from database") - portal.sendErrorStatus(ctx, evt, fmt.Errorf("%w: failed to get edit target: %w", ErrDatabaseError, err)) - return + return EventHandlingResultFailed.WithMSSError(fmt.Errorf("%w: failed to get edit target: %w", ErrDatabaseError, err)) } else if editTarget == nil { log.Warn().Msg("Edit target message not found in database") - portal.sendErrorStatus(ctx, evt, fmt.Errorf("edit %w", ErrTargetMessageNotFound)) - return - } else if caps.EditMaxAge > 0 && time.Since(editTarget.Timestamp) > caps.EditMaxAge { - portal.sendErrorStatus(ctx, evt, ErrEditTargetTooOld) - return + return EventHandlingResultFailed.WithMSSError(fmt.Errorf("edit %w", ErrTargetMessageNotFound)) + } else if caps.EditMaxAge != nil && caps.EditMaxAge.Duration > 0 && time.Since(editTarget.Timestamp) > caps.EditMaxAge.Duration { + return EventHandlingResultFailed.WithMSSError(ErrEditTargetTooOld) } else if caps.EditMaxCount > 0 && editTarget.EditCount >= caps.EditMaxCount { - portal.sendErrorStatus(ctx, evt, ErrEditTargetTooManyEdits) - return + return EventHandlingResultFailed.WithMSSError(ErrEditTargetTooManyEdits) } log.UpdateContext(func(c zerolog.Context) zerolog.Context { return c.Str("edit_target_remote_id", string(editTarget.ID)) @@ -1083,13 +1520,14 @@ func (portal *Portal) handleMatrixEdit(ctx context.Context, sender *UserLogin, o Content: content, OrigSender: origSender, Portal: portal, + + InputTransactionID: portal.parseInputTransactionID(origSender, evt), }, EditTarget: editTarget, }) if err != nil { log.Err(err).Msg("Failed to handle Matrix edit") - portal.sendErrorStatus(ctx, evt, err) - return + return EventHandlingResultFailed.WithMSSError(err) } err = portal.Bridge.DB.Message.Update(ctx, editTarget) if err != nil { @@ -1097,21 +1535,20 @@ func (portal *Portal) handleMatrixEdit(ctx context.Context, sender *UserLogin, o } // TODO allow returning stream order from HandleMatrixEdit portal.sendSuccessStatus(ctx, evt, 0, "") + return EventHandlingResultSuccess } -func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogin, evt *event.Event) { +func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogin, evt *event.Event) (handleRes EventHandlingResult) { log := zerolog.Ctx(ctx) reactingAPI, ok := sender.Client.(ReactionHandlingNetworkAPI) if !ok { log.Debug().Msg("Ignoring reaction as network connector doesn't implement ReactionHandlingNetworkAPI") - portal.sendErrorStatus(ctx, evt, ErrReactionsNotSupported) - return + return EventHandlingResultIgnored.WithMSSError(ErrReactionsNotSupported) } content, ok := evt.Content.Parsed.(*event.ReactionEventContent) if !ok { log.Error().Type("content_type", evt.Content.Parsed).Msg("Unexpected parsed content type") - portal.sendErrorStatus(ctx, evt, fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed)) - return + return EventHandlingResultFailed.WithMSSError(fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed)) } log.UpdateContext(func(c zerolog.Context) zerolog.Context { return c.Stringer("reaction_target_mxid", content.RelatesTo.EventID) @@ -1119,12 +1556,16 @@ func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogi reactionTarget, err := portal.Bridge.DB.Message.GetPartByMXID(ctx, content.RelatesTo.EventID) if err != nil { log.Err(err).Msg("Failed to get reaction target message from database") - portal.sendErrorStatus(ctx, evt, fmt.Errorf("%w: failed to get reaction target: %w", ErrDatabaseError, err)) - return + return EventHandlingResultFailed.WithMSSError(fmt.Errorf("%w: failed to get reaction target: %w", ErrDatabaseError, err)) } else if reactionTarget == nil { log.Warn().Msg("Reaction target message not found in database") - portal.sendErrorStatus(ctx, evt, fmt.Errorf("reaction %w", ErrTargetMessageNotFound)) - return + return EventHandlingResultFailed.WithMSSError(fmt.Errorf("reaction %w", ErrTargetMessageNotFound)) + } + caps := sender.Client.GetCapabilities(ctx, portal) + err = portal.autoAcceptMessageRequest(ctx, evt, sender, nil, caps) + if err != nil { + log.Warn().Err(err).Msg("Failed to auto-accept message request on reaction") + // TODO stop processing? } log.UpdateContext(func(c zerolog.Context) zerolog.Context { return c.Str("reaction_target_remote_id", string(reactionTarget.ID)) @@ -1134,46 +1575,64 @@ func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogi Event: evt, Content: content, Portal: portal, + + InputTransactionID: portal.parseInputTransactionID(nil, evt), }, TargetMessage: reactionTarget, } preResp, err := reactingAPI.PreHandleMatrixReaction(ctx, react) if err != nil { log.Err(err).Msg("Failed to pre-handle Matrix reaction") - portal.sendErrorStatus(ctx, evt, err) - return + return EventHandlingResultFailed.WithMSSError(err) } var deterministicID id.EventID if portal.Bridge.Config.OutgoingMessageReID { deterministicID = portal.Bridge.Matrix.GenerateReactionEventID(portal.MXID, reactionTarget, preResp.SenderID, preResp.EmojiID) } - existing, err := portal.Bridge.DB.Reaction.GetByID(ctx, portal.Receiver, reactionTarget.ID, reactionTarget.PartID, preResp.SenderID, preResp.EmojiID) - if err != nil { - log.Err(err).Msg("Failed to check if reaction is a duplicate") - return - } else if existing != nil { - if existing.EmojiID != "" || existing.Emoji == preResp.Emoji { - log.Debug().Msg("Ignoring duplicate reaction") + defer func() { + // Do this in a defer so that it happens after any potential defer calls to removeOutdatedReaction + if handleRes.Success { portal.sendSuccessStatus(ctx, evt, 0, deterministicID) + } + }() + removeOutdatedReaction := func(oldReact *database.Reaction, deleteDB bool) { + if !handleRes.Success { return } - react.ReactionToOverride = existing - _, err = portal.Bridge.Bot.SendMessage(ctx, portal.MXID, event.EventRedaction, &event.Content{ + _, err := portal.Bridge.Bot.SendMessage(ctx, portal.MXID, event.EventRedaction, &event.Content{ Parsed: &event.RedactionEventContent{ - Redacts: existing.MXID, + Redacts: oldReact.MXID, }, }, nil) if err != nil { log.Err(err).Msg("Failed to remove old reaction") } + if deleteDB { + err = portal.Bridge.DB.Reaction.Delete(ctx, oldReact) + if err != nil { + log.Err(err).Msg("Failed to delete old reaction from database") + } + } + } + existing, err := portal.Bridge.DB.Reaction.GetByID(ctx, portal.Receiver, reactionTarget.ID, reactionTarget.PartID, preResp.SenderID, preResp.EmojiID) + if err != nil { + log.Err(err).Msg("Failed to check if reaction is a duplicate") + return EventHandlingResultFailed.WithMSSError(fmt.Errorf("%w: failed to check for existing reaction: %w", ErrDatabaseError, err)) + } else if existing != nil { + if existing.EmojiID != "" || existing.Emoji == preResp.Emoji { + log.Debug().Msg("Ignoring duplicate reaction") + portal.sendSuccessStatus(ctx, evt, 0, deterministicID) + return EventHandlingResultIgnored.WithEventID(deterministicID) + } + react.ReactionToOverride = existing + defer removeOutdatedReaction(existing, false) } react.PreHandleResp = &preResp if preResp.MaxReactions > 0 { allReactions, err := portal.Bridge.DB.Reaction.GetAllToMessageBySender(ctx, portal.Receiver, reactionTarget.ID, preResp.SenderID) if err != nil { log.Err(err).Msg("Failed to get all reactions to message by sender") - portal.sendErrorStatus(ctx, evt, fmt.Errorf("%w: failed to get previous reactions: %w", ErrDatabaseError, err)) - return + return EventHandlingResultFailed.WithMSSError(fmt.Errorf("%w: failed to get previous reactions: %w", ErrDatabaseError, err)) } if len(allReactions) < preResp.MaxReactions { react.ExistingReactionsToKeep = allReactions @@ -1181,26 +1640,21 @@ func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogi // Keep n-1 previous reactions and remove the rest react.ExistingReactionsToKeep = allReactions[:preResp.MaxReactions-1] for _, oldReaction := range allReactions[preResp.MaxReactions-1:] { - _, err = portal.Bridge.Bot.SendMessage(ctx, portal.MXID, event.EventRedaction, &event.Content{ - Parsed: &event.RedactionEventContent{ - Redacts: oldReaction.MXID, - }, - }, nil) - if err != nil { - log.Err(err).Msg("Failed to remove previous reaction after limit was exceeded") - } - err = portal.Bridge.DB.Reaction.Delete(ctx, oldReaction) - if err != nil { - log.Err(err).Msg("Failed to delete previous reaction from database after limit was exceeded") + if existing != nil && oldReaction.EmojiID == existing.EmojiID { + // Don't double-delete on networks that only allow one emoji + continue } + // Intentionally defer in a loop, there won't be that many items, + // and we want all of them to be done after this function completes successfully + //goland:noinspection GoDeferInLoop + defer removeOutdatedReaction(oldReaction, true) } } } dbReaction, err := reactingAPI.HandleMatrixReaction(ctx, react) if err != nil { log.Err(err).Msg("Failed to handle Matrix reaction") - portal.sendErrorStatus(ctx, evt, err) - return + return EventHandlingResultFailed.WithMSSError(err) } if dbReaction == nil { dbReaction = &database.Reaction{} @@ -1238,7 +1692,7 @@ func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogi if err != nil { log.Err(err).Msg("Failed to save reaction to database") } - portal.sendSuccessStatus(ctx, evt, 0, deterministicID) + return EventHandlingResultSuccess.WithEventID(deterministicID) } func handleMatrixRoomMeta[APIType any, ContentType any]( @@ -1247,35 +1701,53 @@ func handleMatrixRoomMeta[APIType any, ContentType any]( sender *UserLogin, origSender *OrigSender, evt *event.Event, + isStateRequest bool, fn func(APIType, context.Context, *MatrixRoomMeta[ContentType]) (bool, error), -) { +) EventHandlingResult { + if evt.StateKey == nil || *evt.StateKey != "" { + return EventHandlingResultFailed.WithMSSError(ErrInvalidStateKey) + } + //caps := sender.Client.GetCapabilities(ctx, portal) + //if stateCap, ok := caps.State[evt.Type.Type]; !ok || stateCap.Level <= event.CapLevelUnsupported { + // return EventHandlingResultIgnored.WithMSSError(fmt.Errorf("%s %w", evt.Type.Type, ErrRoomMetadataNotAllowed)) + //} api, ok := sender.Client.(APIType) if !ok { - portal.sendErrorStatus(ctx, evt, ErrRoomMetadataNotSupported) - return + return EventHandlingResultIgnored.WithMSSError(fmt.Errorf("%w of type %s", ErrRoomMetadataNotSupported, evt.Type)) } log := zerolog.Ctx(ctx) content, ok := evt.Content.Parsed.(ContentType) if !ok { log.Error().Type("content_type", evt.Content.Parsed).Msg("Unexpected parsed content type") - portal.sendErrorStatus(ctx, evt, fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed)) - return + return EventHandlingResultFailed.WithMSSError(fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed)) } switch typedContent := evt.Content.Parsed.(type) { case *event.RoomNameEventContent: if typedContent.Name == portal.Name { portal.sendSuccessStatus(ctx, evt, 0, "") - return + return EventHandlingResultIgnored } case *event.TopicEventContent: if typedContent.Topic == portal.Topic { portal.sendSuccessStatus(ctx, evt, 0, "") - return + return EventHandlingResultIgnored } case *event.RoomAvatarEventContent: if typedContent.URL == portal.AvatarMXC { portal.sendSuccessStatus(ctx, evt, 0, "") - return + return EventHandlingResultIgnored + } + case *event.BeeperDisappearingTimer: + if typedContent.Type == event.DisappearingTypeNone || typedContent.Timer.Duration <= 0 { + typedContent.Type = event.DisappearingTypeNone + typedContent.Timer.Duration = 0 + } + if typedContent.Type == portal.Disappear.Type && typedContent.Timer.Duration == portal.Disappear.Timer { + portal.sendSuccessStatus(ctx, evt, 0, "") + return EventHandlingResultIgnored + } + if !sender.Client.GetCapabilities(ctx, portal).DisappearingTimer.Supports(typedContent) { + return EventHandlingResultFailed.WithMSSError(ErrDisappearingTimerUnsupported) } } var prevContent ContentType @@ -1290,37 +1762,41 @@ func handleMatrixRoomMeta[APIType any, ContentType any]( Content: content, Portal: portal, OrigSender: origSender, + + InputTransactionID: portal.parseInputTransactionID(origSender, evt), }, - PrevContent: prevContent, + IsStateRequest: isStateRequest, + PrevContent: prevContent, }) if err != nil { log.Err(err).Msg("Failed to handle Matrix room metadata") - portal.sendErrorStatus(ctx, evt, err) - return + return EventHandlingResultFailed.WithMSSError(err) } if changed { - portal.UpdateBridgeInfo(ctx) + if evt.Type != event.StateBeeperDisappearingTimer { + portal.UpdateBridgeInfo(ctx) + } err = portal.Save(ctx) if err != nil { log.Err(err).Msg("Failed to save portal after updating room metadata") } } - portal.sendSuccessStatus(ctx, evt, 0, "") + return EventHandlingResultSuccess.WithMSS() } func handleMatrixAccountData[APIType any, ContentType any]( portal *Portal, ctx context.Context, sender *UserLogin, evt *event.Event, fn func(APIType, context.Context, *MatrixRoomMeta[ContentType]) error, -) { +) EventHandlingResult { api, ok := sender.Client.(APIType) if !ok { - return + return EventHandlingResultIgnored } log := zerolog.Ctx(ctx) content, ok := evt.Content.Parsed.(ContentType) if !ok { log.Error().Type("content_type", evt.Content.Parsed).Msg("Unexpected parsed content type") - return + return EventHandlingResultFailed.WithError(fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed)) } var prevContent ContentType if evt.Unsigned.PrevContent != nil { @@ -1338,7 +1814,9 @@ func handleMatrixAccountData[APIType any, ContentType any]( }) if err != nil { log.Err(err).Msg("Failed to handle Matrix room account data") + return EventHandlingResultFailed.WithError(err) } + return EventHandlingResultSuccess } func (portal *Portal) getTargetUser(ctx context.Context, userID id.UserID) (GhostOrUserLogin, error) { @@ -1358,18 +1836,144 @@ func (portal *Portal) getTargetUser(ctx context.Context, userID id.UserID) (Ghos } } +func (portal *Portal) handleMatrixAcceptMessageRequest( + ctx context.Context, + sender *UserLogin, + origSender *OrigSender, + evt *event.Event, +) EventHandlingResult { + if origSender != nil { + return EventHandlingResultFailed.WithMSSError(ErrIgnoringAcceptRequestRelayedUser) + } + log := zerolog.Ctx(ctx) + content, ok := evt.Content.Parsed.(*event.BeeperAcceptMessageRequestEventContent) + if !ok { + log.Error().Type("content_type", evt.Content.Parsed).Msg("Unexpected parsed content type") + return EventHandlingResultFailed.WithMSSError(fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed)) + } + api, ok := sender.Client.(MessageRequestAcceptingNetworkAPI) + if !ok { + return EventHandlingResultIgnored.WithMSSError(ErrDeleteChatNotSupported) + } + err := api.HandleMatrixAcceptMessageRequest(ctx, &MatrixAcceptMessageRequest{ + Event: evt, + Content: content, + Portal: portal, + }) + if err != nil { + log.Err(err).Msg("Failed to handle Matrix accept message request") + return EventHandlingResultFailed.WithMSSError(err) + } + if portal.MessageRequest { + portal.MessageRequest = false + portal.UpdateBridgeInfo(ctx) + err = portal.Save(ctx) + if err != nil { + log.Err(err).Msg("Failed to save portal after accepting message request") + } + } + return EventHandlingResultSuccess.WithMSS() +} + +func (portal *Portal) autoAcceptMessageRequest( + ctx context.Context, evt *event.Event, sender *UserLogin, origSender *OrigSender, caps *event.RoomFeatures, +) error { + if !portal.MessageRequest || caps.MessageRequest == nil || caps.MessageRequest.AcceptWithMessage == event.CapLevelFullySupported { + return nil + } + mran, ok := sender.Client.(MessageRequestAcceptingNetworkAPI) + if !ok { + return nil + } + err := mran.HandleMatrixAcceptMessageRequest(ctx, &MatrixAcceptMessageRequest{ + Event: evt, + Content: &event.BeeperAcceptMessageRequestEventContent{ + IsImplicit: true, + }, + Portal: portal, + OrigSender: origSender, + }) + if err != nil { + return err + } + if portal.MessageRequest { + portal.MessageRequest = false + portal.UpdateBridgeInfo(ctx) + err = portal.Save(ctx) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to save portal after accepting message request") + } + } + return nil +} + +func (portal *Portal) handleMatrixDeleteChat( + ctx context.Context, + sender *UserLogin, + origSender *OrigSender, + evt *event.Event, +) EventHandlingResult { + if origSender != nil { + return EventHandlingResultFailed.WithMSSError(ErrIgnoringDeleteChatRelayedUser) + } + log := zerolog.Ctx(ctx) + content, ok := evt.Content.Parsed.(*event.BeeperChatDeleteEventContent) + if !ok { + log.Error().Type("content_type", evt.Content.Parsed).Msg("Unexpected parsed content type") + return EventHandlingResultFailed.WithMSSError(fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed)) + } + api, ok := sender.Client.(DeleteChatHandlingNetworkAPI) + if !ok { + return EventHandlingResultIgnored.WithMSSError(ErrDeleteChatNotSupported) + } + err := api.HandleMatrixDeleteChat(ctx, &MatrixDeleteChat{ + Event: evt, + Content: content, + Portal: portal, + }) + if err != nil { + log.Err(err).Msg("Failed to handle Matrix chat delete") + return EventHandlingResultFailed.WithMSSError(err) + } + if portal.Receiver == "" { + _, others, err := portal.findOtherLogins(ctx, sender) + if err != nil { + log.Err(err).Msg("Failed to check if portal has other logins") + return EventHandlingResultFailed.WithError(err) + } else if len(others) > 0 { + log.Debug().Msg("Not deleting portal after chat delete as other logins are present") + return EventHandlingResultSuccess + } + } + err = portal.Delete(ctx) + if err != nil { + log.Err(err).Msg("Failed to delete portal from database") + return EventHandlingResultFailed.WithMSSError(err) + } + err = portal.Bridge.Bot.DeleteRoom(ctx, portal.MXID, false) + if err != nil { + log.Err(err).Msg("Failed to delete Matrix room") + return EventHandlingResultFailed.WithMSSError(err) + } + // No MSS here as the portal was deleted + return EventHandlingResultSuccess +} + func (portal *Portal) handleMatrixMembership( ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event, -) { + isStateRequest bool, +) EventHandlingResult { + if evt.StateKey == nil { + return EventHandlingResultFailed.WithMSSError(ErrInvalidStateKey) + } log := zerolog.Ctx(ctx) content, ok := evt.Content.Parsed.(*event.MemberEventContent) if !ok { log.Error().Type("content_type", evt.Content.Parsed).Msg("Unexpected parsed content type") - portal.sendErrorStatus(ctx, evt, fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed)) - return + return EventHandlingResultFailed.WithMSSError(fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed)) } prevContent := &event.MemberEventContent{Membership: event.MembershipLeave} if evt.Unsigned.PrevContent != nil { @@ -1384,26 +1988,22 @@ func (portal *Portal) handleMatrixMembership( }) api, ok := sender.Client.(MembershipHandlingNetworkAPI) if !ok { - portal.sendErrorStatus(ctx, evt, ErrMembershipNotSupported) - return + return EventHandlingResultIgnored.WithMSSError(ErrMembershipNotSupported) } targetMXID := id.UserID(*evt.StateKey) isSelf := sender.User.MXID == targetMXID target, err := portal.getTargetUser(ctx, targetMXID) if err != nil { log.Err(err).Msg("Failed to get member event target") - portal.sendErrorStatus(ctx, evt, err) - return + return EventHandlingResultFailed.WithMSSError(err) } membershipChangeType := MembershipChangeType{From: prevContent.Membership, To: content.Membership, IsSelf: isSelf} if !portal.Bridge.Config.BridgeMatrixLeave && membershipChangeType == Leave { log.Debug().Msg("Dropping leave event") - //portal.sendErrorStatus(ctx, evt, ErrIgnoringLeaveEvent) - return + return EventHandlingResultIgnored //.WithMSSError(ErrIgnoringLeaveEvent) } targetGhost, _ := target.(*Ghost) - targetUserLogin, _ := target.(*UserLogin) membershipChange := &MatrixMembershipChange{ MatrixRoomMeta: MatrixRoomMeta[*event.MemberEventContent]{ MatrixEventBase: MatrixEventBase[*event.MemberEventContent]{ @@ -1411,20 +2011,63 @@ func (portal *Portal) handleMatrixMembership( Content: content, Portal: portal, OrigSender: origSender, + + InputTransactionID: portal.parseInputTransactionID(origSender, evt), }, - PrevContent: prevContent, + IsStateRequest: isStateRequest, + PrevContent: prevContent, }, - Target: target, - TargetGhost: targetGhost, - TargetUserLogin: targetUserLogin, - Type: membershipChangeType, + Target: target, + Type: membershipChangeType, } - _, err = api.HandleMatrixMembership(ctx, membershipChange) + res, err := api.HandleMatrixMembership(ctx, membershipChange) if err != nil { log.Err(err).Msg("Failed to handle Matrix membership change") - portal.sendErrorStatus(ctx, evt, err) - return + return EventHandlingResultFailed.WithMSSError(err) } + didRedirectInvite := membershipChangeType == Invite && + targetGhost != nil && + res != nil && + res.RedirectTo != "" && + res.RedirectTo != targetGhost.ID + if didRedirectInvite { + log.Debug(). + Str("orig_id", string(targetGhost.ID)). + Str("redirect_id", string(res.RedirectTo)). + Msg("Invite was redirected to different ghost") + var redirectGhost *Ghost + redirectGhost, err = portal.Bridge.GetGhostByID(ctx, res.RedirectTo) + if err != nil { + log.Err(err).Msg("Failed to get redirect target ghost") + return EventHandlingResultFailed.WithError(err) + } + if !isStateRequest { + portal.sendRoomMeta( + ctx, + sender.User.DoublePuppet(ctx), + time.UnixMilli(evt.Timestamp), + event.StateMember, + evt.GetStateKey(), + &event.MemberEventContent{ + Membership: event.MembershipLeave, + Reason: fmt.Sprintf("Invite redirected to %s", res.RedirectTo), + }, + true, + nil, + ) + } + portal.sendRoomMeta( + ctx, + sender.User.DoublePuppet(ctx), + time.UnixMilli(evt.Timestamp), + event.StateMember, + redirectGhost.Intent.GetMXID().String(), + content, + false, + nil, + ) + } + return EventHandlingResultSuccess.WithMSS().WithSkipStateEcho(didRedirectInvite) } func makePLChange(old, new int, newIsSet bool) *SinglePowerLevelChange { @@ -1449,23 +2092,36 @@ func (portal *Portal) handleMatrixPowerLevels( sender *UserLogin, origSender *OrigSender, evt *event.Event, -) { + isStateRequest bool, +) EventHandlingResult { + if evt.StateKey == nil || *evt.StateKey != "" { + return EventHandlingResultFailed.WithMSSError(ErrInvalidStateKey) + } log := zerolog.Ctx(ctx) content, ok := evt.Content.Parsed.(*event.PowerLevelsEventContent) if !ok { log.Error().Type("content_type", evt.Content.Parsed).Msg("Unexpected parsed content type") - portal.sendErrorStatus(ctx, evt, fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed)) - return + return EventHandlingResultFailed.WithMSSError(fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed)) + } + if content.CreateEvent == nil { + ars, ok := portal.Bridge.Matrix.(MatrixConnectorWithArbitraryRoomState) + if ok { + var err error + content.CreateEvent, err = ars.GetStateEvent(ctx, portal.MXID, event.StateCreate, "") + if err != nil { + return EventHandlingResultFailed.WithMSSError(fmt.Errorf("failed to get create event for power levels: %w", err)) + } + } } api, ok := sender.Client.(PowerLevelHandlingNetworkAPI) if !ok { - portal.sendErrorStatus(ctx, evt, ErrPowerLevelsNotSupported) - return + return EventHandlingResultIgnored.WithMSSError(ErrPowerLevelsNotSupported) } prevContent := &event.PowerLevelsEventContent{} if evt.Unsigned.PrevContent != nil { _ = evt.Unsigned.PrevContent.ParseRaw(evt.Type) prevContent, _ = evt.Unsigned.PrevContent.Parsed.(*event.PowerLevelsEventContent) + prevContent.CreateEvent = content.CreateEvent } plChange := &MatrixPowerLevelChange{ @@ -1475,8 +2131,11 @@ func (portal *Portal) handleMatrixPowerLevels( Content: content, Portal: portal, OrigSender: origSender, + + InputTransactionID: portal.parseInputTransactionID(origSender, evt), }, - PrevContent: prevContent, + IsStateRequest: isStateRequest, + PrevContent: prevContent, }, Users: make(map[id.UserID]*UserPowerLevelChange), Events: make(map[string]*SinglePowerLevelChange), @@ -1517,18 +2176,269 @@ func (portal *Portal) handleMatrixPowerLevels( _, err := api.HandleMatrixPowerLevels(ctx, plChange) if err != nil { log.Err(err).Msg("Failed to handle Matrix power level change") - portal.sendErrorStatus(ctx, evt, err) - return + return EventHandlingResultFailed.WithMSSError(err) } + return EventHandlingResultSuccess.WithMSS() } -func (portal *Portal) handleMatrixRedaction(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event) { +func (portal *Portal) handleMatrixTombstone(ctx context.Context, evt *event.Event) EventHandlingResult { + if evt.StateKey == nil || *evt.StateKey != "" || portal.MXID != evt.RoomID { + return EventHandlingResultIgnored + } + log := *zerolog.Ctx(ctx) + sentByBridge := evt.Sender == portal.Bridge.Bot.GetMXID() || portal.Bridge.IsGhostMXID(evt.Sender) + var senderUser *User + var err error + if !sentByBridge { + senderUser, err = portal.Bridge.GetUserByMXID(ctx, evt.Sender) + if err != nil { + log.Err(err).Msg("Failed to get tombstone sender user") + return EventHandlingResultFailed.WithError(err) + } + } + content, ok := evt.Content.Parsed.(*event.TombstoneEventContent) + if !ok { + log.Error().Type("content_type", evt.Content.Parsed).Msg("Unexpected parsed content type") + return EventHandlingResultFailed.WithMSSError(fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed)) + } + log = log.With(). + Stringer("replacement_room", content.ReplacementRoom). + Logger() + if content.ReplacementRoom == "" { + log.Info().Msg("Received tombstone with no replacement room, cleaning up portal") + err := portal.RemoveMXID(ctx) + if err != nil { + log.Err(err).Msg("Failed to remove portal MXID") + return EventHandlingResultFailed.WithMSSError(err) + } + err = portal.Bridge.Bot.DeleteRoom(ctx, portal.MXID, true) + if err != nil { + log.Err(err).Msg("Failed to clean up Matrix room") + return EventHandlingResultFailed.WithError(err) + } + return EventHandlingResultSuccess + } + existingMemberEvt, err := portal.Bridge.Matrix.GetMemberInfo(ctx, content.ReplacementRoom, portal.Bridge.Bot.GetMXID()) + if err != nil { + log.Err(err).Msg("Failed to get member info of bot in replacement room") + return EventHandlingResultFailed.WithError(err) + } + leaveOnError := func() { + if existingMemberEvt != nil && existingMemberEvt.Membership == event.MembershipJoin { + return + } + log.Debug().Msg("Leaving replacement room with bot after tombstone validation failed") + _, err = portal.Bridge.Bot.SendState( + ctx, + content.ReplacementRoom, + event.StateMember, + portal.Bridge.Bot.GetMXID().String(), + &event.Content{ + Parsed: &event.MemberEventContent{ + Membership: event.MembershipLeave, + Reason: fmt.Sprintf("Failed to validate tombstone sent by %s from %s", evt.Sender, evt.RoomID), + }, + }, + time.Time{}, + ) + if err != nil { + log.Err(err).Msg("Failed to leave replacement room after tombstone validation failed") + } + } + var via []string + if senderHS := evt.Sender.Homeserver(); senderHS != "" { + via = []string{senderHS} + } + err = portal.Bridge.Bot.EnsureJoined(ctx, content.ReplacementRoom, EnsureJoinedParams{Via: via}) + if err != nil { + log.Err(err).Msg("Failed to join replacement room from tombstone") + return EventHandlingResultFailed.WithError(err) + } + if !sentByBridge && !senderUser.Permissions.Admin { + powers, err := portal.Bridge.Matrix.GetPowerLevels(ctx, content.ReplacementRoom) + if err != nil { + log.Err(err).Msg("Failed to get power levels in replacement room") + leaveOnError() + return EventHandlingResultFailed.WithError(err) + } + if powers.GetUserLevel(evt.Sender) < powers.Invite() { + log.Warn().Msg("Tombstone sender doesn't have enough power to invite the bot to the replacement room") + leaveOnError() + return EventHandlingResultIgnored + } + } + err = portal.UpdateMatrixRoomID(ctx, content.ReplacementRoom, UpdateMatrixRoomIDParams{ + DeleteOldRoom: true, + FetchInfoVia: senderUser, + }) + if errors.Is(err, ErrTargetRoomIsPortal) { + return EventHandlingResultIgnored + } else if err != nil { + return EventHandlingResultFailed.WithError(err) + } + return EventHandlingResultSuccess +} + +var ErrTargetRoomIsPortal = errors.New("target room is already a portal") +var ErrRoomAlreadyExists = errors.New("this portal already has a room") + +type UpdateMatrixRoomIDParams struct { + SyncDBMetadata func() + FailIfMXIDSet bool + OverwriteOldPortal bool + TombstoneOldRoom bool + DeleteOldRoom bool + + RoomCreateAlreadyLocked bool + + FetchInfoVia *User + ChatInfo *ChatInfo + ChatInfoSource *UserLogin +} + +func (portal *Portal) UpdateMatrixRoomID( + ctx context.Context, + newRoomID id.RoomID, + params UpdateMatrixRoomIDParams, +) error { + if !params.RoomCreateAlreadyLocked { + portal.roomCreateLock.Lock() + defer portal.roomCreateLock.Unlock() + } + oldRoom := portal.MXID + if oldRoom == newRoomID { + return nil + } else if oldRoom != "" && params.FailIfMXIDSet { + return ErrRoomAlreadyExists + } + log := zerolog.Ctx(ctx) + portal.Bridge.cacheLock.Lock() + // Wrap unlock in a sync.OnceFunc because we want to both defer it to catch early returns + // and unlock it before return if nothing goes wrong. + unlockCacheLock := sync.OnceFunc(portal.Bridge.cacheLock.Unlock) + defer unlockCacheLock() + if existingPortal, alreadyExists := portal.Bridge.portalsByMXID[newRoomID]; alreadyExists && !params.OverwriteOldPortal { + log.Warn().Msg("Replacement room is already a portal, ignoring") + return ErrTargetRoomIsPortal + } else if alreadyExists { + log.Debug().Msg("Replacement room is already a portal, overwriting") + existingPortal.MXID = "" + existingPortal.RoomCreated.Clear() + err := existingPortal.Save(ctx) + if err != nil { + return fmt.Errorf("failed to clear mxid of existing portal: %w", err) + } + delete(portal.Bridge.portalsByMXID, portal.MXID) + } + portal.MXID = newRoomID + portal.RoomCreated.Set() + portal.Bridge.portalsByMXID[portal.MXID] = portal + portal.NameSet = false + portal.AvatarSet = false + portal.TopicSet = false + portal.InSpace = false + portal.CapState = database.CapabilityState{} + portal.lastCapUpdate = time.Time{} + if params.SyncDBMetadata != nil { + params.SyncDBMetadata() + } + unlockCacheLock() + portal.updateLogger() + + err := portal.Save(ctx) + if err != nil { + log.Err(err).Msg("Failed to save portal in UpdateMatrixRoomID") + return err + } + log.Info().Msg("Successfully followed tombstone and updated portal MXID") + err = portal.Bridge.DB.UserPortal.MarkAllNotInSpace(ctx, portal.PortalKey) + if err != nil { + log.Err(err).Msg("Failed to update in_space flag for user portals after updating portal MXID") + } + go portal.addToUserSpaces(ctx) + if params.FetchInfoVia != nil { + go portal.updateInfoAfterTombstone(ctx, params.FetchInfoVia) + } else if params.ChatInfo != nil { + go portal.UpdateInfo(ctx, params.ChatInfo, params.ChatInfoSource, nil, time.Time{}) + } else if params.ChatInfoSource != nil { + portal.UpdateCapabilities(ctx, params.ChatInfoSource, true) + portal.UpdateBridgeInfo(ctx) + } + go func() { + // TODO this might become unnecessary if UpdateInfo starts taking care of it + _, err = portal.Bridge.Bot.SendState(ctx, portal.MXID, event.StateElementFunctionalMembers, "", &event.Content{ + Parsed: &event.ElementFunctionalMembersContent{ + ServiceMembers: []id.UserID{portal.Bridge.Bot.GetMXID()}, + }, + }, time.Time{}) + if err != nil { + if err != nil { + log.Warn().Err(err).Msg("Failed to set service members in new room") + } + } + }() + if params.TombstoneOldRoom && oldRoom != "" { + _, err = portal.Bridge.Bot.SendState(ctx, portal.MXID, event.StateTombstone, "", &event.Content{ + Parsed: &event.TombstoneEventContent{ + Body: "Room has been replaced.", + ReplacementRoom: newRoomID, + }, + }, time.Now()) + if err != nil { + log.Err(err).Msg("Failed to send tombstone event to old room") + } + } + if params.DeleteOldRoom && oldRoom != "" { + go func() { + err = portal.Bridge.Bot.DeleteRoom(ctx, oldRoom, true) + if err != nil { + log.Err(err).Msg("Failed to clean up old Matrix room after updating portal MXID") + } + }() + } + return nil +} + +func (portal *Portal) updateInfoAfterTombstone(ctx context.Context, senderUser *User) { + log := zerolog.Ctx(ctx) + logins, err := portal.Bridge.GetUserLoginsInPortal(ctx, portal.PortalKey) + if err != nil { + log.Err(err).Msg("Failed to get user logins in portal to sync info") + return + } + var preferredLogin *UserLogin + for _, login := range logins { + if !login.Client.IsLoggedIn() { + continue + } else if preferredLogin == nil { + preferredLogin = login + } else if senderUser != nil && login.User == senderUser { + preferredLogin = login + } + } + if preferredLogin == nil { + log.Warn().Msg("No logins found to sync info") + return + } + info, err := preferredLogin.Client.GetChatInfo(ctx, portal) + if err != nil { + log.Err(err).Msg("Failed to get chat info") + return + } + log.Info(). + Str("info_source_login", string(preferredLogin.ID)). + Msg("Fetched info to update portal after tombstone") + portal.UpdateInfo(ctx, info, preferredLogin, nil, time.Time{}) +} + +func (portal *Portal) handleMatrixRedaction( + ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event, +) EventHandlingResult { log := zerolog.Ctx(ctx) content, ok := evt.Content.Parsed.(*event.RedactionEventContent) if !ok { log.Error().Type("content_type", evt.Content.Parsed).Msg("Unexpected parsed content type") - portal.sendErrorStatus(ctx, evt, fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed)) - return + return EventHandlingResultFailed.WithMSSError(fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed)) } if evt.Redacts != "" && content.Redacts != evt.Redacts { content.Redacts = evt.Redacts @@ -1540,20 +2450,17 @@ func (portal *Portal) handleMatrixRedaction(ctx context.Context, sender *UserLog reactingAPI, reactOK := sender.Client.(ReactionHandlingNetworkAPI) if !deleteOK && !reactOK { log.Debug().Msg("Ignoring redaction without checking target as network connector doesn't implement RedactionHandlingNetworkAPI nor ReactionHandlingNetworkAPI") - portal.sendErrorStatus(ctx, evt, ErrRedactionsNotSupported) - return + return EventHandlingResultIgnored.WithMSSError(ErrRedactionsNotSupported) } var redactionTargetReaction *database.Reaction redactionTargetMsg, err := portal.Bridge.DB.Message.GetPartByMXID(ctx, content.Redacts) if err != nil { log.Err(err).Msg("Failed to get redaction target message from database") - portal.sendErrorStatus(ctx, evt, fmt.Errorf("%w: failed to get redaction target message: %w", ErrDatabaseError, err)) - return + return EventHandlingResultFailed.WithMSSError(fmt.Errorf("%w: failed to get redaction target message: %w", ErrDatabaseError, err)) } else if redactionTargetMsg != nil { if !deleteOK { log.Debug().Msg("Ignoring message redaction event as network connector doesn't implement RedactionHandlingNetworkAPI") - portal.sendErrorStatus(ctx, evt, ErrRedactionsNotSupported) - return + return EventHandlingResultIgnored.WithMSSError(ErrRedactionsNotSupported) } err = deletingAPI.HandleMatrixMessageRemove(ctx, &MatrixMessageRemove{ MatrixEventBase: MatrixEventBase[*event.RedactionEventContent]{ @@ -1561,18 +2468,18 @@ func (portal *Portal) handleMatrixRedaction(ctx context.Context, sender *UserLog Content: content, Portal: portal, OrigSender: origSender, + + InputTransactionID: portal.parseInputTransactionID(origSender, evt), }, TargetMessage: redactionTargetMsg, }) } else if redactionTargetReaction, err = portal.Bridge.DB.Reaction.GetByMXID(ctx, content.Redacts); err != nil { log.Err(err).Msg("Failed to get redaction target reaction from database") - portal.sendErrorStatus(ctx, evt, fmt.Errorf("%w: failed to get redaction target message reaction: %w", ErrDatabaseError, err)) - return + return EventHandlingResultFailed.WithMSSError(fmt.Errorf("%w: failed to get redaction target message reaction: %w", ErrDatabaseError, err)) } else if redactionTargetReaction != nil { if !reactOK { log.Debug().Msg("Ignoring reaction redaction event as network connector doesn't implement ReactionHandlingNetworkAPI") - portal.sendErrorStatus(ctx, evt, ErrReactionsNotSupported) - return + return EventHandlingResultIgnored.WithMSSError(ErrReactionsNotSupported) } // TODO ignore if sender doesn't match? err = reactingAPI.HandleMatrixReactionRemove(ctx, &MatrixReactionRemove{ @@ -1581,30 +2488,30 @@ func (portal *Portal) handleMatrixRedaction(ctx context.Context, sender *UserLog Content: content, Portal: portal, OrigSender: origSender, + + InputTransactionID: portal.parseInputTransactionID(origSender, evt), }, TargetReaction: redactionTargetReaction, }) } else { log.Debug().Msg("Redaction target message not found in database") - portal.sendErrorStatus(ctx, evt, fmt.Errorf("redaction %w", ErrTargetMessageNotFound)) - return + return EventHandlingResultIgnored.WithMSSError(fmt.Errorf("redaction %w", ErrTargetMessageNotFound)) } if err != nil { log.Err(err).Msg("Failed to handle Matrix redaction") - portal.sendErrorStatus(ctx, evt, err) - return + return EventHandlingResultFailed.WithMSSError(err) } // TODO delete msg/reaction db row - portal.sendSuccessStatus(ctx, evt, 0, "") + return EventHandlingResultSuccess.WithMSS() } -func (portal *Portal) handleRemoteEvent(ctx context.Context, source *UserLogin, evtType RemoteEventType, evt RemoteEvent) { +func (portal *Portal) handleRemoteEvent(ctx context.Context, source *UserLogin, evtType RemoteEventType, evt RemoteEvent) (res EventHandlingResult) { log := zerolog.Ctx(ctx) if portal.MXID == "" { mcp, ok := evt.(RemoteEventThatMayCreatePortal) if !ok || !mcp.ShouldCreatePortal() { log.Debug().Msg("Dropping event as portal doesn't exist") - return + return EventHandlingResultIgnored } infoProvider, ok := mcp.(RemoteChatResyncWithInfo) var info *ChatInfo @@ -1623,8 +2530,7 @@ func (portal *Portal) handleRemoteEvent(ctx context.Context, source *UserLogin, err = portal.createMatrixRoomInLoop(ctx, source, info, bundle) if err != nil { log.Err(err).Msg("Failed to create portal to handle event") - // TODO error - return + return EventHandlingResultFailed.WithError(err) } if evtType == RemoteEventChatResync { log.Debug().Msg("Not handling chat resync event further as portal was created by it") @@ -1632,7 +2538,7 @@ func (portal *Portal) handleRemoteEvent(ctx context.Context, source *UserLogin, if ok { postHandler.PostHandle(ctx, portal) } - return + return EventHandlingResultSuccess } } preHandler, ok := evt.(RemotePreHandler) @@ -1643,34 +2549,35 @@ func (portal *Portal) handleRemoteEvent(ctx context.Context, source *UserLogin, switch evtType { case RemoteEventUnknown: log.Debug().Msg("Ignoring remote event with type unknown") + res = EventHandlingResultIgnored case RemoteEventMessage, RemoteEventMessageUpsert: - portal.handleRemoteMessage(ctx, source, evt.(RemoteMessage)) + res = portal.handleRemoteMessage(ctx, source, evt.(RemoteMessage)) case RemoteEventEdit: - portal.handleRemoteEdit(ctx, source, evt.(RemoteEdit)) + res = portal.handleRemoteEdit(ctx, source, evt.(RemoteEdit)) case RemoteEventReaction: - portal.handleRemoteReaction(ctx, source, evt.(RemoteReaction)) + res = portal.handleRemoteReaction(ctx, source, evt.(RemoteReaction)) case RemoteEventReactionRemove: - portal.handleRemoteReactionRemove(ctx, source, evt.(RemoteReactionRemove)) + res = portal.handleRemoteReactionRemove(ctx, source, evt.(RemoteReactionRemove)) case RemoteEventReactionSync: - portal.handleRemoteReactionSync(ctx, source, evt.(RemoteReactionSync)) + res = portal.handleRemoteReactionSync(ctx, source, evt.(RemoteReactionSync)) case RemoteEventMessageRemove: - portal.handleRemoteMessageRemove(ctx, source, evt.(RemoteMessageRemove)) + res = portal.handleRemoteMessageRemove(ctx, source, evt.(RemoteMessageRemove)) case RemoteEventReadReceipt: - portal.handleRemoteReadReceipt(ctx, source, evt.(RemoteReadReceipt)) + res = portal.handleRemoteReadReceipt(ctx, source, evt.(RemoteReadReceipt)) case RemoteEventMarkUnread: - portal.handleRemoteMarkUnread(ctx, source, evt.(RemoteMarkUnread)) + res = portal.handleRemoteMarkUnread(ctx, source, evt.(RemoteMarkUnread)) case RemoteEventDeliveryReceipt: - portal.handleRemoteDeliveryReceipt(ctx, source, evt.(RemoteDeliveryReceipt)) + res = portal.handleRemoteDeliveryReceipt(ctx, source, evt.(RemoteDeliveryReceipt)) case RemoteEventTyping: - portal.handleRemoteTyping(ctx, source, evt.(RemoteTyping)) + res = portal.handleRemoteTyping(ctx, source, evt.(RemoteTyping)) case RemoteEventChatInfoChange: - portal.handleRemoteChatInfoChange(ctx, source, evt.(RemoteChatInfoChange)) + res = portal.handleRemoteChatInfoChange(ctx, source, evt.(RemoteChatInfoChange)) case RemoteEventChatResync: - portal.handleRemoteChatResync(ctx, source, evt.(RemoteChatResync)) + res = portal.handleRemoteChatResync(ctx, source, evt.(RemoteChatResync)) case RemoteEventChatDelete: - portal.handleRemoteChatDelete(ctx, source, evt.(RemoteChatDelete)) + res = portal.handleRemoteChatDelete(ctx, source, evt.(RemoteChatDelete)) case RemoteEventBackfill: - portal.handleRemoteBackfill(ctx, source, evt.(RemoteBackfill)) + res = portal.handleRemoteBackfill(ctx, source, evt.(RemoteBackfill)) default: log.Warn().Msg("Got remote event with unknown type") } @@ -1678,9 +2585,50 @@ func (portal *Portal) handleRemoteEvent(ctx context.Context, source *UserLogin, if ok { postHandler.PostHandle(ctx, portal) } + return } -func (portal *Portal) getIntentAndUserMXIDFor(ctx context.Context, sender EventSender, source *UserLogin, otherLogins []*UserLogin, evtType RemoteEventType) (intent MatrixAPI, extraUserID id.UserID) { +func (portal *Portal) ensureFunctionalMember(ctx context.Context, ghost *Ghost) { + if !ghost.IsBot || portal.RoomType != database.RoomTypeDM || portal.OtherUserID == ghost.ID || portal.MXID == "" { + return + } + ars, ok := portal.Bridge.Matrix.(MatrixConnectorWithArbitraryRoomState) + if !ok { + return + } + portal.functionalMembersLock.Lock() + defer portal.functionalMembersLock.Unlock() + var functionalMembers *event.ElementFunctionalMembersContent + if portal.functionalMembersCache != nil { + functionalMembers = portal.functionalMembersCache + } else { + evt, err := ars.GetStateEvent(ctx, portal.MXID, event.StateElementFunctionalMembers, "") + if err != nil && !errors.Is(err, mautrix.MNotFound) { + zerolog.Ctx(ctx).Err(err).Msg("Failed to get functional members state event") + return + } + functionalMembers = &event.ElementFunctionalMembersContent{} + if evt != nil { + evtContent, ok := evt.Content.Parsed.(*event.ElementFunctionalMembersContent) + if ok && evtContent != nil { + functionalMembers = evtContent + } + } + } + // TODO what about non-double-puppeted user ghosts? + functionalMembers.Add(portal.Bridge.Bot.GetMXID()) + if functionalMembers.Add(ghost.Intent.GetMXID()) { + _, err := portal.Bridge.Bot.SendState(ctx, portal.MXID, event.StateElementFunctionalMembers, "", &event.Content{ + Parsed: functionalMembers, + }, time.Time{}) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to update functional members state event") + return + } + } +} + +func (portal *Portal) getIntentAndUserMXIDFor(ctx context.Context, sender EventSender, source *UserLogin, otherLogins []*UserLogin, evtType RemoteEventType) (intent MatrixAPI, extraUserID id.UserID, err error) { var ghost *Ghost if !sender.IsFromMe && sender.ForceDMUser && portal.OtherUserID != "" && sender.Sender != portal.OtherUserID { zerolog.Ctx(ctx).Warn(). @@ -1689,21 +2637,21 @@ func (portal *Portal) getIntentAndUserMXIDFor(ctx context.Context, sender EventS Msg("Overriding event sender with primary other user in DM portal") // Ensure the ghost row exists anyway to prevent foreign key errors when saving messages // TODO it'd probably be better to override the sender in the saved message, but that's more effort - _, err := portal.Bridge.GetGhostByID(ctx, sender.Sender) + _, err = portal.Bridge.GetGhostByID(ctx, sender.Sender) if err != nil { zerolog.Ctx(ctx).Warn().Err(err).Msg("Failed to get ghost with original user ID") + return } sender.Sender = portal.OtherUserID } if sender.Sender != "" { - var err error ghost, err = portal.Bridge.GetGhostByID(ctx, sender.Sender) if err != nil { zerolog.Ctx(ctx).Err(err).Msg("Failed to get ghost for message sender") return - } else { - ghost.UpdateInfoIfNecessary(ctx, source, evtType) } + ghost.UpdateInfoIfNecessary(ctx, source, evtType) + portal.ensureFunctionalMember(ctx, ghost) } if sender.IsFromMe { intent = source.User.DoublePuppet(ctx) @@ -1738,58 +2686,90 @@ func (portal *Portal) getIntentAndUserMXIDFor(ctx context.Context, sender EventS return } -func (portal *Portal) GetIntentFor(ctx context.Context, sender EventSender, source *UserLogin, evtType RemoteEventType) MatrixAPI { - intent, _ := portal.getIntentAndUserMXIDFor(ctx, sender, source, nil, evtType) +func (portal *Portal) GetIntentFor(ctx context.Context, sender EventSender, source *UserLogin, evtType RemoteEventType) (MatrixAPI, bool) { + intent, _, err := portal.getIntentAndUserMXIDFor(ctx, sender, source, nil, evtType) + if err != nil { + return nil, false + } if intent == nil { // TODO this is very hacky - we should either insert an empty ghost row automatically // (and not fetch it at runtime) or make the message sender column nullable. portal.Bridge.GetGhostByID(ctx, "") intent = portal.Bridge.Bot + if intent == nil { + panic(fmt.Errorf("bridge bot is nil")) + } } - return intent + return intent, true } -func (portal *Portal) getRelationMeta(ctx context.Context, currentMsg networkid.MessageID, replyToPtr *networkid.MessageOptionalPartID, threadRootPtr *networkid.MessageID, isBatchSend bool) (replyTo, threadRoot, prevThreadEvent *database.Message) { +func (portal *Portal) getRelationMeta( + ctx context.Context, + currentMsgID networkid.MessageID, + currentMsg *ConvertedMessage, + isBatchSend bool, +) (replyTo, threadRoot, prevThreadEvent *database.Message) { log := zerolog.Ctx(ctx) var err error - if replyToPtr != nil { - replyTo, err = portal.Bridge.DB.Message.GetFirstOrSpecificPartByID(ctx, portal.Receiver, *replyToPtr) + if currentMsg.ReplyTo != nil { + replyTo, err = portal.Bridge.DB.Message.GetFirstOrSpecificPartByID(ctx, portal.Receiver, *currentMsg.ReplyTo) if err != nil { log.Err(err).Msg("Failed to get reply target message from database") } else if replyTo == nil { - if isBatchSend { + if isBatchSend || portal.Bridge.Config.OutgoingMessageReID { // This is somewhat evil replyTo = &database.Message{ - MXID: portal.Bridge.Matrix.GenerateDeterministicEventID(portal.MXID, portal.PortalKey, replyToPtr.MessageID, ptr.Val(replyToPtr.PartID)), + MXID: portal.Bridge.Matrix.GenerateDeterministicEventID(portal.MXID, portal.PortalKey, currentMsg.ReplyTo.MessageID, ptr.Val(currentMsg.ReplyTo.PartID)), + Room: currentMsg.ReplyToRoom, + SenderID: currentMsg.ReplyToUser, + } + if currentMsg.ReplyToLogin != "" && (portal.Receiver == "" || portal.Receiver == currentMsg.ReplyToLogin) { + userLogin, err := portal.Bridge.GetExistingUserLoginByID(ctx, currentMsg.ReplyToLogin) + if err != nil { + log.Err(err). + Str("reply_to_login", string(currentMsg.ReplyToLogin)). + Msg("Failed to get reply target user login") + } else if userLogin != nil { + replyTo.SenderMXID = userLogin.UserMXID + } + } else { + ghost, err := portal.Bridge.GetGhostByID(ctx, currentMsg.ReplyToUser) + if err != nil { + log.Err(err). + Str("reply_to_user_id", string(currentMsg.ReplyToUser)). + Msg("Failed to get reply target ghost") + } else { + replyTo.SenderMXID = ghost.Intent.GetMXID() + } } } else { - log.Warn().Any("reply_to", *replyToPtr).Msg("Reply target message not found in database") + log.Warn().Any("reply_to", *currentMsg.ReplyTo).Msg("Reply target message not found in database") } } } - if threadRootPtr != nil && *threadRootPtr != currentMsg { - threadRoot, err = portal.Bridge.DB.Message.GetFirstThreadMessage(ctx, portal.PortalKey, *threadRootPtr) + if currentMsg.ThreadRoot != nil && *currentMsg.ThreadRoot != currentMsgID { + threadRoot, err = portal.Bridge.DB.Message.GetFirstThreadMessage(ctx, portal.PortalKey, *currentMsg.ThreadRoot) if err != nil { log.Err(err).Msg("Failed to get thread root message from database") } else if threadRoot == nil { - if isBatchSend { + if isBatchSend || portal.Bridge.Config.OutgoingMessageReID { threadRoot = &database.Message{ - MXID: portal.Bridge.Matrix.GenerateDeterministicEventID(portal.MXID, portal.PortalKey, *threadRootPtr, ""), + MXID: portal.Bridge.Matrix.GenerateDeterministicEventID(portal.MXID, portal.PortalKey, *currentMsg.ThreadRoot, ""), } } else { - log.Warn().Str("thread_root", string(*threadRootPtr)).Msg("Thread root message not found in database") + log.Warn().Str("thread_root", string(*currentMsg.ThreadRoot)).Msg("Thread root message not found in database") } - } else if prevThreadEvent, err = portal.Bridge.DB.Message.GetLastThreadMessage(ctx, portal.PortalKey, *threadRootPtr); err != nil { + } else if prevThreadEvent, err = portal.Bridge.DB.Message.GetLastThreadMessage(ctx, portal.PortalKey, *currentMsg.ThreadRoot); err != nil { log.Err(err).Msg("Failed to get last thread message from database") } if prevThreadEvent == nil { - prevThreadEvent = threadRoot + prevThreadEvent = ptr.Clone(threadRoot) } } return } -func (portal *Portal) applyRelationMeta(content *event.MessageEventContent, replyTo, threadRoot, prevThreadEvent *database.Message) { +func (portal *Portal) applyRelationMeta(ctx context.Context, content *event.MessageEventContent, replyTo, threadRoot, prevThreadEvent *database.Message) { if content.Mentions == nil { content.Mentions = &event.Mentions{} } @@ -1797,7 +2777,24 @@ func (portal *Portal) applyRelationMeta(content *event.MessageEventContent, repl content.GetRelatesTo().SetThread(threadRoot.MXID, prevThreadEvent.MXID) } if replyTo != nil { - content.GetRelatesTo().SetReplyTo(replyTo.MXID) + crossRoom := !replyTo.Room.IsEmpty() && replyTo.Room != portal.PortalKey + if !crossRoom || portal.Bridge.Config.CrossRoomReplies { + content.GetRelatesTo().SetReplyTo(replyTo.MXID) + } + if crossRoom && portal.Bridge.Config.CrossRoomReplies { + targetPortal, err := portal.Bridge.GetExistingPortalByKey(ctx, replyTo.Room) + if err != nil { + zerolog.Ctx(ctx).Err(err). + Object("target_portal_key", replyTo.Room). + Msg("Failed to get cross-room reply portal") + } else if targetPortal == nil || targetPortal.MXID == "" { + zerolog.Ctx(ctx).Warn(). + Object("target_portal_key", replyTo.Room). + Msg("Cross-room reply portal not found") + } else { + content.RelatesTo.InReplyTo.UnstableRoomID = targetPortal.MXID + } + } content.Mentions.Add(replyTo.SenderMXID) } } @@ -1811,27 +2808,32 @@ func (portal *Portal) sendConvertedMessage( ts time.Time, streamOrder int64, logContext func(*zerolog.Event) *zerolog.Event, -) []*database.Message { +) ([]*database.Message, EventHandlingResult) { if logContext == nil { logContext = func(e *zerolog.Event) *zerolog.Event { return e } } log := zerolog.Ctx(ctx) - replyTo, threadRoot, prevThreadEvent := portal.getRelationMeta(ctx, id, converted.ReplyTo, converted.ThreadRoot, false) + replyTo, threadRoot, prevThreadEvent := portal.getRelationMeta( + ctx, id, converted, false, + ) output := make([]*database.Message, 0, len(converted.Parts)) + allSuccess := true for i, part := range converted.Parts { - portal.applyRelationMeta(part.Content, replyTo, threadRoot, prevThreadEvent) + portal.applyRelationMeta(ctx, part.Content, replyTo, threadRoot, prevThreadEvent) + part.Content.BeeperDisappearingTimer = converted.Disappear.ToEventContent() dbMessage := &database.Message{ - ID: id, - PartID: part.ID, - Room: portal.PortalKey, - SenderID: senderID, - SenderMXID: intent.GetMXID(), - Timestamp: ts, - ThreadRoot: ptr.Val(converted.ThreadRoot), - ReplyTo: ptr.Val(converted.ReplyTo), - Metadata: part.DBMetadata, + ID: id, + PartID: part.ID, + Room: portal.PortalKey, + SenderID: senderID, + SenderMXID: intent.GetMXID(), + Timestamp: ts, + ThreadRoot: ptr.Val(converted.ThreadRoot), + ReplyTo: ptr.Val(converted.ReplyTo), + Metadata: part.DBMetadata, + IsDoublePuppeted: intent.IsDoublePuppet(), } if part.DontBridge { dbMessage.SetFakeMXID() @@ -1851,6 +2853,7 @@ func (portal *Portal) sendConvertedMessage( }) if err != nil { logContext(log.Err(err)).Str("part_id", string(part.ID)).Msg("Failed to send message part to Matrix") + allSuccess = false continue } logContext(log.Debug()). @@ -1862,14 +2865,16 @@ func (portal *Portal) sendConvertedMessage( err := portal.Bridge.DB.Message.Insert(ctx, dbMessage) if err != nil { logContext(log.Err(err)).Str("part_id", string(part.ID)).Msg("Failed to save message part to database") + allSuccess = false } - if converted.Disappear.Type != database.DisappearingTypeNone && !dbMessage.HasFakeMXID() { - if converted.Disappear.Type == database.DisappearingTypeAfterSend && converted.Disappear.DisappearAt.IsZero() { + if converted.Disappear.Type != event.DisappearingTypeNone && !dbMessage.HasFakeMXID() { + if converted.Disappear.Type == event.DisappearingTypeAfterSend && converted.Disappear.DisappearAt.IsZero() { converted.Disappear.DisappearAt = dbMessage.Timestamp.Add(converted.Disappear.Timer) } - go portal.Bridge.DisappearLoop.Add(ctx, &database.DisappearingMessage{ + portal.Bridge.DisappearLoop.Add(ctx, &database.DisappearingMessage{ RoomID: portal.MXID, EventID: dbMessage.MXID, + Timestamp: dbMessage.Timestamp, DisappearingSetting: converted.Disappear, }) } @@ -1878,7 +2883,10 @@ func (portal *Portal) sendConvertedMessage( } output = append(output, dbMessage) } - return output + if !allSuccess { + return output, EventHandlingResultFailed + } + return output, EventHandlingResultSuccess } func (portal *Portal) checkPendingMessage(ctx context.Context, evt RemoteMessage) (bool, *database.Message) { @@ -1938,21 +2946,24 @@ func (portal *Portal) checkPendingMessage(ctx context.Context, evt RemoteMessage return true, pending.db } -func (portal *Portal) handleRemoteUpsert(ctx context.Context, source *UserLogin, evt RemoteMessageUpsert, existing []*database.Message) bool { +func (portal *Portal) handleRemoteUpsert(ctx context.Context, source *UserLogin, evt RemoteMessageUpsert, existing []*database.Message) (handleRes EventHandlingResult, continueHandling bool) { log := zerolog.Ctx(ctx) - intent := portal.GetIntentFor(ctx, evt.GetSender(), source, RemoteEventMessageUpsert) - if intent == nil { - return false + intent, ok := portal.GetIntentFor(ctx, evt.GetSender(), source, RemoteEventMessageUpsert) + if !ok { + return } res, err := evt.HandleExisting(ctx, portal, intent, existing) if err != nil { log.Err(err).Msg("Failed to handle existing message in upsert event after receiving remote echo") + } else { + handleRes = EventHandlingResultSuccess } if res.SaveParts { for _, part := range existing { err = portal.Bridge.DB.Message.Update(ctx, part) if err != nil { log.Err(err).Str("part_id", string(part.PartID)).Msg("Failed to update message part in database") + handleRes = EventHandlingResultFailed.WithError(err) } } } @@ -1964,19 +2975,25 @@ func (portal *Portal) handleRemoteUpsert(ctx context.Context, source *UserLogin, Str("action", "handle remote subevent"). Stringer("bridge_evt_type", subType). Logger() - portal.handleRemoteEvent(log.WithContext(ctx), source, subType, subEvt) + subRes := portal.handleRemoteEvent(log.WithContext(ctx), source, subType, subEvt) + if !subRes.Success { + handleRes.Success = false + } } } - return res.ContinueMessageHandling + continueHandling = res.ContinueMessageHandling + return } -func (portal *Portal) handleRemoteMessage(ctx context.Context, source *UserLogin, evt RemoteMessage) { +func (portal *Portal) handleRemoteMessage(ctx context.Context, source *UserLogin, evt RemoteMessage) (res EventHandlingResult) { log := zerolog.Ctx(ctx) upsertEvt, isUpsert := evt.(RemoteMessageUpsert) isUpsert = isUpsert && evt.GetType() == RemoteEventMessageUpsert if wasPending, dbMessage := portal.checkPendingMessage(ctx, evt); wasPending { if isUpsert && dbMessage != nil { - portal.handleRemoteUpsert(ctx, source, upsertEvt, []*database.Message{dbMessage}) + res, _ = portal.handleRemoteUpsert(ctx, source, upsertEvt, []*database.Message{dbMessage}) + } else { + res = EventHandlingResultIgnored } return } @@ -1985,32 +3002,42 @@ func (portal *Portal) handleRemoteMessage(ctx context.Context, source *UserLogin log.Err(err).Msg("Failed to check if message is a duplicate") } else if len(existing) > 0 { if isUpsert { - if portal.handleRemoteUpsert(ctx, source, upsertEvt, existing) { + var continueHandling bool + res, continueHandling = portal.handleRemoteUpsert(ctx, source, upsertEvt, existing) + if continueHandling { log.Debug().Msg("Upsert handler said to continue message handling normally") } else { - return + return res } } else { log.Debug().Stringer("existing_mxid", existing[0].MXID).Msg("Ignoring duplicate message") - return + return EventHandlingResultIgnored } } - intent := portal.GetIntentFor(ctx, evt.GetSender(), source, RemoteEventMessage) - if intent == nil { - return + intent, ok := portal.GetIntentFor(ctx, evt.GetSender(), source, RemoteEventMessage) + if !ok { + return EventHandlingResultFailed } ts := getEventTS(evt) converted, err := evt.ConvertMessage(ctx, portal, intent) if err != nil { if errors.Is(err, ErrIgnoringRemoteEvent) { log.Debug().Err(err).Msg("Remote message handling was cancelled by convert function") + return EventHandlingResultIgnored } else { log.Err(err).Msg("Failed to convert remote message") portal.sendRemoteErrorNotice(ctx, intent, err, ts, "message") + return EventHandlingResultFailed.WithError(err) } - return } - portal.sendConvertedMessage(ctx, evt.GetID(), intent, evt.GetSender().Sender, converted, ts, getStreamOrder(evt), nil) + _, res = portal.sendConvertedMessage(ctx, evt.GetID(), intent, evt.GetSender().Sender, converted, ts, getStreamOrder(evt), nil) + if portal.currentlyTypingGhosts.Pop(intent.GetMXID()) { + err = intent.MarkTyping(ctx, portal.MXID, TypingTypeText, 0) + if err != nil { + log.Warn().Err(err).Msg("Failed to send stop typing event after bridging message") + } + } + return } func (portal *Portal) sendRemoteErrorNotice(ctx context.Context, intent MatrixAPI, err error, ts time.Time, evtTypeName string) { @@ -2033,7 +3060,7 @@ func (portal *Portal) sendRemoteErrorNotice(ctx context.Context, intent MatrixAP } } -func (portal *Portal) handleRemoteEdit(ctx context.Context, source *UserLogin, evt RemoteEdit) { +func (portal *Portal) handleRemoteEdit(ctx context.Context, source *UserLogin, evt RemoteEdit) EventHandlingResult { log := zerolog.Ctx(ctx) var existing []*database.Message if bundledEvt, ok := evt.(RemoteEventWithBundledParts); ok { @@ -2045,28 +3072,41 @@ func (portal *Portal) handleRemoteEdit(ctx context.Context, source *UserLogin, e existing, err = portal.Bridge.DB.Message.GetAllPartsByID(ctx, portal.Receiver, targetID) if err != nil { log.Err(err).Msg("Failed to get edit target message") - return + return EventHandlingResultFailed.WithError(err) } } if existing == nil { log.Warn().Msg("Edit target message not found") - return + return EventHandlingResultIgnored } - intent := portal.GetIntentFor(ctx, evt.GetSender(), source, RemoteEventEdit) - if intent == nil { - return + intent, ok := portal.GetIntentFor(ctx, evt.GetSender(), source, RemoteEventEdit) + if !ok { + return EventHandlingResultFailed + } else if intent.GetMXID() != existing[0].SenderMXID { + log.Warn(). + Stringer("edit_sender_mxid", intent.GetMXID()). + Stringer("original_sender_mxid", existing[0].SenderMXID). + Msg("Not bridging edit: sender doesn't match original message sender") + return EventHandlingResultIgnored } ts := getEventTS(evt) converted, err := evt.ConvertEdit(ctx, portal, intent, existing) if errors.Is(err, ErrIgnoringRemoteEvent) { log.Debug().Err(err).Msg("Remote edit handling was cancelled by convert function") - return + return EventHandlingResultIgnored } else if err != nil { log.Err(err).Msg("Failed to convert remote edit") portal.sendRemoteErrorNotice(ctx, intent, err, ts, "edit") - return + return EventHandlingResultFailed.WithError(err) } - portal.sendConvertedEdit(ctx, existing[0].ID, evt.GetSender().Sender, converted, intent, ts, getStreamOrder(evt)) + res := portal.sendConvertedEdit(ctx, existing[0].ID, evt.GetSender().Sender, converted, intent, ts, getStreamOrder(evt)) + if portal.currentlyTypingGhosts.Pop(intent.GetMXID()) { + err = intent.MarkTyping(ctx, portal.MXID, TypingTypeText, 0) + if err != nil { + log.Warn().Err(err).Msg("Failed to send stop typing event after bridging edit") + } + } + return res } func (portal *Portal) sendConvertedEdit( @@ -2077,8 +3117,9 @@ func (portal *Portal) sendConvertedEdit( intent MatrixAPI, ts time.Time, streamOrder int64, -) { +) EventHandlingResult { log := zerolog.Ctx(ctx) + allSuccess := true for i, part := range converted.ModifiedParts { if part.Content.Mentions == nil { part.Content.Mentions = &event.Mentions{} @@ -2114,6 +3155,7 @@ func (portal *Portal) sendConvertedEdit( }) if err != nil { log.Err(err).Stringer("part_mxid", part.Part.MXID).Msg("Failed to edit message part") + allSuccess = false continue } else { log.Debug(). @@ -2128,6 +3170,7 @@ func (portal *Portal) sendConvertedEdit( err := portal.Bridge.DB.Message.Update(ctx, part.Part) if err != nil { log.Err(err).Int64("part_rowid", part.Part.RowID).Msg("Failed to update message part in database") + allSuccess = false } } for _, part := range converted.DeletedParts { @@ -2141,6 +3184,7 @@ func (portal *Portal) sendConvertedEdit( }) if err != nil { log.Err(err).Stringer("part_mxid", part.MXID).Msg("Failed to redact message part deleted in edit") + allSuccess = false } else { log.Debug(). Stringer("redaction_event_id", resp.EventID). @@ -2151,11 +3195,19 @@ func (portal *Portal) sendConvertedEdit( err = portal.Bridge.DB.Message.Delete(ctx, part.RowID) if err != nil { log.Err(err).Int64("part_rowid", part.RowID).Msg("Failed to delete message part from database") + allSuccess = false } } if converted.AddedParts != nil { - portal.sendConvertedMessage(ctx, targetID, intent, senderID, converted.AddedParts, ts, streamOrder, nil) + _, res := portal.sendConvertedMessage(ctx, targetID, intent, senderID, converted.AddedParts, ts, streamOrder, nil) + if !res.Success { + allSuccess = false + } } + if !allSuccess { + return EventHandlingResultFailed + } + return EventHandlingResultSuccess } func (portal *Portal) getTargetMessagePart(ctx context.Context, evt RemoteEventWithTargetMessage) (*database.Message, error) { @@ -2188,17 +3240,17 @@ func getStreamOrder(evt RemoteEvent) int64 { return 0 } -func (portal *Portal) handleRemoteReactionSync(ctx context.Context, source *UserLogin, evt RemoteReactionSync) { +func (portal *Portal) handleRemoteReactionSync(ctx context.Context, source *UserLogin, evt RemoteReactionSync) EventHandlingResult { log := zerolog.Ctx(ctx) eventTS := getEventTS(evt) targetMessage, err := portal.getTargetMessagePart(ctx, evt) if err != nil { log.Err(err).Msg("Failed to get target message for reaction") - return + return EventHandlingResultFailed.WithError(err) } else if targetMessage == nil { // TODO use deterministic event ID as target if applicable? log.Warn().Msg("Target message for reaction not found") - return + return EventHandlingResultIgnored } var existingReactions []*database.Reaction if partTargeter, ok := evt.(RemoteEventWithTargetPart); ok { @@ -2206,6 +3258,10 @@ func (portal *Portal) handleRemoteReactionSync(ctx context.Context, source *User } else { existingReactions, err = portal.Bridge.DB.Reaction.GetAllToMessage(ctx, portal.Receiver, evt.GetTargetMessage()) } + if err != nil { + log.Err(err).Msg("Failed to get existing reactions for reaction sync") + return EventHandlingResultFailed.WithError(err) + } existing := make(map[networkid.UserID]map[networkid.EmojiID]*database.Reaction) for _, existingReaction := range existingReactions { if existing[existingReaction.SenderID] == nil { @@ -2214,8 +3270,14 @@ func (portal *Portal) handleRemoteReactionSync(ctx context.Context, source *User existing[existingReaction.SenderID][existingReaction.EmojiID] = existingReaction } - doAddReaction := func(new *BackfillReaction) MatrixAPI { - intent := portal.GetIntentFor(ctx, new.Sender, source, RemoteEventReactionSync) + doAddReaction := func(new *BackfillReaction, intent MatrixAPI) { + if intent == nil { + var ok bool + intent, ok = portal.GetIntentFor(ctx, new.Sender, source, RemoteEventReactionSync) + if !ok { + return + } + } portal.sendConvertedReaction( ctx, new.Sender.Sender, intent, targetMessage, new.EmojiID, new.Emoji, new.Timestamp, new.DBMetadata, new.ExtraContent, @@ -2225,7 +3287,6 @@ func (portal *Portal) handleRemoteReactionSync(ctx context.Context, source *User Time("reaction_ts", new.Timestamp) }, ) - return intent } doRemoveReaction := func(old *database.Reaction, intent MatrixAPI, deleteRow bool) { if intent == nil && old.SenderMXID != "" { @@ -2259,8 +3320,12 @@ func (portal *Portal) handleRemoteReactionSync(ctx context.Context, source *User } } doOverwriteReaction := func(new *BackfillReaction, old *database.Reaction) { - intent := doAddReaction(new) + intent, ok := portal.GetIntentFor(ctx, new.Sender, source, RemoteEventReactionSync) + if !ok { + return + } doRemoveReaction(old, intent, false) + doAddReaction(new, intent) } newData := evt.GetReactions() @@ -2274,12 +3339,12 @@ func (portal *Portal) handleRemoteReactionSync(ctx context.Context, source *User existingReaction, ok := existingUserReactions[reaction.EmojiID] if ok { delete(existingUserReactions, reaction.EmojiID) - if reaction.EmojiID != "" { + if reaction.EmojiID != "" || reaction.Emoji == existingReaction.Emoji { continue } doOverwriteReaction(reaction, existingReaction) } else { - doAddReaction(reaction) + doAddReaction(reaction, nil) } } totalReactionCount := len(existingUserReactions) + len(reactions.Reactions) @@ -2309,30 +3374,34 @@ func (portal *Portal) handleRemoteReactionSync(ctx context.Context, source *User } } } + return EventHandlingResultSuccess } -func (portal *Portal) handleRemoteReaction(ctx context.Context, source *UserLogin, evt RemoteReaction) { +func (portal *Portal) handleRemoteReaction(ctx context.Context, source *UserLogin, evt RemoteReaction) EventHandlingResult { log := zerolog.Ctx(ctx) targetMessage, err := portal.getTargetMessagePart(ctx, evt) if err != nil { log.Err(err).Msg("Failed to get target message for reaction") - return + return EventHandlingResultFailed.WithError(err) } else if targetMessage == nil { // TODO use deterministic event ID as target if applicable? log.Warn().Msg("Target message for reaction not found") - return + return EventHandlingResultIgnored } emoji, emojiID := evt.GetReactionEmoji() existingReaction, err := portal.Bridge.DB.Reaction.GetByID(ctx, portal.Receiver, targetMessage.ID, targetMessage.PartID, evt.GetSender().Sender, emojiID) if err != nil { log.Err(err).Msg("Failed to check if reaction is a duplicate") - return + return EventHandlingResultFailed.WithError(err) } else if existingReaction != nil && (emojiID != "" || existingReaction.Emoji == emoji) { log.Debug().Msg("Ignoring duplicate reaction") - return + return EventHandlingResultIgnored } ts := getEventTS(evt) - intent := portal.GetIntentFor(ctx, evt.GetSender(), source, RemoteEventReaction) + intent, ok := portal.GetIntentFor(ctx, evt.GetSender(), source, RemoteEventReaction) + if !ok { + return EventHandlingResultFailed + } var extra map[string]any if extraContentProvider, ok := evt.(RemoteReactionWithExtraContent); ok { extra = extraContentProvider.GetReactionExtraContent() @@ -2341,7 +3410,6 @@ func (portal *Portal) handleRemoteReaction(ctx context.Context, source *UserLogi if metaProvider, ok := evt.(RemoteReactionWithMeta); ok { dbMetadata = metaProvider.GetReactionDBMetadata() } - portal.sendConvertedReaction(ctx, evt.GetSender().Sender, intent, targetMessage, emojiID, emoji, ts, dbMetadata, extra, nil) if existingReaction != nil { _, err = intent.SendMessage(ctx, portal.MXID, event.EventRedaction, &event.Content{ Parsed: &event.RedactionEventContent{ @@ -2352,13 +3420,14 @@ func (portal *Portal) handleRemoteReaction(ctx context.Context, source *UserLogi log.Err(err).Msg("Failed to redact old reaction") } } + return portal.sendConvertedReaction(ctx, evt.GetSender().Sender, intent, targetMessage, emojiID, emoji, ts, dbMetadata, extra, nil) } func (portal *Portal) sendConvertedReaction( ctx context.Context, senderID networkid.UserID, intent MatrixAPI, targetMessage *database.Message, emojiID networkid.EmojiID, emoji string, ts time.Time, dbMetadata any, extraContent map[string]any, logContext func(*zerolog.Event) *zerolog.Event, -) { +) EventHandlingResult { if logContext == nil { logContext = func(e *zerolog.Event) *zerolog.Event { return e @@ -2393,7 +3462,7 @@ func (portal *Portal) sendConvertedReaction( }) if err != nil { logContext(log.Err(err)).Msg("Failed to send reaction to Matrix") - return + return EventHandlingResultFailed.WithError(err) } logContext(log.Debug()). Stringer("event_id", resp.EventID). @@ -2402,7 +3471,9 @@ func (portal *Portal) sendConvertedReaction( err = portal.Bridge.DB.Reaction.Upsert(ctx, dbReaction) if err != nil { logContext(log.Err(err)).Msg("Failed to save reaction to database") + return EventHandlingResultFailed.WithError(err) } + return EventHandlingResultSuccess } func (portal *Portal) getIntentForMXID(ctx context.Context, userID id.UserID) (MatrixAPI, error) { @@ -2421,22 +3492,26 @@ func (portal *Portal) getIntentForMXID(ctx context.Context, userID id.UserID) (M } } -func (portal *Portal) handleRemoteReactionRemove(ctx context.Context, source *UserLogin, evt RemoteReactionRemove) { +func (portal *Portal) handleRemoteReactionRemove(ctx context.Context, source *UserLogin, evt RemoteReactionRemove) EventHandlingResult { log := zerolog.Ctx(ctx) targetReaction, err := portal.getTargetReaction(ctx, evt) if err != nil { log.Err(err).Msg("Failed to get target reaction for removal") - return + return EventHandlingResultFailed.WithError(err) } else if targetReaction == nil { log.Warn().Msg("Target reaction not found") - return + return EventHandlingResultIgnored } intent, err := portal.getIntentForMXID(ctx, targetReaction.SenderMXID) if err != nil { log.Err(err).Stringer("sender_mxid", targetReaction.SenderMXID).Msg("Failed to get intent for removing reaction") } if intent == nil { - intent = portal.GetIntentFor(ctx, evt.GetSender(), source, RemoteEventReactionRemove) + var ok bool + intent, ok = portal.GetIntentFor(ctx, evt.GetSender(), source, RemoteEventReactionRemove) + if !ok { + return EventHandlingResultFailed + } } ts := getEventTS(evt) _, err = intent.SendMessage(ctx, portal.MXID, event.EventRedaction, &event.Content{ @@ -2446,30 +3521,42 @@ func (portal *Portal) handleRemoteReactionRemove(ctx context.Context, source *Us }, &MatrixSendExtra{Timestamp: ts, ReactionMeta: targetReaction}) if err != nil { log.Err(err).Stringer("reaction_mxid", targetReaction.MXID).Msg("Failed to redact reaction") + return EventHandlingResultFailed.WithError(err) } err = portal.Bridge.DB.Reaction.Delete(ctx, targetReaction) if err != nil { log.Err(err).Msg("Failed to delete target reaction from database") } + return EventHandlingResultSuccess } -func (portal *Portal) handleRemoteMessageRemove(ctx context.Context, source *UserLogin, evt RemoteMessageRemove) { +func (portal *Portal) handleRemoteMessageRemove(ctx context.Context, source *UserLogin, evt RemoteMessageRemove) EventHandlingResult { log := zerolog.Ctx(ctx) targetParts, err := portal.Bridge.DB.Message.GetAllPartsByID(ctx, portal.Receiver, evt.GetTargetMessage()) if err != nil { log.Err(err).Msg("Failed to get target message for removal") - return + return EventHandlingResultFailed.WithError(err) } else if len(targetParts) == 0 { log.Debug().Msg("Target message not found") - return + return EventHandlingResultIgnored } onlyForMeProvider, ok := evt.(RemoteDeleteOnlyForMe) onlyForMe := ok && onlyForMeProvider.DeleteOnlyForMe() if onlyForMe && portal.Receiver == "" { - // TODO check if there are other user logins before deleting + _, others, err := portal.findOtherLogins(ctx, source) + if err != nil { + log.Err(err).Msg("Failed to check if portal has other logins") + return EventHandlingResultFailed.WithError(err) + } else if len(others) > 0 { + log.Debug().Msg("Ignoring delete for me event in portal with multiple logins") + return EventHandlingResultIgnored + } } - intent := portal.GetIntentFor(ctx, evt.GetSender(), source, RemoteEventMessageRemove) + intent, ok := portal.GetIntentFor(ctx, evt.GetSender(), source, RemoteEventMessageRemove) + if !ok { + return EventHandlingResultFailed + } if intent == portal.Bridge.Bot && len(targetParts) > 0 { senderIntent, err := portal.getIntentForMXID(ctx, targetParts[0].SenderMXID) if err != nil { @@ -2478,15 +3565,17 @@ func (portal *Portal) handleRemoteMessageRemove(ctx context.Context, source *Use intent = senderIntent } } - portal.redactMessageParts(ctx, targetParts, intent, getEventTS(evt)) + res := portal.redactMessageParts(ctx, targetParts, intent, getEventTS(evt)) err = portal.Bridge.DB.Message.DeleteAllParts(ctx, portal.Receiver, evt.GetTargetMessage()) if err != nil { log.Err(err).Msg("Failed to delete target message from database") } + return res } -func (portal *Portal) redactMessageParts(ctx context.Context, parts []*database.Message, intent MatrixAPI, ts time.Time) { +func (portal *Portal) redactMessageParts(ctx context.Context, parts []*database.Message, intent MatrixAPI, ts time.Time) EventHandlingResult { log := zerolog.Ctx(ctx) + var anyFailed bool for _, part := range parts { if part.HasFakeMXID() { continue @@ -2498,6 +3587,7 @@ func (portal *Portal) redactMessageParts(ctx context.Context, parts []*database. }, &MatrixSendExtra{Timestamp: ts, MessageMeta: part}) if err != nil { log.Err(err).Stringer("part_mxid", part.MXID).Msg("Failed to redact message part") + anyFailed = true } else { log.Debug(). Stringer("redaction_event_id", resp.EventID). @@ -2506,22 +3596,33 @@ func (portal *Portal) redactMessageParts(ctx context.Context, parts []*database. Msg("Sent redaction of message part to Matrix") } } + if anyFailed { + return EventHandlingResultFailed + } + return EventHandlingResultSuccess } -func (portal *Portal) handleRemoteReadReceipt(ctx context.Context, source *UserLogin, evt RemoteReadReceipt) { - // TODO exclude fake mxids +func (portal *Portal) handleRemoteReadReceipt(ctx context.Context, source *UserLogin, evt RemoteReadReceipt) EventHandlingResult { log := zerolog.Ctx(ctx) var err error var lastTarget *database.Message + readUpTo := evt.GetReadUpTo() if lastTargetID := evt.GetLastReceiptTarget(); lastTargetID != "" { lastTarget, err = portal.Bridge.DB.Message.GetLastPartByID(ctx, portal.Receiver, lastTargetID) if err != nil { log.Err(err).Str("last_target_id", string(lastTargetID)). Msg("Failed to get last target message for read receipt") - return + return EventHandlingResultFailed.WithError(err) } else if lastTarget == nil { log.Debug().Str("last_target_id", string(lastTargetID)). Msg("Last target message not found") + } else if lastTarget.HasFakeMXID() { + log.Debug().Str("last_target_id", string(lastTargetID)). + Msg("Last target message is fake") + if readUpTo.IsZero() { + readUpTo = lastTarget.Timestamp + } + lastTarget = nil } } if lastTarget == nil { @@ -2530,62 +3631,89 @@ func (portal *Portal) handleRemoteReadReceipt(ctx context.Context, source *UserL if err != nil { log.Err(err).Str("target_id", string(targetID)). Msg("Failed to get target message for read receipt") - return - } else if target != nil && (lastTarget == nil || target.Timestamp.After(lastTarget.Timestamp)) { + return EventHandlingResultFailed.WithError(err) + } else if target != nil && !target.HasFakeMXID() && (lastTarget == nil || target.Timestamp.After(lastTarget.Timestamp)) { lastTarget = target } } } - readUpTo := evt.GetReadUpTo() if lastTarget == nil && !readUpTo.IsZero() { - lastTarget, err = portal.Bridge.DB.Message.GetLastPartAtOrBeforeTime(ctx, portal.PortalKey, readUpTo) + lastTarget, err = portal.Bridge.DB.Message.GetLastNonFakePartAtOrBeforeTime(ctx, portal.PortalKey, readUpTo) if err != nil { log.Err(err).Time("read_up_to", readUpTo).Msg("Failed to get target message for read receipt") } } - if lastTarget == nil { - log.Warn().Msg("No target message found for read receipt") - return - } sender := evt.GetSender() - intent := portal.GetIntentFor(ctx, sender, source, RemoteEventReadReceipt) - err = intent.MarkRead(ctx, portal.MXID, lastTarget.MXID, getEventTS(evt)) - if err != nil { - log.Err(err).Stringer("target_mxid", lastTarget.MXID).Msg("Failed to bridge read receipt") + intent, ok := portal.GetIntentFor(ctx, sender, source, RemoteEventReadReceipt) + if !ok { + return EventHandlingResultFailed + } + var addTargetLog func(evt *zerolog.Event) *zerolog.Event + if lastTarget == nil { + sevt, evtOK := evt.(RemoteReadReceiptWithStreamOrder) + soIntent, soIntentOK := intent.(StreamOrderReadingMatrixAPI) + if !evtOK || !soIntentOK || sevt.GetReadUpToStreamOrder() == 0 { + log.Warn().Msg("No target message found for read receipt") + return EventHandlingResultIgnored + } + targetStreamOrder := sevt.GetReadUpToStreamOrder() + addTargetLog = func(evt *zerolog.Event) *zerolog.Event { + return evt.Int64("target_stream_order", targetStreamOrder) + } + err = soIntent.MarkStreamOrderRead(ctx, portal.MXID, targetStreamOrder, getEventTS(evt)) + if readUpTo.IsZero() { + readUpTo = getEventTS(evt) + } } else { - log.Debug().Stringer("target_mxid", lastTarget.MXID).Msg("Bridged read receipt") + addTargetLog = func(evt *zerolog.Event) *zerolog.Event { + return evt.Stringer("target_mxid", lastTarget.MXID) + } + err = intent.MarkRead(ctx, portal.MXID, lastTarget.MXID, getEventTS(evt)) + readUpTo = lastTarget.Timestamp + } + if err != nil { + addTargetLog(log.Err(err)).Msg("Failed to bridge read receipt") + return EventHandlingResultFailed.WithError(err) + } else { + addTargetLog(log.Debug()).Msg("Bridged read receipt") } if sender.IsFromMe { - portal.Bridge.DisappearLoop.StartAll(ctx, portal.MXID) + portal.Bridge.DisappearLoop.StartAllBefore(ctx, portal.MXID, readUpTo) } + return EventHandlingResultSuccess } -func (portal *Portal) handleRemoteMarkUnread(ctx context.Context, source *UserLogin, evt RemoteMarkUnread) { +func (portal *Portal) handleRemoteMarkUnread(ctx context.Context, source *UserLogin, evt RemoteMarkUnread) EventHandlingResult { if !evt.GetSender().IsFromMe { zerolog.Ctx(ctx).Warn().Msg("Ignoring mark unread event from non-self user") - return + return EventHandlingResultIgnored } dp := source.User.DoublePuppet(ctx) if dp == nil { - return + return EventHandlingResultIgnored } err := dp.MarkUnread(ctx, portal.MXID, evt.GetUnread()) if err != nil { zerolog.Ctx(ctx).Err(err).Msg("Failed to bridge mark unread event") + return EventHandlingResultFailed.WithError(err) } + return EventHandlingResultSuccess } -func (portal *Portal) handleRemoteDeliveryReceipt(ctx context.Context, source *UserLogin, evt RemoteDeliveryReceipt) { - if portal.RoomType != database.RoomTypeDM || evt.GetSender().Sender != portal.OtherUserID { - return +func (portal *Portal) handleRemoteDeliveryReceipt(ctx context.Context, source *UserLogin, evt RemoteDeliveryReceipt) EventHandlingResult { + if portal.RoomType != database.RoomTypeDM || (evt.GetSender().Sender != portal.OtherUserID && portal.OtherUserID != "") { + return EventHandlingResultIgnored + } + intent, ok := portal.GetIntentFor(ctx, evt.GetSender(), source, RemoteEventDeliveryReceipt) + if !ok { + return EventHandlingResultFailed } - intent := portal.GetIntentFor(ctx, evt.GetSender(), source, RemoteEventDeliveryReceipt) log := zerolog.Ctx(ctx) for _, target := range evt.GetReceiptTargets() { targetParts, err := portal.Bridge.DB.Message.GetAllPartsByID(ctx, portal.Receiver, target) if err != nil { log.Err(err).Str("target_id", string(target)).Msg("Failed to get target message for delivery receipt") - continue + return EventHandlingResultFailed.WithError(err) } else if len(targetParts) == 0 { continue } else if _, sentByGhost := portal.Bridge.Matrix.ParseGhostMXID(targetParts[0].SenderMXID); sentByGhost { @@ -2599,33 +3727,48 @@ func (portal *Portal) handleRemoteDeliveryReceipt(ctx context.Context, source *U RoomID: portal.MXID, SourceEventID: part.MXID, Sender: part.SenderMXID, + + IsSourceEventDoublePuppeted: part.IsDoublePuppeted, }) } } + return EventHandlingResultSuccess } -func (portal *Portal) handleRemoteTyping(ctx context.Context, source *UserLogin, evt RemoteTyping) { +func (portal *Portal) handleRemoteTyping(ctx context.Context, source *UserLogin, evt RemoteTyping) EventHandlingResult { var typingType TypingType if typedEvt, ok := evt.(RemoteTypingWithType); ok { typingType = typedEvt.GetTypingType() } - intent := portal.GetIntentFor(ctx, evt.GetSender(), source, RemoteEventTyping) - err := intent.MarkTyping(ctx, portal.MXID, typingType, evt.GetTimeout()) + intent, ok := portal.GetIntentFor(ctx, evt.GetSender(), source, RemoteEventTyping) + if !ok { + return EventHandlingResultFailed + } + timeout := evt.GetTimeout() + err := intent.MarkTyping(ctx, portal.MXID, typingType, timeout) if err != nil { zerolog.Ctx(ctx).Err(err).Msg("Failed to bridge typing event") + return EventHandlingResultFailed.WithError(err) } + if timeout == 0 { + portal.currentlyTypingGhosts.Remove(intent.GetMXID()) + } else { + portal.currentlyTypingGhosts.Add(intent.GetMXID()) + } + return EventHandlingResultSuccess } -func (portal *Portal) handleRemoteChatInfoChange(ctx context.Context, source *UserLogin, evt RemoteChatInfoChange) { +func (portal *Portal) handleRemoteChatInfoChange(ctx context.Context, source *UserLogin, evt RemoteChatInfoChange) EventHandlingResult { info, err := evt.GetChatInfoChange(ctx) if err != nil { zerolog.Ctx(ctx).Err(err).Msg("Failed to get chat info change") - return + return EventHandlingResultFailed.WithError(err) } portal.ProcessChatInfoChange(ctx, evt.GetSender(), source, info, getEventTS(evt)) + return EventHandlingResultSuccess } -func (portal *Portal) handleRemoteChatResync(ctx context.Context, source *UserLogin, evt RemoteChatResync) { +func (portal *Portal) handleRemoteChatResync(ctx context.Context, source *UserLogin, evt RemoteChatResync) EventHandlingResult { log := zerolog.Ctx(ctx) infoProvider, ok := evt.(RemoteChatResyncWithInfo) if ok { @@ -2654,29 +3797,120 @@ func (portal *Portal) handleRemoteChatResync(ctx context.Context, source *UserLo portal.doForwardBackfill(ctx, source, latestMessage, bundle) } } + return EventHandlingResultSuccess } -func (portal *Portal) handleRemoteChatDelete(ctx context.Context, source *UserLogin, evt RemoteChatDelete) { +func (portal *Portal) findOtherLogins(ctx context.Context, source *UserLogin) (ownUP *database.UserPortal, others []*database.UserPortal, err error) { + others, err = portal.Bridge.DB.UserPortal.GetAllInPortal(ctx, portal.PortalKey) + if err != nil { + return + } + others = slices.DeleteFunc(others, func(up *database.UserPortal) bool { + if up.LoginID == source.ID { + ownUP = up + return true + } + return false + }) + return +} + +type childDeleteProxy struct { + RemoteChatDeleteWithChildren + child networkid.PortalKey + done func() +} + +func (cdp *childDeleteProxy) AddLogContext(c zerolog.Context) zerolog.Context { + return cdp.RemoteChatDeleteWithChildren.AddLogContext(c).Str("subaction", "delete children") +} +func (cdp *childDeleteProxy) GetPortalKey() networkid.PortalKey { return cdp.child } +func (cdp *childDeleteProxy) ShouldCreatePortal() bool { return false } +func (cdp *childDeleteProxy) PreHandle(ctx context.Context, portal *Portal) {} +func (cdp *childDeleteProxy) PostHandle(ctx context.Context, portal *Portal) { cdp.done() } + +func (portal *Portal) handleRemoteChatDelete(ctx context.Context, source *UserLogin, evt RemoteChatDelete) EventHandlingResult { + log := zerolog.Ctx(ctx) if portal.Receiver == "" && evt.DeleteOnlyForMe() { - // TODO check if there are other users + ownUP, logins, err := portal.findOtherLogins(ctx, source) + if err != nil { + log.Err(err).Msg("Failed to check if portal has other logins") + return EventHandlingResultFailed.WithError(err) + } + if len(logins) > 0 { + log.Debug().Msg("Not deleting portal with other logins in remote chat delete event") + if ownUP != nil { + err = portal.Bridge.DB.UserPortal.Delete(ctx, ownUP) + if err != nil { + log.Err(err).Msg("Failed to delete own user portal row from database") + } else { + log.Debug().Msg("Deleted own user portal row from database") + } + } + _, err = portal.sendStateWithIntentOrBot( + ctx, + source.User.DoublePuppet(ctx), + event.StateMember, + source.UserMXID.String(), + &event.Content{Parsed: &event.MemberEventContent{Membership: event.MembershipLeave}}, + getEventTS(evt), + ) + if err != nil { + log.Err(err).Msg("Failed to send leave state event for user after remote chat delete") + return EventHandlingResultFailed.WithError(err) + } else { + log.Debug().Msg("Sent leave state event for user after remote chat delete") + return EventHandlingResultSuccess + } + } + } + if childDeleter, ok := evt.(RemoteChatDeleteWithChildren); ok && childDeleter.DeleteChildren() && portal.RoomType == database.RoomTypeSpace { + children, err := portal.Bridge.GetChildPortals(ctx, portal.PortalKey) + if err != nil { + log.Err(err).Msg("Failed to fetch children to delete") + return EventHandlingResultFailed.WithError(err) + } + log.Debug(). + Int("portal_count", len(children)). + Msg("Deleting child portals before remote chat delete") + var wg sync.WaitGroup + wg.Add(len(children)) + for _, child := range children { + child.queueEvent(ctx, &portalRemoteEvent{ + evt: &childDeleteProxy{ + RemoteChatDeleteWithChildren: childDeleter, + child: child.PortalKey, + done: wg.Done, + }, + source: source, + evtType: RemoteEventChatDelete, + }) + } + wg.Wait() + log.Debug().Msg("Finished deleting child portals") } err := portal.Delete(ctx) if err != nil { - zerolog.Ctx(ctx).Err(err).Msg("Failed to delete portal from database") - return + log.Err(err).Msg("Failed to delete portal from database") + return EventHandlingResultFailed.WithError(err) } err = portal.Bridge.Bot.DeleteRoom(ctx, portal.MXID, false) if err != nil { - zerolog.Ctx(ctx).Err(err).Msg("Failed to delete Matrix room") + log.Err(err).Msg("Failed to delete Matrix room") + return EventHandlingResultFailed.WithError(err) + } else { + log.Info().Msg("Deleted room after remote chat delete event") + return EventHandlingResultSuccess } } -func (portal *Portal) handleRemoteBackfill(ctx context.Context, source *UserLogin, backfill RemoteBackfill) { +func (portal *Portal) handleRemoteBackfill(ctx context.Context, source *UserLogin, backfill RemoteBackfill) (res EventHandlingResult) { //data, err := backfill.GetBackfillData(ctx, portal) //if err != nil { // zerolog.Ctx(ctx).Err(err).Msg("Failed to get backfill data") // return //} + return } type ChatInfoChange struct { @@ -2689,7 +3923,10 @@ type ChatInfoChange struct { } func (portal *Portal) ProcessChatInfoChange(ctx context.Context, sender EventSender, source *UserLogin, change *ChatInfoChange, ts time.Time) { - intent := portal.GetIntentFor(ctx, sender, source, RemoteEventChatInfoChange) + intent, ok := portal.GetIntentFor(ctx, sender, source, RemoteEventChatInfoChange) + if !ok { + return + } if change.ChatInfo != nil { portal.UpdateInfo(ctx, change.ChatInfo, source, intent, ts) } @@ -2707,13 +3944,45 @@ type PortalInfo = ChatInfo type ChatMember struct { EventSender Membership event.Membership - Nickname *string + // Per-room nickname for the user. Not yet used. + Nickname *string + // The power level to set for the user when syncing power levels. PowerLevel *int - UserInfo *UserInfo - + // Optional user info to sync the ghost user while updating membership. + UserInfo *UserInfo + // The user who sent the membership change (user who invited/kicked/banned this user). + // Not yet used. Not applicable if Membership is join or knock. + MemberSender EventSender + // Extra fields to include in the member event. + MemberEventExtra map[string]any + // The expected previous membership. If this doesn't match, the change is ignored. PrevMembership event.Membership } +type ChatMemberMap map[networkid.UserID]ChatMember + +// Set adds the given entry to this map, overwriting any existing entry with the same Sender field. +func (cmm ChatMemberMap) Set(member ChatMember) ChatMemberMap { + if member.Sender == "" && member.SenderLogin == "" && !member.IsFromMe { + return cmm + } + cmm[member.Sender] = member + return cmm +} + +// Add adds the given entry to this map, but will ignore it if an entry with the same Sender field already exists. +// It returns true if the entry was added, false otherwise. +func (cmm ChatMemberMap) Add(member ChatMember) bool { + if member.Sender == "" && member.SenderLogin == "" && !member.IsFromMe { + return false + } + if _, exists := cmm[member.Sender]; exists { + return false + } + cmm[member.Sender] = member + return true +} + type ChatMemberList struct { // Whether this is the full member list. // If true, any extra members not listed here will be removed from the portal. @@ -2721,6 +3990,10 @@ type ChatMemberList struct { // Should the bridge call IsThisUser for every member in the list? // This should be used when SenderLogin can't be filled accurately. CheckAllLogins bool + // Should any changes have the `com.beeper.exclude_from_timeline` flag set by default? + // This is recommended for syncs with non-real-time changes. + // Real-time changes (e.g. a user joining) should not set this flag set. + ExcludeChangesFromTimeline bool // The total number of members in the chat, regardless of how many of those members are included in MemberMap. TotalMemberCount int @@ -2731,7 +4004,7 @@ type ChatMemberList struct { // Deprecated: Use MemberMap instead to avoid duplicate entries Members []ChatMember - MemberMap map[networkid.UserID]ChatMember + MemberMap ChatMemberMap PowerLevels *PowerLevelOverrides } @@ -2833,9 +4106,11 @@ type ChatInfo struct { Disappear *database.DisappearingSetting ParentID *networkid.PortalID - UserLocal *UserLocalPortalInfo + UserLocal *UserLocalPortalInfo + MessageRequest *bool + CanBackfill bool - CanBackfill bool + ExcludeChangesFromTimeline bool ExtraUpdates ExtraUpdater[*Portal] } @@ -2869,26 +4144,36 @@ type UserLocalPortalInfo struct { Tag *event.RoomTag } -func (portal *Portal) updateName(ctx context.Context, name string, sender MatrixAPI, ts time.Time) bool { +func (portal *Portal) updateName( + ctx context.Context, name string, sender MatrixAPI, ts time.Time, excludeFromTimeline bool, +) bool { if portal.Name == name && (portal.NameSet || portal.MXID == "") { return false } portal.Name = name - portal.NameSet = portal.sendRoomMeta(ctx, sender, ts, event.StateRoomName, "", &event.RoomNameEventContent{Name: name}) + portal.NameSet = portal.sendRoomMeta( + ctx, sender, ts, event.StateRoomName, "", &event.RoomNameEventContent{Name: name}, excludeFromTimeline, nil, + ) return true } -func (portal *Portal) updateTopic(ctx context.Context, topic string, sender MatrixAPI, ts time.Time) bool { +func (portal *Portal) updateTopic( + ctx context.Context, topic string, sender MatrixAPI, ts time.Time, excludeFromTimeline bool, +) bool { if portal.Topic == topic && (portal.TopicSet || portal.MXID == "") { return false } portal.Topic = topic - portal.TopicSet = portal.sendRoomMeta(ctx, sender, ts, event.StateTopic, "", &event.TopicEventContent{Topic: topic}) + portal.TopicSet = portal.sendRoomMeta( + ctx, sender, ts, event.StateTopic, "", &event.TopicEventContent{Topic: topic}, excludeFromTimeline, nil, + ) return true } -func (portal *Portal) updateAvatar(ctx context.Context, avatar *Avatar, sender MatrixAPI, ts time.Time) bool { - if portal.AvatarID == avatar.ID && (portal.AvatarSet || portal.MXID == "") { +func (portal *Portal) updateAvatar( + ctx context.Context, avatar *Avatar, sender MatrixAPI, ts time.Time, excludeFromTimeline bool, +) bool { + if portal.AvatarID == avatar.ID && (avatar.Remove || portal.AvatarMXC != "") && (portal.AvatarSet || portal.MXID == "") { return false } portal.AvatarID = avatar.ID @@ -2904,13 +4189,15 @@ func (portal *Portal) updateAvatar(ctx context.Context, avatar *Avatar, sender M portal.AvatarSet = false zerolog.Ctx(ctx).Err(err).Msg("Failed to reupload room avatar") return true - } else if newHash == portal.AvatarHash && portal.AvatarSet { + } else if newHash == portal.AvatarHash && portal.AvatarMXC != "" && portal.AvatarSet { return true } portal.AvatarMXC = newMXC portal.AvatarHash = newHash } - portal.AvatarSet = portal.sendRoomMeta(ctx, sender, ts, event.StateRoomAvatar, "", &event.RoomAvatarEventContent{URL: portal.AvatarMXC}) + portal.AvatarSet = portal.sendRoomMeta( + ctx, sender, ts, event.StateRoomAvatar, "", &event.RoomAvatarEventContent{URL: portal.AvatarMXC}, excludeFromTimeline, nil, + ) return true } @@ -2924,16 +4211,28 @@ func (portal *Portal) GetTopLevelParent() *Portal { return portal.Parent.GetTopLevelParent() } +func (portal *Portal) getBridgeInfoStateKey() string { + if portal.Bridge.Config.NoBridgeInfoStateKey { + return "" + } + idProvider, ok := portal.Bridge.Matrix.(MatrixConnectorWithBridgeIdentifier) + if ok { + return idProvider.GetUniqueBridgeID() + } + return string(portal.BridgeID) +} + func (portal *Portal) getBridgeInfo() (string, event.BridgeEventContent) { bridgeInfo := event.BridgeEventContent{ BridgeBot: portal.Bridge.Bot.GetMXID(), Creator: portal.Bridge.Bot.GetMXID(), Protocol: portal.Bridge.Network.GetName().AsBridgeInfoSection(), Channel: event.BridgeInfoSection{ - ID: string(portal.ID), - DisplayName: portal.Name, - AvatarURL: portal.AvatarMXC, - Receiver: string(portal.Receiver), + ID: string(portal.ID), + DisplayName: portal.Name, + AvatarURL: portal.AvatarMXC, + Receiver: string(portal.Receiver), + MessageRequest: portal.MessageRequest, // TODO external URL? }, BeeperRoomTypeV2: string(portal.RoomType), @@ -2941,6 +4240,10 @@ func (portal *Portal) getBridgeInfo() (string, event.BridgeEventContent) { if portal.RoomType == database.RoomTypeDM || portal.RoomType == database.RoomTypeGroupDM { bridgeInfo.BeeperRoomType = "dm" } + if bridgeInfo.Protocol.ID == "slackgo" { + bridgeInfo.TempSlackRemoteIDMigratedFlag = true + bridgeInfo.TempSlackRemoteIDMigratedFlag2 = true + } parent := portal.GetTopLevelParent() if parent != nil { bridgeInfo.Network = &event.BridgeInfoSection{ @@ -2954,10 +4257,7 @@ func (portal *Portal) getBridgeInfo() (string, event.BridgeEventContent) { if ok { filler.FillPortalBridgeInfo(portal, &bridgeInfo) } - // TODO use something globally unique instead of bridge ID? - // maybe ask the matrix connector to use serverName+appserviceID+bridgeID - stateKey := string(portal.BridgeID) - return stateKey, bridgeInfo + return portal.getBridgeInfoStateKey(), bridgeInfo } func (portal *Portal) UpdateBridgeInfo(ctx context.Context) { @@ -2965,8 +4265,54 @@ func (portal *Portal) UpdateBridgeInfo(ctx context.Context) { return } stateKey, bridgeInfo := portal.getBridgeInfo() - portal.sendRoomMeta(ctx, nil, time.Now(), event.StateBridge, stateKey, &bridgeInfo) - portal.sendRoomMeta(ctx, nil, time.Now(), event.StateHalfShotBridge, stateKey, &bridgeInfo) + portal.sendRoomMeta(ctx, nil, time.Now(), event.StateBridge, stateKey, &bridgeInfo, false, nil) + portal.sendRoomMeta(ctx, nil, time.Now(), event.StateHalfShotBridge, stateKey, &bridgeInfo, false, nil) +} + +func (portal *Portal) UpdateCapabilities(ctx context.Context, source *UserLogin, implicit bool) bool { + if portal.MXID == "" { + return false + } else if !implicit && time.Since(portal.lastCapUpdate) < 24*time.Hour { + return false + } else if portal.CapState.ID != "" && source.ID != portal.CapState.Source && source.ID != portal.Receiver { + // TODO allow capability state source to change if the old user login is removed from the portal + return false + } + caps := source.Client.GetCapabilities(ctx, portal) + capID := caps.GetID() + if capID == portal.CapState.ID { + return false + } + zerolog.Ctx(ctx).Debug(). + Str("user_login_id", string(source.ID)). + Str("old_id", portal.CapState.ID). + Str("new_id", capID). + Msg("Sending new room capability event") + success := portal.sendRoomMeta(ctx, nil, time.Now(), event.StateBeeperRoomFeatures, portal.getBridgeInfoStateKey(), caps, false, nil) + if !success { + return false + } + portal.CapState = database.CapabilityState{ + Source: source.ID, + ID: capID, + Flags: portal.CapState.Flags, + } + if caps.DisappearingTimer != nil && !portal.CapState.Flags.Has(database.CapStateFlagDisappearingTimerSet) { + zerolog.Ctx(ctx).Debug().Msg("Disappearing timer capability was added, sending disappearing timer state event") + success = portal.sendRoomMeta(ctx, nil, time.Now(), event.StateBeeperDisappearingTimer, "", portal.Disappear.ToEventContent(), true, nil) + if !success { + return false + } + portal.CapState.Flags |= database.CapStateFlagDisappearingTimerSet + } + portal.lastCapUpdate = time.Now() + if implicit { + err := portal.Save(ctx) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to save portal capability state after sending state event") + } + } + return true } func (portal *Portal) sendStateWithIntentOrBot(ctx context.Context, sender MatrixAPI, eventType event.Type, stateKey string, content *event.Content, ts time.Time) (resp *mautrix.RespSendEvent, err error) { @@ -2984,15 +4330,27 @@ func (portal *Portal) sendStateWithIntentOrBot(ctx context.Context, sender Matri return } -func (portal *Portal) sendRoomMeta(ctx context.Context, sender MatrixAPI, ts time.Time, eventType event.Type, stateKey string, content any) bool { +func (portal *Portal) sendRoomMeta( + ctx context.Context, + sender MatrixAPI, + ts time.Time, + eventType event.Type, + stateKey string, + content any, + excludeFromTimeline bool, + extra map[string]any, +) bool { if portal.MXID == "" { return false } - var extra map[string]any + if extra == nil { + extra = make(map[string]any) + } + if excludeFromTimeline { + extra["com.beeper.exclude_from_timeline"] = true + } if !portal.NameIsCustom && (eventType == event.StateRoomName || eventType == event.StateRoomAvatar) { - extra = map[string]any{ - "fi.mau.implicit_name": true, - } + extra["fi.mau.implicit_name"] = true } _, err := portal.sendStateWithIntentOrBot(ctx, sender, eventType, stateKey, &event.Content{ Parsed: content, @@ -3004,9 +4362,55 @@ func (portal *Portal) sendRoomMeta(ctx context.Context, sender MatrixAPI, ts tim Msg("Failed to set room metadata") return false } + if eventType == event.StateBeeperDisappearingTimer { + // TODO remove this debug log at some point + zerolog.Ctx(ctx).Debug(). + Any("content", content). + Msg("Sent new disappearing timer event") + } return true } +func (portal *Portal) revertRoomMeta(ctx context.Context, evt *event.Event) { + if !portal.Bridge.Config.RevertFailedStateChanges { + return + } + if evt.GetStateKey() != "" && evt.Type != event.StateMember { + return + } + switch evt.Type { + case event.StateRoomName: + portal.sendRoomMeta(ctx, nil, time.Time{}, event.StateRoomName, "", &event.RoomNameEventContent{Name: portal.Name}, true, nil) + case event.StateRoomAvatar: + portal.sendRoomMeta(ctx, nil, time.Time{}, event.StateRoomAvatar, "", &event.RoomAvatarEventContent{URL: portal.AvatarMXC}, true, nil) + case event.StateTopic: + portal.sendRoomMeta(ctx, nil, time.Time{}, event.StateTopic, "", &event.TopicEventContent{Topic: portal.Topic}, true, nil) + case event.StateBeeperDisappearingTimer: + portal.sendRoomMeta(ctx, nil, time.Time{}, event.StateBeeperDisappearingTimer, "", portal.Disappear.ToEventContent(), true, nil) + case event.StateMember: + var prevContent *event.MemberEventContent + var extra map[string]any + if evt.Unsigned.PrevContent != nil { + _ = evt.Unsigned.PrevContent.ParseRaw(evt.Type) + prevContent = evt.Unsigned.PrevContent.AsMember() + newContent := evt.Content.AsMember() + if prevContent.Membership == newContent.Membership { + return + } + extra = evt.Unsigned.PrevContent.Raw + } else { + prevContent = &event.MemberEventContent{Membership: event.MembershipLeave} + } + if portal.Bridge.Matrix.GetCapabilities().ArbitraryMemberChange { + if extra == nil { + extra = make(map[string]any) + } + extra["com.beeper.member_rollback"] = true + portal.sendRoomMeta(ctx, nil, time.Time{}, event.StateMember, evt.GetStateKey(), prevContent, true, extra) + } + } +} + func (portal *Portal) getInitialMemberList(ctx context.Context, members *ChatMemberList, source *UserLogin, pl *event.PowerLevelsEventContent) (invite, functional []id.UserID, err error) { if members == nil { invite = []id.UserID{source.UserMXID} @@ -3023,6 +4427,10 @@ func (portal *Portal) getInitialMemberList(ctx context.Context, members *ChatMem members.PowerLevels.Apply("", pl) members.memberListToMap(ctx) for _, member := range members.MemberMap { + if ctx.Err() != nil { + err = ctx.Err() + return + } if member.Membership != event.MembershipJoin && member.Membership != "" { continue } @@ -3034,7 +4442,10 @@ func (portal *Portal) getInitialMemberList(ctx context.Context, members *ChatMem ghost.UpdateInfo(ctx, member.UserInfo) } } - intent, extraUserID := portal.getIntentAndUserMXIDFor(ctx, member.EventSender, source, loginsInPortal, 0) + intent, extraUserID, err := portal.getIntentAndUserMXIDFor(ctx, member.EventSender, source, loginsInPortal, 0) + if err != nil { + return nil, nil, err + } if extraUserID != "" { invite = append(invite, extraUserID) if member.PowerLevel != nil { @@ -3083,7 +4494,46 @@ func (portal *Portal) updateOtherUser(ctx context.Context, members *ChatMemberLi return false } -func (portal *Portal) syncParticipants(ctx context.Context, members *ChatMemberList, source *UserLogin, sender MatrixAPI, ts time.Time) error { +func looksDirectlyJoinable(rule *event.JoinRulesEventContent) bool { + switch rule.JoinRule { + case event.JoinRulePublic: + return true + case event.JoinRuleKnockRestricted, event.JoinRuleRestricted: + for _, allow := range rule.Allow { + if allow.Type == "fi.mau.spam_checker" { + return true + } + } + } + return false +} + +func (portal *Portal) roomIsPublic(ctx context.Context) bool { + mx, ok := portal.Bridge.Matrix.(MatrixConnectorWithArbitraryRoomState) + if !ok { + return false + } + evt, err := mx.GetStateEvent(ctx, portal.MXID, event.StateJoinRules, "") + if err != nil { + zerolog.Ctx(ctx).Warn().Err(err).Msg("Failed to get join rules to check if room is public") + return false + } else if evt == nil { + return false + } + content, ok := evt.Content.Parsed.(*event.JoinRulesEventContent) + if !ok { + return false + } + return looksDirectlyJoinable(content) +} + +func (portal *Portal) syncParticipants( + ctx context.Context, + members *ChatMemberList, + source *UserLogin, + sender MatrixAPI, + ts time.Time, +) error { members.memberListToMap(ctx) var loginsInPortal []*UserLogin var err error @@ -3107,7 +4557,13 @@ func (portal *Portal) syncParticipants(ctx context.Context, members *ChatMemberL } delete(currentMembers, portal.Bridge.Bot.GetMXID()) powerChanged := members.PowerLevels.Apply(portal.Bridge.Bot.GetMXID(), currentPower) - syncUser := func(extraUserID id.UserID, member ChatMember, hasIntent bool) bool { + addExcludeFromTimeline := func(raw map[string]any) { + _, hasKey := raw["com.beeper.exclude_from_timeline"] + if !hasKey && members.ExcludeChangesFromTimeline { + raw["com.beeper.exclude_from_timeline"] = true + } + } + syncUser := func(extraUserID id.UserID, member ChatMember, intent MatrixAPI) bool { if member.Membership == "" { member.Membership = event.MembershipJoin } @@ -3136,58 +4592,74 @@ func (portal *Portal) syncParticipants(ctx context.Context, members *ChatMemberL Displayname: currentMember.Displayname, AvatarURL: currentMember.AvatarURL, } - wrappedContent := &event.Content{Parsed: content, Raw: make(map[string]any)} + wrappedContent := &event.Content{Parsed: content, Raw: exmaps.NonNilClone(member.MemberEventExtra)} + addExcludeFromTimeline(wrappedContent.Raw) thisEvtSender := sender - if member.Membership == event.MembershipJoin { + if member.Membership == event.MembershipJoin && (intent == nil || !portal.roomIsPublic(ctx)) { content.Membership = event.MembershipInvite - if hasIntent { + if intent != nil { wrappedContent.Raw["fi.mau.will_auto_accept"] = true } if thisEvtSender.GetMXID() == extraUserID { thisEvtSender = portal.Bridge.Bot } } + addLogContext := func(e *zerolog.Event) *zerolog.Event { + return e.Stringer("target_user_id", extraUserID). + Stringer("sender_user_id", thisEvtSender.GetMXID()). + Str("prev_membership", string(currentMember.Membership)) + } if currentMember != nil && currentMember.Membership == event.MembershipBan && member.Membership != event.MembershipLeave { unbanContent := *content unbanContent.Membership = event.MembershipLeave wrappedUnbanContent := &event.Content{Parsed: &unbanContent} _, err = portal.sendStateWithIntentOrBot(ctx, thisEvtSender, event.StateMember, extraUserID.String(), wrappedUnbanContent, ts) if err != nil { - log.Err(err). - Stringer("target_user_id", extraUserID). - Stringer("sender_user_id", thisEvtSender.GetMXID()). - Str("prev_membership", string(currentMember.Membership)). - Str("membership", string(member.Membership)). + addLogContext(log.Err(err)). + Str("new_membership", string(unbanContent.Membership)). Msg("Failed to unban user to update membership") } else { - log.Trace(). - Stringer("target_user_id", extraUserID). - Stringer("sender_user_id", thisEvtSender.GetMXID()). - Str("prev_membership", string(currentMember.Membership)). - Str("membership", string(member.Membership)). + addLogContext(log.Trace()). + Str("new_membership", string(unbanContent.Membership)). Msg("Unbanned user to update membership") + currentMember.Membership = event.MembershipLeave } } - _, err = portal.sendStateWithIntentOrBot(ctx, thisEvtSender, event.StateMember, extraUserID.String(), wrappedContent, ts) + if content.Membership == event.MembershipJoin && intent != nil && intent.GetMXID() == extraUserID { + _, err = intent.SendState(ctx, portal.MXID, event.StateMember, extraUserID.String(), wrappedContent, ts) + } else { + _, err = portal.sendStateWithIntentOrBot(ctx, thisEvtSender, event.StateMember, extraUserID.String(), wrappedContent, ts) + } if err != nil { - log.Err(err). - Stringer("target_user_id", extraUserID). - Stringer("sender_user_id", thisEvtSender.GetMXID()). - Str("prev_membership", string(currentMember.Membership)). - Str("membership", string(member.Membership)). + addLogContext(log.Err(err)). + Str("new_membership", string(content.Membership)). Msg("Failed to update user membership") } else { - log.Trace(). - Stringer("target_user_id", extraUserID). - Stringer("sender_user_id", thisEvtSender.GetMXID()). - Str("prev_membership", string(currentMember.Membership)). - Str("membership", string(member.Membership)). - Msg("Updating membership in room") + addLogContext(log.Trace()). + Str("new_membership", string(content.Membership)). + Msg("Updated membership in room") + currentMember.Membership = content.Membership + + if intent != nil && content.Membership == event.MembershipInvite && member.Membership == event.MembershipJoin { + content.Membership = event.MembershipJoin + wrappedJoinContent := &event.Content{Parsed: content, Raw: exmaps.NonNilClone(member.MemberEventExtra)} + addExcludeFromTimeline(wrappedContent.Raw) + _, err = intent.SendState(ctx, portal.MXID, event.StateMember, intent.GetMXID().String(), wrappedJoinContent, ts) + if err != nil { + addLogContext(log.Err(err)). + Str("new_membership", string(content.Membership)). + Msg("Failed to join with intent") + } else { + addLogContext(log.Trace()). + Str("new_membership", string(content.Membership)). + Msg("Joined room with intent") + } + } } return true } syncIntent := func(intent MatrixAPI, member ChatMember) { - if !syncUser(intent.GetMXID(), member, true) { + if !syncUser(intent.GetMXID(), member, intent) { return } if member.Membership == event.MembershipJoin || member.Membership == "" { @@ -3200,6 +4672,9 @@ func (portal *Portal) syncParticipants(ctx context.Context, members *ChatMemberL } } for _, member := range members.MemberMap { + if ctx.Err() != nil { + return ctx.Err() + } if member.Sender != "" && member.UserInfo != nil { ghost, err := portal.Bridge.GetGhostByID(ctx, member.Sender) if err != nil { @@ -3208,12 +4683,15 @@ func (portal *Portal) syncParticipants(ctx context.Context, members *ChatMemberL ghost.UpdateInfo(ctx, member.UserInfo) } } - intent, extraUserID := portal.getIntentAndUserMXIDFor(ctx, member.EventSender, source, loginsInPortal, 0) + intent, extraUserID, err := portal.getIntentAndUserMXIDFor(ctx, member.EventSender, source, loginsInPortal, 0) + if err != nil { + return err + } if intent != nil { syncIntent(intent, member) } if extraUserID != "" { - syncUser(extraUserID, member, false) + syncUser(extraUserID, member, nil) } } if powerChanged { @@ -3228,7 +4706,7 @@ func (portal *Portal) syncParticipants(ctx context.Context, members *ChatMemberL if memberEvt.Membership == event.MembershipLeave || memberEvt.Membership == event.MembershipBan { continue } - if !portal.Bridge.IsGhostMXID(extraMember) && portal.Relay != nil { + if !portal.Bridge.IsGhostMXID(extraMember) && (portal.Relay != nil || !portal.Bridge.Config.KickMatrixUsers) { continue } _, err = portal.Bridge.Bot.SendState(ctx, portal.MXID, event.StateMember, extraMember.String(), &event.Content{ @@ -3238,6 +4716,9 @@ func (portal *Portal) syncParticipants(ctx context.Context, members *ChatMemberL Displayname: memberEvt.Displayname, Reason: "User is not in remote chat", }, + Raw: map[string]any{ + "com.beeper.exclude_from_timeline": members.ExcludeChangesFromTimeline, + }, }, time.Now()) if err != nil { zerolog.Ctx(ctx).Err(err). @@ -3290,27 +4771,44 @@ func (portal *Portal) updateUserLocalInfo(ctx context.Context, info *UserLocalPo func DisappearingMessageNotice(expiration time.Duration, implicit bool) *event.MessageEventContent { formattedDuration := exfmt.DurationCustom(expiration, nil, exfmt.Day, time.Hour, time.Minute, time.Second) content := &event.MessageEventContent{ - MsgType: event.MsgNotice, - Body: fmt.Sprintf("Set the disappearing message timer to %s", formattedDuration), + MsgType: event.MsgNotice, + Body: fmt.Sprintf("Set the disappearing message timer to %s", formattedDuration), + Mentions: &event.Mentions{}, } - if implicit { + if expiration == 0 { + if implicit { + content.Body = "Automatically turned off disappearing messages because incoming message is not disappearing" + } else { + content.Body = "Turned off disappearing messages" + } + } else if implicit { content.Body = fmt.Sprintf("Automatically enabled disappearing message timer (%s) because incoming message is disappearing", formattedDuration) - } else if expiration == 0 { - content.Body = "Turned off disappearing messages" } return content } -func (portal *Portal) UpdateDisappearingSetting(ctx context.Context, setting database.DisappearingSetting, sender MatrixAPI, ts time.Time, implicit, save bool) bool { - if setting.Timer == 0 { - setting.Type = "" - } +type UpdateDisappearingSettingOpts struct { + Sender MatrixAPI + Timestamp time.Time + Implicit bool + Save bool + SendNotice bool + + ExcludeFromTimeline bool +} + +func (portal *Portal) UpdateDisappearingSetting( + ctx context.Context, + setting database.DisappearingSetting, + opts UpdateDisappearingSettingOpts, +) bool { + setting = setting.Normalize() if portal.Disappear.Timer == setting.Timer && portal.Disappear.Type == setting.Type { return false } portal.Disappear.Type = setting.Type portal.Disappear.Timer = setting.Timer - if save { + if opts.Save { err := portal.Save(ctx) if err != nil { zerolog.Ctx(ctx).Err(err).Msg("Failed to save portal to database after updating disappearing setting") @@ -3319,19 +4817,45 @@ func (portal *Portal) UpdateDisappearingSetting(ctx context.Context, setting dat if portal.MXID == "" { return true } - content := DisappearingMessageNotice(setting.Timer, implicit) - if sender == nil { - sender = portal.Bridge.Bot + + if opts.Sender == nil { + opts.Sender = portal.Bridge.Bot } - _, err := sender.SendMessage(ctx, portal.MXID, event.EventMessage, &event.Content{ + if opts.Timestamp.IsZero() { + opts.Timestamp = time.Now() + } + portal.sendRoomMeta( + ctx, + opts.Sender, + opts.Timestamp, + event.StateBeeperDisappearingTimer, + "", + setting.ToEventContent(), + opts.ExcludeFromTimeline, + nil, + ) + + if !opts.SendNotice { + return true + } + content := DisappearingMessageNotice(setting.Timer, opts.Implicit) + _, err := opts.Sender.SendMessage(ctx, portal.MXID, event.EventMessage, &event.Content{ Parsed: content, - }, &MatrixSendExtra{Timestamp: ts}) + Raw: map[string]any{ + "com.beeper.action_message": map[string]any{ + "type": "disappearing_timer", + "timer": setting.Timer.Milliseconds(), + "timer_type": setting.Type, + "implicit": opts.Implicit, + }, + }, + }, &MatrixSendExtra{Timestamp: opts.Timestamp}) if err != nil { zerolog.Ctx(ctx).Err(err).Msg("Failed to send disappearing messages notice") } else { zerolog.Ctx(ctx).Debug(). Dur("new_timer", portal.Disappear.Timer). - Bool("implicit", implicit). + Bool("implicit", opts.Implicit). Msg("Sent disappearing messages notice") } return true @@ -3393,13 +4917,13 @@ func (portal *Portal) UpdateInfoFromGhost(ctx context.Context, ghost *Ghost) (ch return } } - changed = portal.updateName(ctx, ghost.Name, nil, time.Time{}) || changed + changed = portal.updateName(ctx, ghost.Name, nil, time.Time{}, false) || changed changed = portal.updateAvatar(ctx, &Avatar{ ID: ghost.AvatarID, MXC: ghost.AvatarMXC, Hash: ghost.AvatarHash, Remove: ghost.AvatarID == "", - }, nil, time.Time{}) || changed + }, nil, time.Time{}, false) || changed return } @@ -3408,28 +4932,36 @@ func (portal *Portal) UpdateInfo(ctx context.Context, info *ChatInfo, source *Us if info.Name == DefaultChatName { if portal.NameIsCustom { portal.NameIsCustom = false - changed = portal.updateName(ctx, "", sender, ts) || changed + changed = portal.updateName(ctx, "", sender, ts, info.ExcludeChangesFromTimeline) || changed } } else if info.Name != nil { portal.NameIsCustom = true - changed = portal.updateName(ctx, *info.Name, sender, ts) || changed + changed = portal.updateName(ctx, *info.Name, sender, ts, info.ExcludeChangesFromTimeline) || changed } if info.Topic != nil { - changed = portal.updateTopic(ctx, *info.Topic, sender, ts) || changed + changed = portal.updateTopic(ctx, *info.Topic, sender, ts, info.ExcludeChangesFromTimeline) || changed } if info.Avatar != nil { portal.NameIsCustom = true - changed = portal.updateAvatar(ctx, info.Avatar, sender, ts) || changed + changed = portal.updateAvatar(ctx, info.Avatar, sender, ts, info.ExcludeChangesFromTimeline) || changed } if info.Disappear != nil { - changed = portal.UpdateDisappearingSetting(ctx, *info.Disappear, sender, ts, false, false) || changed + changed = portal.UpdateDisappearingSetting(ctx, *info.Disappear, UpdateDisappearingSettingOpts{ + Sender: sender, + Timestamp: ts, + Implicit: false, + Save: false, + + SendNotice: !info.ExcludeChangesFromTimeline, + ExcludeFromTimeline: info.ExcludeChangesFromTimeline, + }) || changed } if info.ParentID != nil { changed = portal.updateParent(ctx, *info.ParentID, source) || changed } if info.JoinRule != nil { // TODO change detection instead of spamming this every time? - portal.sendRoomMeta(ctx, sender, ts, event.StateJoinRules, "", info.JoinRule) + portal.sendRoomMeta(ctx, sender, ts, event.StateJoinRules, "", info.JoinRule, info.ExcludeChangesFromTimeline, nil) } if info.Type != nil && portal.RoomType != *info.Type { if portal.MXID != "" && (*info.Type == database.RoomTypeSpace || portal.RoomType == database.RoomTypeSpace) { @@ -3442,6 +4974,10 @@ func (portal *Portal) UpdateInfo(ctx context.Context, info *ChatInfo, source *Us portal.RoomType = *info.Type } } + if info.MessageRequest != nil && *info.MessageRequest != portal.MessageRequest { + changed = true + portal.MessageRequest = *info.MessageRequest + } if info.Members != nil && portal.MXID != "" && source != nil { err := portal.syncParticipants(ctx, info.Members, source, nil, time.Time{}) if err != nil { @@ -3455,6 +4991,7 @@ func (portal *Portal) UpdateInfo(ctx context.Context, info *ChatInfo, source *Us if source != nil { source.MarkInPortal(ctx, portal) portal.updateUserLocalInfo(ctx, info.UserLocal, source, false) + changed = portal.UpdateCapabilities(ctx, source, false) || changed } if info.CanBackfill && source != nil && portal.MXID != "" { err := portal.Bridge.DB.BackfillTask.EnsureExists(ctx, portal.PortalKey, source.ID) @@ -3482,9 +5019,12 @@ func (portal *Portal) CreateMatrixRoom(ctx context.Context, source *UserLogin, i } return nil } + if portal.deleted.IsSet() { + return ErrPortalIsDeleted + } waiter := make(chan struct{}) closed := false - portal.events <- &portalCreateEvent{ + evt := &portalCreateEvent{ ctx: ctx, source: source, info: info, @@ -3496,6 +5036,15 @@ func (portal *Portal) CreateMatrixRoom(ctx context.Context, source *UserLogin, i } }, } + if PortalEventBuffer == 0 { + go portal.queueEvent(ctx, evt) + } else { + select { + case portal.events <- evt: + case <-portal.deleted.GetChan(): + return ErrPortalIsDeleted + } + } select { case <-ctx.Done(): return ctx.Err() @@ -3505,7 +5054,11 @@ func (portal *Portal) CreateMatrixRoom(ctx context.Context, source *UserLogin, i } func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLogin, info *ChatInfo, backfillBundle any) error { + cancellableCtx, cancel := context.WithCancel(ctx) + defer cancel() + portal.cancelRoomCreate.CompareAndSwap(nil, &cancel) portal.roomCreateLock.Lock() + portal.cancelRoomCreate.Store(&cancel) defer portal.roomCreateLock.Unlock() if portal.MXID != "" { if source != nil { @@ -3516,6 +5069,7 @@ func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLo log := zerolog.Ctx(ctx).With(). Str("action", "create matrix room"). Logger() + cancellableCtx = log.WithContext(cancellableCtx) ctx = log.WithContext(ctx) log.Info().Msg("Creating Matrix room") @@ -3524,14 +5078,17 @@ func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLo if info != nil { log.Warn().Msg("CreateMatrixRoom got info without members. Refetching info") } - info, err = source.Client.GetChatInfo(ctx, portal) + info, err = source.Client.GetChatInfo(cancellableCtx, portal) if err != nil { log.Err(err).Msg("Failed to update portal info for creation") return err } } - portal.UpdateInfo(ctx, info, source, nil, time.Time{}) + portal.UpdateInfo(cancellableCtx, info, source, nil, time.Time{}) + if cancellableCtx.Err() != nil { + return cancellableCtx.Err() + } powerLevels := &event.PowerLevelsEventContent{ Events: map[string]int{ @@ -3543,7 +5100,7 @@ func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLo portal.Bridge.Bot.GetMXID(): 9001, }, } - initialMembers, extraFunctionalMembers, err := portal.getInitialMemberList(ctx, info.Members, source, powerLevels) + initialMembers, extraFunctionalMembers, err := portal.getInitialMemberList(cancellableCtx, info.Members, source, powerLevels) if err != nil { log.Err(err).Msg("Failed to process participant list for portal creation") return err @@ -3552,14 +5109,12 @@ func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLo req := mautrix.ReqCreateRoom{ Visibility: "private", - Name: portal.Name, - Topic: portal.Topic, CreationContent: make(map[string]any), InitialState: make([]*event.Event, 0, 6), Preset: "private_chat", IsDirect: portal.RoomType == database.RoomTypeDM, PowerLevelOverride: powerLevels, - BeeperLocalRoomID: id.RoomID(fmt.Sprintf("!%s.%s:%s", portal.ID, portal.Receiver, portal.Bridge.Matrix.ServerName())), + BeeperLocalRoomID: portal.Bridge.Matrix.GenerateDeterministicRoomID(portal.PortalKey), } autoJoinInvites := portal.Bridge.Matrix.GetCapabilities().AutoJoinInvites if autoJoinInvites { @@ -3572,6 +5127,11 @@ func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLo req.CreationContent["type"] = event.RoomTypeSpace } bridgeInfoStateKey, bridgeInfo := portal.getBridgeInfo() + roomFeatures := source.Client.GetCapabilities(cancellableCtx, portal) + portal.CapState = database.CapabilityState{ + Source: source.ID, + ID: roomFeatures.GetID(), + } req.InitialState = append(req.InitialState, &event.Event{ Type: event.StateElementFunctionalMembers, @@ -3586,19 +5146,51 @@ func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLo StateKey: &bridgeInfoStateKey, Type: event.StateBridge, Content: event.Content{Parsed: &bridgeInfo}, + }, &event.Event{ + StateKey: &bridgeInfoStateKey, + Type: event.StateBeeperRoomFeatures, + Content: event.Content{Parsed: roomFeatures}, + }, &event.Event{ + Type: event.StateTopic, + Content: event.Content{ + Parsed: &event.TopicEventContent{Topic: portal.Topic}, + Raw: map[string]any{ + "com.beeper.exclude_from_timeline": true, + }, + }, }) - if req.Topic == "" { - // Add explicit topic event if topic is empty to ensure the event is set. - // This ensures that there won't be an extra event later if PUT /state/... is called. + if roomFeatures.DisappearingTimer != nil { req.InitialState = append(req.InitialState, &event.Event{ - Type: event.StateTopic, - Content: event.Content{Parsed: &event.TopicEventContent{Topic: ""}}, + Type: event.StateBeeperDisappearingTimer, + Content: event.Content{ + Parsed: portal.Disappear.ToEventContent(), + Raw: map[string]any{ + "com.beeper.exclude_from_timeline": true, + }, + }, + }) + portal.CapState.Flags |= database.CapStateFlagDisappearingTimerSet + } + if portal.Name != "" { + req.InitialState = append(req.InitialState, &event.Event{ + Type: event.StateRoomName, + Content: event.Content{ + Parsed: &event.RoomNameEventContent{Name: portal.Name}, + Raw: map[string]any{ + "com.beeper.exclude_from_timeline": true, + }, + }, }) } if portal.AvatarMXC != "" { req.InitialState = append(req.InitialState, &event.Event{ - Type: event.StateRoomAvatar, - Content: event.Content{Parsed: &event.RoomAvatarEventContent{URL: portal.AvatarMXC}}, + Type: event.StateRoomAvatar, + Content: event.Content{ + Parsed: &event.RoomAvatarEventContent{URL: portal.AvatarMXC}, + Raw: map[string]any{ + "com.beeper.exclude_from_timeline": true, + }, + }, }) } if portal.Parent != nil && portal.Parent.MXID != "" { @@ -3617,6 +5209,9 @@ func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLo Content: event.Content{Parsed: info.JoinRule}, }) } + if cancellableCtx.Err() != nil { + return cancellableCtx.Err() + } roomID, err := portal.Bridge.Bot.CreateRoom(ctx, &req) if err != nil { log.Err(err).Msg("Failed to create Matrix room") @@ -3627,6 +5222,7 @@ func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLo portal.TopicSet = true portal.NameSet = true portal.MXID = roomID + portal.RoomCreated.Set() portal.Bridge.cacheLock.Lock() portal.Bridge.portalsByMXID[roomID] = portal portal.Bridge.cacheLock.Unlock() @@ -3647,7 +5243,7 @@ func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLo } portal.Bridge.WakeupBackfillQueue() } - withoutCancelCtx := context.WithoutCancel(ctx) + withoutCancelCtx := zerolog.Ctx(ctx).WithContext(portal.Bridge.BackgroundCtx) if portal.Parent != nil { if portal.Parent.MXID != "" { portal.addToParentSpaceAndSave(ctx, true) @@ -3673,42 +5269,55 @@ func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLo } } } - if portal.Parent == nil { - if portal.Receiver != "" { - login := portal.Bridge.GetCachedUserLoginByID(portal.Receiver) - if login != nil { - up, err := portal.Bridge.DB.UserPortal.Get(ctx, login.UserLogin, portal.PortalKey) - if err != nil { - log.Err(err).Msg("Failed to get user portal to add portal to spaces") - } else { - login.inPortalCache.Remove(portal.PortalKey) - go login.tryAddPortalToSpace(withoutCancelCtx, portal, up.CopyWithoutValues()) - } - } - } else { - userPortals, err := portal.Bridge.DB.UserPortal.GetAllInPortal(ctx, portal.PortalKey) - if err != nil { - log.Err(err).Msg("Failed to get user logins in portal to add portal to spaces") - } else { - for _, up := range userPortals { - login := portal.Bridge.GetCachedUserLoginByID(up.LoginID) - if login != nil { - login.inPortalCache.Remove(portal.PortalKey) - go login.tryAddPortalToSpace(withoutCancelCtx, portal, up.CopyWithoutValues()) - } - } - } - } - } - if portal.Bridge.Config.Backfill.Enabled && portal.RoomType != database.RoomTypeSpace { + portal.addToUserSpaces(ctx) + if info.CanBackfill && + portal.Bridge.Config.Backfill.Enabled && + portal.RoomType != database.RoomTypeSpace && + !portal.Bridge.Background { portal.doForwardBackfill(ctx, source, nil, backfillBundle) } return nil } +func (portal *Portal) addToUserSpaces(ctx context.Context) { + if portal.Parent != nil { + return + } + log := zerolog.Ctx(ctx) + withoutCancelCtx := log.WithContext(portal.Bridge.BackgroundCtx) + if portal.Receiver != "" { + login := portal.Bridge.GetCachedUserLoginByID(portal.Receiver) + if login != nil { + up, err := portal.Bridge.DB.UserPortal.GetOrCreate(ctx, login.UserLogin, portal.PortalKey) + if err != nil { + log.Err(err).Msg("Failed to get user portal to add portal to spaces") + } else { + login.inPortalCache.Remove(portal.PortalKey) + go login.tryAddPortalToSpace(withoutCancelCtx, portal, up.CopyWithoutValues()) + } + } + } else { + userPortals, err := portal.Bridge.DB.UserPortal.GetAllInPortal(ctx, portal.PortalKey) + if err != nil { + log.Err(err).Msg("Failed to get user logins in portal to add portal to spaces") + } else { + for _, up := range userPortals { + login := portal.Bridge.GetCachedUserLoginByID(up.LoginID) + if login != nil { + login.inPortalCache.Remove(portal.PortalKey) + go login.tryAddPortalToSpace(withoutCancelCtx, portal, up.CopyWithoutValues()) + } + } + } + } +} + func (portal *Portal) Delete(ctx context.Context) error { + if portal.deleted.IsSet() { + return nil + } portal.removeInPortalCache(ctx) - err := portal.Bridge.DB.Portal.Delete(ctx, portal.PortalKey) + err := portal.safeDBDelete(ctx) if err != nil { return err } @@ -3718,11 +5327,21 @@ func (portal *Portal) Delete(ctx context.Context) error { return nil } +func (portal *Portal) safeDBDelete(ctx context.Context) error { + err := portal.Bridge.DB.Message.DeleteInChunks(ctx, portal.PortalKey) + if err != nil { + return fmt.Errorf("failed to delete messages in portal: %w", err) + } + // TODO delete child portals? + return portal.Bridge.DB.Portal.Delete(ctx, portal.PortalKey) +} + func (portal *Portal) RemoveMXID(ctx context.Context) error { if portal.MXID == "" { return nil } portal.MXID = "" + portal.RoomCreated.Clear() err := portal.Save(ctx) if err != nil { return err @@ -3755,8 +5374,10 @@ func (portal *Portal) removeInPortalCache(ctx context.Context) { } func (portal *Portal) unlockedDelete(ctx context.Context) error { - // TODO delete child portals? - err := portal.Bridge.DB.Portal.Delete(ctx, portal.PortalKey) + if portal.deleted.IsSet() { + return nil + } + err := portal.safeDBDelete(ctx) if err != nil { return err } @@ -3765,10 +5386,18 @@ func (portal *Portal) unlockedDelete(ctx context.Context) error { } func (portal *Portal) unlockedDeleteCache() { + if portal.deleted.IsSet() { + return + } delete(portal.Bridge.portalsByKey, portal.PortalKey) if portal.MXID != "" { delete(portal.Bridge.portalsByMXID, portal.MXID) } + portal.deleted.Set() + if portal.events != nil { + // TODO there's a small risk of this racing with a queueEvent call + close(portal.events) + } } func (portal *Portal) Save(ctx context.Context) error { @@ -3776,6 +5405,9 @@ func (portal *Portal) Save(ctx context.Context) error { } func (portal *Portal) SetRelay(ctx context.Context, relay *UserLogin) error { + if portal.Receiver != "" && relay.ID != portal.Receiver { + return fmt.Errorf("can't set non-receiver login as relay") + } portal.Relay = relay if relay == nil { portal.RelayLoginID = "" @@ -3788,3 +5420,17 @@ func (portal *Portal) SetRelay(ctx context.Context, relay *UserLogin) error { } return nil } + +func (portal *Portal) PerMessageProfileForSender(ctx context.Context, sender networkid.UserID) (profile event.BeeperPerMessageProfile, err error) { + var ghost *Ghost + ghost, err = portal.Bridge.GetGhostByID(ctx, sender) + if err != nil { + return + } + profile.ID = string(ghost.Intent.GetMXID()) + profile.Displayname = ghost.Name + if ghost.AvatarMXC != "" { + profile.AvatarURL = &ghost.AvatarMXC + } + return +} diff --git a/bridgev2/portalbackfill.go b/bridgev2/portalbackfill.go index 55225efc..879f07ae 100644 --- a/bridgev2/portalbackfill.go +++ b/bridgev2/portalbackfill.go @@ -61,6 +61,9 @@ func (portal *Portal) doForwardBackfill(ctx context.Context, source *UserLogin, return } else if len(resp.Messages) == 0 { log.Debug().Msg("No messages to backfill") + if resp.CompleteCallback != nil { + resp.CompleteCallback() + } return } log.Debug(). @@ -191,6 +194,9 @@ func (portal *Portal) doThreadBackfill(ctx context.Context, source *UserLogin, t if err != nil { log.Err(err).Msg("Failed to get last thread message") return + } else if anchorMessage == nil { + log.Warn().Msg("No messages found in thread?") + return } resp := portal.fetchThreadBackfill(ctx, source, anchorMessage) if resp != nil { @@ -320,8 +326,13 @@ func (portal *Portal) compileBatchMessage(ctx context.Context, source *UserLogin if len(msg.Parts) == 0 { return } - intent := portal.GetIntentFor(ctx, msg.Sender, source, RemoteEventMessage) - replyTo, threadRoot, prevThreadEvent := portal.getRelationMeta(ctx, msg.ID, msg.ReplyTo, msg.ThreadRoot, true) + intent, ok := portal.GetIntentFor(ctx, msg.Sender, source, RemoteEventMessage) + if !ok { + return + } + replyTo, threadRoot, prevThreadEvent := portal.getRelationMeta( + ctx, msg.ID, msg.ConvertedMessage, true, + ) if threadRoot != nil && out.PrevThreadEvents[*msg.ThreadRoot] != "" { prevThreadEvent.MXID = out.PrevThreadEvents[*msg.ThreadRoot] } @@ -330,19 +341,21 @@ func (portal *Portal) compileBatchMessage(ctx context.Context, source *UserLogin var firstPart *database.Message for i, part := range msg.Parts { partIDs = append(partIDs, part.ID) - portal.applyRelationMeta(part.Content, replyTo, threadRoot, prevThreadEvent) + portal.applyRelationMeta(ctx, part.Content, replyTo, threadRoot, prevThreadEvent) + part.Content.BeeperDisappearingTimer = msg.Disappear.ToEventContent() evtID := portal.Bridge.Matrix.GenerateDeterministicEventID(portal.MXID, portal.PortalKey, msg.ID, part.ID) dbMessage := &database.Message{ - ID: msg.ID, - PartID: part.ID, - MXID: evtID, - Room: portal.PortalKey, - SenderID: msg.Sender.Sender, - SenderMXID: intent.GetMXID(), - Timestamp: msg.Timestamp, - ThreadRoot: ptr.Val(msg.ThreadRoot), - ReplyTo: ptr.Val(msg.ReplyTo), - Metadata: part.DBMetadata, + ID: msg.ID, + PartID: part.ID, + MXID: evtID, + Room: portal.PortalKey, + SenderID: msg.Sender.Sender, + SenderMXID: intent.GetMXID(), + Timestamp: msg.Timestamp, + ThreadRoot: ptr.Val(msg.ThreadRoot), + ReplyTo: ptr.Val(msg.ReplyTo), + Metadata: part.DBMetadata, + IsDoublePuppeted: intent.IsDoublePuppet(), } if part.DontBridge { dbMessage.SetFakeMXID() @@ -370,26 +383,34 @@ func (portal *Portal) compileBatchMessage(ctx context.Context, source *UserLogin prevThreadEvent.MXID = evtID out.PrevThreadEvents[*msg.ThreadRoot] = evtID } - if msg.Disappear.Type != database.DisappearingTypeNone { - if msg.Disappear.Type == database.DisappearingTypeAfterSend && msg.Disappear.DisappearAt.IsZero() { + if msg.Disappear.Type != event.DisappearingTypeNone { + if msg.Disappear.Type == event.DisappearingTypeAfterSend && msg.Disappear.DisappearAt.IsZero() { msg.Disappear.DisappearAt = msg.Timestamp.Add(msg.Disappear.Timer) } out.Disappear = append(out.Disappear, &database.DisappearingMessage{ RoomID: portal.MXID, EventID: evtID, + Timestamp: msg.Timestamp, DisappearingSetting: msg.Disappear, }) } } slices.Sort(partIDs) for _, reaction := range msg.Reactions { - reactionIntent := portal.GetIntentFor(ctx, reaction.Sender, source, RemoteEventReactionRemove) + if reaction == nil { + continue + } + reactionIntent, ok := portal.GetIntentFor(ctx, reaction.Sender, source, RemoteEventReactionRemove) + if !ok { + continue + } if reaction.TargetPart == nil { reaction.TargetPart = &partIDs[0] } if reaction.Timestamp.IsZero() { reaction.Timestamp = msg.Timestamp.Add(10 * time.Millisecond) } + //lint:ignore SA4006 it's a todo targetPart, ok := partMap[*reaction.TargetPart] if !ok { // TODO warning log and/or skip reaction? @@ -509,8 +530,11 @@ func (portal *Portal) sendBatch(ctx context.Context, source *UserLogin, messages func (portal *Portal) sendLegacyBackfill(ctx context.Context, source *UserLogin, messages []*BackfillMessage, markRead bool) { var lastPart id.EventID for _, msg := range messages { - intent := portal.GetIntentFor(ctx, msg.Sender, source, RemoteEventMessage) - dbMessages := portal.sendConvertedMessage(ctx, msg.ID, intent, msg.Sender.Sender, msg.ConvertedMessage, msg.Timestamp, msg.StreamOrder, func(z *zerolog.Event) *zerolog.Event { + intent, ok := portal.GetIntentFor(ctx, msg.Sender, source, RemoteEventMessage) + if !ok { + continue + } + dbMessages, _ := portal.sendConvertedMessage(ctx, msg.ID, intent, msg.Sender.Sender, msg.ConvertedMessage, msg.Timestamp, msg.StreamOrder, func(z *zerolog.Event) *zerolog.Event { return z. Str("message_id", string(msg.ID)). Any("sender_id", msg.Sender). @@ -519,7 +543,10 @@ func (portal *Portal) sendLegacyBackfill(ctx context.Context, source *UserLogin, if len(dbMessages) > 0 { lastPart = dbMessages[len(dbMessages)-1].MXID for _, reaction := range msg.Reactions { - reactionIntent := portal.GetIntentFor(ctx, reaction.Sender, source, RemoteEventReaction) + reactionIntent, ok := portal.GetIntentFor(ctx, reaction.Sender, source, RemoteEventReaction) + if !ok { + continue + } targetPart := dbMessages[0] if reaction.TargetPart != nil { targetPartIdx := slices.IndexFunc(dbMessages, func(dbMsg *database.Message) bool { diff --git a/bridgev2/portalinternal.go b/bridgev2/portalinternal.go index a5da077b..4c7e2447 100644 --- a/bridgev2/portalinternal.go +++ b/bridgev2/portalinternal.go @@ -29,26 +29,30 @@ func (portal *PortalInternals) UpdateLogger() { (*Portal)(portal).updateLogger() } -func (portal *PortalInternals) QueueEvent(ctx context.Context, evt portalEvent) { - (*Portal)(portal).queueEvent(ctx, evt) +func (portal *PortalInternals) QueueEvent(ctx context.Context, evt portalEvent) EventHandlingResult { + return (*Portal)(portal).queueEvent(ctx, evt) } func (portal *PortalInternals) EventLoop() { (*Portal)(portal).eventLoop() } -func (portal *PortalInternals) HandleSingleEventAsync(idx int, rawEvt any) { - (*Portal)(portal).handleSingleEventAsync(idx, rawEvt) +func (portal *PortalInternals) HandleSingleEventWithDelayLogging(idx int, rawEvt any) (outerRes EventHandlingResult) { + return (*Portal)(portal).handleSingleEventWithDelayLogging(idx, rawEvt) } func (portal *PortalInternals) GetEventCtxWithLog(rawEvt any, idx int) context.Context { return (*Portal)(portal).getEventCtxWithLog(rawEvt, idx) } -func (portal *PortalInternals) HandleSingleEvent(ctx context.Context, rawEvt any, doneCallback func()) { +func (portal *PortalInternals) HandleSingleEvent(ctx context.Context, rawEvt any, doneCallback func(EventHandlingResult)) { (*Portal)(portal).handleSingleEvent(ctx, rawEvt, doneCallback) } +func (portal *PortalInternals) UnwrapBeeperSendState(ctx context.Context, evt *event.Event) error { + return (*Portal)(portal).unwrapBeeperSendState(ctx, evt) +} + func (portal *PortalInternals) SendSuccessStatus(ctx context.Context, evt *event.Event, streamOrder int64, newEventID id.EventID) { (*Portal)(portal).sendSuccessStatus(ctx, evt, streamOrder, newEventID) } @@ -61,20 +65,24 @@ func (portal *PortalInternals) CheckConfusableName(ctx context.Context, userID i return (*Portal)(portal).checkConfusableName(ctx, userID, name) } -func (portal *PortalInternals) HandleMatrixEvent(ctx context.Context, sender *User, evt *event.Event) { - (*Portal)(portal).handleMatrixEvent(ctx, sender, evt) +func (portal *PortalInternals) HandleMatrixEvent(ctx context.Context, sender *User, evt *event.Event, isStateRequest bool) EventHandlingResult { + return (*Portal)(portal).handleMatrixEvent(ctx, sender, evt, isStateRequest) } -func (portal *PortalInternals) HandleMatrixReceipts(ctx context.Context, evt *event.Event) { - (*Portal)(portal).handleMatrixReceipts(ctx, evt) +func (portal *PortalInternals) HandleMatrixReceipts(ctx context.Context, evt *event.Event) EventHandlingResult { + return (*Portal)(portal).handleMatrixReceipts(ctx, evt) } func (portal *PortalInternals) HandleMatrixReadReceipt(ctx context.Context, user *User, eventID id.EventID, receipt event.ReadReceipt) { (*Portal)(portal).handleMatrixReadReceipt(ctx, user, eventID, receipt) } -func (portal *PortalInternals) HandleMatrixTyping(ctx context.Context, evt *event.Event) { - (*Portal)(portal).handleMatrixTyping(ctx, evt) +func (portal *PortalInternals) CallReadReceiptHandler(ctx context.Context, login *UserLogin, rrClient ReadReceiptHandlingNetworkAPI, evt *MatrixReadReceipt, userPortal *database.UserPortal) { + (*Portal)(portal).callReadReceiptHandler(ctx, login, rrClient, evt, userPortal) +} + +func (portal *PortalInternals) HandleMatrixTyping(ctx context.Context, evt *event.Event) EventHandlingResult { + return (*Portal)(portal).handleMatrixTyping(ctx, evt) } func (portal *PortalInternals) SendTypings(ctx context.Context, userIDs []id.UserID, typing bool) { @@ -85,55 +93,83 @@ func (portal *PortalInternals) PeriodicTypingUpdater() { (*Portal)(portal).periodicTypingUpdater() } -func (portal *PortalInternals) CheckMessageContentCaps(ctx context.Context, caps *NetworkRoomCapabilities, content *event.MessageEventContent, evt *event.Event) bool { - return (*Portal)(portal).checkMessageContentCaps(ctx, caps, content, evt) +func (portal *PortalInternals) CheckMessageContentCaps(caps *event.RoomFeatures, content *event.MessageEventContent) error { + return (*Portal)(portal).checkMessageContentCaps(caps, content) } -func (portal *PortalInternals) HandleMatrixMessage(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event) { - (*Portal)(portal).handleMatrixMessage(ctx, sender, origSender, evt) +func (portal *PortalInternals) ParseInputTransactionID(origSender *OrigSender, evt *event.Event) networkid.RawTransactionID { + return (*Portal)(portal).parseInputTransactionID(origSender, evt) } -func (portal *PortalInternals) HandleMatrixEdit(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event, content *event.MessageEventContent, caps *NetworkRoomCapabilities) { - (*Portal)(portal).handleMatrixEdit(ctx, sender, origSender, evt, content, caps) +func (portal *PortalInternals) HandleMatrixMessage(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event) EventHandlingResult { + return (*Portal)(portal).handleMatrixMessage(ctx, sender, origSender, evt) } -func (portal *PortalInternals) HandleMatrixReaction(ctx context.Context, sender *UserLogin, evt *event.Event) { - (*Portal)(portal).handleMatrixReaction(ctx, sender, evt) +func (portal *PortalInternals) PendingMessageTimeoutLoop(ctx context.Context, cfg *OutgoingTimeoutConfig) { + (*Portal)(portal).pendingMessageTimeoutLoop(ctx, cfg) +} + +func (portal *PortalInternals) CheckPendingMessages(ctx context.Context, cfg *OutgoingTimeoutConfig) { + (*Portal)(portal).checkPendingMessages(ctx, cfg) +} + +func (portal *PortalInternals) HandleMatrixEdit(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event, content *event.MessageEventContent, caps *event.RoomFeatures) EventHandlingResult { + return (*Portal)(portal).handleMatrixEdit(ctx, sender, origSender, evt, content, caps) +} + +func (portal *PortalInternals) HandleMatrixReaction(ctx context.Context, sender *UserLogin, evt *event.Event) EventHandlingResult { + return (*Portal)(portal).handleMatrixReaction(ctx, sender, evt) } func (portal *PortalInternals) GetTargetUser(ctx context.Context, userID id.UserID) (GhostOrUserLogin, error) { return (*Portal)(portal).getTargetUser(ctx, userID) } -func (portal *PortalInternals) HandleMatrixMembership(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event) { - (*Portal)(portal).handleMatrixMembership(ctx, sender, origSender, evt) +func (portal *PortalInternals) HandleMatrixDeleteChat(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event) EventHandlingResult { + return (*Portal)(portal).handleMatrixDeleteChat(ctx, sender, origSender, evt) } -func (portal *PortalInternals) HandleMatrixPowerLevels(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event) { - (*Portal)(portal).handleMatrixPowerLevels(ctx, sender, origSender, evt) +func (portal *PortalInternals) HandleMatrixMembership(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event, isStateRequest bool) EventHandlingResult { + return (*Portal)(portal).handleMatrixMembership(ctx, sender, origSender, evt, isStateRequest) } -func (portal *PortalInternals) HandleMatrixRedaction(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event) { - (*Portal)(portal).handleMatrixRedaction(ctx, sender, origSender, evt) +func (portal *PortalInternals) HandleMatrixPowerLevels(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event, isStateRequest bool) EventHandlingResult { + return (*Portal)(portal).handleMatrixPowerLevels(ctx, sender, origSender, evt, isStateRequest) } -func (portal *PortalInternals) HandleRemoteEvent(ctx context.Context, source *UserLogin, evtType RemoteEventType, evt RemoteEvent) { - (*Portal)(portal).handleRemoteEvent(ctx, source, evtType, evt) +func (portal *PortalInternals) HandleMatrixTombstone(ctx context.Context, evt *event.Event) EventHandlingResult { + return (*Portal)(portal).handleMatrixTombstone(ctx, evt) } -func (portal *PortalInternals) GetIntentAndUserMXIDFor(ctx context.Context, sender EventSender, source *UserLogin, otherLogins []*UserLogin, evtType RemoteEventType) (intent MatrixAPI, extraUserID id.UserID) { +func (portal *PortalInternals) UpdateInfoAfterTombstone(ctx context.Context, senderUser *User) { + (*Portal)(portal).updateInfoAfterTombstone(ctx, senderUser) +} + +func (portal *PortalInternals) HandleMatrixRedaction(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event) EventHandlingResult { + return (*Portal)(portal).handleMatrixRedaction(ctx, sender, origSender, evt) +} + +func (portal *PortalInternals) HandleRemoteEvent(ctx context.Context, source *UserLogin, evtType RemoteEventType, evt RemoteEvent) (res EventHandlingResult) { + return (*Portal)(portal).handleRemoteEvent(ctx, source, evtType, evt) +} + +func (portal *PortalInternals) EnsureFunctionalMember(ctx context.Context, ghost *Ghost) { + (*Portal)(portal).ensureFunctionalMember(ctx, ghost) +} + +func (portal *PortalInternals) GetIntentAndUserMXIDFor(ctx context.Context, sender EventSender, source *UserLogin, otherLogins []*UserLogin, evtType RemoteEventType) (intent MatrixAPI, extraUserID id.UserID, err error) { return (*Portal)(portal).getIntentAndUserMXIDFor(ctx, sender, source, otherLogins, evtType) } -func (portal *PortalInternals) GetRelationMeta(ctx context.Context, currentMsg networkid.MessageID, replyToPtr *networkid.MessageOptionalPartID, threadRootPtr *networkid.MessageID, isBatchSend bool) (replyTo, threadRoot, prevThreadEvent *database.Message) { - return (*Portal)(portal).getRelationMeta(ctx, currentMsg, replyToPtr, threadRootPtr, isBatchSend) +func (portal *PortalInternals) GetRelationMeta(ctx context.Context, currentMsgID networkid.MessageID, currentMsg *ConvertedMessage, isBatchSend bool) (replyTo, threadRoot, prevThreadEvent *database.Message) { + return (*Portal)(portal).getRelationMeta(ctx, currentMsgID, currentMsg, isBatchSend) } -func (portal *PortalInternals) ApplyRelationMeta(content *event.MessageEventContent, replyTo, threadRoot, prevThreadEvent *database.Message) { - (*Portal)(portal).applyRelationMeta(content, replyTo, threadRoot, prevThreadEvent) +func (portal *PortalInternals) ApplyRelationMeta(ctx context.Context, content *event.MessageEventContent, replyTo, threadRoot, prevThreadEvent *database.Message) { + (*Portal)(portal).applyRelationMeta(ctx, content, replyTo, threadRoot, prevThreadEvent) } -func (portal *PortalInternals) SendConvertedMessage(ctx context.Context, id networkid.MessageID, intent MatrixAPI, senderID networkid.UserID, converted *ConvertedMessage, ts time.Time, streamOrder int64, logContext func(*zerolog.Event) *zerolog.Event) []*database.Message { +func (portal *PortalInternals) SendConvertedMessage(ctx context.Context, id networkid.MessageID, intent MatrixAPI, senderID networkid.UserID, converted *ConvertedMessage, ts time.Time, streamOrder int64, logContext func(*zerolog.Event) *zerolog.Event) ([]*database.Message, EventHandlingResult) { return (*Portal)(portal).sendConvertedMessage(ctx, id, intent, senderID, converted, ts, streamOrder, logContext) } @@ -141,24 +177,24 @@ func (portal *PortalInternals) CheckPendingMessage(ctx context.Context, evt Remo return (*Portal)(portal).checkPendingMessage(ctx, evt) } -func (portal *PortalInternals) HandleRemoteUpsert(ctx context.Context, source *UserLogin, evt RemoteMessageUpsert, existing []*database.Message) bool { +func (portal *PortalInternals) HandleRemoteUpsert(ctx context.Context, source *UserLogin, evt RemoteMessageUpsert, existing []*database.Message) (handleRes EventHandlingResult, continueHandling bool) { return (*Portal)(portal).handleRemoteUpsert(ctx, source, evt, existing) } -func (portal *PortalInternals) HandleRemoteMessage(ctx context.Context, source *UserLogin, evt RemoteMessage) { - (*Portal)(portal).handleRemoteMessage(ctx, source, evt) +func (portal *PortalInternals) HandleRemoteMessage(ctx context.Context, source *UserLogin, evt RemoteMessage) (res EventHandlingResult) { + return (*Portal)(portal).handleRemoteMessage(ctx, source, evt) } func (portal *PortalInternals) SendRemoteErrorNotice(ctx context.Context, intent MatrixAPI, err error, ts time.Time, evtTypeName string) { (*Portal)(portal).sendRemoteErrorNotice(ctx, intent, err, ts, evtTypeName) } -func (portal *PortalInternals) HandleRemoteEdit(ctx context.Context, source *UserLogin, evt RemoteEdit) { - (*Portal)(portal).handleRemoteEdit(ctx, source, evt) +func (portal *PortalInternals) HandleRemoteEdit(ctx context.Context, source *UserLogin, evt RemoteEdit) EventHandlingResult { + return (*Portal)(portal).handleRemoteEdit(ctx, source, evt) } -func (portal *PortalInternals) SendConvertedEdit(ctx context.Context, targetID networkid.MessageID, senderID networkid.UserID, converted *ConvertedEdit, intent MatrixAPI, ts time.Time, streamOrder int64) { - (*Portal)(portal).sendConvertedEdit(ctx, targetID, senderID, converted, intent, ts, streamOrder) +func (portal *PortalInternals) SendConvertedEdit(ctx context.Context, targetID networkid.MessageID, senderID networkid.UserID, converted *ConvertedEdit, intent MatrixAPI, ts time.Time, streamOrder int64) EventHandlingResult { + return (*Portal)(portal).sendConvertedEdit(ctx, targetID, senderID, converted, intent, ts, streamOrder) } func (portal *PortalInternals) GetTargetMessagePart(ctx context.Context, evt RemoteEventWithTargetMessage) (*database.Message, error) { @@ -169,76 +205,84 @@ func (portal *PortalInternals) GetTargetReaction(ctx context.Context, evt Remote return (*Portal)(portal).getTargetReaction(ctx, evt) } -func (portal *PortalInternals) HandleRemoteReactionSync(ctx context.Context, source *UserLogin, evt RemoteReactionSync) { - (*Portal)(portal).handleRemoteReactionSync(ctx, source, evt) +func (portal *PortalInternals) HandleRemoteReactionSync(ctx context.Context, source *UserLogin, evt RemoteReactionSync) EventHandlingResult { + return (*Portal)(portal).handleRemoteReactionSync(ctx, source, evt) } -func (portal *PortalInternals) HandleRemoteReaction(ctx context.Context, source *UserLogin, evt RemoteReaction) { - (*Portal)(portal).handleRemoteReaction(ctx, source, evt) +func (portal *PortalInternals) HandleRemoteReaction(ctx context.Context, source *UserLogin, evt RemoteReaction) EventHandlingResult { + return (*Portal)(portal).handleRemoteReaction(ctx, source, evt) } -func (portal *PortalInternals) SendConvertedReaction(ctx context.Context, senderID networkid.UserID, intent MatrixAPI, targetMessage *database.Message, emojiID networkid.EmojiID, emoji string, ts time.Time, dbMetadata any, extraContent map[string]any, logContext func(*zerolog.Event) *zerolog.Event) { - (*Portal)(portal).sendConvertedReaction(ctx, senderID, intent, targetMessage, emojiID, emoji, ts, dbMetadata, extraContent, logContext) +func (portal *PortalInternals) SendConvertedReaction(ctx context.Context, senderID networkid.UserID, intent MatrixAPI, targetMessage *database.Message, emojiID networkid.EmojiID, emoji string, ts time.Time, dbMetadata any, extraContent map[string]any, logContext func(*zerolog.Event) *zerolog.Event) EventHandlingResult { + return (*Portal)(portal).sendConvertedReaction(ctx, senderID, intent, targetMessage, emojiID, emoji, ts, dbMetadata, extraContent, logContext) } func (portal *PortalInternals) GetIntentForMXID(ctx context.Context, userID id.UserID) (MatrixAPI, error) { return (*Portal)(portal).getIntentForMXID(ctx, userID) } -func (portal *PortalInternals) HandleRemoteReactionRemove(ctx context.Context, source *UserLogin, evt RemoteReactionRemove) { - (*Portal)(portal).handleRemoteReactionRemove(ctx, source, evt) +func (portal *PortalInternals) HandleRemoteReactionRemove(ctx context.Context, source *UserLogin, evt RemoteReactionRemove) EventHandlingResult { + return (*Portal)(portal).handleRemoteReactionRemove(ctx, source, evt) } -func (portal *PortalInternals) HandleRemoteMessageRemove(ctx context.Context, source *UserLogin, evt RemoteMessageRemove) { - (*Portal)(portal).handleRemoteMessageRemove(ctx, source, evt) +func (portal *PortalInternals) HandleRemoteMessageRemove(ctx context.Context, source *UserLogin, evt RemoteMessageRemove) EventHandlingResult { + return (*Portal)(portal).handleRemoteMessageRemove(ctx, source, evt) } -func (portal *PortalInternals) RedactMessageParts(ctx context.Context, parts []*database.Message, intent MatrixAPI, ts time.Time) { - (*Portal)(portal).redactMessageParts(ctx, parts, intent, ts) +func (portal *PortalInternals) RedactMessageParts(ctx context.Context, parts []*database.Message, intent MatrixAPI, ts time.Time) EventHandlingResult { + return (*Portal)(portal).redactMessageParts(ctx, parts, intent, ts) } -func (portal *PortalInternals) HandleRemoteReadReceipt(ctx context.Context, source *UserLogin, evt RemoteReadReceipt) { - (*Portal)(portal).handleRemoteReadReceipt(ctx, source, evt) +func (portal *PortalInternals) HandleRemoteReadReceipt(ctx context.Context, source *UserLogin, evt RemoteReadReceipt) EventHandlingResult { + return (*Portal)(portal).handleRemoteReadReceipt(ctx, source, evt) } -func (portal *PortalInternals) HandleRemoteMarkUnread(ctx context.Context, source *UserLogin, evt RemoteMarkUnread) { - (*Portal)(portal).handleRemoteMarkUnread(ctx, source, evt) +func (portal *PortalInternals) HandleRemoteMarkUnread(ctx context.Context, source *UserLogin, evt RemoteMarkUnread) EventHandlingResult { + return (*Portal)(portal).handleRemoteMarkUnread(ctx, source, evt) } -func (portal *PortalInternals) HandleRemoteDeliveryReceipt(ctx context.Context, source *UserLogin, evt RemoteDeliveryReceipt) { - (*Portal)(portal).handleRemoteDeliveryReceipt(ctx, source, evt) +func (portal *PortalInternals) HandleRemoteDeliveryReceipt(ctx context.Context, source *UserLogin, evt RemoteDeliveryReceipt) EventHandlingResult { + return (*Portal)(portal).handleRemoteDeliveryReceipt(ctx, source, evt) } -func (portal *PortalInternals) HandleRemoteTyping(ctx context.Context, source *UserLogin, evt RemoteTyping) { - (*Portal)(portal).handleRemoteTyping(ctx, source, evt) +func (portal *PortalInternals) HandleRemoteTyping(ctx context.Context, source *UserLogin, evt RemoteTyping) EventHandlingResult { + return (*Portal)(portal).handleRemoteTyping(ctx, source, evt) } -func (portal *PortalInternals) HandleRemoteChatInfoChange(ctx context.Context, source *UserLogin, evt RemoteChatInfoChange) { - (*Portal)(portal).handleRemoteChatInfoChange(ctx, source, evt) +func (portal *PortalInternals) HandleRemoteChatInfoChange(ctx context.Context, source *UserLogin, evt RemoteChatInfoChange) EventHandlingResult { + return (*Portal)(portal).handleRemoteChatInfoChange(ctx, source, evt) } -func (portal *PortalInternals) HandleRemoteChatResync(ctx context.Context, source *UserLogin, evt RemoteChatResync) { - (*Portal)(portal).handleRemoteChatResync(ctx, source, evt) +func (portal *PortalInternals) HandleRemoteChatResync(ctx context.Context, source *UserLogin, evt RemoteChatResync) EventHandlingResult { + return (*Portal)(portal).handleRemoteChatResync(ctx, source, evt) } -func (portal *PortalInternals) HandleRemoteChatDelete(ctx context.Context, source *UserLogin, evt RemoteChatDelete) { - (*Portal)(portal).handleRemoteChatDelete(ctx, source, evt) +func (portal *PortalInternals) FindOtherLogins(ctx context.Context, source *UserLogin) (ownUP *database.UserPortal, others []*database.UserPortal, err error) { + return (*Portal)(portal).findOtherLogins(ctx, source) } -func (portal *PortalInternals) HandleRemoteBackfill(ctx context.Context, source *UserLogin, backfill RemoteBackfill) { - (*Portal)(portal).handleRemoteBackfill(ctx, source, backfill) +func (portal *PortalInternals) HandleRemoteChatDelete(ctx context.Context, source *UserLogin, evt RemoteChatDelete) EventHandlingResult { + return (*Portal)(portal).handleRemoteChatDelete(ctx, source, evt) } -func (portal *PortalInternals) UpdateName(ctx context.Context, name string, sender MatrixAPI, ts time.Time) bool { - return (*Portal)(portal).updateName(ctx, name, sender, ts) +func (portal *PortalInternals) HandleRemoteBackfill(ctx context.Context, source *UserLogin, backfill RemoteBackfill) (res EventHandlingResult) { + return (*Portal)(portal).handleRemoteBackfill(ctx, source, backfill) } -func (portal *PortalInternals) UpdateTopic(ctx context.Context, topic string, sender MatrixAPI, ts time.Time) bool { - return (*Portal)(portal).updateTopic(ctx, topic, sender, ts) +func (portal *PortalInternals) UpdateName(ctx context.Context, name string, sender MatrixAPI, ts time.Time, excludeFromTimeline bool) bool { + return (*Portal)(portal).updateName(ctx, name, sender, ts, excludeFromTimeline) } -func (portal *PortalInternals) UpdateAvatar(ctx context.Context, avatar *Avatar, sender MatrixAPI, ts time.Time) bool { - return (*Portal)(portal).updateAvatar(ctx, avatar, sender, ts) +func (portal *PortalInternals) UpdateTopic(ctx context.Context, topic string, sender MatrixAPI, ts time.Time, excludeFromTimeline bool) bool { + return (*Portal)(portal).updateTopic(ctx, topic, sender, ts, excludeFromTimeline) +} + +func (portal *PortalInternals) UpdateAvatar(ctx context.Context, avatar *Avatar, sender MatrixAPI, ts time.Time, excludeFromTimeline bool) bool { + return (*Portal)(portal).updateAvatar(ctx, avatar, sender, ts, excludeFromTimeline) +} + +func (portal *PortalInternals) GetBridgeInfoStateKey() string { + return (*Portal)(portal).getBridgeInfoStateKey() } func (portal *PortalInternals) GetBridgeInfo() (string, event.BridgeEventContent) { @@ -249,8 +293,12 @@ func (portal *PortalInternals) SendStateWithIntentOrBot(ctx context.Context, sen return (*Portal)(portal).sendStateWithIntentOrBot(ctx, sender, eventType, stateKey, content, ts) } -func (portal *PortalInternals) SendRoomMeta(ctx context.Context, sender MatrixAPI, ts time.Time, eventType event.Type, stateKey string, content any) bool { - return (*Portal)(portal).sendRoomMeta(ctx, sender, ts, eventType, stateKey, content) +func (portal *PortalInternals) SendRoomMeta(ctx context.Context, sender MatrixAPI, ts time.Time, eventType event.Type, stateKey string, content any, excludeFromTimeline bool, extra map[string]any) bool { + return (*Portal)(portal).sendRoomMeta(ctx, sender, ts, eventType, stateKey, content, excludeFromTimeline, extra) +} + +func (portal *PortalInternals) RevertRoomMeta(ctx context.Context, evt *event.Event) { + (*Portal)(portal).revertRoomMeta(ctx, evt) } func (portal *PortalInternals) GetInitialMemberList(ctx context.Context, members *ChatMemberList, source *UserLogin, pl *event.PowerLevelsEventContent) (invite, functional []id.UserID, err error) { @@ -261,6 +309,10 @@ func (portal *PortalInternals) UpdateOtherUser(ctx context.Context, members *Cha return (*Portal)(portal).updateOtherUser(ctx, members) } +func (portal *PortalInternals) RoomIsPublic(ctx context.Context) bool { + return (*Portal)(portal).roomIsPublic(ctx) +} + func (portal *PortalInternals) SyncParticipants(ctx context.Context, members *ChatMemberList, source *UserLogin, sender MatrixAPI, ts time.Time) error { return (*Portal)(portal).syncParticipants(ctx, members, source, sender, ts) } @@ -281,6 +333,10 @@ func (portal *PortalInternals) CreateMatrixRoomInLoop(ctx context.Context, sourc return (*Portal)(portal).createMatrixRoomInLoop(ctx, source, info, backfillBundle) } +func (portal *PortalInternals) AddToUserSpaces(ctx context.Context) { + (*Portal)(portal).addToUserSpaces(ctx) +} + func (portal *PortalInternals) RemoveInPortalCache(ctx context.Context) { (*Portal)(portal).removeInPortalCache(ctx) } @@ -344,7 +400,3 @@ func (portal *PortalInternals) AddToParentSpaceAndSave(ctx context.Context, save func (portal *PortalInternals) ToggleSpace(ctx context.Context, spaceID id.RoomID, canonical, remove bool) error { return (*Portal)(portal).toggleSpace(ctx, spaceID, canonical, remove) } - -func (portal *PortalInternals) SetMXIDToExistingRoom(roomID id.RoomID) bool { - return (*Portal)(portal).setMXIDToExistingRoom(roomID) -} diff --git a/bridgev2/portalreid.go b/bridgev2/portalreid.go index a25fe820..c976d97c 100644 --- a/bridgev2/portalreid.go +++ b/bridgev2/portalreid.go @@ -32,21 +32,40 @@ func (br *Bridge) ReIDPortal(ctx context.Context, source, target networkid.Porta if source == target { return ReIDResultError, nil, fmt.Errorf("illegal re-ID call: source and target are the same") } - log := zerolog.Ctx(ctx) - log.Debug().Msg("Re-ID'ing portal") + log := zerolog.Ctx(ctx).With(). + Str("action", "re-id portal"). + Stringer("source_portal_key", source). + Stringer("target_portal_key", target). + Logger() + ctx = log.WithContext(ctx) defer func() { log.Debug().Msg("Finished handling portal re-ID") }() - br.cacheLock.Lock() - defer br.cacheLock.Unlock() - sourcePortal, err := br.UnlockedGetPortalByKey(ctx, source, true) + acquireCacheLock := func() { + if !br.cacheLock.TryLock() { + log.Debug().Msg("Waiting for global cache lock") + br.cacheLock.Lock() + log.Debug().Msg("Acquired global cache lock after waiting") + } else { + log.Trace().Msg("Acquired global cache lock without waiting") + } + } + log.Debug().Msg("Re-ID'ing portal") + sourcePortal, err := br.GetExistingPortalByKey(ctx, source) if err != nil { return ReIDResultError, nil, fmt.Errorf("failed to get source portal: %w", err) } else if sourcePortal == nil { log.Debug().Msg("Source portal not found, re-ID is no-op") return ReIDResultNoOp, nil, nil } - sourcePortal.roomCreateLock.Lock() + if !sourcePortal.roomCreateLock.TryLock() { + if cancelCreate := sourcePortal.cancelRoomCreate.Swap(nil); cancelCreate != nil { + (*cancelCreate)() + } + log.Debug().Msg("Waiting for source portal room creation lock") + sourcePortal.roomCreateLock.Lock() + log.Debug().Msg("Acquired source portal room creation lock after waiting") + } defer sourcePortal.roomCreateLock.Unlock() if sourcePortal.MXID == "" { log.Info().Msg("Source portal doesn't have Matrix room, deleting row") @@ -59,22 +78,37 @@ func (br *Bridge) ReIDPortal(ctx context.Context, source, target networkid.Porta log.UpdateContext(func(c zerolog.Context) zerolog.Context { return c.Stringer("source_portal_mxid", sourcePortal.MXID) }) + + acquireCacheLock() targetPortal, err := br.UnlockedGetPortalByKey(ctx, target, true) if err != nil { + br.cacheLock.Unlock() return ReIDResultError, nil, fmt.Errorf("failed to get target portal: %w", err) } if targetPortal == nil { log.Info().Msg("Target portal doesn't exist, re-ID'ing source portal") err = sourcePortal.unlockedReID(ctx, target) + br.cacheLock.Unlock() if err != nil { return ReIDResultError, nil, fmt.Errorf("failed to re-ID source portal: %w", err) } return ReIDResultSourceReIDd, sourcePortal, nil } - targetPortal.roomCreateLock.Lock() + br.cacheLock.Unlock() + + if !targetPortal.roomCreateLock.TryLock() { + if cancelCreate := targetPortal.cancelRoomCreate.Swap(nil); cancelCreate != nil { + (*cancelCreate)() + } + log.Debug().Msg("Waiting for target portal room creation lock") + targetPortal.roomCreateLock.Lock() + log.Debug().Msg("Acquired target portal room creation lock after waiting") + } defer targetPortal.roomCreateLock.Unlock() if targetPortal.MXID == "" { log.Info().Msg("Target portal row exists, but doesn't have a Matrix room. Deleting target portal row and re-ID'ing source portal") + acquireCacheLock() + defer br.cacheLock.Unlock() err = targetPortal.unlockedDelete(ctx) if err != nil { return ReIDResultError, nil, fmt.Errorf("failed to delete target portal: %w", err) @@ -89,6 +123,9 @@ func (br *Bridge) ReIDPortal(ctx context.Context, source, target networkid.Porta return c.Stringer("target_portal_mxid", targetPortal.MXID) }) log.Info().Msg("Both target and source portals have Matrix rooms, tombstoning source portal") + sourcePortal.removeInPortalCache(ctx) + acquireCacheLock() + defer br.cacheLock.Unlock() err = sourcePortal.unlockedDelete(ctx) if err != nil { return ReIDResultError, nil, fmt.Errorf("failed to delete source portal row: %w", err) @@ -96,7 +133,7 @@ func (br *Bridge) ReIDPortal(ctx context.Context, source, target networkid.Porta go func() { _, err := br.Bot.SendState(ctx, sourcePortal.MXID, event.StateTombstone, "", &event.Content{ Parsed: &event.TombstoneEventContent{ - Body: fmt.Sprintf("This room has been merged"), + Body: "This room has been merged", ReplacementRoom: targetPortal.MXID, }, }, time.Now()) diff --git a/bridgev2/provisionutil/creategroup.go b/bridgev2/provisionutil/creategroup.go new file mode 100644 index 00000000..72bacaff --- /dev/null +++ b/bridgev2/provisionutil/creategroup.go @@ -0,0 +1,149 @@ +// Copyright (c) 2025 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package provisionutil + +import ( + "context" + + "github.com/rs/zerolog" + "go.mau.fi/util/ptr" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/bridgev2" + "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +type RespCreateGroup struct { + ID networkid.PortalID `json:"id"` + MXID id.RoomID `json:"mxid"` + Portal *bridgev2.Portal `json:"-"` + + FailedParticipants map[networkid.UserID]*bridgev2.CreateChatFailedParticipant `json:"failed_participants,omitempty"` +} + +func CreateGroup(ctx context.Context, login *bridgev2.UserLogin, params *bridgev2.GroupCreateParams) (*RespCreateGroup, error) { + api, ok := login.Client.(bridgev2.GroupCreatingNetworkAPI) + if !ok { + return nil, bridgev2.RespError(mautrix.MUnrecognized.WithMessage("This bridge does not support creating groups")) + } + zerolog.Ctx(ctx).Debug(). + Any("create_params", params). + Msg("Creating group chat on remote network") + caps := login.Bridge.Network.GetCapabilities() + typeSpec, validType := caps.Provisioning.GroupCreation[params.Type] + if !validType { + return nil, bridgev2.RespError(mautrix.MUnrecognized.WithMessage("Unrecognized group type %s", params.Type)) + } + if len(params.Participants) < typeSpec.Participants.MinLength { + return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Must have at least %d members", typeSpec.Participants.MinLength)) + } else if typeSpec.Participants.MaxLength > 0 && len(params.Participants) > typeSpec.Participants.MaxLength { + return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Must have at most %d members", typeSpec.Participants.MaxLength)) + } + userIDValidatingNetwork, uidValOK := login.Bridge.Network.(bridgev2.IdentifierValidatingNetwork) + for i, participant := range params.Participants { + parsedParticipant, ok := login.Bridge.Matrix.ParseGhostMXID(id.UserID(participant)) + if ok { + participant = parsedParticipant + params.Participants[i] = participant + } + if !typeSpec.Participants.SkipIdentifierValidation { + if uidValOK && !userIDValidatingNetwork.ValidateUserID(participant) { + return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("User ID %q is not valid on this network", participant)) + } + } + if api.IsThisUser(ctx, participant) { + return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("You can't include yourself in the participants list", participant)) + } + } + if (params.Name == nil || params.Name.Name == "") && typeSpec.Name.Required { + return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Name is required")) + } else if nameLen := len(ptr.Val(params.Name).Name); nameLen > 0 && nameLen < typeSpec.Name.MinLength { + return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Name must be at least %d characters", typeSpec.Name.MinLength)) + } else if typeSpec.Name.MaxLength > 0 && nameLen > typeSpec.Name.MaxLength { + return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Name must be at most %d characters", typeSpec.Name.MaxLength)) + } + if (params.Avatar == nil || params.Avatar.URL == "") && typeSpec.Avatar.Required { + return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Avatar is required")) + } + if (params.Topic == nil || params.Topic.Topic == "") && typeSpec.Topic.Required { + return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Topic is required")) + } else if topicLen := len(ptr.Val(params.Topic).Topic); topicLen > 0 && topicLen < typeSpec.Topic.MinLength { + return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Topic must be at least %d characters", typeSpec.Topic.MinLength)) + } else if typeSpec.Topic.MaxLength > 0 && topicLen > typeSpec.Topic.MaxLength { + return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Topic must be at most %d characters", typeSpec.Topic.MaxLength)) + } + if (params.Disappear == nil || params.Disappear.Timer.Duration == 0) && typeSpec.Disappear.Required { + return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Disappearing timer is required")) + } else if !typeSpec.Disappear.DisappearSettings.Supports(params.Disappear) { + return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Unsupported value for disappearing timer")) + } + if params.Username == "" && typeSpec.Username.Required { + return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Username is required")) + } else if len(params.Username) > 0 && len(params.Username) < typeSpec.Username.MinLength { + return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Username must be at least %d characters", typeSpec.Username.MinLength)) + } else if typeSpec.Username.MaxLength > 0 && len(params.Username) > typeSpec.Username.MaxLength { + return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Username must be at most %d characters", typeSpec.Username.MaxLength)) + } + if params.Parent == nil && typeSpec.Parent.Required { + return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Parent is required")) + } + resp, err := api.CreateGroup(ctx, params) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to create group") + return nil, err + } + if resp.PortalKey.IsEmpty() { + return nil, ErrNoPortalKey + } + zerolog.Ctx(ctx).Debug(). + Object("portal_key", resp.PortalKey). + Msg("Successfully created group on remote network") + if resp.Portal == nil { + resp.Portal, err = login.Bridge.GetPortalByKey(ctx, resp.PortalKey) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to get portal") + return nil, bridgev2.RespError(mautrix.MUnknown.WithMessage("Failed to get portal")) + } + } + if resp.Portal.MXID == "" { + err = resp.Portal.CreateMatrixRoom(ctx, login, resp.PortalInfo) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to create portal room") + return nil, bridgev2.RespError(mautrix.MUnknown.WithMessage("Failed to create portal room")) + } + } + for key, fp := range resp.FailedParticipants { + if fp.InviteEventType == "" { + fp.InviteEventType = event.EventMessage.Type + } + if fp.UserMXID == "" { + ghost, err := login.Bridge.GetGhostByID(ctx, key) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to get ghost for failed participant") + } else if ghost != nil { + fp.UserMXID = ghost.Intent.GetMXID() + } + } + if fp.DMRoomMXID == "" { + portal, err := login.Bridge.GetDMPortal(ctx, login.ID, key) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to get DM portal for failed participant") + } else if portal != nil { + fp.DMRoomMXID = portal.MXID + } + } + } + return &RespCreateGroup{ + ID: resp.Portal.ID, + MXID: resp.Portal.MXID, + Portal: resp.Portal, + + FailedParticipants: resp.FailedParticipants, + }, nil +} diff --git a/bridgev2/provisionutil/listcontacts.go b/bridgev2/provisionutil/listcontacts.go new file mode 100644 index 00000000..ce163e67 --- /dev/null +++ b/bridgev2/provisionutil/listcontacts.go @@ -0,0 +1,98 @@ +// Copyright (c) 2025 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package provisionutil + +import ( + "context" + + "github.com/rs/zerolog" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/bridgev2" +) + +type RespGetContactList struct { + Contacts []*RespResolveIdentifier `json:"contacts"` +} + +type RespSearchUsers struct { + Results []*RespResolveIdentifier `json:"results"` +} + +func GetContactList(ctx context.Context, login *bridgev2.UserLogin) (*RespGetContactList, error) { + api, ok := login.Client.(bridgev2.ContactListingNetworkAPI) + if !ok { + return nil, bridgev2.RespError(mautrix.MUnrecognized.WithMessage("This bridge does not support listing contacts")) + } + resp, err := api.GetContactList(ctx) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to get contact list") + return nil, err + } + return &RespGetContactList{ + Contacts: processResolveIdentifiers(ctx, login.Bridge, resp, false), + }, nil +} + +func SearchUsers(ctx context.Context, login *bridgev2.UserLogin, query string) (*RespSearchUsers, error) { + api, ok := login.Client.(bridgev2.UserSearchingNetworkAPI) + if !ok { + return nil, bridgev2.RespError(mautrix.MUnrecognized.WithMessage("This bridge does not support searching for users")) + } + resp, err := api.SearchUsers(ctx, query) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to get contact list") + return nil, err + } + return &RespSearchUsers{ + Results: processResolveIdentifiers(ctx, login.Bridge, resp, true), + }, nil +} + +func processResolveIdentifiers(ctx context.Context, br *bridgev2.Bridge, resp []*bridgev2.ResolveIdentifierResponse, syncInfo bool) (apiResp []*RespResolveIdentifier) { + apiResp = make([]*RespResolveIdentifier, len(resp)) + for i, contact := range resp { + apiContact := &RespResolveIdentifier{ + ID: contact.UserID, + } + apiResp[i] = apiContact + if contact.UserInfo != nil { + if contact.UserInfo.Name != nil { + apiContact.Name = *contact.UserInfo.Name + } + if contact.UserInfo.Identifiers != nil { + apiContact.Identifiers = contact.UserInfo.Identifiers + } + } + if contact.Ghost != nil { + if syncInfo && contact.UserInfo != nil { + contact.Ghost.UpdateInfo(ctx, contact.UserInfo) + } + if contact.Ghost.Name != "" { + apiContact.Name = contact.Ghost.Name + } + if len(contact.Ghost.Identifiers) >= len(apiContact.Identifiers) { + apiContact.Identifiers = contact.Ghost.Identifiers + } + apiContact.AvatarURL = contact.Ghost.AvatarMXC + apiContact.MXID = contact.Ghost.Intent.GetMXID() + } + if contact.Chat != nil { + if contact.Chat.Portal == nil { + var err error + contact.Chat.Portal, err = br.GetPortalByKey(ctx, contact.Chat.PortalKey) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to get portal") + } + } + if contact.Chat.Portal != nil { + apiContact.DMRoomID = contact.Chat.Portal.MXID + } + } + } + return +} diff --git a/bridgev2/provisionutil/resolveidentifier.go b/bridgev2/provisionutil/resolveidentifier.go new file mode 100644 index 00000000..cfc388d0 --- /dev/null +++ b/bridgev2/provisionutil/resolveidentifier.go @@ -0,0 +1,125 @@ +// Copyright (c) 2025 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package provisionutil + +import ( + "context" + "errors" + + "github.com/rs/zerolog" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/bridgev2" + "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/id" +) + +type RespResolveIdentifier struct { + ID networkid.UserID `json:"id"` + Name string `json:"name,omitempty"` + AvatarURL id.ContentURIString `json:"avatar_url,omitempty"` + Identifiers []string `json:"identifiers,omitempty"` + MXID id.UserID `json:"mxid,omitempty"` + DMRoomID id.RoomID `json:"dm_room_mxid,omitempty"` + + Portal *bridgev2.Portal `json:"-"` + Ghost *bridgev2.Ghost `json:"-"` + JustCreated bool `json:"-"` +} + +var ErrNoPortalKey = errors.New("network API didn't return portal key for createChat request") + +func ResolveIdentifier( + ctx context.Context, + login *bridgev2.UserLogin, + identifier string, + createChat bool, +) (*RespResolveIdentifier, error) { + api, ok := login.Client.(bridgev2.IdentifierResolvingNetworkAPI) + if !ok { + return nil, bridgev2.RespError(mautrix.MUnrecognized.WithMessage("This bridge does not support resolving identifiers")) + } + var resp *bridgev2.ResolveIdentifierResponse + parsedUserID, ok := login.Bridge.Matrix.ParseGhostMXID(id.UserID(identifier)) + validator, vOK := login.Bridge.Network.(bridgev2.IdentifierValidatingNetwork) + if ok && (!vOK || validator.ValidateUserID(parsedUserID)) { + ghost, err := login.Bridge.GetGhostByID(ctx, parsedUserID) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to get ghost by ID") + return nil, err + } + resp = &bridgev2.ResolveIdentifierResponse{ + Ghost: ghost, + UserID: parsedUserID, + } + gdcAPI, ok := api.(bridgev2.GhostDMCreatingNetworkAPI) + if ok && createChat { + resp.Chat, err = gdcAPI.CreateChatWithGhost(ctx, ghost) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to create chat") + return nil, err + } + } else if createChat || ghost.Name == "" { + zerolog.Ctx(ctx).Debug(). + Bool("create_chat", createChat). + Bool("has_name", ghost.Name != ""). + Msg("Falling back to resolving identifier") + resp = nil + identifier = string(parsedUserID) + } + } + if resp == nil { + var err error + resp, err = api.ResolveIdentifier(ctx, identifier, createChat) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to resolve identifier") + return nil, err + } else if resp == nil { + return nil, nil + } + } + apiResp := &RespResolveIdentifier{ + ID: resp.UserID, + Ghost: resp.Ghost, + } + if resp.Ghost != nil { + if resp.UserInfo != nil { + resp.Ghost.UpdateInfo(ctx, resp.UserInfo) + } + apiResp.Name = resp.Ghost.Name + apiResp.AvatarURL = resp.Ghost.AvatarMXC + apiResp.Identifiers = resp.Ghost.Identifiers + apiResp.MXID = resp.Ghost.Intent.GetMXID() + } else if resp.UserInfo != nil && resp.UserInfo.Name != nil { + apiResp.Name = *resp.UserInfo.Name + } + if resp.Chat != nil { + if resp.Chat.PortalKey.IsEmpty() { + return nil, ErrNoPortalKey + } + if resp.Chat.Portal == nil { + var err error + resp.Chat.Portal, err = login.Bridge.GetPortalByKey(ctx, resp.Chat.PortalKey) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to get portal") + return nil, bridgev2.RespError(mautrix.MUnknown.WithMessage("Failed to get portal")) + } + } + resp.Chat.Portal.CleanupOrphanedDM(ctx, login.UserMXID) + if createChat && resp.Chat.Portal.MXID == "" { + apiResp.JustCreated = true + err := resp.Chat.Portal.CreateMatrixRoom(ctx, login, resp.Chat.PortalInfo) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to create portal room") + return nil, bridgev2.RespError(mautrix.MUnknown.WithMessage("Failed to create portal room")) + } + } + apiResp.Portal = resp.Chat.Portal + apiResp.DMRoomID = resp.Chat.Portal.MXID + } + return apiResp, nil +} diff --git a/bridgev2/queue.go b/bridgev2/queue.go index 38895953..3775c825 100644 --- a/bridgev2/queue.go +++ b/bridgev2/queue.go @@ -63,7 +63,14 @@ func (br *Bridge) rejectInviteOnNoPermission(ctx context.Context, evt *event.Eve return true } -func (br *Bridge) QueueMatrixEvent(ctx context.Context, evt *event.Event) { +var ( + ErrEventSenderUserNotFound = WrapErrorInStatus(errors.New("sender not found for event")).WithIsCertain(true).WithErrorAsMessage() + ErrNoPermissionToInteract = WrapErrorInStatus(errors.New("you don't have permission to send messages")).WithIsCertain(true).WithSendNotice(false).WithErrorAsMessage() + ErrNoPermissionForCommands = WrapErrorInStatus(WrapErrorInStatus(errors.New("you don't have permission to use commands")).WithIsCertain(true).WithSendNotice(false).WithErrorAsMessage()) + ErrCantRelayStateRequest = WrapErrorInStatus(errors.New("relayed users can't use beeper state requests")).WithIsCertain(true).WithErrorAsMessage() +) + +func (br *Bridge) QueueMatrixEvent(ctx context.Context, evt *event.Event) EventHandlingResult { // TODO maybe HandleMatrixEvent would be more appropriate as this also handles bot invites and commands log := zerolog.Ctx(ctx) @@ -75,37 +82,34 @@ func (br *Bridge) QueueMatrixEvent(ctx context.Context, evt *event.Event) { log.Err(err).Msg("Failed to get sender user for incoming Matrix event") status := WrapErrorInStatus(fmt.Errorf("%w: failed to get sender user: %w", ErrDatabaseError, err)) br.Matrix.SendMessageStatus(ctx, &status, StatusEventInfoFromEvent(evt)) - return + return EventHandlingResultFailed } else if sender == nil { log.Error().Msg("Couldn't get sender for incoming non-ephemeral Matrix event") - status := WrapErrorInStatus(errors.New("sender not found for event")).WithIsCertain(true).WithErrorAsMessage() - br.Matrix.SendMessageStatus(ctx, &status, StatusEventInfoFromEvent(evt)) - return + br.Matrix.SendMessageStatus(ctx, &ErrEventSenderUserNotFound, StatusEventInfoFromEvent(evt)) + return EventHandlingResultFailed } else if !sender.Permissions.SendEvents { if !br.rejectInviteOnNoPermission(ctx, evt, "interact with") { - status := WrapErrorInStatus(errors.New("you don't have permission to send messages")).WithIsCertain(true).WithSendNotice(false).WithErrorAsMessage() - br.Matrix.SendMessageStatus(ctx, &status, StatusEventInfoFromEvent(evt)) + br.Matrix.SendMessageStatus(ctx, &ErrNoPermissionToInteract, StatusEventInfoFromEvent(evt)) } - return + return EventHandlingResultIgnored } else if !sender.Permissions.Commands && br.rejectInviteOnNoPermission(ctx, evt, "send commands to") { - return + return EventHandlingResultIgnored } } else if evt.Type.Class != event.EphemeralEventType { log.Error().Msg("Missing sender for incoming non-ephemeral Matrix event") - status := WrapErrorInStatus(errors.New("sender not found for event")).WithIsCertain(true).WithErrorAsMessage() - br.Matrix.SendMessageStatus(ctx, &status, StatusEventInfoFromEvent(evt)) - return + br.Matrix.SendMessageStatus(ctx, &ErrEventSenderUserNotFound, StatusEventInfoFromEvent(evt)) + return EventHandlingResultIgnored } if evt.Type == event.EventMessage && sender != nil { msg := evt.Content.AsMessage() msg.RemoveReplyFallback() + msg.RemovePerMessageProfileFallback() if strings.HasPrefix(msg.Body, br.Config.CommandPrefix) || evt.RoomID == sender.ManagementRoom { if !sender.Permissions.Commands { - status := WrapErrorInStatus(errors.New("you don't have permission to use commands")).WithIsCertain(true).WithSendNotice(false).WithErrorAsMessage() - br.Matrix.SendMessageStatus(ctx, &status, StatusEventInfoFromEvent(evt)) - return + br.Matrix.SendMessageStatus(ctx, &ErrNoPermissionForCommands, StatusEventInfoFromEvent(evt)) + return EventHandlingResultIgnored } - br.Commands.Handle( + go br.Commands.Handle( ctx, evt.RoomID, evt.ID, @@ -113,39 +117,112 @@ func (br *Bridge) QueueMatrixEvent(ctx context.Context, evt *event.Event) { strings.TrimPrefix(msg.Body, br.Config.CommandPrefix+" "), msg.RelatesTo.GetReplyTo(), ) - return + return EventHandlingResultQueued } } if evt.Type == event.StateMember && evt.GetStateKey() == br.Bot.GetMXID().String() && evt.Content.AsMember().Membership == event.MembershipInvite && sender != nil { - br.handleBotInvite(ctx, evt, sender) - return + return br.handleBotInvite(ctx, evt, sender) + } else if sender != nil && evt.RoomID == sender.ManagementRoom { + if evt.Type == event.StateMember && evt.Content.AsMember().Membership == event.MembershipLeave && (evt.GetStateKey() == br.Bot.GetMXID().String() || evt.GetStateKey() == sender.MXID.String()) { + sender.ManagementRoom = "" + err := br.DB.User.Update(ctx, sender.User) + if err != nil { + log.Err(err).Msg("Failed to clear user's management room in database") + return EventHandlingResultFailed + } else { + log.Debug().Msg("Cleared user's management room due to leave event") + } + } + return EventHandlingResultSuccess } portal, err := br.GetPortalByMXID(ctx, evt.RoomID) if err != nil { log.Err(err).Msg("Failed to get portal for incoming Matrix event") status := WrapErrorInStatus(fmt.Errorf("%w: failed to get portal: %w", ErrDatabaseError, err)) br.Matrix.SendMessageStatus(ctx, &status, StatusEventInfoFromEvent(evt)) - return + return EventHandlingResultFailed } else if portal != nil { - portal.queueEvent(ctx, &portalMatrixEvent{ + return portal.queueEvent(ctx, &portalMatrixEvent{ evt: evt, sender: sender, }) } else if evt.Type == event.StateMember && br.IsGhostMXID(id.UserID(evt.GetStateKey())) && evt.Content.AsMember().Membership == event.MembershipInvite && evt.Content.AsMember().IsDirect { - br.handleGhostDMInvite(ctx, evt, sender) + return br.handleGhostDMInvite(ctx, evt, sender) } else { status := WrapErrorInStatus(ErrNoPortal) br.Matrix.SendMessageStatus(ctx, &status, StatusEventInfoFromEvent(evt)) + return EventHandlingResultIgnored } } -func (ul *UserLogin) QueueRemoteEvent(evt RemoteEvent) { - ul.Bridge.QueueRemoteEvent(ul, evt) +type EventHandlingResult struct { + Success bool + Ignored bool + Queued bool + + SkipStateEcho bool + + // Error is an optional reason for failure. It is not required, Success may be false even without a specific error. + Error error + // Whether the Error should be sent as a MSS event. + SendMSS bool + + // EventID from the network + EventID id.EventID + // Stream order from the network + StreamOrder int64 } -func (br *Bridge) QueueRemoteEvent(login *UserLogin, evt RemoteEvent) { +func (ehr EventHandlingResult) WithEventID(id id.EventID) EventHandlingResult { + ehr.EventID = id + return ehr +} + +func (ehr EventHandlingResult) WithStreamOrder(order int64) EventHandlingResult { + ehr.StreamOrder = order + return ehr +} + +func (ehr EventHandlingResult) WithError(err error) EventHandlingResult { + if err == nil { + return ehr + } + ehr.Error = err + ehr.Success = false + return ehr +} + +func (ehr EventHandlingResult) WithMSS() EventHandlingResult { + ehr.SendMSS = true + return ehr +} + +func (ehr EventHandlingResult) WithSkipStateEcho(skip bool) EventHandlingResult { + ehr.SkipStateEcho = skip + return ehr +} + +func (ehr EventHandlingResult) WithMSSError(err error) EventHandlingResult { + if err == nil { + return ehr + } + return ehr.WithError(err).WithMSS() +} + +var ( + EventHandlingResultFailed = EventHandlingResult{} + EventHandlingResultQueued = EventHandlingResult{Success: true, Queued: true} + EventHandlingResultSuccess = EventHandlingResult{Success: true} + EventHandlingResultIgnored = EventHandlingResult{Success: true, Ignored: true} +) + +func (ul *UserLogin) QueueRemoteEvent(evt RemoteEvent) EventHandlingResult { + return ul.Bridge.QueueRemoteEvent(ul, evt) +} + +func (br *Bridge) QueueRemoteEvent(login *UserLogin, evt RemoteEvent) EventHandlingResult { log := login.Log - ctx := log.WithContext(context.TODO()) + ctx := log.WithContext(br.BackgroundCtx) maybeUncertain, ok := evt.(RemoteEventWithUncertainPortalReceiver) isUncertain := ok && maybeUncertain.PortalReceiverIsUncertain() key := evt.GetPortalKey() @@ -159,18 +236,18 @@ func (br *Bridge) QueueRemoteEvent(login *UserLogin, evt RemoteEvent) { if err != nil { log.Err(err).Object("portal_key", key).Bool("uncertain_receiver", isUncertain). Msg("Failed to get portal to handle remote event") - return + return EventHandlingResultFailed.WithError(fmt.Errorf("failed to get portal: %w", err)) } else if portal == nil { log.Warn(). Stringer("event_type", evt.GetType()). Object("portal_key", key). Bool("uncertain_receiver", isUncertain). Msg("Portal not found to handle remote event") - return + return EventHandlingResultFailed.WithError(ErrPortalNotFoundInEventHandler) } // TODO put this in a better place, and maybe cache to avoid constant db queries login.MarkInPortal(ctx, portal) - portal.queueEvent(ctx, &portalRemoteEvent{ + return portal.queueEvent(ctx, &portalRemoteEvent{ evt: evt, source: login, }) diff --git a/bridgev2/simplevent/chat.go b/bridgev2/simplevent/chat.go index c725141b..56e3a6b1 100644 --- a/bridgev2/simplevent/chat.go +++ b/bridgev2/simplevent/chat.go @@ -65,14 +65,19 @@ func (evt *ChatResync) GetChatInfo(ctx context.Context, portal *bridgev2.Portal) type ChatDelete struct { EventMeta OnlyForMe bool + Children bool } -var _ bridgev2.RemoteChatDelete = (*ChatDelete)(nil) +var _ bridgev2.RemoteChatDeleteWithChildren = (*ChatDelete)(nil) func (evt *ChatDelete) DeleteOnlyForMe() bool { return evt.OnlyForMe } +func (evt *ChatDelete) DeleteChildren() bool { + return evt.Children +} + // ChatInfoChange is a simple implementation of [bridgev2.RemoteChatInfoChange]. type ChatInfoChange struct { EventMeta diff --git a/bridgev2/simplevent/message.go b/bridgev2/simplevent/message.go index f648ab12..f8f8d7e1 100644 --- a/bridgev2/simplevent/message.go +++ b/bridgev2/simplevent/message.go @@ -59,6 +59,41 @@ func (evt *Message[T]) GetTransactionID() networkid.TransactionID { return evt.TransactionID } +// PreConvertedMessage is a simple implementation of [bridgev2.RemoteMessage] with pre-converted data. +type PreConvertedMessage struct { + EventMeta + Data *bridgev2.ConvertedMessage + ID networkid.MessageID + TransactionID networkid.TransactionID + + HandleExistingFunc func(ctx context.Context, portal *bridgev2.Portal, intent bridgev2.MatrixAPI, existing []*database.Message) (bridgev2.UpsertResult, error) +} + +var ( + _ bridgev2.RemoteMessage = (*PreConvertedMessage)(nil) + _ bridgev2.RemoteMessageUpsert = (*PreConvertedMessage)(nil) + _ bridgev2.RemoteMessageWithTransactionID = (*PreConvertedMessage)(nil) +) + +func (evt *PreConvertedMessage) ConvertMessage(ctx context.Context, portal *bridgev2.Portal, intent bridgev2.MatrixAPI) (*bridgev2.ConvertedMessage, error) { + return evt.Data, nil +} + +func (evt *PreConvertedMessage) GetID() networkid.MessageID { + return evt.ID +} + +func (evt *PreConvertedMessage) GetTransactionID() networkid.TransactionID { + return evt.TransactionID +} + +func (evt *PreConvertedMessage) HandleExisting(ctx context.Context, portal *bridgev2.Portal, intent bridgev2.MatrixAPI, existing []*database.Message) (bridgev2.UpsertResult, error) { + if evt.HandleExistingFunc == nil { + return bridgev2.UpsertResult{}, nil + } + return evt.HandleExistingFunc(ctx, portal, intent, existing) +} + type MessageRemove struct { EventMeta diff --git a/bridgev2/simplevent/meta.go b/bridgev2/simplevent/meta.go index 8aa91866..449a8773 100644 --- a/bridgev2/simplevent/meta.go +++ b/bridgev2/simplevent/meta.go @@ -101,6 +101,18 @@ func (evt EventMeta) WithLogContext(f func(c zerolog.Context) zerolog.Context) E return evt } +func (evt EventMeta) WithMoreLogContext(f func(c zerolog.Context) zerolog.Context) EventMeta { + origFunc := evt.LogContext + if origFunc == nil { + evt.LogContext = f + return evt + } + evt.LogContext = func(c zerolog.Context) zerolog.Context { + return f(origFunc(c)) + } + return evt +} + func (evt EventMeta) WithPortalKey(p networkid.PortalKey) EventMeta { evt.PortalKey = p return evt diff --git a/bridgev2/simplevent/receipt.go b/bridgev2/simplevent/receipt.go index 3565986b..41614e40 100644 --- a/bridgev2/simplevent/receipt.go +++ b/bridgev2/simplevent/receipt.go @@ -19,6 +19,8 @@ type Receipt struct { LastTarget networkid.MessageID Targets []networkid.MessageID ReadUpTo time.Time + + ReadUpToStreamOrder int64 } var ( @@ -38,6 +40,10 @@ func (evt *Receipt) GetReadUpTo() time.Time { return evt.ReadUpTo } +func (evt *Receipt) GetReadUpToStreamOrder() int64 { + return evt.ReadUpToStreamOrder +} + type MarkUnread struct { EventMeta Unread bool diff --git a/bridgev2/space.go b/bridgev2/space.go index 17388f3e..2ca2bce3 100644 --- a/bridgev2/space.go +++ b/bridgev2/space.go @@ -43,7 +43,7 @@ func (ul *UserLogin) MarkInPortal(ctx context.Context, portal *Portal) { } } if ul.Bridge.Config.PersonalFilteringSpaces && (userPortal.InSpace == nil || !*userPortal.InSpace) { - go ul.tryAddPortalToSpace(ctx, portal, userPortal.CopyWithoutValues()) + go ul.tryAddPortalToSpace(context.WithoutCancel(ctx), portal, userPortal.CopyWithoutValues()) } } } @@ -171,6 +171,10 @@ func (ul *UserLogin) GetSpaceRoom(ctx context.Context) (id.RoomID, error) { // TODO remove this after initial_members is supported in hungryserv req.BeeperAutoJoinInvites = true } + pfc, ok := ul.Client.(PersonalFilteringCustomizingNetworkAPI) + if ok { + pfc.CustomizePersonalFilteringSpace(req) + } ul.SpaceRoom, err = ul.Bridge.Bot.CreateRoom(ctx, req) if err != nil { return "", fmt.Errorf("failed to create space room: %w", err) diff --git a/bridge/status/bridgestate.go b/bridgev2/status/bridgestate.go similarity index 75% rename from bridge/status/bridgestate.go rename to bridgev2/status/bridgestate.go index 1aa4bb1f..5925dd4f 100644 --- a/bridge/status/bridgestate.go +++ b/bridgev2/status/bridgestate.go @@ -12,16 +12,17 @@ import ( "encoding/json" "fmt" "io" + "maps" "net/http" "reflect" "time" "github.com/tidwall/sjson" "go.mau.fi/util/jsontime" - "go.mau.fi/util/ptr" - "golang.org/x/exp/maps" "maunium.net/go/mautrix" + "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" ) @@ -53,12 +54,42 @@ const ( StateLoggedOut BridgeStateEvent = "LOGGED_OUT" ) +func (e BridgeStateEvent) IsValid() bool { + switch e { + case + StateStarting, + StateUnconfigured, + StateRunning, + StateBridgeUnreachable, + StateConnecting, + StateBackfilling, + StateConnected, + StateTransientDisconnect, + StateBadCredentials, + StateUnknownError, + StateLoggedOut: + return true + default: + return false + } +} + +type BridgeStateUserAction string + +const ( + UserActionOpenNative BridgeStateUserAction = "OPEN_NATIVE" + UserActionRelogin BridgeStateUserAction = "RELOGIN" + UserActionRestart BridgeStateUserAction = "RESTART" +) + type RemoteProfile struct { Phone string `json:"phone,omitempty"` Email string `json:"email,omitempty"` Username string `json:"username,omitempty"` Name string `json:"name,omitempty"` Avatar id.ContentURIString `json:"avatar,omitempty"` + + AvatarFile *event.EncryptedFileInfo `json:"avatar_file,omitempty"` } func coalesce[T ~string](a, b T) T { @@ -74,11 +105,14 @@ func (rp *RemoteProfile) Merge(other RemoteProfile) RemoteProfile { other.Username = coalesce(rp.Username, other.Username) other.Name = coalesce(rp.Name, other.Name) other.Avatar = coalesce(rp.Avatar, other.Avatar) + if rp.AvatarFile != nil { + other.AvatarFile = rp.AvatarFile + } return other } -func (rp *RemoteProfile) IsEmpty() bool { - return rp == nil || (rp.Phone == "" && rp.Email == "" && rp.Username == "" && rp.Name == "" && rp.Avatar == "") +func (rp *RemoteProfile) IsZero() bool { + return rp == nil || (rp.Phone == "" && rp.Email == "" && rp.Username == "" && rp.Name == "" && rp.Avatar == "" && rp.AvatarFile == nil) } type BridgeState struct { @@ -90,10 +124,12 @@ type BridgeState struct { Error BridgeStateErrorCode `json:"error,omitempty"` Message string `json:"message,omitempty"` - UserID id.UserID `json:"user_id,omitempty"` - RemoteID string `json:"remote_id,omitempty"` - RemoteName string `json:"remote_name,omitempty"` - RemoteProfile *RemoteProfile `json:"remote_profile,omitempty"` + UserAction BridgeStateUserAction `json:"user_action,omitempty"` + + UserID id.UserID `json:"user_id,omitempty"` + RemoteID networkid.UserLoginID `json:"remote_id,omitempty"` + RemoteName string `json:"remote_name,omitempty"` + RemoteProfile RemoteProfile `json:"remote_profile,omitzero"` Reason string `json:"reason,omitempty"` Info map[string]interface{} `json:"info,omitempty"` @@ -105,31 +141,15 @@ type GlobalBridgeState struct { } type BridgeStateFiller interface { - GetMXID() id.UserID - GetRemoteID() string - GetRemoteName() string -} - -type StandaloneCustomBridgeStateFiller interface { FillBridgeState(BridgeState) BridgeState } -type CustomBridgeStateFiller interface { - BridgeStateFiller - StandaloneCustomBridgeStateFiller -} +// Deprecated: use BridgeStateFiller instead +type StandaloneCustomBridgeStateFiller = BridgeStateFiller -func (pong BridgeState) Fill(user any) BridgeState { +func (pong BridgeState) Fill(user BridgeStateFiller) BridgeState { if user != nil { - if std, ok := user.(BridgeStateFiller); ok { - pong.UserID = std.GetMXID() - pong.RemoteID = std.GetRemoteID() - pong.RemoteName = std.GetRemoteName() - } - - if custom, ok := user.(StandaloneCustomBridgeStateFiller); ok { - pong = custom.FillBridgeState(pong) - } + pong = user.FillBridgeState(pong) } pong.Timestamp = jsontime.UnixNow() @@ -188,7 +208,8 @@ func (pong *BridgeState) ShouldDeduplicate(newPong *BridgeState) bool { return pong != nil && pong.StateEvent == newPong.StateEvent && pong.RemoteName == newPong.RemoteName && - ptr.Val(pong.RemoteProfile) == ptr.Val(newPong.RemoteProfile) && + pong.UserAction == newPong.UserAction && + pong.RemoteProfile == newPong.RemoteProfile && pong.Error == newPong.Error && maps.EqualFunc(pong.Info, newPong.Info, reflect.DeepEqual) && pong.Timestamp.Add(time.Duration(pong.TTL)*time.Second).After(time.Now()) diff --git a/bridgev2/status/localbridgestate.go b/bridgev2/status/localbridgestate.go new file mode 100644 index 00000000..3ad66538 --- /dev/null +++ b/bridgev2/status/localbridgestate.go @@ -0,0 +1,23 @@ +package status + +type LocalBridgeAccountState string + +const ( + // LocalBridgeAccountStateSetup means the user wants this account to be setup and connected + LocalBridgeAccountStateSetup LocalBridgeAccountState = "SETUP" + // LocalBridgeAccountStateDeleted means the user wants this account to be deleted + LocalBridgeAccountStateDeleted LocalBridgeAccountState = "DELETED" +) + +type LocalBridgeDeviceState string + +const ( + // LocalBridgeDeviceStateSetup means this device is setup to be connected to this account + LocalBridgeDeviceStateSetup LocalBridgeDeviceState = "SETUP" + // LocalBridgeDeviceStateLoggedOut means the user has logged this particular device out while wanting their other devices to remain setup + LocalBridgeDeviceStateLoggedOut LocalBridgeDeviceState = "LOGGED_OUT" + // LocalBridgeDeviceStateError means this particular device has fallen into a persistent error state that may need user intervention to fix + LocalBridgeDeviceStateError LocalBridgeDeviceState = "ERROR" + // LocalBridgeDeviceStateDeleted means this particular device has cleaned up after the account as a whole was requested to be deleted + LocalBridgeDeviceStateDeleted LocalBridgeDeviceState = "DELETED" +) diff --git a/bridge/status/messagecheckpoint.go b/bridgev2/status/messagecheckpoint.go similarity index 96% rename from bridge/status/messagecheckpoint.go rename to bridgev2/status/messagecheckpoint.go index ea859b84..b3c05f4f 100644 --- a/bridge/status/messagecheckpoint.go +++ b/bridgev2/status/messagecheckpoint.go @@ -169,13 +169,13 @@ type CheckpointsJSON struct { Checkpoints []*MessageCheckpoint `json:"checkpoints"` } -func (cj *CheckpointsJSON) SendHTTP(endpoint string, token string) error { +func (cj *CheckpointsJSON) SendHTTP(ctx context.Context, cli *http.Client, endpoint string, token string) error { var body bytes.Buffer if err := json.NewEncoder(&body).Encode(cj); err != nil { return fmt.Errorf("failed to encode message checkpoint JSON: %w", err) } - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := context.WithTimeout(ctx, 5*time.Second) defer cancel() req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, &body) if err != nil { @@ -186,7 +186,10 @@ func (cj *CheckpointsJSON) SendHTTP(endpoint string, token string) error { req.Header.Set("User-Agent", mautrix.DefaultUserAgent+" (checkpoint sender)") req.Header.Set("Content-Type", "application/json") - resp, err := http.DefaultClient.Do(req) + if cli == nil { + cli = http.DefaultClient + } + resp, err := cli.Do(req) if err != nil { return mautrix.HTTPError{ Request: req, diff --git a/bridgev2/user.go b/bridgev2/user.go index e6a5dd99..9a7896d6 100644 --- a/bridgev2/user.go +++ b/bridgev2/user.go @@ -176,6 +176,10 @@ func (user *User) GetUserLogins() []*UserLogin { return maps.Values(user.logins) } +func (user *User) HasTooManyLogins() bool { + return user.Permissions.MaxLogins > 0 && len(user.GetUserLoginIDs()) >= user.Permissions.MaxLogins +} + func (user *User) GetFormattedUserLogins() string { user.Bridge.cacheLock.Lock() logins := make([]string, len(user.logins)) diff --git a/bridgev2/userlogin.go b/bridgev2/userlogin.go index 142d67d4..d56dc4cc 100644 --- a/bridgev2/userlogin.go +++ b/bridgev2/userlogin.go @@ -10,6 +10,7 @@ import ( "cmp" "context" "fmt" + "maps" "slices" "sync" "time" @@ -17,10 +18,10 @@ import ( "github.com/rs/zerolog" "go.mau.fi/util/exsync" - "maunium.net/go/mautrix/bridge/status" "maunium.net/go/mautrix/bridgev2/bridgeconfig" "maunium.net/go/mautrix/bridgev2/database" "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/bridgev2/status" "maunium.net/go/mautrix/event" ) @@ -37,6 +38,7 @@ type UserLogin struct { spaceCreateLock sync.Mutex deleteLock sync.Mutex + disconnectOnce sync.Once } func (br *Bridge) loadUserLogin(ctx context.Context, user *User, dbUserLogin *database.UserLogin) (*UserLogin, error) { @@ -49,6 +51,8 @@ func (br *Bridge) loadUserLogin(ctx context.Context, user *User, dbUserLogin *da if err != nil { return nil, fmt.Errorf("failed to get user: %w", err) } + // TODO if loading the user caused the provided userlogin to be loaded, cancel here? + // Currently this will double-load it } userLogin := &UserLogin{ UserLogin: dbUserLogin, @@ -62,6 +66,9 @@ func (br *Bridge) loadUserLogin(ctx context.Context, user *User, dbUserLogin *da if err != nil { userLogin.Log.Err(err).Msg("Failed to load user login") return nil, nil + } else if userLogin.Client == nil { + userLogin.Log.Error().Msg("LoadUserLogin didn't fill Client") + return nil, nil } userLogin.BridgeState = br.NewBridgeStateQueue(userLogin) user.logins[userLogin.ID] = userLogin @@ -136,6 +143,12 @@ func (br *Bridge) GetCachedUserLoginByID(id networkid.UserLoginID) *UserLogin { return br.userLoginsByID[id] } +func (br *Bridge) GetAllCachedUserLogins() (logins []*UserLogin) { + br.cacheLock.Lock() + defer br.cacheLock.Unlock() + return slices.Collect(maps.Values(br.userLoginsByID)) +} + func (br *Bridge) GetCurrentBridgeStates() (states []status.BridgeState) { br.cacheLock.Lock() defer br.cacheLock.Unlock() @@ -224,19 +237,23 @@ func (user *User) NewLogin(ctx context.Context, data *database.UserLogin, params } ul.BridgeState = user.Bridge.NewBridgeStateQueue(ul) } - err = params.LoadUserLogin(ul.Log.WithContext(context.Background()), ul) + noCancelCtx := ul.Log.WithContext(user.Bridge.BackgroundCtx) + err = params.LoadUserLogin(noCancelCtx, ul) if err != nil { return nil, err + } else if ul.Client == nil { + ul.Log.Error().Msg("LoadUserLogin didn't fill Client in NewLogin") + return nil, fmt.Errorf("client not filled by LoadUserLogin") } if doInsert { - err = user.Bridge.DB.UserLogin.Insert(ctx, ul.UserLogin) + err = user.Bridge.DB.UserLogin.Insert(noCancelCtx, ul.UserLogin) if err != nil { return nil, err } user.Bridge.userLoginsByID[ul.ID] = ul user.logins[ul.ID] = ul } else { - err = ul.Save(ctx) + err = ul.Save(noCancelCtx) if err != nil { return nil, err } @@ -273,7 +290,8 @@ func (ul *UserLogin) Delete(ctx context.Context, state status.BridgeState, opts if opts.LogoutRemote { ul.Client.LogoutRemote(ctx) } else { - ul.Disconnect(nil) + // we probably shouldn't delete the login if disconnect isn't finished + ul.Disconnect() } var portals []*database.UserPortal var err error @@ -295,7 +313,7 @@ func (ul *UserLogin) Delete(ctx context.Context, state status.BridgeState, opts if !opts.unlocked { ul.Bridge.cacheLock.Unlock() } - backgroundCtx := context.WithoutCancel(ctx) + backgroundCtx := zerolog.Ctx(ctx).WithContext(ul.Bridge.BackgroundCtx) if !opts.BlockingCleanup { go ul.deleteSpace(backgroundCtx) } else { @@ -488,36 +506,66 @@ func (ul *UserLogin) MarkAsPreferredIn(ctx context.Context, portal *Portal) erro return ul.Bridge.DB.UserPortal.MarkAsPreferred(ctx, ul.UserLogin, portal.PortalKey) } -var _ status.StandaloneCustomBridgeStateFiller = (*UserLogin)(nil) +var _ status.BridgeStateFiller = (*UserLogin)(nil) func (ul *UserLogin) FillBridgeState(state status.BridgeState) status.BridgeState { state.UserID = ul.UserMXID - state.RemoteID = string(ul.ID) + state.RemoteID = ul.ID state.RemoteName = ul.RemoteName - state.RemoteProfile = &ul.RemoteProfile - filler, ok := ul.Client.(status.StandaloneCustomBridgeStateFiller) + state.RemoteProfile = ul.RemoteProfile + filler, ok := ul.Client.(status.BridgeStateFiller) if ok { return filler.FillBridgeState(state) } return state } -func (ul *UserLogin) Disconnect(done func()) { - if done != nil { - defer done() +func (ul *UserLogin) Disconnect() { + ul.DisconnectWithTimeout(0) +} + +func (ul *UserLogin) DisconnectWithTimeout(timeout time.Duration) { + ul.disconnectOnce.Do(func() { + ul.disconnectInternal(timeout) + }) +} + +func (ul *UserLogin) disconnectInternal(timeout time.Duration) { + ul.BridgeState.StopUnknownErrorReconnect() + disconnected := make(chan struct{}) + go func() { + ul.Client.Disconnect() + close(disconnected) + }() + + var timeoutC <-chan time.Time + if timeout > 0 { + timeoutC = time.After(timeout) } - client := ul.Client - if client != nil { - ul.Client = nil - disconnected := make(chan struct{}) - go func() { - client.Disconnect() - close(disconnected) - }() + for { select { case <-disconnected: - case <-time.After(5 * time.Second): - ul.Log.Warn().Msg("Client disconnection timed out") + return + case <-time.After(2 * time.Second): + ul.Log.Warn().Msg("Client disconnection taking long") + case <-timeoutC: + ul.Log.Error().Msg("Client disconnection timed out") + return } } } + +func (ul *UserLogin) recreateClient(ctx context.Context) error { + oldClient := ul.Client + err := ul.Bridge.Network.LoadUserLogin(ctx, ul) + if err != nil { + return err + } + if ul.Client == oldClient { + zerolog.Ctx(ctx).Warn().Msg("LoadUserLogin didn't update client") + } else { + zerolog.Ctx(ctx).Debug().Msg("Recreated user login client") + } + ul.disconnectOnce = sync.Once{} + return nil +} diff --git a/client.go b/client.go index e8689708..7062d9b9 100644 --- a/client.go +++ b/client.go @@ -13,6 +13,7 @@ import ( "net/http" "net/url" "os" + "runtime" "slices" "strconv" "strings" @@ -76,18 +77,19 @@ type VerificationHelper interface { // Client represents a Matrix client. type Client struct { - HomeserverURL *url.URL // The base homeserver URL - UserID id.UserID // The user ID of the client. Used for forming HTTP paths which use the client's user ID. - DeviceID id.DeviceID // The device ID of the client. - AccessToken string // The access_token for the client. - UserAgent string // The value for the User-Agent header - Client *http.Client // The underlying HTTP client which will be used to make HTTP requests. - Syncer Syncer // The thing which can process /sync responses - Store SyncStore // The thing which can store tokens/ids - StateStore StateStore - Crypto CryptoHelper - Verification VerificationHelper - SpecVersions *RespVersions + HomeserverURL *url.URL // The base homeserver URL + UserID id.UserID // The user ID of the client. Used for forming HTTP paths which use the client's user ID. + DeviceID id.DeviceID // The device ID of the client. + AccessToken string // The access_token for the client. + UserAgent string // The value for the User-Agent header + Client *http.Client // The underlying HTTP client which will be used to make HTTP requests. + Syncer Syncer // The thing which can process /sync responses + Store SyncStore // The thing which can store tokens/ids + StateStore StateStore + Crypto CryptoHelper + Verification VerificationHelper + SpecVersions *RespVersions + ExternalClient *http.Client // The HTTP client used for external (not matrix) media HTTP requests. Log zerolog.Logger @@ -109,6 +111,8 @@ type Client struct { // Set to true to disable automatically sleeping on 429 errors. IgnoreRateLimit bool + ResponseSizeLimit int64 + txnID int32 // Should the ?user_id= query parameter be set in requests? @@ -138,6 +142,12 @@ type IdentityServerInfo struct { // Use ParseUserID to extract the server name from a user ID. // https://spec.matrix.org/v1.2/client-server-api/#server-discovery func DiscoverClientAPI(ctx context.Context, serverName string) (*ClientWellKnown, error) { + return DiscoverClientAPIWithClient(ctx, &http.Client{Timeout: 30 * time.Second}, serverName) +} + +const WellKnownMaxSize = 64 * 1024 + +func DiscoverClientAPIWithClient(ctx context.Context, client *http.Client, serverName string) (*ClientWellKnown, error) { wellKnownURL := url.URL{ Scheme: "https", Host: serverName, @@ -149,10 +159,11 @@ func DiscoverClientAPI(ctx context.Context, serverName string) (*ClientWellKnown return nil, err } - req.Header.Set("Accept", "application/json") - req.Header.Set("User-Agent", DefaultUserAgent+" (.well-known fetcher)") + if runtime.GOOS != "js" { + req.Header.Set("Accept", "application/json") + req.Header.Set("User-Agent", DefaultUserAgent+" (.well-known fetcher)") + } - client := &http.Client{Timeout: 30 * time.Second} resp, err := client.Do(req) if err != nil { return nil, err @@ -161,11 +172,15 @@ func DiscoverClientAPI(ctx context.Context, serverName string) (*ClientWellKnown if resp.StatusCode == http.StatusNotFound { return nil, nil + } else if resp.ContentLength > WellKnownMaxSize { + return nil, errors.New(".well-known response too large") } - data, err := io.ReadAll(resp.Body) + data, err := io.ReadAll(io.LimitReader(resp.Body, WellKnownMaxSize)) if err != nil { return nil, err + } else if len(data) >= WellKnownMaxSize { + return nil, errors.New(".well-known response too large") } var wellKnown ClientWellKnown @@ -315,15 +330,28 @@ type contextKey int const ( LogBodyContextKey contextKey = iota LogRequestIDContextKey + MaxAttemptsContextKey + SyncTokenContextKey ) func (cli *Client) RequestStart(req *http.Request) { - if cli.RequestHook != nil { + if cli != nil && cli.RequestHook != nil { cli.RequestHook(req) } } +// WithMaxRetries updates the context to set the maximum number of retries for any HTTP requests made with the context. +// +// 0 means the request will only be attempted once and will not be retried. +// Negative values will remove the override and fallback to the defaults. +func WithMaxRetries(ctx context.Context, maxRetries int) context.Context { + return context.WithValue(ctx, MaxAttemptsContextKey, maxRetries+1) +} + func (cli *Client) LogRequestDone(req *http.Request, resp *http.Response, err error, handlerErr error, contentLength int, duration time.Duration) { + if cli == nil { + return + } var evt *zerolog.Event if errors.Is(err, context.Canceled) { evt = zerolog.Ctx(req.Context()).Warn() @@ -358,7 +386,14 @@ func (cli *Client) LogRequestDone(req *http.Request, resp *http.Response, err er } } if body := req.Context().Value(LogBodyContextKey); body != nil { - evt.Interface("req_body", body) + switch typedLogBody := body.(type) { + case json.RawMessage: + evt.RawJSON("req_body", typedLogBody) + case string: + evt.Str("req_body", typedLogBody) + default: + panic(fmt.Errorf("invalid type for LogBodyContextKey: %T", body)) + } } if errors.Is(err, context.Canceled) { evt.Msg("Request canceled") @@ -375,32 +410,43 @@ func (cli *Client) MakeRequest(ctx context.Context, method string, httpURL strin return cli.MakeFullRequest(ctx, FullRequest{Method: method, URL: httpURL, RequestJSON: reqBody, ResponseJSON: resBody}) } -type ClientResponseHandler = func(req *http.Request, res *http.Response, responseJSON interface{}) ([]byte, error) +type ClientResponseHandler = func(req *http.Request, res *http.Response, responseJSON any, sizeLimit int64) ([]byte, error) type FullRequest struct { - Method string - URL string - Headers http.Header - RequestJSON interface{} - RequestBytes []byte - RequestBody io.Reader - RequestLength int64 - ResponseJSON interface{} - MaxAttempts int - BackoffDuration time.Duration - SensitiveContent bool - Handler ClientResponseHandler - DontReadResponse bool - Logger *zerolog.Logger - Client *http.Client + Method string + URL string + Headers http.Header + RequestJSON interface{} + RequestBytes []byte + RequestBody io.Reader + RequestLength int64 + ResponseJSON interface{} + MaxAttempts int + BackoffDuration time.Duration + SensitiveContent bool + Handler ClientResponseHandler + DontReadResponse bool + ResponseSizeLimit int64 + Logger *zerolog.Logger + Client *http.Client } var requestID int32 var logSensitiveContent = os.Getenv("MAUTRIX_LOG_SENSITIVE_CONTENT") == "yes" func (params *FullRequest) compileRequest(ctx context.Context) (*http.Request, error) { + reqID := atomic.AddInt32(&requestID, 1) + logger := zerolog.Ctx(ctx) + if logger.GetLevel() == zerolog.Disabled || logger == zerolog.DefaultContextLogger { + logger = params.Logger + } + ctx = logger.With(). + Int32("req_id", reqID). + Logger().WithContext(ctx) + var logBody any - reqBody := params.RequestBody + var reqBody io.Reader + var reqLen int64 if params.RequestJSON != nil { jsonStr, err := json.Marshal(params.RequestJSON) if err != nil { @@ -411,33 +457,38 @@ func (params *FullRequest) compileRequest(ctx context.Context) (*http.Request, e } if params.SensitiveContent && !logSensitiveContent { logBody = "" + } else if len(jsonStr) > 32768 { + logBody = fmt.Sprintf("", len(jsonStr)) } else { - logBody = params.RequestJSON + logBody = json.RawMessage(jsonStr) } reqBody = bytes.NewReader(jsonStr) + reqLen = int64(len(jsonStr)) } else if params.RequestBytes != nil { logBody = fmt.Sprintf("<%d bytes>", len(params.RequestBytes)) reqBody = bytes.NewReader(params.RequestBytes) - params.RequestLength = int64(len(params.RequestBytes)) - } else if params.RequestLength > 0 && params.RequestBody != nil { - logBody = fmt.Sprintf("<%d bytes>", params.RequestLength) + reqLen = int64(len(params.RequestBytes)) + } else if params.RequestBody != nil { + logBody = "" + reqLen = -1 + if params.RequestLength > 0 { + logBody = fmt.Sprintf("<%d bytes>", params.RequestLength) + reqLen = params.RequestLength + } else if params.RequestLength == 0 { + zerolog.Ctx(ctx).Warn(). + Msg("RequestBody passed without specifying request length") + } + reqBody = params.RequestBody if rsc, ok := params.RequestBody.(io.ReadSeekCloser); ok { // Prevent HTTP from closing the request body, it might be needed for retries reqBody = nopCloseSeeker{rsc} } } else if params.Method != http.MethodGet && params.Method != http.MethodHead { params.RequestJSON = struct{}{} - logBody = params.RequestJSON + logBody = json.RawMessage("{}") reqBody = bytes.NewReader([]byte("{}")) + reqLen = 2 } - reqID := atomic.AddInt32(&requestID, 1) - logger := zerolog.Ctx(ctx) - if logger.GetLevel() == zerolog.Disabled || logger == zerolog.DefaultContextLogger { - logger = params.Logger - } - ctx = logger.With(). - Int32("req_id", reqID). - Logger().WithContext(ctx) ctx = context.WithValue(ctx, LogBodyContextKey, logBody) ctx = context.WithValue(ctx, LogRequestIDContextKey, int(reqID)) req, err := http.NewRequestWithContext(ctx, params.Method, params.URL, reqBody) @@ -453,9 +504,7 @@ func (params *FullRequest) compileRequest(ctx context.Context) (*http.Request, e if params.RequestJSON != nil { req.Header.Set("Content-Type", "application/json") } - if params.RequestLength > 0 && params.RequestBody != nil { - req.ContentLength = params.RequestLength - } + req.ContentLength = reqLen return req, nil } @@ -465,8 +514,19 @@ func (cli *Client) MakeFullRequest(ctx context.Context, params FullRequest) ([]b } func (cli *Client) MakeFullRequestWithResp(ctx context.Context, params FullRequest) ([]byte, *http.Response, error) { + if cli == nil { + return nil, nil, ErrClientIsNil + } + if cli.HomeserverURL == nil || cli.HomeserverURL.Scheme == "" { + return nil, nil, ErrClientHasNoHomeserver + } if params.MaxAttempts == 0 { - params.MaxAttempts = 1 + cli.DefaultHTTPRetries + maxAttempts, ok := ctx.Value(MaxAttemptsContextKey).(int) + if ok && maxAttempts > 0 { + params.MaxAttempts = maxAttempts + } else { + params.MaxAttempts = 1 + cli.DefaultHTTPRetries + } } if params.BackoffDuration == 0 { if cli.DefaultHTTPBackoff == 0 { @@ -489,14 +549,31 @@ func (cli *Client) MakeFullRequestWithResp(ctx context.Context, params FullReque params.Handler = handleNormalResponse } } - req.Header.Set("User-Agent", cli.UserAgent) + if cli.UserAgent != "" { + req.Header.Set("User-Agent", cli.UserAgent) + } if len(cli.AccessToken) > 0 { req.Header.Set("Authorization", "Bearer "+cli.AccessToken) } + if params.ResponseSizeLimit == 0 { + params.ResponseSizeLimit = cli.ResponseSizeLimit + } + if params.ResponseSizeLimit == 0 { + params.ResponseSizeLimit = DefaultResponseSizeLimit + } if params.Client == nil { params.Client = cli.Client } - return cli.executeCompiledRequest(req, params.MaxAttempts-1, params.BackoffDuration, params.ResponseJSON, params.Handler, params.DontReadResponse, params.Client) + return cli.executeCompiledRequest( + req, + params.MaxAttempts-1, + params.BackoffDuration, + params.ResponseJSON, + params.Handler, + params.DontReadResponse, + params.ResponseSizeLimit, + params.Client, + ) } func (cli *Client) cliOrContextLog(ctx context.Context) *zerolog.Logger { @@ -507,7 +584,17 @@ func (cli *Client) cliOrContextLog(ctx context.Context) *zerolog.Logger { return log } -func (cli *Client) doRetry(req *http.Request, cause error, retries int, backoff time.Duration, responseJSON any, handler ClientResponseHandler, dontReadResponse bool, client *http.Client) ([]byte, *http.Response, error) { +func (cli *Client) doRetry( + req *http.Request, + cause error, + retries int, + backoff time.Duration, + responseJSON any, + handler ClientResponseHandler, + dontReadResponse bool, + sizeLimit int64, + client *http.Client, +) ([]byte, *http.Response, error) { log := zerolog.Ctx(req.Context()) if req.Body != nil { var err error @@ -529,17 +616,37 @@ func (cli *Client) doRetry(req *http.Request, cause error, retries int, backoff } } log.Warn().Err(cause). + Str("method", req.Method). + Str("url", req.URL.String()). Int("retry_in_seconds", int(backoff.Seconds())). Msg("Request failed, retrying") - time.Sleep(backoff) + select { + case <-time.After(backoff): + case <-req.Context().Done(): + if !errors.Is(context.Cause(req.Context()), ErrContextCancelRetry) { + return nil, nil, req.Context().Err() + } + } if cli.UpdateRequestOnRetry != nil { req = cli.UpdateRequestOnRetry(req, cause) } - return cli.executeCompiledRequest(req, retries-1, backoff*2, responseJSON, handler, dontReadResponse, client) + return cli.executeCompiledRequest(req, retries-1, backoff*2, responseJSON, handler, dontReadResponse, sizeLimit, client) } -func readResponseBody(req *http.Request, res *http.Response) ([]byte, error) { - contents, err := io.ReadAll(res.Body) +func readResponseBody(req *http.Request, res *http.Response, limit int64) ([]byte, error) { + if res.ContentLength > limit { + return nil, HTTPError{ + Request: req, + Response: res, + + Message: "not reading response", + WrappedError: fmt.Errorf("%w (%.2f MiB)", ErrResponseTooLong, float64(res.ContentLength)/1024/1024), + } + } + contents, err := io.ReadAll(io.LimitReader(res.Body, limit+1)) + if err == nil && len(contents) > int(limit) { + err = ErrBodyReadReachedLimit + } if err != nil { return nil, HTTPError{ Request: req, @@ -560,17 +667,20 @@ func closeTemp(log *zerolog.Logger, file *os.File) { } } -func streamResponse(req *http.Request, res *http.Response, responseJSON interface{}) ([]byte, error) { +func streamResponse(req *http.Request, res *http.Response, responseJSON any, limit int64) ([]byte, error) { log := zerolog.Ctx(req.Context()) file, err := os.CreateTemp("", "mautrix-response-") if err != nil { log.Warn().Err(err).Msg("Failed to create temporary file for streaming response") - _, err = handleNormalResponse(req, res, responseJSON) + _, err = handleNormalResponse(req, res, responseJSON, limit) return nil, err } defer closeTemp(log, file) - if _, err = io.Copy(file, res.Body); err != nil { + var n int64 + if n, err = io.Copy(file, io.LimitReader(res.Body, limit+1)); err != nil { return nil, fmt.Errorf("failed to copy response to file: %w", err) + } else if n > limit { + return nil, ErrBodyReadReachedLimit } else if _, err = file.Seek(0, 0); err != nil { return nil, fmt.Errorf("failed to seek to beginning of response file: %w", err) } else if err = json.NewDecoder(file).Decode(responseJSON); err != nil { @@ -580,12 +690,12 @@ func streamResponse(req *http.Request, res *http.Response, responseJSON interfac } } -func noopHandleResponse(req *http.Request, res *http.Response, responseJSON interface{}) ([]byte, error) { +func noopHandleResponse(req *http.Request, res *http.Response, responseJSON any, limit int64) ([]byte, error) { return nil, nil } -func handleNormalResponse(req *http.Request, res *http.Response, responseJSON interface{}) ([]byte, error) { - if contents, err := readResponseBody(req, res); err != nil { +func handleNormalResponse(req *http.Request, res *http.Response, responseJSON any, limit int64) ([]byte, error) { + if contents, err := readResponseBody(req, res, limit); err != nil { return nil, err } else if responseJSON == nil { return contents, nil @@ -603,8 +713,13 @@ func handleNormalResponse(req *http.Request, res *http.Response, responseJSON in } } +const ErrorResponseSizeLimit = 512 * 1024 + +var DefaultResponseSizeLimit int64 = 512 * 1024 * 1024 + func ParseErrorResponse(req *http.Request, res *http.Response) ([]byte, error) { - contents, err := readResponseBody(req, res) + defer res.Body.Close() + contents, err := readResponseBody(req, res, ErrorResponseSizeLimit) if err != nil { return contents, err } @@ -623,17 +738,31 @@ func ParseErrorResponse(req *http.Request, res *http.Response) ([]byte, error) { } } -func (cli *Client) executeCompiledRequest(req *http.Request, retries int, backoff time.Duration, responseJSON any, handler ClientResponseHandler, dontReadResponse bool, client *http.Client) ([]byte, *http.Response, error) { +func (cli *Client) executeCompiledRequest( + req *http.Request, + retries int, + backoff time.Duration, + responseJSON any, + handler ClientResponseHandler, + dontReadResponse bool, + sizeLimit int64, + client *http.Client, +) ([]byte, *http.Response, error) { cli.RequestStart(req) startTime := time.Now() res, err := client.Do(req) - duration := time.Now().Sub(startTime) + duration := time.Since(startTime) if res != nil && !dontReadResponse { defer res.Body.Close() } if err != nil { - if retries > 0 && !errors.Is(err, context.Canceled) { - return cli.doRetry(req, err, retries, backoff, responseJSON, handler, dontReadResponse, client) + // Either error is *not* canceled or the underlying cause of cancelation explicitly asks to retry + canRetry := !errors.Is(err, context.Canceled) || + errors.Is(context.Cause(req.Context()), ErrContextCancelRetry) + if retries > 0 && canRetry { + return cli.doRetry( + req, err, retries, backoff, responseJSON, handler, dontReadResponse, sizeLimit, client, + ) } err = HTTPError{ Request: req, @@ -648,7 +777,9 @@ func (cli *Client) executeCompiledRequest(req *http.Request, retries int, backof if retries > 0 && retryafter.Should(res.StatusCode, !cli.IgnoreRateLimit) { backoff = retryafter.Parse(res.Header.Get("Retry-After"), backoff) - return cli.doRetry(req, fmt.Errorf("HTTP %d", res.StatusCode), retries, backoff, responseJSON, handler, dontReadResponse, client) + return cli.doRetry( + req, fmt.Errorf("HTTP %d", res.StatusCode), retries, backoff, responseJSON, handler, dontReadResponse, sizeLimit, client, + ) } var body []byte @@ -656,7 +787,7 @@ func (cli *Client) executeCompiledRequest(req *http.Request, retries int, backof body, err = ParseErrorResponse(req, res) cli.LogRequestDone(req, res, nil, nil, len(body), duration) } else { - body, err = handler(req, res, responseJSON) + body, err = handler(req, res, responseJSON, sizeLimit) cli.LogRequestDone(req, res, nil, err, len(body), duration) } return body, res, err @@ -664,7 +795,6 @@ func (cli *Client) executeCompiledRequest(req *http.Request, retries int, backof // Whoami gets the user ID of the current user. See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3accountwhoami func (cli *Client) Whoami(ctx context.Context) (resp *RespWhoami, err error) { - urlPath := cli.BuildClientURL("v3", "account", "whoami") _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp) return @@ -695,6 +825,7 @@ type ReqSync struct { FullState bool SetPresence event.Presence StreamResponse bool + UseStateAfter bool BeeperStreaming bool Client *http.Client } @@ -715,9 +846,10 @@ func (req *ReqSync) BuildQuery() map[string]string { if req.FullState { query["full_state"] = "true" } + if req.UseStateAfter { + query["use_state_after"] = "true" + } if req.BeeperStreaming { - // TODO remove this - query["streaming"] = "" query["com.beeper.streaming"] = "true" } return query @@ -739,7 +871,7 @@ func (cli *Client) FullSyncRequest(ctx context.Context, req ReqSync) (resp *Resp } start := time.Now() _, err = cli.MakeFullRequest(ctx, fullReq) - duration := time.Now().Sub(start) + duration := time.Since(start) timeout := time.Duration(req.Timeout) * time.Millisecond buffer := 10 * time.Second if req.Since == "" { @@ -786,7 +918,7 @@ func (cli *Client) RegisterAvailable(ctx context.Context, username string) (resp return } -func (cli *Client) register(ctx context.Context, url string, req *ReqRegister) (resp *RespRegister, uiaResp *RespUserInteractive, err error) { +func (cli *Client) register(ctx context.Context, url string, req *ReqRegister[any]) (resp *RespRegister, uiaResp *RespUserInteractive, err error) { var bodyBytes []byte bodyBytes, err = cli.MakeFullRequest(ctx, FullRequest{ Method: http.MethodPost, @@ -810,7 +942,7 @@ func (cli *Client) register(ctx context.Context, url string, req *ReqRegister) ( // Register makes an HTTP request according to https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3register // // Registers with kind=user. For kind=guest, see RegisterGuest. -func (cli *Client) Register(ctx context.Context, req *ReqRegister) (*RespRegister, *RespUserInteractive, error) { +func (cli *Client) Register(ctx context.Context, req *ReqRegister[any]) (*RespRegister, *RespUserInteractive, error) { u := cli.BuildClientURL("v3", "register") return cli.register(ctx, u, req) } @@ -819,7 +951,7 @@ func (cli *Client) Register(ctx context.Context, req *ReqRegister) (*RespRegiste // with kind=guest. // // For kind=user, see Register. -func (cli *Client) RegisterGuest(ctx context.Context, req *ReqRegister) (*RespRegister, *RespUserInteractive, error) { +func (cli *Client) RegisterGuest(ctx context.Context, req *ReqRegister[any]) (*RespRegister, *RespUserInteractive, error) { query := map[string]string{ "kind": "guest", } @@ -842,8 +974,8 @@ func (cli *Client) RegisterGuest(ctx context.Context, req *ReqRegister) (*RespRe // panic(err) // } // token := res.AccessToken -func (cli *Client) RegisterDummy(ctx context.Context, req *ReqRegister) (*RespRegister, error) { - res, uia, err := cli.Register(ctx, req) +func (cli *Client) RegisterDummy(ctx context.Context, req *ReqRegister[any]) (*RespRegister, error) { + _, uia, err := cli.Register(ctx, req) if err != nil && uia == nil { return nil, err } else if uia == nil { @@ -852,7 +984,7 @@ func (cli *Client) RegisterDummy(ctx context.Context, req *ReqRegister) (*RespRe return nil, errors.New("server does not support m.login.dummy") } req.Auth = BaseAuthData{Type: AuthTypeDummy, Session: uia.Session} - res, _, err = cli.Register(ctx, req) + res, _, err := cli.Register(ctx, req) if err != nil { return nil, err } @@ -951,20 +1083,19 @@ func (cli *Client) Capabilities(ctx context.Context) (resp *RespCapabilities, er return } -// JoinRoom joins the client to a room ID or alias. See https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3joinroomidoralias +// JoinRoom joins the client to a room ID or alias. See https://spec.matrix.org/v1.13/client-server-api/#post_matrixclientv3joinroomidoralias // -// If serverName is specified, this will be added as a query param to instruct the homeserver to join via that server. If content is specified, it will -// be JSON encoded and used as the request body. -func (cli *Client) JoinRoom(ctx context.Context, roomIDorAlias, serverName string, content interface{}) (resp *RespJoinRoom, err error) { - var urlPath string - if serverName != "" { - urlPath = cli.BuildURLWithQuery(ClientURLPath{"v3", "join", roomIDorAlias}, map[string]string{ - "via": serverName, - }) - } else { - urlPath = cli.BuildClientURL("v3", "join", roomIDorAlias) +// The last parameter contains optional extra fields and can be left nil. +func (cli *Client) JoinRoom(ctx context.Context, roomIDorAlias string, req *ReqJoinRoom) (resp *RespJoinRoom, err error) { + if req == nil { + req = &ReqJoinRoom{} } - _, err = cli.MakeRequest(ctx, http.MethodPost, urlPath, content, &resp) + urlPath := cli.BuildURLWithFullQuery(ClientURLPath{"v3", "join", roomIDorAlias}, func(q url.Values) { + if len(req.Via) > 0 { + q["via"] = req.Via + } + }) + _, err = cli.MakeRequest(ctx, http.MethodPost, urlPath, req, &resp) if err == nil && cli.StateStore != nil { err = cli.StateStore.SetMembership(ctx, resp.RoomID, cli.UserID, event.MembershipJoin) if err != nil { @@ -974,6 +1105,28 @@ func (cli *Client) JoinRoom(ctx context.Context, roomIDorAlias, serverName strin return } +// KnockRoom requests to join a room ID or alias. See https://spec.matrix.org/v1.13/client-server-api/#post_matrixclientv3knockroomidoralias +// +// The last parameter contains optional extra fields and can be left nil. +func (cli *Client) KnockRoom(ctx context.Context, roomIDorAlias string, req *ReqKnockRoom) (resp *RespKnockRoom, err error) { + if req == nil { + req = &ReqKnockRoom{} + } + urlPath := cli.BuildURLWithFullQuery(ClientURLPath{"v3", "knock", roomIDorAlias}, func(q url.Values) { + if len(req.Via) > 0 { + q["via"] = req.Via + } + }) + _, err = cli.MakeRequest(ctx, http.MethodPost, urlPath, req, &resp) + if err == nil && cli.StateStore != nil { + err = cli.StateStore.SetMembership(ctx, resp.RoomID, cli.UserID, event.MembershipKnock) + if err != nil { + err = fmt.Errorf("failed to update state store: %w", err) + } + } + return +} + // JoinRoomByID joins the client to a room ID. See https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3roomsroomidjoin // // Unlike JoinRoom, this method can only be used to join rooms that the server already knows about. @@ -995,8 +1148,19 @@ func (cli *Client) GetProfile(ctx context.Context, mxid id.UserID) (resp *RespUs return } +func (cli *Client) SearchUserDirectory(ctx context.Context, query string, limit int) (resp *RespSearchUserDirectory, err error) { + urlPath := cli.BuildClientURL("v3", "user_directory", "search") + _, err = cli.MakeRequest(ctx, http.MethodPost, urlPath, &ReqSearchUserDirectory{ + SearchTerm: query, + Limit: limit, + }, &resp) + return +} + func (cli *Client) GetMutualRooms(ctx context.Context, otherUserID id.UserID, extras ...ReqMutualRooms) (resp *RespMutualRooms, err error) { - if cli.SpecVersions != nil && !cli.SpecVersions.Supports(FeatureMutualRooms) { + supportsStable := cli.SpecVersions.Supports(FeatureStableMutualRooms) + supportsUnstable := cli.SpecVersions.Supports(FeatureUnstableMutualRooms) + if cli.SpecVersions != nil && !supportsUnstable && !supportsStable { err = fmt.Errorf("server does not support fetching mutual rooms") return } @@ -1006,15 +1170,32 @@ func (cli *Client) GetMutualRooms(ctx context.Context, otherUserID id.UserID, ex if len(extras) > 0 { query["from"] = extras[0].From } - urlPath := cli.BuildURLWithQuery(ClientURLPath{"unstable", "uk.half-shot.msc2666", "user", "mutual_rooms"}, query) + urlPath := cli.BuildURLWithQuery(ClientURLPath{"v1", "user", "mutual_rooms"}, query) + if !supportsStable && supportsUnstable { + urlPath = cli.BuildURLWithQuery(ClientURLPath{"unstable", "uk.half-shot.msc2666", "user", "mutual_rooms"}, query) + } _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp) return } +func (cli *Client) GetRoomSummary(ctx context.Context, roomIDOrAlias string, via ...string) (resp *RespRoomSummary, err error) { + urlPath := ClientURLPath{"unstable", "im.nheko.summary", "summary", roomIDOrAlias} + if cli.SpecVersions.ContainsGreaterOrEqual(SpecV115) { + urlPath = ClientURLPath{"v1", "room_summary", roomIDOrAlias} + } + // TODO add version check after one is added to MSC3266 + fullURL := cli.BuildURLWithFullQuery(urlPath, func(q url.Values) { + if len(via) > 0 { + q["via"] = via + } + }) + _, err = cli.MakeRequest(ctx, http.MethodGet, fullURL, nil, &resp) + return +} + // GetDisplayName returns the display name of the user with the specified MXID. See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3profileuseriddisplayname func (cli *Client) GetDisplayName(ctx context.Context, mxid id.UserID) (resp *RespUserDisplayName, err error) { - urlPath := cli.BuildClientURL("v3", "profile", mxid, "displayname") - _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp) + err = cli.GetProfileField(ctx, mxid, "displayname", &resp) return } @@ -1025,25 +1206,47 @@ func (cli *Client) GetOwnDisplayName(ctx context.Context) (resp *RespUserDisplay // SetDisplayName sets the user's profile display name. See https://spec.matrix.org/v1.2/client-server-api/#put_matrixclientv3profileuseriddisplayname func (cli *Client) SetDisplayName(ctx context.Context, displayName string) (err error) { - urlPath := cli.BuildClientURL("v3", "profile", cli.UserID, "displayname") - s := struct { - DisplayName string `json:"displayname"` - }{displayName} - _, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, &s, nil) + return cli.SetProfileField(ctx, "displayname", displayName) +} + +// SetProfileField sets an arbitrary profile field. See https://spec.matrix.org/v1.16/client-server-api/#put_matrixclientv3profileuseridkeyname +func (cli *Client) SetProfileField(ctx context.Context, key string, value any) (err error) { + urlPath := cli.BuildClientURL("v3", "profile", cli.UserID, key) + if key != "displayname" && key != "avatar_url" && !cli.SpecVersions.Supports(FeatureArbitraryProfileFields) && cli.SpecVersions.Supports(FeatureUnstableProfileFields) { + urlPath = cli.BuildClientURL("unstable", "uk.tcpip.msc4133", "profile", cli.UserID, key) + } + _, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, map[string]any{ + key: value, + }, nil) + return +} + +// DeleteProfileField deletes an arbitrary profile field. See https://spec.matrix.org/v1.16/client-server-api/#put_matrixclientv3profileuseridkeyname +func (cli *Client) DeleteProfileField(ctx context.Context, key string) (err error) { + urlPath := cli.BuildClientURL("v3", "profile", cli.UserID, key) + if key != "displayname" && key != "avatar_url" && !cli.SpecVersions.Supports(FeatureArbitraryProfileFields) && cli.SpecVersions.Supports(FeatureUnstableProfileFields) { + urlPath = cli.BuildClientURL("unstable", "uk.tcpip.msc4133", "profile", cli.UserID, key) + } + _, err = cli.MakeRequest(ctx, http.MethodDelete, urlPath, nil, nil) + return +} + +// GetProfileField gets an arbitrary profile field and parses the response into the given struct. See https://spec.matrix.org/unstable/client-server-api/#get_matrixclientv3profileuseridkeyname +func (cli *Client) GetProfileField(ctx context.Context, userID id.UserID, key string, into any) (err error) { + urlPath := cli.BuildClientURL("v3", "profile", userID, key) + if key != "displayname" && key != "avatar_url" && !cli.SpecVersions.Supports(FeatureArbitraryProfileFields) && cli.SpecVersions.Supports(FeatureUnstableProfileFields) { + urlPath = cli.BuildClientURL("unstable", "uk.tcpip.msc4133", "profile", cli.UserID, key) + } + _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, into) return } // GetAvatarURL gets the avatar URL of the user with the specified MXID. See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3profileuseridavatar_url func (cli *Client) GetAvatarURL(ctx context.Context, mxid id.UserID) (url id.ContentURI, err error) { - urlPath := cli.BuildClientURL("v3", "profile", mxid, "avatar_url") s := struct { AvatarURL id.ContentURI `json:"avatar_url"` }{} - - _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &s) - if err != nil { - return - } + err = cli.GetProfileField(ctx, mxid, "avatar_url", &s) url = s.AvatarURL return } @@ -1110,15 +1313,6 @@ func (cli *Client) SetRoomAccountData(ctx context.Context, roomID id.RoomID, nam return nil } -type ReqSendEvent struct { - Timestamp int64 - TransactionID string - - DontEncrypt bool - - MeowEventID id.EventID -} - // SendMessageEvent sends a message event into a room. See https://spec.matrix.org/v1.2/client-server-api/#put_matrixclientv3roomsroomidsendeventtypetxnid // contentJSON should be a pointer to something that can be encoded as JSON using json.Marshal. func (cli *Client) SendMessageEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, contentJSON interface{}, extra ...ReqSendEvent) (resp *RespSendEvent, err error) { @@ -1141,8 +1335,14 @@ func (cli *Client) SendMessageEvent(ctx context.Context, roomID id.RoomID, event if req.MeowEventID != "" { queryParams["fi.mau.event_id"] = req.MeowEventID.String() } + if req.UnstableDelay > 0 { + queryParams["org.matrix.msc4140.delay"] = strconv.FormatInt(req.UnstableDelay.Milliseconds(), 10) + } + if req.UnstableStickyDuration > 0 { + queryParams["org.matrix.msc4354.sticky_duration_ms"] = strconv.FormatInt(req.UnstableStickyDuration.Milliseconds(), 10) + } - if !req.DontEncrypt && cli.Crypto != nil && eventType != event.EventReaction && eventType != event.EventEncrypted { + if !req.DontEncrypt && cli != nil && cli.Crypto != nil && eventType != event.EventReaction && eventType != event.EventEncrypted { var isEncrypted bool isEncrypted, err = cli.StateStore.IsEncrypted(ctx, roomID) if err != nil { @@ -1164,9 +1364,51 @@ func (cli *Client) SendMessageEvent(ctx context.Context, roomID id.RoomID, event return } -// SendStateEvent sends a state event into a room. See https://spec.matrix.org/v1.2/client-server-api/#put_matrixclientv3roomsroomidstateeventtypestatekey +// BeeperSendEphemeralEvent sends an ephemeral event into a room using Beeper's unstable endpoint. +// contentJSON should be a value that can be encoded as JSON using json.Marshal. +func (cli *Client) BeeperSendEphemeralEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, contentJSON any, extra ...ReqSendEvent) (resp *RespSendEvent, err error) { + var req ReqSendEvent + if len(extra) > 0 { + req = extra[0] + } + + var txnID string + if len(req.TransactionID) > 0 { + txnID = req.TransactionID + } else { + txnID = cli.TxnID() + } + + queryParams := map[string]string{} + if req.Timestamp > 0 { + queryParams["ts"] = strconv.FormatInt(req.Timestamp, 10) + } + + if !req.DontEncrypt && cli != nil && cli.Crypto != nil && eventType != event.EventEncrypted { + var isEncrypted bool + isEncrypted, err = cli.StateStore.IsEncrypted(ctx, roomID) + if err != nil { + err = fmt.Errorf("failed to check if room is encrypted: %w", err) + return + } + if isEncrypted { + if contentJSON, err = cli.Crypto.Encrypt(ctx, roomID, eventType, contentJSON); err != nil { + err = fmt.Errorf("failed to encrypt event: %w", err) + return + } + eventType = event.EventEncrypted + } + } + + urlData := ClientURLPath{"unstable", "com.beeper.ephemeral", "rooms", roomID, "ephemeral", eventType.String(), txnID} + urlPath := cli.BuildURLWithQuery(urlData, queryParams) + _, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, contentJSON, &resp) + return +} + +// SendStateEvent sends a state event into a room. See https://spec.matrix.org/v1.16/client-server-api/#put_matrixclientv3roomsroomidstateeventtypestatekey // contentJSON should be a pointer to something that can be encoded as JSON using json.Marshal. -func (cli *Client) SendStateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, contentJSON interface{}, extra ...ReqSendEvent) (resp *RespSendEvent, err error) { +func (cli *Client) SendStateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, contentJSON any, extra ...ReqSendEvent) (resp *RespSendEvent, err error) { var req ReqSendEvent if len(extra) > 0 { req = extra[0] @@ -1176,11 +1418,23 @@ func (cli *Client) SendStateEvent(ctx context.Context, roomID id.RoomID, eventTy if req.MeowEventID != "" { queryParams["fi.mau.event_id"] = req.MeowEventID.String() } + if req.TransactionID != "" { + queryParams["fi.mau.transaction_id"] = req.TransactionID + } + if req.UnstableDelay > 0 { + queryParams["org.matrix.msc4140.delay"] = strconv.FormatInt(req.UnstableDelay.Milliseconds(), 10) + } + if req.UnstableStickyDuration > 0 { + queryParams["org.matrix.msc4354.sticky_duration_ms"] = strconv.FormatInt(req.UnstableStickyDuration.Milliseconds(), 10) + } + if req.Timestamp > 0 { + queryParams["ts"] = strconv.FormatInt(req.Timestamp, 10) + } urlData := ClientURLPath{"v3", "rooms", roomID, "state", eventType.String(), stateKey} urlPath := cli.BuildURLWithQuery(urlData, queryParams) _, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, contentJSON, &resp) - if err == nil && cli.StateStore != nil { + if err == nil && cli.StateStore != nil && req.UnstableDelay == 0 { cli.updateStoreWithOutgoingEvent(ctx, roomID, eventType, stateKey, contentJSON) } return @@ -1188,14 +1442,44 @@ func (cli *Client) SendStateEvent(ctx context.Context, roomID id.RoomID, eventTy // SendMassagedStateEvent sends a state event into a room with a custom timestamp. See https://spec.matrix.org/v1.2/client-server-api/#put_matrixclientv3roomsroomidstateeventtypestatekey // contentJSON should be a pointer to something that can be encoded as JSON using json.Marshal. +// +// Deprecated: SendStateEvent accepts a timestamp via ReqSendEvent and should be used instead. func (cli *Client) SendMassagedStateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, contentJSON interface{}, ts int64) (resp *RespSendEvent, err error) { - urlPath := cli.BuildURLWithQuery(ClientURLPath{"v3", "rooms", roomID, "state", eventType.String(), stateKey}, map[string]string{ - "ts": strconv.FormatInt(ts, 10), + resp, err = cli.SendStateEvent(ctx, roomID, eventType, stateKey, contentJSON, ReqSendEvent{ + Timestamp: ts, }) - _, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, contentJSON, &resp) - if err == nil && cli.StateStore != nil { - cli.updateStoreWithOutgoingEvent(ctx, roomID, eventType, stateKey, contentJSON) + return +} + +func (cli *Client) DelayedEvents(ctx context.Context, req *ReqDelayedEvents) (resp *RespDelayedEvents, err error) { + query := map[string]string{} + if req.DelayID != "" { + query["delay_id"] = string(req.DelayID) } + if req.Status != "" { + query["status"] = string(req.Status) + } + if req.NextBatch != "" { + query["next_batch"] = req.NextBatch + } + + urlPath := cli.BuildURLWithQuery(ClientURLPath{"unstable", "org.matrix.msc4140", "delayed_events"}, query) + _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, req, &resp) + + // Migration: merge old keys with new ones + if resp != nil { + resp.Scheduled = append(resp.Scheduled, resp.DelayedEvents...) + resp.DelayedEvents = nil + resp.Finalised = append(resp.Finalised, resp.FinalisedEvents...) + resp.FinalisedEvents = nil + } + + return +} + +func (cli *Client) UpdateDelayedEvent(ctx context.Context, req *ReqUpdateDelayedEvent) (resp *RespUpdateDelayedEvent, err error) { + urlPath := cli.BuildClientURL("unstable", "org.matrix.msc4140", "delayed_events", req.DelayID) + _, err = cli.MakeRequest(ctx, http.MethodPost, urlPath, req, &resp) return } @@ -1250,6 +1534,19 @@ func (cli *Client) RedactEvent(ctx context.Context, roomID id.RoomID, eventID id return } +func (cli *Client) UnstableRedactUserEvents(ctx context.Context, roomID id.RoomID, userID id.UserID, req *ReqRedactUser) (resp *RespRedactUserEvents, err error) { + if req == nil { + req = &ReqRedactUser{} + } + query := map[string]string{} + if req.Limit > 0 { + query["limit"] = strconv.Itoa(req.Limit) + } + urlPath := cli.BuildURLWithQuery(ClientURLPath{"unstable", "org.matrix.msc4194", "rooms", roomID, "redact", "user", userID}, query) + _, err = cli.MakeRequest(ctx, http.MethodPost, urlPath, req, &resp) + return +} + // CreateRoom creates a new Matrix room. See https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3createroom // // resp, err := cli.CreateRoom(&mautrix.ReqCreateRoom{ @@ -1267,6 +1564,10 @@ func (cli *Client) CreateRoom(ctx context.Context, req *ReqCreateRoom) (resp *Re Msg("Failed to update creator membership in state store after creating room") } for _, evt := range req.InitialState { + evt.RoomID = resp.RoomID + if evt.StateKey == nil { + evt.StateKey = ptr.Ptr("") + } UpdateStateStore(ctx, cli.StateStore, evt) } inviteMembership := event.MembershipInvite @@ -1281,9 +1582,6 @@ func (cli *Client) CreateRoom(ctx context.Context, req *ReqCreateRoom) (resp *Re Msg("Failed to update membership in state store after creating room") } } - for _, evt := range req.InitialState { - cli.updateStoreWithOutgoingEvent(ctx, resp.RoomID, evt.Type, evt.GetStateKey(), &evt.Content) - } } return } @@ -1394,15 +1692,14 @@ func (cli *Client) GetOwnPresence(ctx context.Context) (resp *RespPresence, err return cli.GetPresence(ctx, cli.UserID) } -func (cli *Client) SetPresence(ctx context.Context, status event.Presence) (err error) { - req := ReqPresence{Presence: status} +func (cli *Client) SetPresence(ctx context.Context, presence ReqPresence) (err error) { u := cli.BuildClientURL("v3", "presence", cli.UserID, "status") - _, err = cli.MakeRequest(ctx, http.MethodPut, u, req, nil) + _, err = cli.MakeRequest(ctx, http.MethodPut, u, presence, nil) return } func (cli *Client) updateStoreWithOutgoingEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, contentJSON interface{}) { - if cli.StateStore == nil { + if cli == nil || cli.StateStore == nil { return } fakeEvt := &event.Event{ @@ -1434,8 +1731,8 @@ func (cli *Client) updateStoreWithOutgoingEvent(ctx context.Context, roomID id.R UpdateStateStore(ctx, cli.StateStore, fakeEvt) } -// StateEvent gets a single state event in a room. It will attempt to JSON unmarshal into the given "outContent" struct with -// the HTTP response body, or return an error. +// StateEvent gets the content of a single state event in a room. +// It will attempt to JSON unmarshal into the given "outContent" struct with the HTTP response body, or return an error. // See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3roomsroomidstateeventtypestatekey func (cli *Client) StateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, outContent interface{}) (err error) { u := cli.BuildClientURL("v3", "rooms", roomID, "state", eventType.String(), stateKey) @@ -1446,12 +1743,43 @@ func (cli *Client) StateEvent(ctx context.Context, roomID id.RoomID, eventType e return } +// FullStateEvent gets a single state event in a room. Unlike [StateEvent], this gets the entire event +// (including details like the sender and timestamp). +// This requires the server to support the ?format=event query parameter, which is currently missing from the spec. +// See https://github.com/matrix-org/matrix-spec/issues/1047 for more info +func (cli *Client) FullStateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string) (evt *event.Event, err error) { + u := cli.BuildURLWithQuery(ClientURLPath{"v3", "rooms", roomID, "state", eventType.String(), stateKey}, map[string]string{ + "format": "event", + }) + _, err = cli.MakeRequest(ctx, http.MethodGet, u, nil, &evt) + if evt != nil { + evt.Type.Class = event.StateEventType + _ = evt.Content.ParseRaw(evt.Type) + if evt.RoomID == "" { + evt.RoomID = roomID + } + } + if err == nil && cli.StateStore != nil { + UpdateStateStore(ctx, cli.StateStore, evt) + } + return +} + // parseRoomStateArray parses a JSON array as a stream and stores the events inside it in a room state map. -func parseRoomStateArray(_ *http.Request, res *http.Response, responseJSON interface{}) ([]byte, error) { +func parseRoomStateArray(req *http.Request, res *http.Response, responseJSON any, limit int64) ([]byte, error) { + if res.ContentLength > limit { + return nil, HTTPError{ + Request: req, + Response: res, + + Message: "not reading response", + WrappedError: fmt.Errorf("%w (%.2f MiB)", ErrResponseTooLong, float64(res.ContentLength)/1024/1024), + } + } response := make(RoomStateMap) responsePtr := responseJSON.(*map[event.Type]map[string]*event.Event) *responsePtr = response - dec := json.NewDecoder(res.Body) + dec := json.NewDecoder(io.LimitReader(res.Body, limit)) arrayStart, err := dec.Token() if err != nil { @@ -1485,6 +1813,8 @@ func parseRoomStateArray(_ *http.Request, res *http.Response, responseJSON inter return nil, nil } +type RoomStateMap = map[event.Type]map[string]*event.Event + // State gets all state in a room. // See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3roomsroomidstate func (cli *Client) State(ctx context.Context, roomID id.RoomID) (stateMap RoomStateMap, err error) { @@ -1494,12 +1824,21 @@ func (cli *Client) State(ctx context.Context, roomID id.RoomID) (stateMap RoomSt ResponseJSON: &stateMap, Handler: parseRoomStateArray, }) + if stateMap != nil { + pls, ok := stateMap[event.StatePowerLevels][""] + if ok { + pls.Content.AsPowerLevels().CreateEvent = stateMap[event.StateCreate][""] + } + } if err == nil && cli.StateStore != nil { for evtType, evts := range stateMap { if evtType == event.StateMember { continue } for _, evt := range evts { + if evt.RoomID == "" { + evt.RoomID = roomID + } UpdateStateStore(ctx, cli.StateStore, evt) } } @@ -1531,8 +1870,17 @@ func (cli *Client) GetMediaConfig(ctx context.Context) (resp *RespMediaConfig, e return } +func (cli *Client) RequestOpenIDToken(ctx context.Context) (resp *RespOpenIDToken, err error) { + _, err = cli.MakeRequest(ctx, http.MethodPost, cli.BuildClientURL("v3", "user", cli.UserID, "openid", "request_token"), nil, &resp) + return +} + // UploadLink uploads an HTTP URL and then returns an MXC URI. func (cli *Client) UploadLink(ctx context.Context, link string) (*RespMediaUpload, error) { + if cli == nil { + return nil, ErrClientIsNil + } + req, err := http.NewRequestWithContext(ctx, http.MethodGet, link, nil) if err != nil { return nil, err @@ -1549,6 +1897,9 @@ func (cli *Client) UploadLink(ctx context.Context, link string) (*RespMediaUploa } func (cli *Client) Download(ctx context.Context, mxcURL id.ContentURI) (*http.Response, error) { + if mxcURL.IsEmpty() { + return nil, fmt.Errorf("empty mxc uri provided to Download") + } _, resp, err := cli.MakeFullRequestWithResp(ctx, FullRequest{ Method: http.MethodGet, URL: cli.BuildClientURL("v1", "media", "download", mxcURL.Homeserver, mxcURL.FileID), @@ -1557,6 +1908,41 @@ func (cli *Client) Download(ctx context.Context, mxcURL id.ContentURI) (*http.Re return resp, err } +type DownloadThumbnailExtra struct { + Method string + Animated bool +} + +func (cli *Client) DownloadThumbnail(ctx context.Context, mxcURL id.ContentURI, height, width int, extras ...DownloadThumbnailExtra) (*http.Response, error) { + if mxcURL.IsEmpty() { + return nil, fmt.Errorf("empty mxc uri provided to DownloadThumbnail") + } + if len(extras) > 1 { + panic(fmt.Errorf("invalid number of arguments to DownloadThumbnail: %d", len(extras))) + } + var extra DownloadThumbnailExtra + if len(extras) == 1 { + extra = extras[0] + } + path := ClientURLPath{"v1", "media", "thumbnail", mxcURL.Homeserver, mxcURL.FileID} + query := map[string]string{ + "height": strconv.Itoa(height), + "width": strconv.Itoa(width), + } + if extra.Method != "" { + query["method"] = extra.Method + } + if extra.Animated { + query["animated"] = "true" + } + _, resp, err := cli.MakeFullRequestWithResp(ctx, FullRequest{ + Method: http.MethodGet, + URL: cli.BuildURLWithQuery(path, query), + DontReadResponse: true, + }) + return resp, err +} + func (cli *Client) DownloadBytes(ctx context.Context, mxcURL id.ContentURI) ([]byte, error) { resp, err := cli.Download(ctx, mxcURL) if err != nil { @@ -1603,10 +1989,15 @@ func (cli *Client) UploadAsync(ctx context.Context, req ReqUploadMedia) (*RespCr } req.MXC = resp.ContentURI req.UnstableUploadURL = resp.UnstableUploadURL + if req.AsyncContext == nil { + req.AsyncContext = cli.cliOrContextLog(ctx).WithContext(context.Background()) + } go func() { - _, err = cli.UploadMedia(ctx, req) + _, err = cli.UploadMedia(req.AsyncContext, req) if err != nil { - cli.Log.Error().Str("mxc", req.MXC.String()).Err(err).Msg("Async upload of media failed") + zerolog.Ctx(req.AsyncContext).Err(err). + Stringer("mxc", req.MXC). + Msg("Async upload of media failed") } }() return resp, nil @@ -1642,6 +2033,7 @@ type ReqUploadMedia struct { ContentType string FileName string + AsyncContext context.Context DoneCallback func() // MXC specifies an existing MXC URI which doesn't have content yet to upload into. @@ -1654,16 +2046,25 @@ type ReqUploadMedia struct { } func (cli *Client) tryUploadMediaToURL(ctx context.Context, url, contentType string, content io.Reader, contentLength int64) (*http.Response, error) { - cli.Log.Debug().Str("url", url).Msg("Uploading media to external URL") + cli.Log.Debug(). + Str("url", url). + Int64("content_length", contentLength). + Msg("Uploading media to external URL") req, err := http.NewRequestWithContext(ctx, http.MethodPut, url, content) if err != nil { return nil, err } req.ContentLength = contentLength req.Header.Set("Content-Type", contentType) - req.Header.Set("User-Agent", cli.UserAgent+" (external media uploader)") + if cli.UserAgent != "" { + req.Header.Set("User-Agent", cli.UserAgent+" (external media uploader)") + } - return http.DefaultClient.Do(req) + if cli.ExternalClient != nil { + return cli.ExternalClient.Do(req) + } else { + return http.DefaultClient.Do(req) + } } func (cli *Client) uploadMediaToURL(ctx context.Context, data ReqUploadMedia) (*RespMediaUpload, error) { @@ -1688,14 +2089,25 @@ func (cli *Client) uploadMediaToURL(ctx context.Context, data ReqUploadMedia) (* break } err = fmt.Errorf("HTTP %d", resp.StatusCode) + } else if errors.Is(err, context.Canceled) { + cli.Log.Warn().Str("url", data.UnstableUploadURL).Msg("External media upload canceled") + return nil, err } if retries <= 0 { cli.Log.Warn().Str("url", data.UnstableUploadURL).Err(err). Msg("Error uploading media to external URL, not retrying") return nil, err } - cli.Log.Warn().Str("url", data.UnstableUploadURL).Err(err). + backoff := time.Second * time.Duration(cli.DefaultHTTPRetries-retries) + cli.Log.Warn().Err(err). + Str("url", data.UnstableUploadURL). + Int("retry_in_seconds", int(backoff.Seconds())). Msg("Error uploading media to external URL, retrying") + select { + case <-time.After(backoff): + case <-ctx.Done(): + return nil, ctx.Err() + } retries-- _, err = readerSeeker.Seek(0, io.SeekStart) if err != nil { @@ -1733,6 +2145,9 @@ func (cli *Client) UploadMedia(ctx context.Context, data ReqUploadMedia) (*RespM if data.DoneCallback != nil { defer data.DoneCallback() } + if cli == nil { + return nil, ErrClientIsNil + } if data.UnstableUploadURL != "" { if data.MXC.IsEmpty() { return nil, errors.New("MXC must also be set when uploading to external URL") @@ -1866,6 +2281,12 @@ func (cli *Client) JoinedRooms(ctx context.Context) (resp *RespJoinedRooms, err return } +func (cli *Client) PublicRooms(ctx context.Context, req *ReqPublicRooms) (resp *RespPublicRooms, err error) { + urlPath := cli.BuildURLWithQuery(ClientURLPath{"v3", "publicRooms"}, req.Query()) + _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp) + return +} + // Hierarchy returns a list of rooms that are in the room's hierarchy. See https://spec.matrix.org/v1.4/client-server-api/#get_matrixclientv1roomsroomidhierarchy // // The hierarchy API is provided to walk the space tree and discover the rooms with their aesthetic details. works in a depth-first manner: @@ -1948,6 +2369,20 @@ func (cli *Client) GetEvent(ctx context.Context, roomID id.RoomID, eventID id.Ev return } +func (cli *Client) GetUnredactedEventContent(ctx context.Context, roomID id.RoomID, eventID id.EventID) (resp *event.Event, err error) { + urlPath := cli.BuildURLWithQuery(ClientURLPath{"v3", "rooms", roomID, "event", eventID}, map[string]string{ + "fi.mau.msc2815.include_unredacted_content": "true", + }) + _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp) + return +} + +func (cli *Client) GetRelations(ctx context.Context, roomID id.RoomID, eventID id.EventID, req *ReqGetRelations) (resp *RespGetRelations, err error) { + urlPath := cli.BuildURLWithQuery(append(ClientURLPath{"v1", "rooms", roomID, "relations", eventID}, req.PathSuffix()...), req.Query()) + _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp) + return +} + func (cli *Client) MarkRead(ctx context.Context, roomID id.RoomID, eventID id.EventID) (err error) { return cli.SendReceipt(ctx, roomID, eventID, event.ReceiptTypeRead, nil) } @@ -2252,15 +2687,15 @@ func (cli *Client) SetDeviceInfo(ctx context.Context, deviceID id.DeviceID, req return err } -func (cli *Client) DeleteDevice(ctx context.Context, deviceID id.DeviceID, req *ReqDeleteDevice) error { +func (cli *Client) DeleteDevice(ctx context.Context, deviceID id.DeviceID, req *ReqDeleteDevice[any]) error { urlPath := cli.BuildClientURL("v3", "devices", deviceID) _, err := cli.MakeRequest(ctx, http.MethodDelete, urlPath, req, nil) return err } -func (cli *Client) DeleteDevices(ctx context.Context, req *ReqDeleteDevices) error { +func (cli *Client) DeleteDevices(ctx context.Context, req *ReqDeleteDevices[any]) error { urlPath := cli.BuildClientURL("v3", "delete_devices") - _, err := cli.MakeRequest(ctx, http.MethodDelete, urlPath, req, nil) + _, err := cli.MakeRequest(ctx, http.MethodPost, urlPath, req, nil) return err } @@ -2269,7 +2704,7 @@ type UIACallback = func(*RespUserInteractive) interface{} // UploadCrossSigningKeys uploads the given cross-signing keys to the server. // Because the endpoint requires user-interactive authentication a callback must be provided that, // given the UI auth parameters, produces the required result (or nil to end the flow). -func (cli *Client) UploadCrossSigningKeys(ctx context.Context, keys *UploadCrossSigningKeysReq, uiaCallback UIACallback) error { +func (cli *Client) UploadCrossSigningKeys(ctx context.Context, keys *UploadCrossSigningKeysReq[any], uiaCallback UIACallback) error { content, err := cli.MakeFullRequest(ctx, FullRequest{ Method: http.MethodPost, URL: cli.BuildClientURL("v3", "keys", "device_signing", "upload"), @@ -2351,24 +2786,61 @@ func (cli *Client) ReportRoom(ctx context.Context, roomID id.RoomID, reason stri return err } -// BatchSend sends a batch of historical events into a room. This is only available for appservices. +// AdminWhoIs fetches session information belonging to a specific user. Typically requires being a server admin. // -// Deprecated: MSC2716 has been abandoned, so this is now Beeper-specific. BeeperBatchSend should be used instead. -func (cli *Client) BatchSend(ctx context.Context, roomID id.RoomID, req *ReqBatchSend) (resp *RespBatchSend, err error) { - path := ClientURLPath{"unstable", "org.matrix.msc2716", "rooms", roomID, "batch_send"} - query := map[string]string{ - "prev_event_id": req.PrevEventID.String(), +// https://spec.matrix.org/v1.15/client-server-api/#get_matrixclientv3adminwhoisuserid +func (cli *Client) AdminWhoIs(ctx context.Context, userID id.UserID) (resp RespWhoIs, err error) { + urlPath := cli.BuildClientURL("v3", "admin", "whois", userID) + _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp) + return +} + +func (cli *Client) makeMSC4323URL(action string, target id.UserID) string { + if cli.SpecVersions.Supports(FeatureUnstableAccountModeration) { + return cli.BuildClientURL("unstable", "uk.timedout.msc4323", "admin", action, target) + } else if cli.SpecVersions.Supports(FeatureStableAccountModeration) { + return cli.BuildClientURL("v1", "admin", action, target) } - if req.BeeperNewMessages { - query["com.beeper.new_messages"] = "true" + return "" +} + +// GetSuspendedStatus uses MSC4323 to check if a user is suspended. +func (cli *Client) GetSuspendedStatus(ctx context.Context, userID id.UserID) (res *RespSuspended, err error) { + urlPath := cli.makeMSC4323URL("suspend", userID) + if urlPath == "" { + return nil, MUnrecognized.WithMessage("Homeserver does not advertise MSC4323 support") } - if req.BeeperMarkReadBy != "" { - query["com.beeper.mark_read_by"] = req.BeeperMarkReadBy.String() + _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, res) + return +} + +// GetLockStatus uses MSC4323 to check if a user is locked. +func (cli *Client) GetLockStatus(ctx context.Context, userID id.UserID) (res *RespLocked, err error) { + urlPath := cli.makeMSC4323URL("lock", userID) + if urlPath == "" { + return nil, MUnrecognized.WithMessage("Homeserver does not advertise MSC4323 support") } - if len(req.BatchID) > 0 { - query["batch_id"] = req.BatchID.String() + _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, res) + return +} + +// SetSuspendedStatus uses MSC4323 to set whether a user account is suspended. +func (cli *Client) SetSuspendedStatus(ctx context.Context, userID id.UserID, suspended bool) (res *RespSuspended, err error) { + urlPath := cli.makeMSC4323URL("suspend", userID) + if urlPath == "" { + return nil, MUnrecognized.WithMessage("Homeserver does not advertise MSC4323 support") } - _, err = cli.MakeRequest(ctx, http.MethodPost, cli.BuildURLWithQuery(path, query), req, &resp) + _, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, &ReqSuspend{Suspended: suspended}, res) + return +} + +// SetLockStatus uses MSC4323 to set whether a user account is locked. +func (cli *Client) SetLockStatus(ctx context.Context, userID id.UserID, locked bool) (res *RespLocked, err error) { + urlPath := cli.makeMSC4323URL("lock", userID) + if urlPath == "" { + return nil, MUnrecognized.WithMessage("Homeserver does not advertise MSC4323 support") + } + _, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, &ReqLocked{Locked: locked}, res) return } @@ -2409,6 +2881,9 @@ func (cli *Client) BeeperDeleteRoom(ctx context.Context, roomID id.RoomID) (err // TxnID returns the next transaction ID. func (cli *Client) TxnID() string { + if cli == nil { + return "client is nil" + } txnID := atomic.AddInt32(&cli.txnID, 1) return fmt.Sprintf("mautrix-go_%d_%d", time.Now().UnixNano(), txnID) } diff --git a/client_ephemeral_test.go b/client_ephemeral_test.go new file mode 100644 index 00000000..c2846427 --- /dev/null +++ b/client_ephemeral_test.go @@ -0,0 +1,158 @@ +// Copyright (c) 2026 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package mautrix_test + +import ( + "context" + "encoding/json" + "errors" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +func TestClient_SendEphemeralEvent_UsesUnstablePathTxnAndTS(t *testing.T) { + roomID := id.RoomID("!room:example.com") + evtType := event.Type{Type: "com.example.ephemeral", Class: event.EphemeralEventType} + txnID := "txn-123" + + var gotPath string + var gotQueryTS string + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotPath = r.URL.Path + gotQueryTS = r.URL.Query().Get("ts") + assert.Equal(t, http.MethodPut, r.Method) + _, _ = w.Write([]byte(`{"event_id":"$evt"}`)) + })) + defer ts.Close() + + cli, err := mautrix.NewClient(ts.URL, "", "") + require.NoError(t, err) + + _, err = cli.BeeperSendEphemeralEvent( + context.Background(), + roomID, + evtType, + map[string]any{"foo": "bar"}, + mautrix.ReqSendEvent{TransactionID: txnID, Timestamp: 1234}, + ) + require.NoError(t, err) + + assert.True(t, strings.Contains(gotPath, "/_matrix/client/unstable/com.beeper.ephemeral/rooms/")) + assert.True(t, strings.HasSuffix(gotPath, "/ephemeral/com.example.ephemeral/"+txnID)) + assert.Equal(t, "1234", gotQueryTS) +} + +func TestClient_SendEphemeralEvent_UnsupportedReturnsMUnrecognized(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte(`{"errcode":"M_UNRECOGNIZED","error":"Unrecognized endpoint"}`)) + })) + defer ts.Close() + + cli, err := mautrix.NewClient(ts.URL, "", "") + require.NoError(t, err) + + _, err = cli.BeeperSendEphemeralEvent( + context.Background(), + id.RoomID("!room:example.com"), + event.Type{Type: "com.example.ephemeral", Class: event.EphemeralEventType}, + map[string]any{"foo": "bar"}, + ) + require.Error(t, err) + assert.True(t, errors.Is(err, mautrix.MUnrecognized)) +} + +func TestClient_SendEphemeralEvent_EncryptsInEncryptedRooms(t *testing.T) { + roomID := id.RoomID("!room:example.com") + evtType := event.Type{Type: "com.example.ephemeral", Class: event.EphemeralEventType} + txnID := "txn-encrypted" + + stateStore := mautrix.NewMemoryStateStore() + err := stateStore.SetEncryptionEvent(context.Background(), roomID, &event.EncryptionEventContent{ + Algorithm: id.AlgorithmMegolmV1, + }) + require.NoError(t, err) + + fakeCrypto := &fakeCryptoHelper{ + encryptedContent: &event.EncryptedEventContent{ + Algorithm: id.AlgorithmMegolmV1, + MegolmCiphertext: []byte("ciphertext"), + }, + } + + var gotPath string + var gotBody map[string]any + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotPath = r.URL.Path + assert.Equal(t, http.MethodPut, r.Method) + err := json.NewDecoder(r.Body).Decode(&gotBody) + require.NoError(t, err) + _, _ = w.Write([]byte(`{"event_id":"$evt"}`)) + })) + defer ts.Close() + + cli, err := mautrix.NewClient(ts.URL, "", "") + require.NoError(t, err) + cli.StateStore = stateStore + cli.Crypto = fakeCrypto + + _, err = cli.BeeperSendEphemeralEvent( + context.Background(), + roomID, + evtType, + map[string]any{"foo": "bar"}, + mautrix.ReqSendEvent{TransactionID: txnID}, + ) + require.NoError(t, err) + + assert.True(t, strings.HasSuffix(gotPath, "/ephemeral/m.room.encrypted/"+txnID)) + assert.Equal(t, string(id.AlgorithmMegolmV1), gotBody["algorithm"]) + assert.Equal(t, 1, fakeCrypto.encryptCalls) + assert.Equal(t, roomID, fakeCrypto.lastRoomID) + assert.Equal(t, evtType, fakeCrypto.lastEventType) +} + +type fakeCryptoHelper struct { + encryptCalls int + lastRoomID id.RoomID + lastEventType event.Type + lastEncryptInput any + encryptedContent *event.EncryptedEventContent +} + +func (f *fakeCryptoHelper) Encrypt(_ context.Context, roomID id.RoomID, eventType event.Type, content any) (*event.EncryptedEventContent, error) { + f.encryptCalls++ + f.lastRoomID = roomID + f.lastEventType = eventType + f.lastEncryptInput = content + return f.encryptedContent, nil +} + +func (f *fakeCryptoHelper) Decrypt(context.Context, *event.Event) (*event.Event, error) { + return nil, nil +} + +func (f *fakeCryptoHelper) WaitForSession(context.Context, id.RoomID, id.SenderKey, id.SessionID, time.Duration) bool { + return false +} + +func (f *fakeCryptoHelper) RequestSession(context.Context, id.RoomID, id.SenderKey, id.SessionID, id.UserID, id.DeviceID) { +} + +func (f *fakeCryptoHelper) Init(context.Context) error { + return nil +} diff --git a/commands/container.go b/commands/container.go new file mode 100644 index 00000000..9b909b75 --- /dev/null +++ b/commands/container.go @@ -0,0 +1,133 @@ +// Copyright (c) 2026 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package commands + +import ( + "fmt" + "slices" + "strings" + "sync" + + "go.mau.fi/util/exmaps" + + "maunium.net/go/mautrix/event/cmdschema" +) + +type CommandContainer[MetaType any] struct { + commands map[string]*Handler[MetaType] + aliases map[string]string + lock sync.RWMutex + parent *Handler[MetaType] +} + +func NewCommandContainer[MetaType any]() *CommandContainer[MetaType] { + return &CommandContainer[MetaType]{ + commands: make(map[string]*Handler[MetaType]), + aliases: make(map[string]string), + } +} + +func (cont *CommandContainer[MetaType]) AllSpecs() []*cmdschema.EventContent { + data := make(exmaps.Set[*Handler[MetaType]]) + cont.collectHandlers(data) + specs := make([]*cmdschema.EventContent, 0, data.Size()) + for handler := range data.Iter() { + if handler.Parameters != nil { + specs = append(specs, handler.Spec()) + } + } + return specs +} + +func (cont *CommandContainer[MetaType]) collectHandlers(into exmaps.Set[*Handler[MetaType]]) { + cont.lock.RLock() + defer cont.lock.RUnlock() + for _, handler := range cont.commands { + into.Add(handler) + if handler.subcommandContainer != nil { + handler.subcommandContainer.collectHandlers(into) + } + } +} + +// Register registers the given command handlers. +func (cont *CommandContainer[MetaType]) Register(handlers ...*Handler[MetaType]) { + if cont == nil { + return + } + cont.lock.Lock() + defer cont.lock.Unlock() + for i, handler := range handlers { + if handler == nil { + panic(fmt.Errorf("handler #%d is nil", i+1)) + } + cont.registerOne(handler) + } +} + +func (cont *CommandContainer[MetaType]) registerOne(handler *Handler[MetaType]) { + if strings.ToLower(handler.Name) != handler.Name { + panic(fmt.Errorf("command %q is not lowercase", handler.Name)) + } else if val, alreadyExists := cont.commands[handler.Name]; alreadyExists && val != handler { + panic(fmt.Errorf("tried to register command %q, but it's already registered", handler.Name)) + } else if aliasTarget, alreadyExists := cont.aliases[handler.Name]; alreadyExists { + panic(fmt.Errorf("tried to register command %q, but it's already registered as an alias for %q", handler.Name, aliasTarget)) + } + if !slices.Contains(handler.parents, cont.parent) { + handler.parents = append(handler.parents, cont.parent) + handler.nestedNameCache = nil + } + cont.commands[handler.Name] = handler + for _, alias := range handler.Aliases { + if strings.ToLower(alias) != alias { + panic(fmt.Errorf("alias %q is not lowercase", alias)) + } else if val, alreadyExists := cont.aliases[alias]; alreadyExists && val != handler.Name { + panic(fmt.Errorf("tried to register alias %q for %q, but it's already registered for %q", alias, handler.Name, cont.aliases[alias])) + } else if _, alreadyExists = cont.commands[alias]; alreadyExists { + panic(fmt.Errorf("tried to register alias %q for %q, but it's already registered as a command", alias, handler.Name)) + } + cont.aliases[alias] = handler.Name + } + handler.initSubcommandContainer() +} + +func (cont *CommandContainer[MetaType]) Unregister(handlers ...*Handler[MetaType]) { + if cont == nil { + return + } + cont.lock.Lock() + defer cont.lock.Unlock() + for _, handler := range handlers { + cont.unregisterOne(handler) + } +} + +func (cont *CommandContainer[MetaType]) unregisterOne(handler *Handler[MetaType]) { + delete(cont.commands, handler.Name) + for _, alias := range handler.Aliases { + if cont.aliases[alias] == handler.Name { + delete(cont.aliases, alias) + } + } +} + +func (cont *CommandContainer[MetaType]) GetHandler(name string) *Handler[MetaType] { + if cont == nil { + return nil + } + cont.lock.RLock() + defer cont.lock.RUnlock() + alias, ok := cont.aliases[name] + if ok { + name = alias + } + handler, ok := cont.commands[name] + if !ok { + handler = cont.commands[UnknownCommandName] + } + return handler +} diff --git a/commands/event.go b/commands/event.go new file mode 100644 index 00000000..76d6c9f0 --- /dev/null +++ b/commands/event.go @@ -0,0 +1,237 @@ +// Copyright (c) 2026 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package commands + +import ( + "context" + "encoding/json" + "fmt" + "strings" + + "github.com/rs/zerolog" + + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/format" + "maunium.net/go/mautrix/id" +) + +// Event contains the data of a single command event. +// It also provides some helper methods for responding to the command. +type Event[MetaType any] struct { + *event.Event + // RawInput is the entire message before splitting into command and arguments. + RawInput string + // ParentCommands is the chain of commands leading up to this command. + // This is only set if the command is a subcommand. + ParentCommands []string + ParentHandlers []*Handler[MetaType] + // Command is the lowercased first word of the message. + Command string + // Args are the rest of the message split by whitespace ([strings.Fields]). + Args []string + // RawArgs is the same as args, but without the splitting by whitespace. + RawArgs string + + StructuredArgs json.RawMessage + + Ctx context.Context + Log *zerolog.Logger + Proc *Processor[MetaType] + Handler *Handler[MetaType] + Meta MetaType + + redactedBy id.EventID +} + +var IDHTMLParser = &format.HTMLParser{ + PillConverter: func(displayname, mxid, eventID string, ctx format.Context) string { + if len(mxid) == 0 { + return displayname + } + if eventID != "" { + return fmt.Sprintf("https://matrix.to/#/%s/%s", mxid, eventID) + } + return mxid + }, + ItalicConverter: func(s string, c format.Context) string { + return fmt.Sprintf("*%s*", s) + }, + Newline: "\n", +} + +// ParseEvent parses a message into a command event struct. +func (proc *Processor[MetaType]) ParseEvent(ctx context.Context, evt *event.Event) *Event[MetaType] { + content, ok := evt.Content.Parsed.(*event.MessageEventContent) + if !ok || content.MsgType == event.MsgNotice || content.RelatesTo.GetReplaceID() != "" { + return nil + } + text := content.Body + if content.Format == event.FormatHTML { + text = IDHTMLParser.Parse(content.FormattedBody, format.NewContext(ctx)) + } + if content.MSC4391BotCommand != nil { + if !content.Mentions.Has(proc.Client.UserID) || len(content.Mentions.UserIDs) != 1 { + return nil + } + wrapped := StructuredCommandToEvent[MetaType](ctx, evt, content.MSC4391BotCommand) + wrapped.RawInput = text + return wrapped + } + if len(text) == 0 { + return nil + } + return RawTextToEvent[MetaType](ctx, evt, text) +} + +func StructuredCommandToEvent[MetaType any](ctx context.Context, evt *event.Event, content *event.MSC4391BotCommandInput) *Event[MetaType] { + commandParts := strings.Split(content.Command, " ") + return &Event[MetaType]{ + Event: evt, + // Fake a command and args to let the subcommand finder in Process work. + Command: commandParts[0], + Args: commandParts[1:], + Ctx: ctx, + Log: zerolog.Ctx(ctx), + + StructuredArgs: content.Arguments, + } +} + +func RawTextToEvent[MetaType any](ctx context.Context, evt *event.Event, text string) *Event[MetaType] { + parts := strings.Fields(text) + if len(parts) == 0 { + parts = []string{""} + } + return &Event[MetaType]{ + Event: evt, + RawInput: text, + Command: strings.ToLower(parts[0]), + Args: parts[1:], + RawArgs: strings.TrimLeft(strings.TrimPrefix(text, parts[0]), " "), + Log: zerolog.Ctx(ctx), + Ctx: ctx, + } +} + +type ReplyOpts struct { + AllowHTML bool + AllowMarkdown bool + Reply bool + Thread bool + SendAsText bool + Edit id.EventID + OverrideMentions *event.Mentions + Extra map[string]any +} + +func (evt *Event[MetaType]) Reply(msg string, args ...any) id.EventID { + if len(args) > 0 { + msg = fmt.Sprintf(msg, args...) + } + return evt.Respond(msg, ReplyOpts{AllowMarkdown: true, Reply: true}) +} + +func (evt *Event[MetaType]) Respond(msg string, opts ReplyOpts) id.EventID { + content := format.RenderMarkdown(msg, opts.AllowMarkdown, opts.AllowHTML) + if opts.Thread { + content.SetThread(evt.Event) + } + if opts.Reply { + content.SetReply(evt.Event) + } + if !opts.SendAsText { + content.MsgType = event.MsgNotice + } + if opts.Edit != "" { + content.SetEdit(opts.Edit) + } + if opts.OverrideMentions != nil { + content.Mentions = opts.OverrideMentions + } + var wrapped any = &content + if opts.Extra != nil { + wrapped = &event.Content{ + Parsed: &content, + Raw: opts.Extra, + } + } + resp, err := evt.Proc.Client.SendMessageEvent(evt.Ctx, evt.RoomID, event.EventMessage, wrapped) + if err != nil { + zerolog.Ctx(evt.Ctx).Err(err).Msg("Failed to send reply") + return "" + } + return resp.EventID +} + +func (evt *Event[MetaType]) React(emoji string) id.EventID { + resp, err := evt.Proc.Client.SendReaction(evt.Ctx, evt.RoomID, evt.ID, emoji) + if err != nil { + zerolog.Ctx(evt.Ctx).Err(err).Msg("Failed to send reaction") + return "" + } + return resp.EventID +} + +func (evt *Event[MetaType]) Redact() id.EventID { + if evt.redactedBy != "" { + return evt.redactedBy + } + resp, err := evt.Proc.Client.RedactEvent(evt.Ctx, evt.RoomID, evt.ID) + if err != nil { + zerolog.Ctx(evt.Ctx).Err(err).Msg("Failed to redact command") + return "" + } + evt.redactedBy = resp.EventID + return resp.EventID +} + +func (evt *Event[MetaType]) MarkRead() { + err := evt.Proc.Client.MarkRead(evt.Ctx, evt.RoomID, evt.ID) + if err != nil { + zerolog.Ctx(evt.Ctx).Err(err).Msg("Failed to send read receipt") + } +} + +// ShiftArg removes the first argument from the Args list and RawArgs data and returns it. +// RawInput will not be modified. +func (evt *Event[MetaType]) ShiftArg() string { + if len(evt.Args) == 0 { + return "" + } + firstArg := evt.Args[0] + evt.RawArgs = strings.TrimLeft(strings.TrimPrefix(evt.RawArgs, evt.Args[0]), " ") + evt.Args = evt.Args[1:] + return firstArg +} + +// UnshiftArg reverses ShiftArg by adding the given value to the beginning of the Args list and RawArgs data. +func (evt *Event[MetaType]) UnshiftArg(arg string) { + evt.RawArgs = arg + " " + evt.RawArgs + evt.Args = append([]string{arg}, evt.Args...) +} + +func (evt *Event[MetaType]) ParseArgs(into any) error { + return json.Unmarshal(evt.StructuredArgs, into) +} + +func ParseArgs[T, MetaType any](evt *Event[MetaType]) (into T, err error) { + err = evt.ParseArgs(&into) + return +} + +func WithParsedArgs[T, MetaType any](fn func(*Event[MetaType], T)) func(*Event[MetaType]) { + return func(evt *Event[MetaType]) { + parsed, err := ParseArgs[T, MetaType](evt) + if err != nil { + evt.Log.Debug().Err(err).Msg("Failed to parse structured args into struct") + // TODO better error, usage info? deduplicate with Process + evt.Reply("Failed to parse arguments: %v", err) + return + } + fn(evt, parsed) + } +} diff --git a/commands/handler.go b/commands/handler.go new file mode 100644 index 00000000..56f27f06 --- /dev/null +++ b/commands/handler.go @@ -0,0 +1,105 @@ +// Copyright (c) 2026 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package commands + +import ( + "strings" + + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/event/cmdschema" +) + +type Handler[MetaType any] struct { + // Func is the function that is called when the command is executed. + Func func(ce *Event[MetaType]) + + // Name is the primary name of the command. It must be lowercase. + Name string + // Aliases are alternative names for the command. They must be lowercase. + Aliases []string + // Subcommands are subcommands of this command. + Subcommands []*Handler[MetaType] + // PreFunc is a function that is called before checking subcommands. + // It can be used to have parameters between subcommands (e.g. `!rooms `). + // Event.ShiftArg will likely be useful for implementing such parameters. + PreFunc func(ce *Event[MetaType]) + + // Description is a short description of the command. + Description *event.ExtensibleTextContainer + // Parameters is a description of structured command parameters. + // If set, the StructuredArgs field of Event will be populated. + Parameters []*cmdschema.Parameter + TailParam string + + parents []*Handler[MetaType] + nestedNameCache []string + subcommandContainer *CommandContainer[MetaType] +} + +func (h *Handler[MetaType]) NestedNames() []string { + if h.nestedNameCache != nil { + return h.nestedNameCache + } + nestedNames := make([]string, 0, (1+len(h.Aliases))*len(h.parents)) + for _, parent := range h.parents { + if parent == nil { + nestedNames = append(nestedNames, h.Name) + nestedNames = append(nestedNames, h.Aliases...) + } else { + for _, parentName := range parent.NestedNames() { + nestedNames = append(nestedNames, parentName+" "+h.Name) + for _, alias := range h.Aliases { + nestedNames = append(nestedNames, parentName+" "+alias) + } + } + } + } + h.nestedNameCache = nestedNames + return nestedNames +} + +func (h *Handler[MetaType]) Spec() *cmdschema.EventContent { + names := h.NestedNames() + return &cmdschema.EventContent{ + Command: names[0], + Aliases: names[1:], + Parameters: h.Parameters, + Description: h.Description, + TailParam: h.TailParam, + } +} + +func (h *Handler[MetaType]) CopyFrom(other *Handler[MetaType]) { + if h.Parameters == nil { + h.Parameters = other.Parameters + h.TailParam = other.TailParam + } + h.Func = other.Func +} + +func (h *Handler[MetaType]) initSubcommandContainer() { + if len(h.Subcommands) > 0 { + h.subcommandContainer = NewCommandContainer[MetaType]() + h.subcommandContainer.parent = h + h.subcommandContainer.Register(h.Subcommands...) + } else { + h.subcommandContainer = nil + } +} + +func MakeUnknownCommandHandler[MetaType any](prefix string) *Handler[MetaType] { + return &Handler[MetaType]{ + Name: UnknownCommandName, + Func: func(ce *Event[MetaType]) { + if len(ce.ParentCommands) == 0 { + ce.Reply("Unknown command `%s%s`", prefix, ce.Command) + } else { + ce.Reply("Unknown subcommand `%s%s %s`", prefix, strings.Join(ce.ParentCommands, " "), ce.Command) + } + }, + } +} diff --git a/commands/prevalidate.go b/commands/prevalidate.go new file mode 100644 index 00000000..facca4da --- /dev/null +++ b/commands/prevalidate.go @@ -0,0 +1,84 @@ +// Copyright (c) 2025 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package commands + +import ( + "strings" +) + +// A PreValidator contains a function that takes an Event and returns true if the event should be processed further. +// +// The [PreValidator] field in [Processor] is called before the handler of the command is checked. +// It can be used to modify the command or arguments, or to skip the command entirely. +// +// The primary use case is removing a static command prefix, such as requiring all commands start with `!`. +type PreValidator[MetaType any] interface { + Validate(*Event[MetaType]) bool +} + +// FuncPreValidator is a simple function that implements the PreValidator interface. +type FuncPreValidator[MetaType any] func(*Event[MetaType]) bool + +func (f FuncPreValidator[MetaType]) Validate(ce *Event[MetaType]) bool { + return f(ce) +} + +// AllPreValidator can be used to combine multiple PreValidators, such that +// all of them must return true for the command to be processed further. +type AllPreValidator[MetaType any] []PreValidator[MetaType] + +func (f AllPreValidator[MetaType]) Validate(ce *Event[MetaType]) bool { + for _, validator := range f { + if !validator.Validate(ce) { + return false + } + } + return true +} + +// AnyPreValidator can be used to combine multiple PreValidators, such that +// at least one of them must return true for the command to be processed further. +type AnyPreValidator[MetaType any] []PreValidator[MetaType] + +func (f AnyPreValidator[MetaType]) Validate(ce *Event[MetaType]) bool { + for _, validator := range f { + if validator.Validate(ce) { + return true + } + } + return false +} + +// ValidatePrefixCommand checks that the first word in the input is exactly the given string, +// and if so, removes it from the command and sets the command to the next word. +// +// For example, `ValidateCommandPrefix("!mybot")` would only allow commands in the form `!mybot foo`, +// where `foo` would be used to look up the command handler. +func ValidatePrefixCommand[MetaType any](prefix string) PreValidator[MetaType] { + return FuncPreValidator[MetaType](func(ce *Event[MetaType]) bool { + if ce.Command == prefix && len(ce.Args) > 0 { + ce.Command = strings.ToLower(ce.ShiftArg()) + return true + } + return false + }) +} + +// ValidatePrefixSubstring checks that the command starts with the given prefix, +// and if so, removes it from the command. +// +// For example, `ValidatePrefixSubstring("!")` would only allow commands in the form `!foo`, +// where `foo` would be used to look up the command handler. +func ValidatePrefixSubstring[MetaType any](prefix string) PreValidator[MetaType] { + return FuncPreValidator[MetaType](func(ce *Event[MetaType]) bool { + if strings.HasPrefix(ce.Command, prefix) { + ce.Command = ce.Command[len(prefix):] + return true + } + return false + }) +} diff --git a/commands/processor.go b/commands/processor.go new file mode 100644 index 00000000..80f6745d --- /dev/null +++ b/commands/processor.go @@ -0,0 +1,152 @@ +// Copyright (c) 2026 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package commands + +import ( + "context" + "runtime/debug" + "strings" + + "github.com/rs/zerolog" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/event" +) + +// Processor implements boilerplate code for splitting messages into a command and arguments, +// and finding the appropriate handler for the command. +type Processor[MetaType any] struct { + *CommandContainer[MetaType] + + Client *mautrix.Client + LogArgs bool + PreValidator PreValidator[MetaType] + Meta MetaType + + ReactionCommandPrefix string +} + +// UnknownCommandName is the name of the fallback handler which is used if no other handler is found. +// If even the unknown command handler is not found, the command is ignored. +const UnknownCommandName = "__unknown-command__" + +func NewProcessor[MetaType any](cli *mautrix.Client) *Processor[MetaType] { + proc := &Processor[MetaType]{ + CommandContainer: NewCommandContainer[MetaType](), + Client: cli, + PreValidator: ValidatePrefixSubstring[MetaType]("!"), + } + proc.Register(MakeUnknownCommandHandler[MetaType]("!")) + return proc +} + +func (proc *Processor[MetaType]) Process(ctx context.Context, evt *event.Event) { + log := zerolog.Ctx(ctx).With(). + Stringer("sender", evt.Sender). + Stringer("room_id", evt.RoomID). + Stringer("event_id", evt.ID). + Logger() + defer func() { + panicErr := recover() + if panicErr != nil { + logEvt := log.Error(). + Bytes(zerolog.ErrorStackFieldName, debug.Stack()) + if realErr, ok := panicErr.(error); ok { + logEvt = logEvt.Err(realErr) + } else { + logEvt = logEvt.Any(zerolog.ErrorFieldName, panicErr) + } + logEvt.Msg("Panic in command handler") + _, err := proc.Client.SendReaction(ctx, evt.RoomID, evt.ID, "💥") + if err != nil { + log.Err(err).Msg("Failed to send reaction after panic") + } + } + }() + var parsed *Event[MetaType] + switch evt.Type { + case event.EventReaction: + parsed = proc.ParseReaction(ctx, evt) + case event.EventMessage: + parsed = proc.ParseEvent(ctx, evt) + } + if parsed == nil || (!proc.PreValidator.Validate(parsed) && parsed.StructuredArgs == nil) { + return + } + parsed.Proc = proc + parsed.Meta = proc.Meta + parsed.Ctx = ctx + + handler := proc.GetHandler(parsed.Command) + if handler == nil { + return + } + parsed.Handler = handler + if handler.PreFunc != nil { + handler.PreFunc(parsed) + } + handlerChain := zerolog.Arr() + handlerChain.Str(handler.Name) + for handler.subcommandContainer != nil && len(parsed.Args) > 0 { + subHandler := handler.subcommandContainer.GetHandler(strings.ToLower(parsed.Args[0])) + if subHandler != nil { + parsed.ParentCommands = append(parsed.ParentCommands, parsed.Command) + parsed.ParentHandlers = append(parsed.ParentHandlers, handler) + handler = subHandler + handlerChain.Str(subHandler.Name) + parsed.Command = strings.ToLower(parsed.ShiftArg()) + parsed.Handler = subHandler + if subHandler.PreFunc != nil { + subHandler.PreFunc(parsed) + } + } else { + break + } + } + if parsed.StructuredArgs != nil && len(parsed.Args) > 0 { + // TODO allow unknown command handlers to be called? + // The client sent MSC4391 data, but the target command wasn't found + log.Debug().Msg("Didn't find handler for MSC4391 command") + return + } + + logWith := log.With(). + Str("command", parsed.Command). + Array("handler", handlerChain) + if len(parsed.ParentCommands) > 0 { + logWith = logWith.Strs("parent_commands", parsed.ParentCommands) + } + if proc.LogArgs { + logWith = logWith.Strs("args", parsed.Args) + if parsed.StructuredArgs != nil { + logWith = logWith.RawJSON("structured_args", parsed.StructuredArgs) + } + } + log = logWith.Logger() + parsed.Ctx = log.WithContext(ctx) + parsed.Log = &log + + if handler.Parameters != nil && parsed.StructuredArgs == nil { + // The handler wants structured parameters, but the client didn't send MSC4391 data + var err error + parsed.StructuredArgs, err = handler.Spec().ParseArguments(parsed.RawArgs) + if err != nil { + log.Debug().Err(err).Msg("Failed to parse structured arguments") + // TODO better error, usage info? deduplicate with WithParsedArgs + parsed.Reply("Failed to parse arguments: %v", err) + return + } + if proc.LogArgs { + log.UpdateContext(func(c zerolog.Context) zerolog.Context { + return c.RawJSON("structured_args", parsed.StructuredArgs) + }) + } + } + + log.Debug().Msg("Processing command") + handler.Func(parsed) +} diff --git a/commands/reactions.go b/commands/reactions.go new file mode 100644 index 00000000..0d316219 --- /dev/null +++ b/commands/reactions.go @@ -0,0 +1,143 @@ +// Copyright (c) 2026 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package commands + +import ( + "context" + "encoding/json" + "strings" + + "github.com/rs/zerolog" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/event" +) + +const ReactionCommandsKey = "fi.mau.reaction_commands" +const ReactionMultiUseKey = "fi.mau.reaction_multi_use" + +type ReactionCommandData struct { + Command string `json:"command"` + Args any `json:"args,omitempty"` +} + +func (proc *Processor[MetaType]) ParseReaction(ctx context.Context, evt *event.Event) *Event[MetaType] { + content, ok := evt.Content.Parsed.(*event.ReactionEventContent) + if !ok { + return nil + } + evtID := content.RelatesTo.EventID + if evtID == "" || !strings.HasPrefix(content.RelatesTo.Key, proc.ReactionCommandPrefix) { + return nil + } + targetEvt, err := proc.Client.GetEvent(ctx, evt.RoomID, evtID) + if err != nil { + zerolog.Ctx(ctx).Err(err).Stringer("target_event_id", evtID).Msg("Failed to get target event for reaction") + return nil + } else if targetEvt.Sender != proc.Client.UserID || targetEvt.Unsigned.RedactedBecause != nil { + return nil + } + if targetEvt.Type == event.EventEncrypted { + if proc.Client.Crypto == nil { + zerolog.Ctx(ctx).Warn(). + Stringer("target_event_id", evtID). + Msg("Received reaction to encrypted event, but don't have crypto helper in client") + return nil + } + _ = targetEvt.Content.ParseRaw(targetEvt.Type) + targetEvt, err = proc.Client.Crypto.Decrypt(ctx, targetEvt) + if err != nil { + zerolog.Ctx(ctx).Err(err). + Stringer("target_event_id", evtID). + Msg("Failed to decrypt target event for reaction") + return nil + } + } + reactionCommands, ok := targetEvt.Content.Raw[ReactionCommandsKey].(map[string]any) + if !ok { + zerolog.Ctx(ctx).Trace(). + Stringer("target_event_id", evtID). + Msg("Reaction target event doesn't have commands key") + return nil + } + isMultiUse, _ := targetEvt.Content.Raw[ReactionMultiUseKey].(bool) + rawCmd, ok := reactionCommands[content.RelatesTo.Key] + if !ok { + zerolog.Ctx(ctx).Debug(). + Stringer("target_event_id", evtID). + Str("reaction_key", content.RelatesTo.Key). + Msg("Reaction command not found in target event") + return nil + } + var wrappedEvt *Event[MetaType] + switch typedCmd := rawCmd.(type) { + case string: + wrappedEvt = RawTextToEvent[MetaType](ctx, evt, typedCmd) + case map[string]any: + var input event.MSC4391BotCommandInput + if marshaled, err := json.Marshal(typedCmd); err != nil { + + } else if err = json.Unmarshal(marshaled, &input); err != nil { + + } else { + wrappedEvt = StructuredCommandToEvent[MetaType](ctx, evt, &input) + } + } + if wrappedEvt == nil { + zerolog.Ctx(ctx).Debug(). + Stringer("target_event_id", evtID). + Str("reaction_key", content.RelatesTo.Key). + Msg("Reaction command data is invalid") + return nil + } + wrappedEvt.Proc = proc + wrappedEvt.Redact() + if !isMultiUse { + DeleteAllReactions(ctx, proc.Client, evt) + } + if wrappedEvt.Command == "" { + return nil + } + return wrappedEvt +} + +func DeleteAllReactionsCommandFunc[MetaType any](ce *Event[MetaType]) { + DeleteAllReactions(ce.Ctx, ce.Proc.Client, ce.Event) +} + +func DeleteAllReactions(ctx context.Context, client *mautrix.Client, evt *event.Event) { + rel, ok := evt.Content.Parsed.(event.Relatable) + if !ok { + return + } + relation := rel.OptionalGetRelatesTo() + if relation == nil { + return + } + targetEvt := relation.GetReplyTo() + if targetEvt == "" { + targetEvt = relation.GetAnnotationID() + } + if targetEvt == "" { + return + } + relations, err := client.GetRelations(ctx, evt.RoomID, targetEvt, &mautrix.ReqGetRelations{ + RelationType: event.RelAnnotation, + EventType: event.EventReaction, + Limit: 20, + }) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to get reactions to delete") + return + } + for _, relEvt := range relations.Chunk { + _, err = client.RedactEvent(ctx, relEvt.RoomID, relEvt.ID) + if err != nil { + zerolog.Ctx(ctx).Err(err).Stringer("event_id", relEvt.ID).Msg("Failed to redact reaction event") + } + } +} diff --git a/crypto/aescbc/aes_cbc_test.go b/crypto/aescbc/aes_cbc_test.go index bb03f706..d6611dc9 100644 --- a/crypto/aescbc/aes_cbc_test.go +++ b/crypto/aescbc/aes_cbc_test.go @@ -7,11 +7,13 @@ package aescbc_test import ( - "bytes" "crypto/aes" "crypto/rand" "testing" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "maunium.net/go/mautrix/crypto/aescbc" ) @@ -22,32 +24,23 @@ func TestAESCBC(t *testing.T) { // The key length can be 32, 24, 16 bytes (OR in bits: 128, 192 or 256) key := make([]byte, 32) _, err = rand.Read(key) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) iv := make([]byte, aes.BlockSize) _, err = rand.Read(iv) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) plaintext = []byte("secret message for testing") //increase to next block size for len(plaintext)%8 != 0 { plaintext = append(plaintext, []byte("-")...) } - if ciphertext, err = aescbc.Encrypt(key, iv, plaintext); err != nil { - t.Fatal(err) - } + ciphertext, err = aescbc.Encrypt(key, iv, plaintext) + require.NoError(t, err) resultPlainText, err := aescbc.Decrypt(key, iv, ciphertext) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) - if string(resultPlainText) != string(plaintext) { - t.Fatalf("message '%s' (length %d) != '%s'", resultPlainText, len(resultPlainText), plaintext) - } + assert.Equal(t, string(resultPlainText), string(plaintext)) } func TestAESCBCCase1(t *testing.T) { @@ -61,18 +54,10 @@ func TestAESCBCCase1(t *testing.T) { key := make([]byte, 32) iv := make([]byte, aes.BlockSize) encrypted, err := aescbc.Encrypt(key, iv, input) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(expected, encrypted) { - t.Fatalf("encrypted did not match expected:\n%v\n%v\n", encrypted, expected) - } + require.NoError(t, err) + assert.Equal(t, expected, encrypted, "encrypted output does not match expected") decrypted, err := aescbc.Decrypt(key, iv, encrypted) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(input, decrypted) { - t.Fatalf("decrypted did not match expected:\n%v\n%v\n", decrypted, input) - } + require.NoError(t, err) + assert.Equal(t, input, decrypted, "decrypted output does not match input") } diff --git a/crypto/attachment/attachments.go b/crypto/attachment/attachments.go index cfa1c3e5..727aacbf 100644 --- a/crypto/attachment/attachments.go +++ b/crypto/attachment/attachments.go @@ -9,6 +9,7 @@ package attachment import ( "crypto/aes" "crypto/cipher" + "crypto/hmac" "crypto/sha256" "encoding/base64" "errors" @@ -20,13 +21,24 @@ import ( ) var ( - HashMismatch = errors.New("mismatching SHA-256 digest") - UnsupportedVersion = errors.New("unsupported Matrix file encryption version") - UnsupportedAlgorithm = errors.New("unsupported JWK encryption algorithm") - InvalidKey = errors.New("failed to decode key") - InvalidInitVector = errors.New("failed to decode initialization vector") - InvalidHash = errors.New("failed to decode SHA-256 hash") - ReaderClosed = errors.New("encrypting reader was already closed") + ErrHashMismatch = errors.New("mismatching SHA-256 digest") + ErrUnsupportedVersion = errors.New("unsupported Matrix file encryption version") + ErrUnsupportedAlgorithm = errors.New("unsupported JWK encryption algorithm") + ErrInvalidKey = errors.New("failed to decode key") + ErrInvalidInitVector = errors.New("failed to decode initialization vector") + ErrInvalidHash = errors.New("failed to decode SHA-256 hash") + ErrReaderClosed = errors.New("encrypting reader was already closed") +) + +// Deprecated: use variables prefixed with Err +var ( + HashMismatch = ErrHashMismatch + UnsupportedVersion = ErrUnsupportedVersion + UnsupportedAlgorithm = ErrUnsupportedAlgorithm + InvalidKey = ErrInvalidKey + InvalidInitVector = ErrInvalidInitVector + InvalidHash = ErrInvalidHash + ReaderClosed = ErrReaderClosed ) var ( @@ -84,25 +96,25 @@ func (ef *EncryptedFile) decodeKeys(includeHash bool) error { if ef.decoded != nil { return nil } else if len(ef.Key.Key) != keyBase64Length { - return InvalidKey + return ErrInvalidKey } else if len(ef.InitVector) != ivBase64Length { - return InvalidInitVector + return ErrInvalidInitVector } else if includeHash && len(ef.Hashes.SHA256) != hashBase64Length { - return InvalidHash + return ErrInvalidHash } ef.decoded = &decodedKeys{} _, err := base64.RawURLEncoding.Decode(ef.decoded.key[:], []byte(ef.Key.Key)) if err != nil { - return InvalidKey + return ErrInvalidKey } _, err = base64.RawStdEncoding.Decode(ef.decoded.iv[:], []byte(ef.InitVector)) if err != nil { - return InvalidInitVector + return ErrInvalidInitVector } if includeHash { _, err = base64.RawStdEncoding.Decode(ef.decoded.sha256[:], []byte(ef.Hashes.SHA256)) if err != nil { - return InvalidHash + return ErrInvalidHash } } return nil @@ -178,7 +190,7 @@ var _ io.ReadSeekCloser = (*encryptingReader)(nil) func (r *encryptingReader) Seek(offset int64, whence int) (int64, error) { if r.closed { - return 0, ReaderClosed + return 0, ErrReaderClosed } if offset != 0 || whence != io.SeekStart { return 0, fmt.Errorf("attachments.EncryptStream: only seeking to the beginning is supported") @@ -199,15 +211,20 @@ func (r *encryptingReader) Seek(offset int64, whence int) (int64, error) { func (r *encryptingReader) Read(dst []byte) (n int, err error) { if r.closed { - return 0, ReaderClosed + return 0, ErrReaderClosed } else if r.isDecrypting && r.file.decoded == nil { if err = r.file.PrepareForDecryption(); err != nil { return } } n, err = r.source.Read(dst) + if r.isDecrypting { + r.hash.Write(dst[:n]) + } r.stream.XORKeyStream(dst[:n], dst[:n]) - r.hash.Write(dst[:n]) + if !r.isDecrypting { + r.hash.Write(dst[:n]) + } return } @@ -217,10 +234,8 @@ func (r *encryptingReader) Close() (err error) { err = closer.Close() } if r.isDecrypting { - var downloadedChecksum [utils.SHAHashLength]byte - r.hash.Sum(downloadedChecksum[:]) - if downloadedChecksum != r.file.decoded.sha256 { - return HashMismatch + if !hmac.Equal(r.hash.Sum(nil), r.file.decoded.sha256[:]) { + return ErrHashMismatch } } else { r.file.Hashes.SHA256 = base64.RawStdEncoding.EncodeToString(r.hash.Sum(nil)) @@ -261,9 +276,9 @@ func (ef *EncryptedFile) Decrypt(ciphertext []byte) ([]byte, error) { // DecryptInPlace will always call this automatically, so calling this manually is not necessary when using that function. func (ef *EncryptedFile) PrepareForDecryption() error { if ef.Version != "v2" { - return UnsupportedVersion + return ErrUnsupportedVersion } else if ef.Key.Algorithm != "A256CTR" { - return UnsupportedAlgorithm + return ErrUnsupportedAlgorithm } else if err := ef.decodeKeys(true); err != nil { return err } @@ -274,12 +289,13 @@ func (ef *EncryptedFile) PrepareForDecryption() error { func (ef *EncryptedFile) DecryptInPlace(data []byte) error { if err := ef.PrepareForDecryption(); err != nil { return err - } else if ef.decoded.sha256 != sha256.Sum256(data) { - return HashMismatch - } else { - utils.XorA256CTR(data, ef.decoded.key, ef.decoded.iv) - return nil } + dataHash := sha256.Sum256(data) + if !hmac.Equal(ef.decoded.sha256[:], dataHash[:]) { + return ErrHashMismatch + } + utils.XorA256CTR(data, ef.decoded.key, ef.decoded.iv) + return nil } // DecryptStream wraps the given io.Reader in order to decrypt the data. @@ -292,9 +308,10 @@ func (ef *EncryptedFile) DecryptInPlace(data []byte) error { func (ef *EncryptedFile) DecryptStream(reader io.Reader) io.ReadSeekCloser { block, _ := aes.NewCipher(ef.decoded.key[:]) return &encryptingReader{ - stream: cipher.NewCTR(block, ef.decoded.iv[:]), - hash: sha256.New(), - source: reader, - file: ef, + isDecrypting: true, + stream: cipher.NewCTR(block, ef.decoded.iv[:]), + hash: sha256.New(), + source: reader, + file: ef, } } diff --git a/crypto/attachment/attachments_test.go b/crypto/attachment/attachments_test.go index d7f1394a..9fe929ab 100644 --- a/crypto/attachment/attachments_test.go +++ b/crypto/attachment/attachments_test.go @@ -53,33 +53,33 @@ func TestUnsupportedVersion(t *testing.T) { file := parseHelloWorld() file.Version = "foo" err := file.DecryptInPlace([]byte(helloWorldCiphertext)) - assert.ErrorIs(t, err, UnsupportedVersion) + assert.ErrorIs(t, err, ErrUnsupportedVersion) } func TestUnsupportedAlgorithm(t *testing.T) { file := parseHelloWorld() file.Key.Algorithm = "bar" err := file.DecryptInPlace([]byte(helloWorldCiphertext)) - assert.ErrorIs(t, err, UnsupportedAlgorithm) + assert.ErrorIs(t, err, ErrUnsupportedAlgorithm) } func TestHashMismatch(t *testing.T) { file := parseHelloWorld() file.Hashes.SHA256 = base64.RawStdEncoding.EncodeToString([]byte(random32Bytes)) err := file.DecryptInPlace([]byte(helloWorldCiphertext)) - assert.ErrorIs(t, err, HashMismatch) + assert.ErrorIs(t, err, ErrHashMismatch) } func TestTooLongHash(t *testing.T) { file := parseHelloWorld() file.Hashes.SHA256 = "TG9yZW0gaXBzdW0gZG9sb3Igc2l0IGFtZXQsIGNvbnNlY3RldHVlciBhZGlwaXNjaW5nIGVsaXQuIFNlZCBwb3N1ZXJlIGludGVyZHVtIHNlbS4gUXVpc3F1ZSBsaWd1bGEgZXJvcyB1bGxhbWNvcnBlciBxdWlzLCBsYWNpbmlhIHF1aXMgZmFjaWxpc2lzIHNlZCBzYXBpZW4uCg" err := file.DecryptInPlace([]byte(helloWorldCiphertext)) - assert.ErrorIs(t, err, InvalidHash) + assert.ErrorIs(t, err, ErrInvalidHash) } func TestTooShortHash(t *testing.T) { file := parseHelloWorld() file.Hashes.SHA256 = "5/Gy1JftyyQ" err := file.DecryptInPlace([]byte(helloWorldCiphertext)) - assert.ErrorIs(t, err, InvalidHash) + assert.ErrorIs(t, err, ErrInvalidHash) } diff --git a/crypto/backup/encryptedsessiondata.go b/crypto/backup/encryptedsessiondata.go index ec551dbe..25250178 100644 --- a/crypto/backup/encryptedsessiondata.go +++ b/crypto/backup/encryptedsessiondata.go @@ -68,6 +68,10 @@ func calculateCompatMAC(macKey []byte) []byte { // // [Section 11.12.3.2.2 of the Spec]: https://spec.matrix.org/v1.9/client-server-api/#backup-algorithm-mmegolm_backupv1curve25519-aes-sha2 func EncryptSessionData[T any](backupKey *MegolmBackupKey, sessionData T) (*EncryptedSessionData[T], error) { + return EncryptSessionDataWithPubkey(backupKey.PublicKey(), sessionData) +} + +func EncryptSessionDataWithPubkey[T any](pubkey *ecdh.PublicKey, sessionData T) (*EncryptedSessionData[T], error) { sessionJSON, err := json.Marshal(sessionData) if err != nil { return nil, err @@ -78,7 +82,7 @@ func EncryptSessionData[T any](backupKey *MegolmBackupKey, sessionData T) (*Encr return nil, err } - sharedSecret, err := ephemeralKey.ECDH(backupKey.PublicKey()) + sharedSecret, err := ephemeralKey.ECDH(pubkey) if err != nil { return nil, err } diff --git a/crypto/canonicaljson/json_test.go b/crypto/canonicaljson/json_test.go index d1a7f0a5..36476aa4 100644 --- a/crypto/canonicaljson/json_test.go +++ b/crypto/canonicaljson/json_test.go @@ -17,31 +17,43 @@ package canonicaljson import ( "testing" + + "github.com/stretchr/testify/assert" ) -func testSortJSON(t *testing.T, input, want string) { - got := SortJSON([]byte(input), nil) - - // Squash out the whitespace before comparing the JSON in case SortJSON had inserted whitespace. - if string(CompactJSON(got, nil)) != want { - t.Errorf("SortJSON(%q): want %q got %q", input, want, got) - } -} - func TestSortJSON(t *testing.T) { - testSortJSON(t, `[{"b":"two","a":1}]`, `[{"a":1,"b":"two"}]`) - testSortJSON(t, `{"B":{"4":4,"3":3},"A":{"1":1,"2":2}}`, - `{"A":{"1":1,"2":2},"B":{"3":3,"4":4}}`) - testSortJSON(t, `[true,false,null]`, `[true,false,null]`) - testSortJSON(t, `[9007199254740991]`, `[9007199254740991]`) - testSortJSON(t, "\t\n[9007199254740991]", `[9007199254740991]`) + var tests = []struct { + input string + want string + }{ + {"{}", "{}"}, + {`[{"b":"two","a":1}]`, `[{"a":1,"b":"two"}]`}, + {`{"B":{"4":4,"3":3},"A":{"1":1,"2":2}}`, `{"A":{"1":1,"2":2},"B":{"3":3,"4":4}}`}, + {`[true,false,null]`, `[true,false,null]`}, + {`[9007199254740991]`, `[9007199254740991]`}, + {"\t\n[9007199254740991]", `[9007199254740991]`}, + {`[true,false,null]`, `[true,false,null]`}, + {`[{"b":"two","a":1}]`, `[{"a":1,"b":"two"}]`}, + {`{"B":{"4":4,"3":3},"A":{"1":1,"2":2}}`, `{"A":{"1":1,"2":2},"B":{"3":3,"4":4}}`}, + {`[true,false,null]`, `[true,false,null]`}, + {`[9007199254740991]`, `[9007199254740991]`}, + {"\t\n[9007199254740991]", `[9007199254740991]`}, + {`[true,false,null]`, `[true,false,null]`}, + } + for _, test := range tests { + t.Run(test.input, func(t *testing.T) { + got := SortJSON([]byte(test.input), nil) + + // Squash out the whitespace before comparing the JSON in case SortJSON had inserted whitespace. + assert.EqualValues(t, test.want, string(CompactJSON(got, nil))) + }) + } } func testCompactJSON(t *testing.T, input, want string) { + t.Helper() got := string(CompactJSON([]byte(input), nil)) - if got != want { - t.Errorf("CompactJSON(%q): want %q got %q", input, want, got) - } + assert.EqualValues(t, want, got) } func TestCompactJSON(t *testing.T) { @@ -74,18 +86,23 @@ func TestCompactJSON(t *testing.T) { testCompactJSON(t, `["\"\\\/"]`, `["\"\\/"]`) } -func testReadHex(t *testing.T, input string, want uint32) { - got := readHexDigits([]byte(input)) - if want != got { - t.Errorf("readHexDigits(%q): want 0x%x got 0x%x", input, want, got) +func TestReadHex(t *testing.T) { + tests := []struct { + input string + want uint32 + }{ + + {"0123", 0x0123}, + {"4567", 0x4567}, + {"89AB", 0x89AB}, + {"CDEF", 0xCDEF}, + {"89ab", 0x89AB}, + {"cdef", 0xCDEF}, + } + for _, test := range tests { + t.Run(test.input, func(t *testing.T) { + got := readHexDigits([]byte(test.input)) + assert.Equal(t, test.want, got) + }) } } - -func TestReadHex(t *testing.T) { - testReadHex(t, "0123", 0x0123) - testReadHex(t, "4567", 0x4567) - testReadHex(t, "89AB", 0x89AB) - testReadHex(t, "CDEF", 0xCDEF) - testReadHex(t, "89ab", 0x89AB) - testReadHex(t, "cdef", 0xCDEF) -} diff --git a/crypto/cross_sign_key.go b/crypto/cross_sign_key.go index 97ecd865..5d9bf5b3 100644 --- a/crypto/cross_sign_key.go +++ b/crypto/cross_sign_key.go @@ -11,6 +11,8 @@ import ( "context" "fmt" + "go.mau.fi/util/jsonbytes" + "maunium.net/go/mautrix" "maunium.net/go/mautrix/crypto/olm" "maunium.net/go/mautrix/crypto/signatures" @@ -33,9 +35,9 @@ func (cskc *CrossSigningKeysCache) PublicKeys() *CrossSigningPublicKeysCache { } type CrossSigningSeeds struct { - MasterKey []byte - SelfSigningKey []byte - UserSigningKey []byte + MasterKey jsonbytes.UnpaddedURLBytes `json:"m.cross_signing.master"` + SelfSigningKey jsonbytes.UnpaddedURLBytes `json:"m.cross_signing.self_signing"` + UserSigningKey jsonbytes.UnpaddedURLBytes `json:"m.cross_signing.user_signing"` } func (mach *OlmMachine) ExportCrossSigningKeys() CrossSigningSeeds { @@ -133,7 +135,7 @@ func (mach *OlmMachine) PublishCrossSigningKeys(ctx context.Context, keys *Cross } userKey.Signatures = signatures.NewSingleSignature(userID, id.KeyAlgorithmEd25519, keys.MasterKey.PublicKey().String(), userSig) - err = mach.Client.UploadCrossSigningKeys(ctx, &mautrix.UploadCrossSigningKeysReq{ + err = mach.Client.UploadCrossSigningKeys(ctx, &mautrix.UploadCrossSigningKeysReq[any]{ Master: masterKey, SelfSigning: selfKey, UserSigning: userKey, diff --git a/crypto/cross_sign_pubkey.go b/crypto/cross_sign_pubkey.go index 77efab5b..223fc7b5 100644 --- a/crypto/cross_sign_pubkey.go +++ b/crypto/cross_sign_pubkey.go @@ -20,6 +20,20 @@ type CrossSigningPublicKeysCache struct { UserSigningKey id.Ed25519 } +func (mach *OlmMachine) GetOwnVerificationStatus(ctx context.Context) (hasKeys, isVerified bool, err error) { + pubkeys := mach.GetOwnCrossSigningPublicKeys(ctx) + if pubkeys != nil { + hasKeys = true + isVerified, err = mach.CryptoStore.IsKeySignedBy( + ctx, mach.Client.UserID, mach.GetAccount().SigningKey(), mach.Client.UserID, pubkeys.SelfSigningKey, + ) + if err != nil { + err = fmt.Errorf("failed to check if current device is signed by own self-signing key: %w", err) + } + } + return +} + func (mach *OlmMachine) GetOwnCrossSigningPublicKeys(ctx context.Context) *CrossSigningPublicKeysCache { if mach.crossSigningPubkeys != nil { return mach.crossSigningPubkeys @@ -49,8 +63,8 @@ func (mach *OlmMachine) GetCrossSigningPublicKeys(ctx context.Context, userID id if len(dbKeys) > 0 { masterKey, ok := dbKeys[id.XSUsageMaster] if ok { - selfSigning, _ := dbKeys[id.XSUsageSelfSigning] - userSigning, _ := dbKeys[id.XSUsageUserSigning] + selfSigning := dbKeys[id.XSUsageSelfSigning] + userSigning := dbKeys[id.XSUsageUserSigning] return &CrossSigningPublicKeysCache{ MasterKey: masterKey.Key, SelfSigningKey: selfSigning.Key, diff --git a/crypto/cross_sign_ssss.go b/crypto/cross_sign_ssss.go index 389a9fd2..fd42880d 100644 --- a/crypto/cross_sign_ssss.go +++ b/crypto/cross_sign_ssss.go @@ -8,6 +8,7 @@ package crypto import ( "context" + "errors" "fmt" "maunium.net/go/mautrix" @@ -71,6 +72,46 @@ func (mach *OlmMachine) GenerateAndUploadCrossSigningKeysWithPassword(ctx contex }, passphrase) } +func (mach *OlmMachine) VerifyWithRecoveryKey(ctx context.Context, recoveryKey string) error { + keyID, keyData, err := mach.SSSS.GetDefaultKeyData(ctx) + if err != nil { + return fmt.Errorf("failed to get default SSSS key data: %w", err) + } + key, err := keyData.VerifyRecoveryKey(keyID, recoveryKey) + if errors.Is(err, ssss.ErrUnverifiableKey) { + mach.machOrContextLog(ctx).Warn(). + Str("key_id", keyID). + Msg("SSSS key is unverifiable, trying to use without verifying") + } else if err != nil { + return err + } + err = mach.FetchCrossSigningKeysFromSSSS(ctx, key) + if err != nil { + return fmt.Errorf("failed to fetch cross-signing keys from SSSS: %w", err) + } + err = mach.SignOwnDevice(ctx, mach.OwnIdentity()) + if err != nil { + return fmt.Errorf("failed to sign own device: %w", err) + } + err = mach.SignOwnMasterKey(ctx) + if err != nil { + return fmt.Errorf("failed to sign own master key: %w", err) + } + return nil +} + +func (mach *OlmMachine) GenerateAndVerifyWithRecoveryKey(ctx context.Context) (recoveryKey string, err error) { + recoveryKey, _, err = mach.GenerateAndUploadCrossSigningKeys(ctx, nil, "") + if err != nil { + err = fmt.Errorf("failed to generate and upload cross-signing keys: %w", err) + } else if err = mach.SignOwnDevice(ctx, mach.OwnIdentity()); err != nil { + err = fmt.Errorf("failed to sign own device: %w", err) + } else if err = mach.SignOwnMasterKey(ctx); err != nil { + err = fmt.Errorf("failed to sign own master key: %w", err) + } + return +} + // GenerateAndUploadCrossSigningKeys generates a new key with all corresponding cross-signing keys. // // A passphrase can be provided to generate the SSSS key. If the passphrase is empty, a random key @@ -97,12 +138,12 @@ func (mach *OlmMachine) GenerateAndUploadCrossSigningKeys(ctx context.Context, u // Publish cross-signing keys err = mach.PublishCrossSigningKeys(ctx, keysCache, uiaCallback) if err != nil { - return "", nil, fmt.Errorf("failed to publish cross-signing keys: %w", err) + return key.RecoveryKey(), keysCache, fmt.Errorf("failed to publish cross-signing keys: %w", err) } err = mach.SSSS.SetDefaultKeyID(ctx, key.ID) if err != nil { - return "", nil, fmt.Errorf("failed to mark %s as the default key: %w", key.ID, err) + return key.RecoveryKey(), keysCache, fmt.Errorf("failed to mark %s as the default key: %w", key.ID, err) } return key.RecoveryKey(), keysCache, nil diff --git a/crypto/cross_sign_store.go b/crypto/cross_sign_store.go index b583bada..57406b11 100644 --- a/crypto/cross_sign_store.go +++ b/crypto/cross_sign_store.go @@ -20,36 +20,34 @@ import ( func (mach *OlmMachine) storeCrossSigningKeys(ctx context.Context, crossSigningKeys map[id.UserID]mautrix.CrossSigningKeys, deviceKeys map[id.UserID]map[id.DeviceID]mautrix.DeviceKeys) { log := mach.machOrContextLog(ctx) for userID, userKeys := range crossSigningKeys { - log := log.With().Str("user_id", userID.String()).Logger() + log := log.With().Stringer("user_id", userID).Logger() currentKeys, err := mach.CryptoStore.GetCrossSigningKeys(ctx, userID) if err != nil { log.Error().Err(err). Msg("Error fetching current cross-signing keys of user") } - if currentKeys != nil { - for curKeyUsage, curKey := range currentKeys { - log := log.With().Str("old_key", curKey.Key.String()).Str("old_key_usage", string(curKeyUsage)).Logger() - // got a new key with the same usage as an existing key - for _, newKeyUsage := range userKeys.Usage { - if newKeyUsage == curKeyUsage { - if _, ok := userKeys.Keys[id.NewKeyID(id.KeyAlgorithmEd25519, curKey.Key.String())]; !ok { - // old key is not in the new key map, so we drop signatures made by it - if count, err := mach.CryptoStore.DropSignaturesByKey(ctx, userID, curKey.Key); err != nil { - log.Error().Err(err).Msg("Error deleting old signatures made by user") - } else { - log.Debug(). - Int64("signature_count", count). - Msg("Dropped signatures made by old key as it has been replaced") - } + for curKeyUsage, curKey := range currentKeys { + log := log.With().Stringer("old_key", curKey.Key).Str("old_key_usage", string(curKeyUsage)).Logger() + // got a new key with the same usage as an existing key + for _, newKeyUsage := range userKeys.Usage { + if newKeyUsage == curKeyUsage { + if _, ok := userKeys.Keys[id.NewKeyID(id.KeyAlgorithmEd25519, curKey.Key.String())]; !ok { + // old key is not in the new key map, so we drop signatures made by it + if count, err := mach.CryptoStore.DropSignaturesByKey(ctx, userID, curKey.Key); err != nil { + log.Error().Err(err).Msg("Error deleting old signatures made by user") + } else { + log.Debug(). + Int64("signature_count", count). + Msg("Dropped signatures made by old key as it has been replaced") } - break } + break } } } for _, key := range userKeys.Keys { - log := log.With().Str("key", key.String()).Array("usages", exzerolog.ArrayOfStrs(userKeys.Usage)).Logger() + log := log.With().Stringer("key", key).Array("usages", exzerolog.ArrayOfStrs(userKeys.Usage)).Logger() for _, usage := range userKeys.Usage { log.Trace().Str("usage", string(usage)).Msg("Storing cross-signing key") if err = mach.CryptoStore.PutCrossSigningKey(ctx, userID, usage, key); err != nil { diff --git a/crypto/cross_sign_test.go b/crypto/cross_sign_test.go index e11fb018..b70370a2 100644 --- a/crypto/cross_sign_test.go +++ b/crypto/cross_sign_test.go @@ -13,6 +13,8 @@ import ( "testing" "github.com/rs/zerolog" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "go.mau.fi/util/dbutil" "maunium.net/go/mautrix" @@ -24,17 +26,12 @@ var noopLogger = zerolog.Nop() func getOlmMachine(t *testing.T) *OlmMachine { rawDB, err := sql.Open("sqlite3", ":memory:?_busy_timeout=5000") - if err != nil { - t.Fatalf("Error opening db: %v", err) - } + require.NoError(t, err, "Error opening raw database") db, err := dbutil.NewWithDB(rawDB, "sqlite3") - if err != nil { - t.Fatalf("Error opening db: %v", err) - } + require.NoError(t, err, "Error creating database wrapper") sqlStore := NewSQLCryptoStore(db, nil, "accid", id.DeviceID("dev"), []byte("test")) - if err = sqlStore.DB.Upgrade(context.TODO()); err != nil { - t.Fatalf("Error creating tables: %v", err) - } + err = sqlStore.DB.Upgrade(context.TODO()) + require.NoError(t, err, "Error upgrading database") userID := id.UserID("@mautrix") mk, _ := olm.NewPKSigning() @@ -66,29 +63,25 @@ func TestTrustOwnDevice(t *testing.T) { DeviceID: "device", SigningKey: id.Ed25519("deviceKey"), } - if m.IsDeviceTrusted(ownDevice) { - t.Error("Own device trusted while it shouldn't be") - } + assert.False(t, m.IsDeviceTrusted(context.TODO(), ownDevice), "Own device trusted while it shouldn't be") m.CryptoStore.PutSignature(context.TODO(), ownDevice.UserID, m.CrossSigningKeys.SelfSigningKey.PublicKey(), ownDevice.UserID, m.CrossSigningKeys.MasterKey.PublicKey(), "sig1") m.CryptoStore.PutSignature(context.TODO(), ownDevice.UserID, ownDevice.SigningKey, ownDevice.UserID, m.CrossSigningKeys.SelfSigningKey.PublicKey(), "sig2") - if trusted, _ := m.IsUserTrusted(context.TODO(), ownDevice.UserID); !trusted { - t.Error("Own user not trusted while they should be") - } - if !m.IsDeviceTrusted(ownDevice) { - t.Error("Own device not trusted while it should be") - } + trusted, err := m.IsUserTrusted(context.TODO(), ownDevice.UserID) + require.NoError(t, err, "Error checking if own user is trusted") + assert.True(t, trusted, "Own user not trusted while they should be") + assert.True(t, m.IsDeviceTrusted(context.TODO(), ownDevice), "Own device not trusted while it should be") } func TestTrustOtherUser(t *testing.T) { m := getOlmMachine(t) otherUser := id.UserID("@user") - if trusted, _ := m.IsUserTrusted(context.TODO(), otherUser); trusted { - t.Error("Other user trusted while they shouldn't be") - } + trusted, err := m.IsUserTrusted(context.TODO(), otherUser) + require.NoError(t, err, "Error checking if other user is trusted") + assert.False(t, trusted, "Other user trusted while they shouldn't be") theirMasterKey, _ := olm.NewPKSigning() m.CryptoStore.PutCrossSigningKey(context.TODO(), otherUser, id.XSUsageMaster, theirMasterKey.PublicKey()) @@ -100,16 +93,16 @@ func TestTrustOtherUser(t *testing.T) { m.CryptoStore.PutSignature(context.TODO(), otherUser, theirMasterKey.PublicKey(), m.Client.UserID, m.CrossSigningKeys.SelfSigningKey.PublicKey(), "invalid_sig") - if trusted, _ := m.IsUserTrusted(context.TODO(), otherUser); trusted { - t.Error("Other user trusted before their master key has been signed with our user-signing key") - } + trusted, err = m.IsUserTrusted(context.TODO(), otherUser) + require.NoError(t, err, "Error checking if other user is trusted") + assert.False(t, trusted, "Other user trusted before their master key has been signed with our user-signing key") m.CryptoStore.PutSignature(context.TODO(), otherUser, theirMasterKey.PublicKey(), m.Client.UserID, m.CrossSigningKeys.UserSigningKey.PublicKey(), "sig2") - if trusted, _ := m.IsUserTrusted(context.TODO(), otherUser); !trusted { - t.Error("Other user not trusted while they should be") - } + trusted, err = m.IsUserTrusted(context.TODO(), otherUser) + require.NoError(t, err, "Error checking if other user is trusted") + assert.True(t, trusted, "Other user not trusted while they should be") } func TestTrustOtherDevice(t *testing.T) { @@ -120,12 +113,11 @@ func TestTrustOtherDevice(t *testing.T) { DeviceID: "theirDevice", SigningKey: id.Ed25519("theirDeviceKey"), } - if trusted, _ := m.IsUserTrusted(context.TODO(), otherUser); trusted { - t.Error("Other user trusted while they shouldn't be") - } - if m.IsDeviceTrusted(theirDevice) { - t.Error("Other device trusted while it shouldn't be") - } + + trusted, err := m.IsUserTrusted(context.TODO(), otherUser) + require.NoError(t, err, "Error checking if other user is trusted") + assert.False(t, trusted, "Other user trusted while they shouldn't be") + assert.False(t, m.IsDeviceTrusted(context.TODO(), theirDevice), "Other device trusted while it shouldn't be") theirMasterKey, _ := olm.NewPKSigning() m.CryptoStore.PutCrossSigningKey(context.TODO(), otherUser, id.XSUsageMaster, theirMasterKey.PublicKey()) @@ -137,21 +129,17 @@ func TestTrustOtherDevice(t *testing.T) { m.CryptoStore.PutSignature(context.TODO(), otherUser, theirMasterKey.PublicKey(), m.Client.UserID, m.CrossSigningKeys.UserSigningKey.PublicKey(), "sig2") - if trusted, _ := m.IsUserTrusted(context.TODO(), otherUser); !trusted { - t.Error("Other user not trusted while they should be") - } + trusted, err = m.IsUserTrusted(context.TODO(), otherUser) + require.NoError(t, err, "Error checking if other user is trusted") + assert.True(t, trusted, "Other user not trusted while they should be") m.CryptoStore.PutSignature(context.TODO(), otherUser, theirSSK.PublicKey(), otherUser, theirMasterKey.PublicKey(), "sig3") - if m.IsDeviceTrusted(theirDevice) { - t.Error("Other device trusted before it has been signed with user's SSK") - } + assert.False(t, m.IsDeviceTrusted(context.TODO(), theirDevice), "Other device trusted before it has been signed with user's SSK") m.CryptoStore.PutSignature(context.TODO(), otherUser, theirDevice.SigningKey, otherUser, theirSSK.PublicKey(), "sig4") - if !m.IsDeviceTrusted(theirDevice) { - t.Error("Other device not trusted while it should be") - } + assert.True(t, m.IsDeviceTrusted(context.TODO(), theirDevice), "Other device not trusted after it has been signed with user's SSK") } diff --git a/crypto/cross_sign_validation.go b/crypto/cross_sign_validation.go index 04a179df..4cdf0dd5 100644 --- a/crypto/cross_sign_validation.go +++ b/crypto/cross_sign_validation.go @@ -13,6 +13,9 @@ import ( "maunium.net/go/mautrix/id" ) +// ResolveTrust resolves the trust state of the device from cross-signing. +// +// Deprecated: This method doesn't take a context. Use [OlmMachine.ResolveTrustContext] instead. func (mach *OlmMachine) ResolveTrust(device *id.Device) id.TrustState { state, _ := mach.ResolveTrustContext(context.Background(), device) return state @@ -77,8 +80,12 @@ func (mach *OlmMachine) ResolveTrustContext(ctx context.Context, device *id.Devi } // IsDeviceTrusted returns whether a device has been determined to be trusted either through verification or cross-signing. -func (mach *OlmMachine) IsDeviceTrusted(device *id.Device) bool { - switch mach.ResolveTrust(device) { +// +// Note: this will return false if resolving the trust state fails due to database errors. +// Use [OlmMachine.ResolveTrustContext] if special error handling is required. +func (mach *OlmMachine) IsDeviceTrusted(ctx context.Context, device *id.Device) bool { + trust, _ := mach.ResolveTrustContext(ctx, device) + switch trust { case id.TrustStateVerified, id.TrustStateCrossSignedTOFU, id.TrustStateCrossSignedVerified: return true default: diff --git a/crypto/cryptohelper/cryptohelper.go b/crypto/cryptohelper/cryptohelper.go index 0b3fbeaa..b62dc128 100644 --- a/crypto/cryptohelper/cryptohelper.go +++ b/crypto/cryptohelper/cryptohelper.go @@ -15,6 +15,7 @@ import ( "github.com/rs/zerolog" "go.mau.fi/util/dbutil" + _ "go.mau.fi/util/dbutil/litestream" "maunium.net/go/mautrix" "maunium.net/go/mautrix/crypto" @@ -36,6 +37,7 @@ type CryptoHelper struct { DecryptErrorCallback func(*event.Event, error) + MSC4190 bool LoginAs *mautrix.ReqLogin ASEventProcessor crypto.ASEventProcessor @@ -77,7 +79,7 @@ func NewCryptoHelper(cli *mautrix.Client, pickleKey []byte, store any) (*CryptoH } unmanagedCryptoStore = typedStore case string: - db, err := dbutil.NewWithDialect(typedStore, "sqlite3") + db, err := dbutil.NewWithDialect(fmt.Sprintf("file:%s?_txlock=immediate", typedStore), "sqlite3-fk-wal") if err != nil { return nil, err } @@ -151,7 +153,14 @@ func (helper *CryptoHelper) Init(ctx context.Context) error { if err != nil { return fmt.Errorf("failed to find existing device ID: %w", err) } - if helper.LoginAs != nil && helper.LoginAs.Type == mautrix.AuthTypeAppservice && helper.client.SetAppServiceDeviceID { + if helper.MSC4190 { + helper.log.Debug().Msg("Creating bot device with MSC4190") + err = helper.client.CreateDeviceMSC4190(ctx, storedDeviceID, helper.LoginAs.InitialDeviceDisplayName) + if err != nil { + return fmt.Errorf("failed to create device for bot: %w", err) + } + rawCryptoStore.DeviceID = helper.client.DeviceID + } else if helper.LoginAs != nil && helper.LoginAs.Type == mautrix.AuthTypeAppservice && helper.client.SetAppServiceDeviceID { if storedDeviceID == "" { helper.log.Debug(). Str("username", helper.LoginAs.Identifier.User). @@ -161,14 +170,12 @@ func (helper *CryptoHelper) Init(ctx context.Context) error { if err != nil { return err } - rawCryptoStore.DeviceID = resp.DeviceID helper.client.DeviceID = resp.DeviceID } else { helper.log.Debug(). Str("username", helper.LoginAs.Identifier.User). Stringer("device_id", storedDeviceID). Msg("Using existing device") - rawCryptoStore.DeviceID = storedDeviceID helper.client.DeviceID = storedDeviceID } } else if helper.LoginAs != nil { @@ -184,12 +191,10 @@ func (helper *CryptoHelper) Init(ctx context.Context) error { if err != nil { return err } - if storedDeviceID == "" { - rawCryptoStore.DeviceID = helper.client.DeviceID - } } else if storedDeviceID != "" && storedDeviceID != helper.client.DeviceID { return fmt.Errorf("mismatching device ID in client and crypto store (%q != %q)", storedDeviceID, helper.client.DeviceID) } + rawCryptoStore.DeviceID = helper.client.DeviceID } else if helper.LoginAs != nil { return fmt.Errorf("LoginAs can only be used with a managed crypto store") } @@ -220,13 +225,6 @@ func (helper *CryptoHelper) Init(ctx context.Context) error { helper.ASEventProcessor.On(event.EventEncrypted, helper.HandleEncrypted) } - if helper.client.SetAppServiceDeviceID { - err = helper.mach.ShareKeys(ctx, -1) - if err != nil { - return fmt.Errorf("failed to share keys: %w", err) - } - } - return nil } @@ -263,24 +261,24 @@ func (helper *CryptoHelper) verifyDeviceKeysOnServer(ctx context.Context) error if !ok || len(device.Keys) == 0 { if isShared { return fmt.Errorf("olm account is marked as shared, keys seem to have disappeared from the server") - } else { - helper.log.Debug().Msg("Olm account not shared and keys not on server, so device is probably fine") - return nil } + helper.log.Debug().Msg("Olm account not shared and keys not on server, sharing initial keys") + err = helper.mach.ShareKeys(ctx, -1) + if err != nil { + return fmt.Errorf("failed to share keys: %w", err) + } + return nil } else if !isShared { return fmt.Errorf("olm account is not marked as shared, but there are keys on the server") } else if ed := device.Keys.GetEd25519(helper.client.DeviceID); ownID.SigningKey != ed { return fmt.Errorf("mismatching identity key on server (%q != %q)", ownID.SigningKey, ed) - } - if !isShared { - helper.log.Debug().Msg("Olm account not marked as shared, but keys on server match?") } else { helper.log.Debug().Msg("Olm account marked as shared and keys on server match, device is fine") + return nil } - return nil } -var NoSessionFound = crypto.NoSessionFound +var NoSessionFound = crypto.ErrNoSessionFound const initialSessionWaitTimeout = 3 * time.Second const extendedSessionWaitTimeout = 22 * time.Second @@ -299,24 +297,14 @@ func (helper *CryptoHelper) HandleEncrypted(ctx context.Context, evt *event.Even ctx = log.WithContext(ctx) decrypted, err := helper.Decrypt(ctx, evt) - if errors.Is(err, NoSessionFound) { - log.Debug(). - Int("wait_seconds", int(initialSessionWaitTimeout.Seconds())). - Msg("Couldn't find session, waiting for keys to arrive...") - if helper.mach.WaitForSession(ctx, evt.RoomID, content.SenderKey, content.SessionID, initialSessionWaitTimeout) { - log.Debug().Msg("Got keys after waiting, trying to decrypt event again") - decrypted, err = helper.Decrypt(ctx, evt) - } else { - go helper.waitLongerForSession(ctx, log, evt) - return - } - } - if err != nil { + if errors.Is(err, NoSessionFound) && ctx.Value(mautrix.SyncTokenContextKey) != "" { + go helper.waitForSession(ctx, evt) + } else if err != nil { log.Warn().Err(err).Msg("Failed to decrypt event") helper.DecryptErrorCallback(evt, err) - return + } else { + helper.postDecrypt(ctx, decrypted) } - helper.postDecrypt(ctx, decrypted) } func (helper *CryptoHelper) postDecrypt(ctx context.Context, decrypted *event.Event) { @@ -357,10 +345,33 @@ func (helper *CryptoHelper) RequestSession(ctx context.Context, roomID id.RoomID } } -func (helper *CryptoHelper) waitLongerForSession(ctx context.Context, log zerolog.Logger, evt *event.Event) { +func (helper *CryptoHelper) waitForSession(ctx context.Context, evt *event.Event) { + log := zerolog.Ctx(ctx) + content := evt.Content.AsEncrypted() + + log.Debug(). + Int("wait_seconds", int(initialSessionWaitTimeout.Seconds())). + Msg("Couldn't find session, waiting for keys to arrive...") + if helper.mach.WaitForSession(ctx, evt.RoomID, content.SenderKey, content.SessionID, initialSessionWaitTimeout) { + log.Debug().Msg("Got keys after waiting, trying to decrypt event again") + decrypted, err := helper.Decrypt(ctx, evt) + if err != nil { + log.Warn().Err(err).Msg("Failed to decrypt event") + helper.DecryptErrorCallback(evt, err) + } else { + helper.postDecrypt(ctx, decrypted) + } + } else { + go helper.waitLongerForSession(ctx, evt) + } +} + +func (helper *CryptoHelper) waitLongerForSession(ctx context.Context, evt *event.Event) { + log := zerolog.Ctx(ctx) content := evt.Content.AsEncrypted() log.Debug().Int("wait_seconds", int(extendedSessionWaitTimeout.Seconds())).Msg("Couldn't find session, requesting keys and waiting longer...") + //lint:ignore SA1019 RequestSession will gracefully request from all devices if DeviceID is blank go helper.RequestSession(context.TODO(), evt.RoomID, content.SenderKey, content.SessionID, evt.Sender, content.DeviceID) if !helper.mach.WaitForSession(ctx, evt.RoomID, content.SenderKey, content.SessionID, extendedSessionWaitTimeout) { @@ -408,7 +419,7 @@ func (helper *CryptoHelper) EncryptWithStateKey(ctx context.Context, roomID id.R defer helper.lock.RUnlock() encrypted, err = helper.mach.EncryptMegolmEventWithStateKey(ctx, roomID, evtType, stateKey, content) if err != nil { - if !errors.Is(err, crypto.SessionExpired) && err != crypto.NoGroupSession && !errors.Is(err, crypto.SessionNotShared) { + if !errors.Is(err, crypto.ErrSessionExpired) && err != crypto.ErrNoGroupSession && !errors.Is(err, crypto.ErrSessionNotShared) { return } helper.log.Debug(). diff --git a/crypto/decryptmegolm.go b/crypto/decryptmegolm.go index 00f99ce4..457d5a0c 100644 --- a/crypto/decryptmegolm.go +++ b/crypto/decryptmegolm.go @@ -24,13 +24,23 @@ import ( ) var ( - IncorrectEncryptedContentType = errors.New("event content is not instance of *event.EncryptedEventContent") - NoSessionFound = errors.New("failed to decrypt megolm event: no session with given ID found") - DuplicateMessageIndex = errors.New("duplicate megolm message index") - WrongRoom = errors.New("encrypted megolm event is not intended for this room") - DeviceKeyMismatch = errors.New("device keys in event and verified device info do not match") - SenderKeyMismatch = errors.New("sender keys in content and megolm session do not match") - RatchetError = errors.New("failed to ratchet session after use") + ErrIncorrectEncryptedContentType = errors.New("event content is not instance of *event.EncryptedEventContent") + ErrNoSessionFound = errors.New("failed to decrypt megolm event: no session with given ID found") + ErrDuplicateMessageIndex = errors.New("duplicate megolm message index") + ErrWrongRoom = errors.New("encrypted megolm event is not intended for this room") + ErrDeviceKeyMismatch = errors.New("device keys in event and verified device info do not match") + ErrRatchetError = errors.New("failed to ratchet session after use") + ErrCorruptedMegolmPayload = errors.New("corrupted megolm payload") +) + +// Deprecated: use variables prefixed with Err +var ( + IncorrectEncryptedContentType = ErrIncorrectEncryptedContentType + NoSessionFound = ErrNoSessionFound + DuplicateMessageIndex = ErrDuplicateMessageIndex + WrongRoom = ErrWrongRoom + DeviceKeyMismatch = ErrDeviceKeyMismatch + RatchetError = ErrRatchetError ) type megolmEvent struct { @@ -45,13 +55,30 @@ var ( relatesToTopLevelPath = exgjson.Path("content", "m.relates_to") ) +const sessionIDLength = 43 + +func validateCiphertextCharacters(ciphertext []byte) bool { + for _, b := range ciphertext { + if (b < 'a' || b > 'z') && (b < 'A' || b > 'Z') && (b < '0' || b > '9') && b != '+' && b != '/' { + return false + } + } + return true +} + // DecryptMegolmEvent decrypts an m.room.encrypted event where the algorithm is m.megolm.v1.aes-sha2 func (mach *OlmMachine) DecryptMegolmEvent(ctx context.Context, evt *event.Event) (*event.Event, error) { content, ok := evt.Content.Parsed.(*event.EncryptedEventContent) if !ok { - return nil, IncorrectEncryptedContentType + return nil, ErrIncorrectEncryptedContentType } else if content.Algorithm != id.AlgorithmMegolmV1 { - return nil, UnsupportedAlgorithm + return nil, ErrUnsupportedAlgorithm + } else if len(content.MegolmCiphertext) < 74 { + return nil, fmt.Errorf("%w: ciphertext too short (%d bytes)", ErrCorruptedMegolmPayload, len(content.MegolmCiphertext)) + } else if len(content.SessionID) != sessionIDLength { + return nil, fmt.Errorf("%w: invalid session ID length %d", ErrCorruptedMegolmPayload, len(content.SessionID)) + } else if !validateCiphertextCharacters(content.MegolmCiphertext) { + return nil, fmt.Errorf("%w: invalid characters in ciphertext", ErrCorruptedMegolmPayload) } log := mach.machOrContextLog(ctx).With(). Str("action", "decrypt megolm event"). @@ -97,7 +124,13 @@ func (mach *OlmMachine) DecryptMegolmEvent(ctx context.Context, evt *event.Event Msg("Couldn't resolve trust level of session: sent by unknown device") trustLevel = id.TrustStateUnknownDevice } else if device.SigningKey != sess.SigningKey || device.IdentityKey != sess.SenderKey { - return nil, DeviceKeyMismatch + log.Debug(). + Stringer("session_sender_key", sess.SenderKey). + Stringer("device_sender_key", device.IdentityKey). + Stringer("session_signing_key", sess.SigningKey). + Stringer("device_signing_key", device.SigningKey). + Msg("Device keys don't match keys in session, marking as untrusted") + trustLevel = id.TrustStateDeviceKeyMismatch } else { trustLevel, err = mach.ResolveTrustContext(ctx, device) if err != nil { @@ -147,9 +180,9 @@ func (mach *OlmMachine) DecryptMegolmEvent(ctx context.Context, evt *event.Event if err != nil { return nil, fmt.Errorf("failed to parse megolm payload: %w", err) } else if megolmEvt.RoomID != encryptionRoomID { - return nil, WrongRoom + return nil, ErrWrongRoom } - if evt.StateKey != nil && megolmEvt.StateKey != nil { + if evt.StateKey != nil && megolmEvt.StateKey != nil && mach.AllowEncryptedState { megolmEvt.Type.Class = event.StateEventType } else { megolmEvt.Type.Class = evt.Type.Class @@ -160,7 +193,7 @@ func (mach *OlmMachine) DecryptMegolmEvent(ctx context.Context, evt *event.Event if err != nil { if errors.Is(err, event.ErrUnsupportedContentType) { log.Warn().Msg("Unsupported event type in encrypted event") - } else { + } else if !mach.IgnorePostDecryptionParseErrors { return nil, fmt.Errorf("failed to parse content of megolm payload event: %w", err) } } @@ -180,6 +213,7 @@ func (mach *OlmMachine) DecryptMegolmEvent(ctx context.Context, evt *event.Event TrustSource: device, ForwardedKeys: forwardedKeys, WasEncrypted: true, + EventSource: evt.Mautrix.EventSource | event.SourceDecrypted, ReceivedAt: evt.Mautrix.ReceivedAt, }, }, nil @@ -201,19 +235,19 @@ func (mach *OlmMachine) checkUndecryptableMessageIndexDuplication(ctx context.Co messageIndex, decodeErr := ParseMegolmMessageIndex(content.MegolmCiphertext) if decodeErr != nil { log.Warn().Err(decodeErr).Msg("Failed to parse message index to check if it's a duplicate for message that failed to decrypt") - return 0, fmt.Errorf("%w (also failed to parse message index)", olm.UnknownMessageIndex) + return 0, fmt.Errorf("%w (also failed to parse message index)", olm.ErrUnknownMessageIndex) } firstKnown := sess.Internal.FirstKnownIndex() log = log.With().Uint("message_index", messageIndex).Uint32("first_known_index", firstKnown).Logger() if ok, err := mach.CryptoStore.ValidateMessageIndex(ctx, sess.SenderKey, content.SessionID, evt.ID, messageIndex, evt.Timestamp); err != nil { log.Debug().Err(err).Msg("Failed to check if message index is duplicate") - return messageIndex, fmt.Errorf("%w (failed to check if index is duplicate; received: %d, earliest known: %d)", olm.UnknownMessageIndex, messageIndex, firstKnown) + return messageIndex, fmt.Errorf("%w (failed to check if index is duplicate; received: %d, earliest known: %d)", olm.ErrUnknownMessageIndex, messageIndex, firstKnown) } else if !ok { log.Debug().Msg("Failed to decrypt message due to unknown index and found duplicate") - return messageIndex, fmt.Errorf("%w %d (also failed to decrypt because earliest known index is %d)", DuplicateMessageIndex, messageIndex, firstKnown) + return messageIndex, fmt.Errorf("%w %d (also failed to decrypt because earliest known index is %d)", ErrDuplicateMessageIndex, messageIndex, firstKnown) } log.Debug().Msg("Failed to decrypt message due to unknown index, but index is not duplicate") - return messageIndex, fmt.Errorf("%w (not duplicate index; received: %d, earliest known: %d)", olm.UnknownMessageIndex, messageIndex, firstKnown) + return messageIndex, fmt.Errorf("%w (not duplicate index; received: %d, earliest known: %d)", olm.ErrUnknownMessageIndex, messageIndex, firstKnown) } func (mach *OlmMachine) actuallyDecryptMegolmEvent(ctx context.Context, evt *event.Event, encryptionRoomID id.RoomID, content *event.EncryptedEventContent) (*InboundGroupSession, []byte, uint, error) { @@ -224,13 +258,11 @@ func (mach *OlmMachine) actuallyDecryptMegolmEvent(ctx context.Context, evt *eve if err != nil { return nil, nil, 0, fmt.Errorf("failed to get group session: %w", err) } else if sess == nil { - return nil, nil, 0, fmt.Errorf("%w (ID %s)", NoSessionFound, content.SessionID) - } else if content.SenderKey != "" && content.SenderKey != sess.SenderKey { - return sess, nil, 0, SenderKeyMismatch + return nil, nil, 0, fmt.Errorf("%w (ID %s)", ErrNoSessionFound, content.SessionID) } plaintext, messageIndex, err := sess.Internal.Decrypt(content.MegolmCiphertext) if err != nil { - if errors.Is(err, olm.UnknownMessageIndex) && mach.RatchetKeysOnDecrypt { + if errors.Is(err, olm.ErrUnknownMessageIndex) && mach.RatchetKeysOnDecrypt { messageIndex, err = mach.checkUndecryptableMessageIndexDuplication(ctx, sess, evt, content) return sess, nil, messageIndex, fmt.Errorf("failed to decrypt megolm event: %w", err) } @@ -238,7 +270,7 @@ func (mach *OlmMachine) actuallyDecryptMegolmEvent(ctx context.Context, evt *eve } else if ok, err := mach.CryptoStore.ValidateMessageIndex(ctx, sess.SenderKey, content.SessionID, evt.ID, messageIndex, evt.Timestamp); err != nil { return sess, nil, messageIndex, fmt.Errorf("failed to check if message index is duplicate: %w", err) } else if !ok { - return sess, nil, messageIndex, fmt.Errorf("%w %d", DuplicateMessageIndex, messageIndex) + return sess, nil, messageIndex, fmt.Errorf("%w %d", ErrDuplicateMessageIndex, messageIndex) } // Normal clients don't care about tracking the ratchet state, so let them bypass the rest of the function @@ -290,24 +322,24 @@ func (mach *OlmMachine) actuallyDecryptMegolmEvent(ctx context.Context, evt *eve err = mach.CryptoStore.RedactGroupSession(ctx, sess.RoomID, sess.ID(), "maximum messages reached") if err != nil { log.Err(err).Msg("Failed to delete fully used session") - return sess, plaintext, messageIndex, RatchetError + return sess, plaintext, messageIndex, ErrRatchetError } else { log.Info().Msg("Deleted fully used session") } } else if ratchetCurrentIndex < ratchetTargetIndex && mach.RatchetKeysOnDecrypt { if err = sess.RatchetTo(ratchetTargetIndex); err != nil { log.Err(err).Msg("Failed to ratchet session") - return sess, plaintext, messageIndex, RatchetError + return sess, plaintext, messageIndex, ErrRatchetError } else if err = mach.CryptoStore.PutGroupSession(ctx, sess); err != nil { log.Err(err).Msg("Failed to store ratcheted session") - return sess, plaintext, messageIndex, RatchetError + return sess, plaintext, messageIndex, ErrRatchetError } else { log.Info().Msg("Ratcheted session forward") } } else if didModify { if err = mach.CryptoStore.PutGroupSession(ctx, sess); err != nil { log.Err(err).Msg("Failed to store updated ratchet safety data") - return sess, plaintext, messageIndex, RatchetError + return sess, plaintext, messageIndex, ErrRatchetError } else { log.Debug().Msg("Ratchet safety data changed (ratchet state didn't change)") } diff --git a/crypto/decryptolm.go b/crypto/decryptolm.go index 8f1eb1f7..aea5e6dc 100644 --- a/crypto/decryptolm.go +++ b/crypto/decryptolm.go @@ -8,6 +8,8 @@ package crypto import ( "context" + "crypto/sha256" + "encoding/base64" "encoding/json" "errors" "fmt" @@ -15,20 +17,36 @@ import ( "time" "github.com/rs/zerolog" + "go.mau.fi/util/exerrors" + "go.mau.fi/util/ptr" + "maunium.net/go/mautrix/crypto/goolm/account" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" ) var ( - UnsupportedAlgorithm = errors.New("unsupported event encryption algorithm") - NotEncryptedForMe = errors.New("olm event doesn't contain ciphertext for this device") - UnsupportedOlmMessageType = errors.New("unsupported olm message type") - DecryptionFailedWithMatchingSession = errors.New("decryption failed with matching session") - DecryptionFailedForNormalMessage = errors.New("decryption failed for normal message") - SenderMismatch = errors.New("mismatched sender in olm payload") - RecipientMismatch = errors.New("mismatched recipient in olm payload") - RecipientKeyMismatch = errors.New("mismatched recipient key in olm payload") + ErrUnsupportedAlgorithm = errors.New("unsupported event encryption algorithm") + ErrNotEncryptedForMe = errors.New("olm event doesn't contain ciphertext for this device") + ErrUnsupportedOlmMessageType = errors.New("unsupported olm message type") + ErrDecryptionFailedWithMatchingSession = errors.New("decryption failed with matching session") + ErrDecryptionFailedForNormalMessage = errors.New("decryption failed for normal message") + ErrSenderMismatch = errors.New("mismatched sender in olm payload") + ErrRecipientMismatch = errors.New("mismatched recipient in olm payload") + ErrRecipientKeyMismatch = errors.New("mismatched recipient key in olm payload") + ErrDuplicateMessage = errors.New("duplicate olm message") +) + +// Deprecated: use variables prefixed with Err +var ( + UnsupportedAlgorithm = ErrUnsupportedAlgorithm + NotEncryptedForMe = ErrNotEncryptedForMe + UnsupportedOlmMessageType = ErrUnsupportedOlmMessageType + DecryptionFailedWithMatchingSession = ErrDecryptionFailedWithMatchingSession + DecryptionFailedForNormalMessage = ErrDecryptionFailedForNormalMessage + SenderMismatch = ErrSenderMismatch + RecipientMismatch = ErrRecipientMismatch + RecipientKeyMismatch = ErrRecipientKeyMismatch ) // DecryptedOlmEvent represents an event that was decrypted from an event encrypted with the m.olm.v1.curve25519-aes-sha2 algorithm. @@ -50,13 +68,13 @@ type DecryptedOlmEvent struct { func (mach *OlmMachine) decryptOlmEvent(ctx context.Context, evt *event.Event) (*DecryptedOlmEvent, error) { content, ok := evt.Content.Parsed.(*event.EncryptedEventContent) if !ok { - return nil, IncorrectEncryptedContentType + return nil, ErrIncorrectEncryptedContentType } else if content.Algorithm != id.AlgorithmOlmV1 { - return nil, UnsupportedAlgorithm + return nil, ErrUnsupportedAlgorithm } ownContent, ok := content.OlmCiphertext[mach.account.IdentityKey()] if !ok { - return nil, NotEncryptedForMe + return nil, ErrNotEncryptedForMe } decrypted, err := mach.decryptAndParseOlmCiphertext(ctx, evt, content.SenderKey, ownContent.Type, ownContent.Body) if err != nil { @@ -72,7 +90,7 @@ type OlmEventKeys struct { func (mach *OlmMachine) decryptAndParseOlmCiphertext(ctx context.Context, evt *event.Event, senderKey id.SenderKey, olmType id.OlmMsgType, ciphertext string) (*DecryptedOlmEvent, error) { if olmType != id.OlmMsgTypePreKey && olmType != id.OlmMsgTypeMsg { - return nil, UnsupportedOlmMessageType + return nil, ErrUnsupportedOlmMessageType } log := mach.machOrContextLog(ctx).With(). @@ -96,16 +114,18 @@ func (mach *OlmMachine) decryptAndParseOlmCiphertext(ctx context.Context, evt *e } olmEvt.Type.Class = evt.Type.Class if evt.Sender != olmEvt.Sender { - return nil, SenderMismatch + return nil, ErrSenderMismatch } else if mach.Client.UserID != olmEvt.Recipient { - return nil, RecipientMismatch + return nil, ErrRecipientMismatch } else if mach.account.SigningKey() != olmEvt.RecipientKeys.Ed25519 { - return nil, RecipientKeyMismatch + return nil, ErrRecipientKeyMismatch } - err = olmEvt.Content.ParseRaw(olmEvt.Type) - if err != nil && !errors.Is(err, event.ErrUnsupportedContentType) { - return nil, fmt.Errorf("failed to parse content of olm payload event: %w", err) + if len(olmEvt.Content.VeryRaw) > 0 { + err = olmEvt.Content.ParseRaw(olmEvt.Type) + if err != nil && !errors.Is(err, event.ErrUnsupportedContentType) { + return nil, fmt.Errorf("failed to parse content of olm payload event: %w", err) + } } olmEvt.SenderKey = senderKey @@ -113,16 +133,40 @@ func (mach *OlmMachine) decryptAndParseOlmCiphertext(ctx context.Context, evt *e return &olmEvt, nil } +func olmMessageHash(ciphertext string) ([32]byte, error) { + if ciphertext == "" { + return [32]byte{}, fmt.Errorf("empty ciphertext") + } + ciphertextBytes, err := base64.RawStdEncoding.DecodeString(ciphertext) + return sha256.Sum256(ciphertextBytes), err +} + func (mach *OlmMachine) tryDecryptOlmCiphertext(ctx context.Context, sender id.UserID, senderKey id.SenderKey, olmType id.OlmMsgType, ciphertext string) ([]byte, error) { + ciphertextHash, err := olmMessageHash(ciphertext) + if err != nil { + return nil, fmt.Errorf("failed to hash olm ciphertext: %w", err) + } + log := *zerolog.Ctx(ctx) endTimeTrace := mach.timeTrace(ctx, "waiting for olm lock", 5*time.Second) mach.olmLock.Lock() endTimeTrace() defer mach.olmLock.Unlock() - plaintext, err := mach.tryDecryptOlmCiphertextWithExistingSession(ctx, senderKey, olmType, ciphertext) + duplicateTS, err := mach.CryptoStore.GetOlmHash(ctx, ciphertextHash) if err != nil { - if err == DecryptionFailedWithMatchingSession { + log.Warn().Err(err).Msg("Failed to check for duplicate olm message") + } else if !duplicateTS.IsZero() { + log.Warn(). + Hex("ciphertext_hash", ciphertextHash[:]). + Time("duplicate_ts", duplicateTS). + Msg("Ignoring duplicate olm message") + return nil, ErrDuplicateMessage + } + + plaintext, err := mach.tryDecryptOlmCiphertextWithExistingSession(ctx, senderKey, olmType, ciphertext, ciphertextHash) + if err != nil { + if err == ErrDecryptionFailedWithMatchingSession { log.Warn().Msg("Found matching session, but decryption failed") go mach.unwedgeDevice(log, sender, senderKey) } @@ -140,9 +184,10 @@ func (mach *OlmMachine) tryDecryptOlmCiphertext(ctx context.Context, sender id.U // if it isn't one at this point in time anymore, so return early. if olmType != id.OlmMsgTypePreKey { go mach.unwedgeDevice(log, sender, senderKey) - return nil, DecryptionFailedForNormalMessage + return nil, ErrDecryptionFailedForNormalMessage } + accountBackup, _ := mach.account.Internal.Pickle([]byte("tmp")) log.Trace().Msg("Trying to create inbound session") endTimeTrace = mach.timeTrace(ctx, "creating inbound olm session", time.Second) session, err := mach.createInboundSession(ctx, senderKey, ciphertext) @@ -153,6 +198,8 @@ func (mach *OlmMachine) tryDecryptOlmCiphertext(ctx context.Context, sender id.U } log = log.With().Str("new_olm_session_id", session.ID().String()).Logger() log.Debug(). + Hex("ciphertext_hash", ciphertextHash[:]). + Hex("ciphertext_hash_repeat", ptr.Ptr(exerrors.Must(olmMessageHash(ciphertext)))[:]). Str("olm_session_description", session.Describe()). Msg("Created inbound olm session") ctx = log.WithContext(ctx) @@ -161,11 +208,28 @@ func (mach *OlmMachine) tryDecryptOlmCiphertext(ctx context.Context, sender id.U plaintext, err = session.Decrypt(ciphertext, olmType) endTimeTrace() if err != nil { + log.Debug(). + Hex("ciphertext_hash", ciphertextHash[:]). + Hex("ciphertext_hash_repeat", ptr.Ptr(exerrors.Must(olmMessageHash(ciphertext)))[:]). + Str("ciphertext", ciphertext). + Str("olm_session_description", session.Describe()). + Msg("DEBUG: Failed to decrypt prekey olm message with newly created session") + err2 := mach.goolmRetryHack(ctx, senderKey, ciphertext, accountBackup) + if err2 != nil { + log.Debug().Err(err2).Msg("Goolm confirmed decryption failure") + } else { + log.Warn().Msg("Goolm decryption was successful after libolm failure?") + } + go mach.unwedgeDevice(log, sender, senderKey) return nil, fmt.Errorf("failed to decrypt olm event with session created from prekey message: %w", err) } endTimeTrace = mach.timeTrace(ctx, "updating new session in database", time.Second) + err = mach.CryptoStore.PutOlmHash(ctx, ciphertextHash, time.Now()) + if err != nil { + log.Warn().Err(err).Msg("Failed to store olm message hash after decrypting") + } err = mach.CryptoStore.UpdateSession(ctx, senderKey, session) endTimeTrace() if err != nil { @@ -174,9 +238,28 @@ func (mach *OlmMachine) tryDecryptOlmCiphertext(ctx context.Context, sender id.U return plaintext, nil } +func (mach *OlmMachine) goolmRetryHack(ctx context.Context, senderKey id.SenderKey, ciphertext string, accountBackup []byte) error { + acc, err := account.AccountFromPickled(accountBackup, []byte("tmp")) + if err != nil { + return fmt.Errorf("failed to unpickle olm account: %w", err) + } + sess, err := acc.NewInboundSessionFrom(&senderKey, ciphertext) + if err != nil { + return fmt.Errorf("failed to create inbound session: %w", err) + } + _, err = sess.Decrypt(ciphertext, id.OlmMsgTypePreKey) + if err != nil { + // This is the expected result if libolm failed + return fmt.Errorf("failed to decrypt with new session: %w", err) + } + return nil +} + const MaxOlmSessionsPerDevice = 5 -func (mach *OlmMachine) tryDecryptOlmCiphertextWithExistingSession(ctx context.Context, senderKey id.SenderKey, olmType id.OlmMsgType, ciphertext string) ([]byte, error) { +func (mach *OlmMachine) tryDecryptOlmCiphertextWithExistingSession( + ctx context.Context, senderKey id.SenderKey, olmType id.OlmMsgType, ciphertext string, ciphertextHash [32]byte, +) ([]byte, error) { log := *zerolog.Ctx(ctx) endTimeTrace := mach.timeTrace(ctx, "getting sessions with sender key", time.Second) sessions, err := mach.CryptoStore.GetSessions(ctx, senderKey) @@ -229,19 +312,28 @@ func (mach *OlmMachine) tryDecryptOlmCiphertextWithExistingSession(ctx context.C endTimeTrace() if err != nil { log.Warn().Err(err). + Hex("ciphertext_hash", ciphertextHash[:]). + Hex("ciphertext_hash_repeat", ptr.Ptr(exerrors.Must(olmMessageHash(ciphertext)))[:]). Str("session_description", session.Describe()). Msg("Failed to decrypt olm message") if olmType == id.OlmMsgTypePreKey { - return nil, DecryptionFailedWithMatchingSession + return nil, ErrDecryptionFailedWithMatchingSession } } else { endTimeTrace = mach.timeTrace(ctx, "updating session in database", time.Second) + err = mach.CryptoStore.PutOlmHash(ctx, ciphertextHash, time.Now()) + if err != nil { + log.Warn().Err(err).Msg("Failed to store olm message hash after decrypting") + } err = mach.CryptoStore.UpdateSession(ctx, senderKey, session) endTimeTrace() if err != nil { log.Warn().Err(err).Msg("Failed to update olm session in crypto store after decrypting") } - log.Debug().Str("session_description", session.Describe()).Msg("Decrypted olm message") + log.Debug(). + Hex("ciphertext_hash", ciphertextHash[:]). + Str("session_description", session.Describe()). + Msg("Decrypted olm message") return plaintext, nil } } @@ -265,10 +357,10 @@ const MinUnwedgeInterval = 1 * time.Hour func (mach *OlmMachine) unwedgeDevice(log zerolog.Logger, sender id.UserID, senderKey id.SenderKey) { log = log.With().Str("action", "unwedge olm session").Logger() - ctx := log.WithContext(context.TODO()) + ctx := log.WithContext(mach.backgroundCtx) mach.recentlyUnwedgedLock.Lock() prevUnwedge, ok := mach.recentlyUnwedged[senderKey] - delta := time.Now().Sub(prevUnwedge) + delta := time.Since(prevUnwedge) if ok && delta < MinUnwedgeInterval { log.Debug(). Str("previous_recreation", delta.String()). @@ -279,6 +371,17 @@ func (mach *OlmMachine) unwedgeDevice(log zerolog.Logger, sender id.UserID, send mach.recentlyUnwedged[senderKey] = time.Now() mach.recentlyUnwedgedLock.Unlock() + lastCreatedAt, err := mach.CryptoStore.GetNewestSessionCreationTS(ctx, senderKey) + if err != nil { + log.Warn().Err(err).Msg("Failed to get newest session creation timestamp") + return + } else if time.Since(lastCreatedAt) < MinUnwedgeInterval { + log.Debug(). + Time("last_created_at", lastCreatedAt). + Msg("Not creating new Olm session as it was already recreated recently") + return + } + deviceIdentity, err := mach.GetOrFetchDeviceByKey(ctx, sender, senderKey) if err != nil { log.Error().Err(err).Msg("Failed to find device info by identity key") @@ -288,7 +391,10 @@ func (mach *OlmMachine) unwedgeDevice(log zerolog.Logger, sender id.UserID, send return } - log.Debug().Str("device_id", deviceIdentity.DeviceID.String()).Msg("Creating new Olm session") + log.Debug(). + Time("last_created", lastCreatedAt). + Stringer("device_id", deviceIdentity.DeviceID). + Msg("Creating new Olm session") mach.devicesToUnwedgeLock.Lock() mach.devicesToUnwedge[senderKey] = true mach.devicesToUnwedgeLock.Unlock() diff --git a/crypto/devicelist.go b/crypto/devicelist.go index a2116ed5..f0d2b129 100644 --- a/crypto/devicelist.go +++ b/crypto/devicelist.go @@ -22,14 +22,23 @@ import ( ) var ( - MismatchingDeviceID = errors.New("mismatching device ID in parameter and keys object") - MismatchingUserID = errors.New("mismatching user ID in parameter and keys object") - MismatchingSigningKey = errors.New("received update for device with different signing key") - NoSigningKeyFound = errors.New("didn't find ed25519 signing key") - NoIdentityKeyFound = errors.New("didn't find curve25519 identity key") - InvalidKeySignature = errors.New("invalid signature on device keys") + ErrMismatchingDeviceID = errors.New("mismatching device ID in parameter and keys object") + ErrMismatchingUserID = errors.New("mismatching user ID in parameter and keys object") + ErrMismatchingSigningKey = errors.New("received update for device with different signing key") + ErrNoSigningKeyFound = errors.New("didn't find ed25519 signing key") + ErrNoIdentityKeyFound = errors.New("didn't find curve25519 identity key") + ErrInvalidKeySignature = errors.New("invalid signature on device keys") + ErrUserNotTracked = errors.New("user is not tracked") +) - ErrUserNotTracked = errors.New("user is not tracked") +// Deprecated: use variables prefixed with Err +var ( + MismatchingDeviceID = ErrMismatchingDeviceID + MismatchingUserID = ErrMismatchingUserID + MismatchingSigningKey = ErrMismatchingSigningKey + NoSigningKeyFound = ErrNoSigningKeyFound + NoIdentityKeyFound = ErrNoIdentityKeyFound + InvalidKeySignature = ErrInvalidKeySignature ) func (mach *OlmMachine) LoadDevices(ctx context.Context, user id.UserID) (keys map[id.DeviceID]*id.Device) { @@ -206,7 +215,7 @@ func (mach *OlmMachine) FetchKeys(ctx context.Context, users []id.UserID, includ log.Trace().Int("user_count", len(resp.DeviceKeys)).Msg("Query key result received") data = make(map[id.UserID]map[id.DeviceID]*id.Device) for userID, devices := range resp.DeviceKeys { - log := log.With().Str("user_id", userID.String()).Logger() + log := log.With().Stringer("user_id", userID).Logger() delete(req.DeviceKeys, userID) newDevices := make(map[id.DeviceID]*id.Device) @@ -222,7 +231,7 @@ func (mach *OlmMachine) FetchKeys(ctx context.Context, users []id.UserID, includ Msg("Updating devices in store") changed := false for deviceID, deviceKeys := range devices { - log := log.With().Str("device_id", deviceID.String()).Logger() + log := log.With().Stringer("device_id", deviceID).Logger() existing, ok := existingDevices[deviceID] if !ok { // New device @@ -270,7 +279,7 @@ func (mach *OlmMachine) FetchKeys(ctx context.Context, users []id.UserID, includ } } for userID := range req.DeviceKeys { - log.Warn().Str("user_id", userID.String()).Msg("Didn't get any keys for user") + log.Warn().Stringer("user_id", userID).Msg("Didn't get any keys for user") } mach.storeCrossSigningKeys(ctx, resp.MasterKeys, resp.DeviceKeys) @@ -312,28 +321,28 @@ func (mach *OlmMachine) OnDevicesChanged(ctx context.Context, userID id.UserID) func (mach *OlmMachine) validateDevice(userID id.UserID, deviceID id.DeviceID, deviceKeys mautrix.DeviceKeys, existing *id.Device) (*id.Device, error) { if deviceID != deviceKeys.DeviceID { - return nil, fmt.Errorf("%w (expected %s, got %s)", MismatchingDeviceID, deviceID, deviceKeys.DeviceID) + return nil, fmt.Errorf("%w (expected %s, got %s)", ErrMismatchingDeviceID, deviceID, deviceKeys.DeviceID) } else if userID != deviceKeys.UserID { - return nil, fmt.Errorf("%w (expected %s, got %s)", MismatchingUserID, userID, deviceKeys.UserID) + return nil, fmt.Errorf("%w (expected %s, got %s)", ErrMismatchingUserID, userID, deviceKeys.UserID) } signingKey := deviceKeys.Keys.GetEd25519(deviceID) identityKey := deviceKeys.Keys.GetCurve25519(deviceID) if signingKey == "" { - return nil, NoSigningKeyFound + return nil, ErrNoSigningKeyFound } else if identityKey == "" { - return nil, NoIdentityKeyFound + return nil, ErrNoIdentityKeyFound } if existing != nil && existing.SigningKey != signingKey { - return existing, fmt.Errorf("%w (expected %s, got %s)", MismatchingSigningKey, existing.SigningKey, signingKey) + return existing, fmt.Errorf("%w (expected %s, got %s)", ErrMismatchingSigningKey, existing.SigningKey, signingKey) } ok, err := signatures.VerifySignatureJSON(deviceKeys, userID, deviceID.String(), signingKey) if err != nil { return existing, fmt.Errorf("failed to verify signature: %w", err) } else if !ok { - return existing, InvalidKeySignature + return existing, ErrInvalidKeySignature } name, ok := deviceKeys.Unsigned["device_display_name"].(string) diff --git a/crypto/encryptmegolm.go b/crypto/encryptmegolm.go index ef5f404f..88f9c8d4 100644 --- a/crypto/encryptmegolm.go +++ b/crypto/encryptmegolm.go @@ -25,8 +25,12 @@ import ( ) var ( - AlreadyShared = errors.New("group session already shared") - NoGroupSession = errors.New("no group session created") + ErrNoGroupSession = errors.New("no group session created") +) + +// Deprecated: use variables prefixed with Err +var ( + NoGroupSession = ErrNoGroupSession ) func getRawJSON[T any](content json.RawMessage, path ...string) *T { @@ -42,7 +46,7 @@ func getRawJSON[T any](content json.RawMessage, path ...string) *T { return &result } -func getRelatesTo(content any) *event.RelatesTo { +func getRelatesTo(content any, plaintext json.RawMessage) *event.RelatesTo { contentJSON, ok := content.(json.RawMessage) if ok { return getRawJSON[event.RelatesTo](contentJSON, "m.relates_to") @@ -55,7 +59,7 @@ func getRelatesTo(content any) *event.RelatesTo { if ok { return relatable.OptionalGetRelatesTo() } - return nil + return getRawJSON[event.RelatesTo](plaintext, "content", "m.relates_to") } func getMentions(content any) *event.Mentions { @@ -83,15 +87,20 @@ type rawMegolmEvent struct { // IsShareError returns true if the error is caused by the lack of an outgoing megolm session and can be solved with OlmMachine.ShareGroupSession func IsShareError(err error) bool { - return err == SessionExpired || err == SessionNotShared || err == NoGroupSession + return err == ErrSessionExpired || err == ErrSessionNotShared || err == ErrNoGroupSession } func ParseMegolmMessageIndex(ciphertext []byte) (uint, error) { + if len(ciphertext) == 0 { + return 0, fmt.Errorf("empty ciphertext") + } decoded := make([]byte, base64.RawStdEncoding.DecodedLen(len(ciphertext))) var err error _, err = base64.RawStdEncoding.Decode(decoded, ciphertext) if err != nil { return 0, err + } else if len(decoded) < 2+binary.MaxVarintLen64 { + return 0, fmt.Errorf("decoded ciphertext too short: %d bytes", len(decoded)) } else if decoded[0] != 3 || decoded[1] != 8 { return 0, fmt.Errorf("unexpected initial bytes %d and %d", decoded[0], decoded[1]) } @@ -121,7 +130,7 @@ func (mach *OlmMachine) EncryptMegolmEventWithStateKey(ctx context.Context, room if err != nil { return nil, fmt.Errorf("failed to get outbound group session: %w", err) } else if session == nil { - return nil, NoGroupSession + return nil, ErrNoGroupSession } plaintext, err := json.Marshal(&rawMegolmEvent{ RoomID: roomID, @@ -159,12 +168,21 @@ func (mach *OlmMachine) EncryptMegolmEventWithStateKey(ctx context.Context, room Algorithm: id.AlgorithmMegolmV1, SessionID: session.ID(), MegolmCiphertext: ciphertext, - RelatesTo: getRelatesTo(content), + RelatesTo: getRelatesTo(content, plaintext), // These are deprecated SenderKey: mach.account.IdentityKey(), DeviceID: mach.Client.DeviceID, } + if mach.MSC4392Relations && encrypted.RelatesTo != nil { + // When MSC4392 mode is enabled, reply and reaction metadata is stripped from the unencrypted content. + // Other relations like threads are still left unencrypted. + encrypted.RelatesTo.InReplyTo = nil + encrypted.RelatesTo.IsFallingBack = false + if evtType == event.EventReaction || encrypted.RelatesTo.Type == "" { + encrypted.RelatesTo = nil + } + } if mach.PlaintextMentions { encrypted.Mentions = getMentions(content) } @@ -209,7 +227,8 @@ func (mach *OlmMachine) ShareGroupSession(ctx context.Context, roomID id.RoomID, if err != nil { return fmt.Errorf("failed to get previous outbound group session: %w", err) } else if session != nil && session.Shared && !session.Expired() { - return AlreadyShared + mach.machOrContextLog(ctx).Debug().Stringer("room_id", roomID).Msg("Not re-sharing group session, already shared") + return nil } log := mach.machOrContextLog(ctx).With(). Str("room_id", roomID.String()). @@ -233,7 +252,7 @@ func (mach *OlmMachine) ShareGroupSession(ctx context.Context, roomID id.RoomID, var fetchKeysForUsers []id.UserID for _, userID := range users { - log := log.With().Str("target_user_id", userID.String()).Logger() + log := log.With().Stringer("target_user_id", userID).Logger() devices, err := mach.CryptoStore.GetDevices(ctx, userID) if err != nil { log.Err(err).Msg("Failed to get devices of user") @@ -305,7 +324,7 @@ func (mach *OlmMachine) ShareGroupSession(ctx context.Context, roomID id.RoomID, toDeviceWithheld.Messages[userID] = withheld } - log := log.With().Str("target_user_id", userID.String()).Logger() + log := log.With().Stringer("target_user_id", userID).Logger() log.Trace().Msg("Trying to find olm session to encrypt megolm session for user (post-fetch retry)") mach.findOlmSessionsForUser(ctx, session, userID, devices, output, withheld, nil) log.Debug(). @@ -351,26 +370,19 @@ func (mach *OlmMachine) encryptAndSendGroupSession(ctx context.Context, session log.Trace().Msg("Encrypting group session for all found devices") deviceCount := 0 toDevice := &mautrix.ReqSendToDevice{Messages: make(map[id.UserID]map[id.DeviceID]*event.Content)} + logUsers := zerolog.Dict() for userID, sessions := range olmSessions { if len(sessions) == 0 { continue } + logDevices := zerolog.Dict() output := make(map[id.DeviceID]*event.Content) toDevice.Messages[userID] = output for deviceID, device := range sessions { - log.Trace(). - Stringer("target_user_id", userID). - Stringer("target_device_id", deviceID). - Stringer("target_identity_key", device.identity.IdentityKey). - Msg("Encrypting group session for device") content := mach.encryptOlmEvent(ctx, device.session, device.identity, event.ToDeviceRoomKey, session.ShareContent()) output[deviceID] = &event.Content{Parsed: content} + logDevices.Str(string(deviceID), string(device.identity.IdentityKey)) deviceCount++ - log.Debug(). - Stringer("target_user_id", userID). - Stringer("target_device_id", deviceID). - Stringer("target_identity_key", device.identity.IdentityKey). - Msg("Encrypted group session for device") if !mach.DisableSharedGroupSessionTracking { err := mach.CryptoStore.MarkOutboundGroupSessionShared(ctx, userID, device.identity.IdentityKey, session.id) if err != nil { @@ -384,11 +396,13 @@ func (mach *OlmMachine) encryptAndSendGroupSession(ctx context.Context, session } } } + logUsers.Dict(string(userID), logDevices) } log.Debug(). Int("device_count", deviceCount). Int("user_count", len(toDevice.Messages)). + Dict("destination_map", logUsers). Msg("Sending to-device messages to share group session") _, err := mach.Client.SendToDevice(ctx, event.ToDeviceEncrypted, toDevice) return err @@ -417,7 +431,7 @@ func (mach *OlmMachine) findOlmSessionsForUser(ctx context.Context, session *Out Reason: "Device is blacklisted", }} session.Users[userKey] = OGSIgnored - } else if trustState := mach.ResolveTrust(device); trustState < mach.SendKeysMinTrust { + } else if trustState, _ := mach.ResolveTrustContext(ctx, device); trustState < mach.SendKeysMinTrust { log.Debug(). Str("min_trust", mach.SendKeysMinTrust.String()). Str("device_trust", trustState.String()). diff --git a/crypto/encryptolm.go b/crypto/encryptolm.go index 52e30166..765307af 100644 --- a/crypto/encryptolm.go +++ b/crypto/encryptolm.go @@ -17,6 +17,70 @@ import ( "maunium.net/go/mautrix/id" ) +func (mach *OlmMachine) EncryptToDevices(ctx context.Context, eventType event.Type, req *mautrix.ReqSendToDevice) (*mautrix.ReqSendToDevice, error) { + devicesToCreateSessions := make(map[id.UserID]map[id.DeviceID]*id.Device) + for userID, devices := range req.Messages { + for deviceID := range devices { + device, err := mach.GetOrFetchDevice(ctx, userID, deviceID) + if err != nil { + return nil, fmt.Errorf("failed to get device %s of user %s: %w", deviceID, userID, err) + } + + if _, ok := devicesToCreateSessions[userID]; !ok { + devicesToCreateSessions[userID] = make(map[id.DeviceID]*id.Device) + } + devicesToCreateSessions[userID][deviceID] = device + } + } + if err := mach.createOutboundSessions(ctx, devicesToCreateSessions); err != nil { + return nil, fmt.Errorf("failed to create outbound sessions: %w", err) + } + + mach.olmLock.Lock() + defer mach.olmLock.Unlock() + + encryptedReq := &mautrix.ReqSendToDevice{ + Messages: make(map[id.UserID]map[id.DeviceID]*event.Content), + } + + log := mach.machOrContextLog(ctx) + + for userID, devices := range req.Messages { + encryptedReq.Messages[userID] = make(map[id.DeviceID]*event.Content) + + for deviceID, content := range devices { + device := devicesToCreateSessions[userID][deviceID] + + olmSess, err := mach.CryptoStore.GetLatestSession(ctx, device.IdentityKey) + if err != nil { + return nil, fmt.Errorf("failed to get latest session for device %s of %s: %w", deviceID, userID, err) + } else if olmSess == nil { + log.Warn(). + Str("target_user_id", userID.String()). + Str("target_device_id", deviceID.String()). + Str("identity_key", device.IdentityKey.String()). + Msg("No outbound session found for device") + continue + } + + encrypted := mach.encryptOlmEvent(ctx, olmSess, device, eventType, *content) + encryptedContent := &event.Content{Parsed: &encrypted} + + log.Debug(). + Str("decrypted_type", eventType.Type). + Str("target_user_id", userID.String()). + Str("target_device_id", deviceID.String()). + Str("target_identity_key", device.IdentityKey.String()). + Str("olm_session_id", olmSess.ID().String()). + Msg("Encrypted to-device event") + + encryptedReq.Messages[userID][deviceID] = encryptedContent + } + } + + return encryptedReq, nil +} + func (mach *OlmMachine) encryptOlmEvent(ctx context.Context, session *OlmSession, recipient *id.Device, evtType event.Type, content event.Content) *event.EncryptedEventContent { evt := &DecryptedOlmEvent{ Sender: mach.Client.UserID, @@ -32,15 +96,19 @@ func (mach *OlmMachine) encryptOlmEvent(ctx context.Context, session *OlmSession panic(err) } log := mach.machOrContextLog(ctx) - log.Debug(). - Str("recipient_identity_key", recipient.IdentityKey.String()). - Str("olm_session_id", session.ID().String()). - Str("olm_session_description", session.Describe()). - Msg("Encrypting olm message") msgType, ciphertext, err := session.Encrypt(plaintext) if err != nil { panic(err) } + ciphertextStr := string(ciphertext) + ciphertextHash, _ := olmMessageHash(ciphertextStr) + log.Debug(). + Stringer("event_type", evtType). + Str("recipient_identity_key", recipient.IdentityKey.String()). + Str("olm_session_id", session.ID().String()). + Str("olm_session_description", session.Describe()). + Hex("ciphertext_hash", ciphertextHash[:]). + Msg("Encrypted olm message") err = mach.CryptoStore.UpdateSession(ctx, recipient.IdentityKey, session) if err != nil { log.Error().Err(err).Msg("Failed to update olm session in crypto store after encrypting") @@ -51,7 +119,7 @@ func (mach *OlmMachine) encryptOlmEvent(ctx context.Context, session *OlmSession OlmCiphertext: event.OlmCiphertexts{ recipient.IdentityKey: { Type: msgType, - Body: string(ciphertext), + Body: ciphertextStr, }, }, } diff --git a/crypto/goolm/account/account.go b/crypto/goolm/account/account.go index 099cc493..b48843a4 100644 --- a/crypto/goolm/account/account.go +++ b/crypto/goolm/account/account.go @@ -8,11 +8,9 @@ import ( "maunium.net/go/mautrix/id" - "maunium.net/go/mautrix/crypto/goolm/cipher" "maunium.net/go/mautrix/crypto/goolm/crypto" "maunium.net/go/mautrix/crypto/goolm/libolmpickle" "maunium.net/go/mautrix/crypto/goolm/session" - "maunium.net/go/mautrix/crypto/goolm/utilities" "maunium.net/go/mautrix/crypto/olm" ) @@ -76,12 +74,12 @@ func NewAccount() (*Account, error) { // PickleAsJSON returns an Account as a base64 string encrypted using the supplied key. The unencrypted representation of the Account is in JSON format. func (a *Account) PickleAsJSON(key []byte) ([]byte, error) { - return utilities.PickleAsJSON(a, accountPickleVersionJSON, key) + return libolmpickle.PickleAsJSON(a, accountPickleVersionJSON, key) } // UnpickleAsJSON updates an Account by a base64 encrypted string using the supplied key. The unencrypted representation has to be in JSON format. func (a *Account) UnpickleAsJSON(pickled, key []byte) error { - return utilities.UnpickleAsJSON(a, pickled, key, accountPickleVersionJSON) + return libolmpickle.UnpickleAsJSON(a, pickled, key, accountPickleVersionJSON) } // IdentityKeysJSON returns the public parts of the identity keys for the Account in a JSON string. @@ -322,7 +320,7 @@ func (a *Account) ForgetOldFallbackKey() { // Unpickle decodes the base64 encoded string and decrypts the result with the key. // The decrypted value is then passed to UnpickleLibOlm. func (a *Account) Unpickle(pickled, key []byte) error { - decrypted, err := cipher.Unpickle(key, pickled) + decrypted, err := libolmpickle.Unpickle(key, pickled) if err != nil { return err } @@ -336,7 +334,7 @@ func (a *Account) UnpickleLibOlm(buf []byte) error { if err != nil { return err } else if pickledVersion != accountPickleVersionLibOLM && pickledVersion != 3 && pickledVersion != 2 { - return fmt.Errorf("unpickle account: %w (found version %d)", olm.ErrBadVersion, pickledVersion) + return fmt.Errorf("unpickle account: %w (found version %d)", olm.ErrUnknownOlmPickleVersion, pickledVersion) } else if err = a.IdKeys.Ed25519.UnpickleLibOlm(decoder); err != nil { // read the ed25519 key pair return err } else if err = a.IdKeys.Curve25519.UnpickleLibOlm(decoder); err != nil { // read curve25519 key pair @@ -410,7 +408,7 @@ func (a *Account) Pickle(key []byte) ([]byte, error) { if len(key) == 0 { return nil, olm.ErrNoKeyProvided } - return cipher.Pickle(key, a.PickleLibOlm()) + return libolmpickle.Pickle(key, a.PickleLibOlm()) } // PickleLibOlm pickles the [Account] and returns the raw bytes. diff --git a/crypto/goolm/account/account_test.go b/crypto/goolm/account/account_test.go index e1c9b452..d0dec5f0 100644 --- a/crypto/goolm/account/account_test.go +++ b/crypto/goolm/account/account_test.go @@ -124,7 +124,7 @@ func TestOldAccountPickle(t *testing.T) { account, err := account.NewAccount() assert.NoError(t, err) err = account.Unpickle(pickled, pickleKey) - assert.ErrorIs(t, err, olm.ErrBadVersion) + assert.ErrorIs(t, err, olm.ErrUnknownOlmPickleVersion) } func TestLoopback(t *testing.T) { diff --git a/crypto/goolm/account/register.go b/crypto/goolm/account/register.go index c6b9e523..ec392d7e 100644 --- a/crypto/goolm/account/register.go +++ b/crypto/goolm/account/register.go @@ -10,7 +10,7 @@ import ( "maunium.net/go/mautrix/crypto/olm" ) -func init() { +func Register() { olm.InitNewAccount = func() (olm.Account, error) { return NewAccount() } diff --git a/crypto/goolm/aessha2/aessha2.go b/crypto/goolm/aessha2/aessha2.go new file mode 100644 index 00000000..42d9811b --- /dev/null +++ b/crypto/goolm/aessha2/aessha2.go @@ -0,0 +1,59 @@ +// Copyright (c) 2024 Sumner Evans +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +// Package aessha2 implements the m.megolm.v1.aes-sha2 encryption algorithm +// described in [Section 10.12.4.3] in the Spec +// +// [Section 10.12.4.3]: https://spec.matrix.org/v1.12/client-server-api/#mmegolmv1aes-sha2 +package aessha2 + +import ( + "crypto/hmac" + "crypto/sha256" + "crypto/subtle" + "io" + + "golang.org/x/crypto/hkdf" + + "maunium.net/go/mautrix/crypto/aescbc" +) + +type AESSHA2 struct { + aesKey, hmacKey, iv []byte +} + +func NewAESSHA2(secret, info []byte) (AESSHA2, error) { + kdf := hkdf.New(sha256.New, secret, nil, info) + keymatter := make([]byte, 80) + _, err := io.ReadFull(kdf, keymatter) + return AESSHA2{ + keymatter[:32], // AES Key + keymatter[32:64], // HMAC Key + keymatter[64:], // IV + }, err +} + +func (a *AESSHA2) Encrypt(plaintext []byte) ([]byte, error) { + return aescbc.Encrypt(a.aesKey, a.iv, plaintext) +} + +func (a *AESSHA2) Decrypt(ciphertext []byte) ([]byte, error) { + return aescbc.Decrypt(a.aesKey, a.iv, ciphertext) +} + +func (a *AESSHA2) MAC(ciphertext []byte) ([]byte, error) { + hash := hmac.New(sha256.New, a.hmacKey) + _, err := hash.Write(ciphertext) + return hash.Sum(nil), err +} + +func (a *AESSHA2) VerifyMAC(ciphertext, theirMAC []byte) (bool, error) { + if mac, err := a.MAC(ciphertext); err != nil { + return false, err + } else { + return subtle.ConstantTimeCompare(mac[:len(theirMAC)], theirMAC) == 1, nil + } +} diff --git a/crypto/goolm/aessha2/aessha2_test.go b/crypto/goolm/aessha2/aessha2_test.go new file mode 100644 index 00000000..b2cfe8aa --- /dev/null +++ b/crypto/goolm/aessha2/aessha2_test.go @@ -0,0 +1,33 @@ +package aessha2_test + +import ( + "crypto/aes" + "testing" + + "github.com/stretchr/testify/assert" + + "maunium.net/go/mautrix/crypto/goolm/aessha2" +) + +func TestCipherAESSha256(t *testing.T) { + key := []byte("test key") + cipher, err := aessha2.NewAESSHA2(key, []byte("testKDFinfo")) + assert.NoError(t, err) + message := []byte("this is a random message for testing the implementation") + //increase to next block size + for len(message)%aes.BlockSize != 0 { + message = append(message, []byte("-")...) + } + encrypted, err := cipher.Encrypt([]byte(message)) + assert.NoError(t, err) + mac, err := cipher.MAC(encrypted) + assert.NoError(t, err) + + verified, err := cipher.VerifyMAC(encrypted, mac[:8]) + assert.NoError(t, err) + assert.True(t, verified, "signature verification failed") + + resultPlainText, err := cipher.Decrypt(encrypted) + assert.NoError(t, err) + assert.Equal(t, message, resultPlainText) +} diff --git a/crypto/goolm/cipher/aes_sha256.go b/crypto/goolm/cipher/aes_sha256.go deleted file mode 100644 index 42f5d069..00000000 --- a/crypto/goolm/cipher/aes_sha256.go +++ /dev/null @@ -1,81 +0,0 @@ -package cipher - -import ( - "bytes" - "crypto/hmac" - "crypto/sha256" - "io" - - "golang.org/x/crypto/hkdf" - - "maunium.net/go/mautrix/crypto/aescbc" -) - -// derivedAESKeys stores the derived keys for the AESSHA256 cipher -type derivedAESKeys struct { - key []byte - hmacKey []byte - iv []byte -} - -// deriveAESKeys derives three keys for the AESSHA256 cipher -func deriveAESKeys(kdfInfo []byte, key []byte) (derivedAESKeys, error) { - kdf := hkdf.New(sha256.New, key, nil, kdfInfo) - keymatter := make([]byte, 80) - _, err := io.ReadFull(kdf, keymatter) - return derivedAESKeys{ - key: keymatter[:32], - hmacKey: keymatter[32:64], - iv: keymatter[64:], - }, err -} - -// AESSHA256 is a valid cipher using AES with CBC and HKDFSha256. -type AESSHA256 struct { - kdfInfo []byte -} - -// NewAESSHA256 returns a new AESSHA256 cipher with the key derive function info (kdfInfo). -func NewAESSHA256(kdfInfo []byte) *AESSHA256 { - return &AESSHA256{ - kdfInfo: kdfInfo, - } -} - -// Encrypt encrypts the plaintext with the key. The key is used to derive the actual encryption key (32 bytes) as well as the iv (16 bytes). -func (c AESSHA256) Encrypt(key, plaintext []byte) (ciphertext []byte, err error) { - keys, err := deriveAESKeys(c.kdfInfo, key) - if err != nil { - return nil, err - } - return aescbc.Encrypt(keys.key, keys.iv, plaintext) -} - -// Decrypt decrypts the ciphertext with the key. The key is used to derive the actual encryption key (32 bytes) as well as the iv (16 bytes). -func (c AESSHA256) Decrypt(key, ciphertext []byte) (plaintext []byte, err error) { - keys, err := deriveAESKeys(c.kdfInfo, key) - if err != nil { - return nil, err - } - return aescbc.Decrypt(keys.key, keys.iv, ciphertext) -} - -// MAC returns the MAC for the message using the key. The key is used to derive the actual mac key (32 bytes). -func (c AESSHA256) MAC(key, message []byte) ([]byte, error) { - keys, err := deriveAESKeys(c.kdfInfo, key) - if err != nil { - return nil, err - } - hash := hmac.New(sha256.New, keys.hmacKey) - _, err = hash.Write(message) - return hash.Sum(nil), err -} - -// Verify checks the MAC of the message using the key against the givenMAC. The key is used to derive the actual mac key (32 bytes). -func (c AESSHA256) Verify(key, message, givenMAC []byte) (bool, error) { - mac, err := c.MAC(key, message) - if err != nil { - return false, err - } - return bytes.Equal(givenMAC, mac[:len(givenMAC)]), nil -} diff --git a/crypto/goolm/cipher/aes_sha256_test.go b/crypto/goolm/cipher/aes_sha256_test.go deleted file mode 100644 index 2f58605f..00000000 --- a/crypto/goolm/cipher/aes_sha256_test.go +++ /dev/null @@ -1,60 +0,0 @@ -package cipher - -import ( - "crypto/aes" - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestDeriveAESKeys(t *testing.T) { - derivedKeys, err := deriveAESKeys([]byte("test"), []byte("test key")) - assert.NoError(t, err) - derivedKeys2, err := deriveAESKeys([]byte("test"), []byte("test key")) - assert.NoError(t, err) - - //derivedKeys and derivedKeys2 should be identical - assert.Equal(t, derivedKeys.key, derivedKeys2.key) - assert.Equal(t, derivedKeys.iv, derivedKeys2.iv) - assert.Equal(t, derivedKeys.hmacKey, derivedKeys2.hmacKey) - - //changing kdfInfo - derivedKeys2, err = deriveAESKeys([]byte("other kdf"), []byte("test key")) - assert.NoError(t, err) - - //derivedKeys and derivedKeys2 should now be different - assert.NotEqual(t, derivedKeys.key, derivedKeys2.key) - assert.NotEqual(t, derivedKeys.iv, derivedKeys2.iv) - assert.NotEqual(t, derivedKeys.hmacKey, derivedKeys2.hmacKey) - - //changing key - derivedKeys, err = deriveAESKeys([]byte("test"), []byte("other test key")) - assert.NoError(t, err) - - //derivedKeys and derivedKeys2 should now be different - assert.NotEqual(t, derivedKeys.key, derivedKeys2.key) - assert.NotEqual(t, derivedKeys.iv, derivedKeys2.iv) - assert.NotEqual(t, derivedKeys.hmacKey, derivedKeys2.hmacKey) -} - -func TestCipherAESSha256(t *testing.T) { - key := []byte("test key") - cipher := NewAESSHA256([]byte("testKDFinfo")) - message := []byte("this is a random message for testing the implementation") - //increase to next block size - for len(message)%aes.BlockSize != 0 { - message = append(message, []byte("-")...) - } - encrypted, err := cipher.Encrypt(key, []byte(message)) - assert.NoError(t, err) - mac, err := cipher.MAC(key, encrypted) - assert.NoError(t, err) - - verified, err := cipher.Verify(key, encrypted, mac[:8]) - assert.NoError(t, err) - assert.True(t, verified, "signature verification failed") - - resultPlainText, err := cipher.Decrypt(key, encrypted) - assert.NoError(t, err) - assert.Equal(t, message, resultPlainText) -} diff --git a/crypto/goolm/cipher/cipher.go b/crypto/goolm/cipher/cipher.go deleted file mode 100644 index 43580b0b..00000000 --- a/crypto/goolm/cipher/cipher.go +++ /dev/null @@ -1,18 +0,0 @@ -// Package cipher provides the methods and structs to do encryptions for -// olm/megolm. -package cipher - -// Cipher defines a valid cipher. -type Cipher interface { - // Encrypt encrypts the plaintext. - Encrypt(key, plaintext []byte) (ciphertext []byte, err error) - - // Decrypt decrypts the ciphertext. - Decrypt(key, ciphertext []byte) (plaintext []byte, err error) - - //MAC returns the MAC of the message calculated with the key. - MAC(key, message []byte) ([]byte, error) - - //Verify checks the MAC of the message calculated with the key against the givenMAC. - Verify(key, message, givenMAC []byte) (bool, error) -} diff --git a/crypto/goolm/cipher/pickle.go b/crypto/goolm/cipher/pickle.go deleted file mode 100644 index 754c7963..00000000 --- a/crypto/goolm/cipher/pickle.go +++ /dev/null @@ -1,55 +0,0 @@ -package cipher - -import ( - "crypto/aes" - "fmt" - - "maunium.net/go/mautrix/crypto/goolm/goolmbase64" - "maunium.net/go/mautrix/crypto/olm" -) - -const ( - kdfPickle = "Pickle" //used to derive the keys for encryption - pickleMACLength = 8 -) - -// PickleBlockSize returns the blocksize of the used cipher. -func PickleBlockSize() int { - return aes.BlockSize -} - -// Pickle encrypts the input with the key and the cipher AESSHA256. The result is then encoded in base64. -func Pickle(key, input []byte) ([]byte, error) { - pickleCipher := NewAESSHA256([]byte(kdfPickle)) - ciphertext, err := pickleCipher.Encrypt(key, input) - if err != nil { - return nil, err - } - mac, err := pickleCipher.MAC(key, ciphertext) - if err != nil { - return nil, err - } - ciphertext = append(ciphertext, mac[:pickleMACLength]...) - return goolmbase64.Encode(ciphertext), nil -} - -// Unpickle decodes the input from base64 and decrypts the decoded input with the key and the cipher AESSHA256. -func Unpickle(key, input []byte) ([]byte, error) { - pickleCipher := NewAESSHA256([]byte(kdfPickle)) - ciphertext, err := goolmbase64.Decode(input) - if err != nil { - return nil, err - } - //remove mac and check - verified, err := pickleCipher.Verify(key, ciphertext[:len(ciphertext)-pickleMACLength], ciphertext[len(ciphertext)-pickleMACLength:]) - if err != nil { - return nil, err - } - if !verified { - return nil, fmt.Errorf("decrypt pickle: %w", olm.ErrBadMAC) - } - //Set to next block size - targetCipherText := make([]byte, int(len(ciphertext)/PickleBlockSize())*PickleBlockSize()) - copy(targetCipherText, ciphertext) - return pickleCipher.Decrypt(key, targetCipherText) -} diff --git a/crypto/goolm/cipher/pickle_test.go b/crypto/goolm/cipher/pickle_test.go deleted file mode 100644 index b6cfe809..00000000 --- a/crypto/goolm/cipher/pickle_test.go +++ /dev/null @@ -1,28 +0,0 @@ -package cipher_test - -import ( - "crypto/aes" - "testing" - - "github.com/stretchr/testify/assert" - - "maunium.net/go/mautrix/crypto/goolm/cipher" -) - -func TestEncoding(t *testing.T) { - key := []byte("test key") - input := []byte("test") - //pad marshaled to get block size - toEncrypt := input - if len(input)%aes.BlockSize != 0 { - padding := aes.BlockSize - len(input)%aes.BlockSize - toEncrypt = make([]byte, len(input)+padding) - copy(toEncrypt, input) - } - encoded, err := cipher.Pickle(key, toEncrypt) - assert.NoError(t, err) - - decoded, err := cipher.Unpickle(key, encoded) - assert.NoError(t, err) - assert.Equal(t, toEncrypt, decoded) -} diff --git a/crypto/goolm/crypto/curve25519.go b/crypto/goolm/crypto/curve25519.go index e9759501..6e42d886 100644 --- a/crypto/goolm/crypto/curve25519.go +++ b/crypto/goolm/crypto/curve25519.go @@ -53,6 +53,7 @@ func (c Curve25519KeyPair) B64Encoded() id.Curve25519 { // SharedSecret returns the shared secret between the key pair and the given public key. func (c Curve25519KeyPair) SharedSecret(pubKey Curve25519PublicKey) ([]byte, error) { + // Note: the standard library checks that the output is non-zero return c.PrivateKey.SharedSecret(pubKey) } diff --git a/crypto/goolm/crypto/curve25519_test.go b/crypto/goolm/crypto/curve25519_test.go index 9039c126..2550f15e 100644 --- a/crypto/goolm/crypto/curve25519_test.go +++ b/crypto/goolm/crypto/curve25519_test.go @@ -25,6 +25,8 @@ func TestCurve25519(t *testing.T) { fromPrivate, err := crypto.Curve25519GenerateFromPrivate(firstKeypair.PrivateKey) assert.NoError(t, err) assert.Equal(t, fromPrivate, firstKeypair) + _, err = secondKeypair.SharedSecret(make([]byte, crypto.Curve25519PublicKeyLength)) + assert.Error(t, err) } func TestCurve25519Case1(t *testing.T) { diff --git a/crypto/goolm/goolmbase64/base64.go b/crypto/goolm/goolmbase64/base64.go index 061a052a..58ee26f7 100644 --- a/crypto/goolm/goolmbase64/base64.go +++ b/crypto/goolm/goolmbase64/base64.go @@ -4,7 +4,8 @@ import ( "encoding/base64" ) -// Deprecated: base64.RawStdEncoding should be used directly +// These methods should only be used for raw byte operations, never with string conversion + func Decode(input []byte) ([]byte, error) { decoded := make([]byte, base64.RawStdEncoding.DecodedLen(len(input))) writtenBytes, err := base64.RawStdEncoding.Decode(decoded, input) @@ -14,7 +15,6 @@ func Decode(input []byte) ([]byte, error) { return decoded[:writtenBytes], nil } -// Deprecated: base64.RawStdEncoding should be used directly func Encode(input []byte) []byte { encoded := make([]byte, base64.RawStdEncoding.EncodedLen(len(input))) base64.RawStdEncoding.Encode(encoded, input) diff --git a/crypto/goolm/libolmpickle/encoder.go b/crypto/goolm/libolmpickle/encoder.go new file mode 100644 index 00000000..63e7b09b --- /dev/null +++ b/crypto/goolm/libolmpickle/encoder.go @@ -0,0 +1,40 @@ +package libolmpickle + +import ( + "bytes" + "encoding/binary" + + "go.mau.fi/util/exerrors" +) + +const ( + PickleBoolLength = 1 + PickleUInt8Length = 1 + PickleUInt32Length = 4 +) + +type Encoder struct { + bytes.Buffer +} + +func NewEncoder() *Encoder { return &Encoder{} } + +func (p *Encoder) WriteUInt8(value uint8) { + exerrors.PanicIfNotNil(p.WriteByte(value)) +} + +func (p *Encoder) WriteBool(value bool) { + if value { + exerrors.PanicIfNotNil(p.WriteByte(0x01)) + } else { + exerrors.PanicIfNotNil(p.WriteByte(0x00)) + } +} + +func (p *Encoder) WriteEmptyBytes(count int) { + exerrors.Must(p.Write(make([]byte, count))) +} + +func (p *Encoder) WriteUInt32(value uint32) { + exerrors.PanicIfNotNil(binary.Write(&p.Buffer, binary.BigEndian, value)) +} diff --git a/crypto/goolm/libolmpickle/encoder_test.go b/crypto/goolm/libolmpickle/encoder_test.go new file mode 100644 index 00000000..c7811225 --- /dev/null +++ b/crypto/goolm/libolmpickle/encoder_test.go @@ -0,0 +1,99 @@ +package libolmpickle_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "maunium.net/go/mautrix/crypto/goolm/libolmpickle" +) + +func TestEncoder(t *testing.T) { + var encoder libolmpickle.Encoder + encoder.WriteUInt32(4) + encoder.WriteUInt8(8) + encoder.WriteBool(false) + encoder.WriteEmptyBytes(10) + encoder.WriteBool(true) + encoder.Write([]byte("test")) + encoder.WriteUInt32(420_000) + assert.Equal(t, []byte{ + 0x00, 0x00, 0x00, 0x04, // 4 + 0x08, // 8 + 0x00, // false + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // ten empty bytes + 0x01, //true + 0x74, 0x65, 0x73, 0x74, // "test" (ASCII) + 0x00, 0x06, 0x68, 0xa0, // 420,000 + }, encoder.Bytes()) +} + +func TestPickleUInt32(t *testing.T) { + values := []uint32{ + 0xffffffff, + 0x00ff00ff, + 0xf0000000, + 0xf00f0000, + } + expected := [][]byte{ + {0xff, 0xff, 0xff, 0xff}, + {0x00, 0xff, 0x00, 0xff}, + {0xf0, 0x00, 0x00, 0x00}, + {0xf0, 0x0f, 0x00, 0x00}, + } + for i, value := range values { + var encoder libolmpickle.Encoder + encoder.WriteUInt32(value) + assert.Equal(t, expected[i], encoder.Bytes()) + } +} + +func TestPickleBool(t *testing.T) { + values := []bool{ + true, + false, + } + expected := [][]byte{ + {0x01}, + {0x00}, + } + for i, value := range values { + var encoder libolmpickle.Encoder + encoder.WriteBool(value) + assert.Equal(t, expected[i], encoder.Bytes()) + } +} + +func TestPickleUInt8(t *testing.T) { + values := []uint8{ + 0xff, + 0x1a, + } + expected := [][]byte{ + {0xff}, + {0x1a}, + } + for i, value := range values { + var encoder libolmpickle.Encoder + encoder.WriteUInt8(value) + assert.Equal(t, expected[i], encoder.Bytes()) + } +} + +func TestPickleBytes(t *testing.T) { + values := [][]byte{ + {0xff, 0xff, 0xff, 0xff}, + {0x00, 0xff, 0x00, 0xff}, + {0xf0, 0x00, 0x00, 0x00}, + } + expected := [][]byte{ + {0xff, 0xff, 0xff, 0xff}, + {0x00, 0xff, 0x00, 0xff}, + {0xf0, 0x00, 0x00, 0x00}, + } + for i, value := range values { + var encoder libolmpickle.Encoder + encoder.Write(value) + assert.Equal(t, expected[i], encoder.Bytes()) + } +} diff --git a/crypto/goolm/libolmpickle/pickle.go b/crypto/goolm/libolmpickle/pickle.go index 590033fc..d15358fd 100644 --- a/crypto/goolm/libolmpickle/pickle.go +++ b/crypto/goolm/libolmpickle/pickle.go @@ -1,40 +1,48 @@ package libolmpickle import ( - "bytes" - "encoding/binary" + "crypto/aes" + "fmt" - "go.mau.fi/util/exerrors" + "maunium.net/go/mautrix/crypto/goolm/aessha2" + "maunium.net/go/mautrix/crypto/goolm/goolmbase64" + "maunium.net/go/mautrix/crypto/olm" ) -const ( - PickleBoolLength = 1 - PickleUInt8Length = 1 - PickleUInt32Length = 4 -) +const pickleMACLength = 8 -type Encoder struct { - bytes.Buffer -} +var kdfPickle = []byte("Pickle") //used to derive the keys for encryption -func NewEncoder() *Encoder { return &Encoder{} } - -func (p *Encoder) WriteUInt8(value uint8) { - exerrors.PanicIfNotNil(p.WriteByte(value)) -} - -func (p *Encoder) WriteBool(value bool) { - if value { - exerrors.PanicIfNotNil(p.WriteByte(0x01)) +// Pickle encrypts the input with the key and the cipher AESSHA256. The result is then encoded in base64. +func Pickle(key, plaintext []byte) ([]byte, error) { + if c, err := aessha2.NewAESSHA2(key, kdfPickle); err != nil { + return nil, err + } else if ciphertext, err := c.Encrypt(plaintext); err != nil { + return nil, err + } else if mac, err := c.MAC(ciphertext); err != nil { + return nil, err } else { - exerrors.PanicIfNotNil(p.WriteByte(0x00)) + return goolmbase64.Encode(append(ciphertext, mac[:pickleMACLength]...)), nil } } -func (p *Encoder) WriteEmptyBytes(count int) { - exerrors.Must(p.Write(make([]byte, count))) -} - -func (p *Encoder) WriteUInt32(value uint32) { - exerrors.Must(p.Write(binary.BigEndian.AppendUint32(nil, value))) +// Unpickle decodes the input from base64 and decrypts the decoded input with the key and the cipher AESSHA256. +func Unpickle(key, input []byte) ([]byte, error) { + ciphertext, err := goolmbase64.Decode(input) + if err != nil { + return nil, err + } + ciphertext, mac := ciphertext[:len(ciphertext)-pickleMACLength], ciphertext[len(ciphertext)-pickleMACLength:] + if c, err := aessha2.NewAESSHA2(key, kdfPickle); err != nil { + return nil, err + } else if verified, err := c.VerifyMAC(ciphertext, mac); err != nil { + return nil, err + } else if !verified { + return nil, fmt.Errorf("decrypt pickle: %w", olm.ErrBadMAC) + } else { + // Set to next block size + targetCipherText := make([]byte, int(len(ciphertext)/aes.BlockSize)*aes.BlockSize) + copy(targetCipherText, ciphertext) + return c.Decrypt(targetCipherText) + } } diff --git a/crypto/goolm/libolmpickle/pickle_test.go b/crypto/goolm/libolmpickle/pickle_test.go index c7811225..0720e008 100644 --- a/crypto/goolm/libolmpickle/pickle_test.go +++ b/crypto/goolm/libolmpickle/pickle_test.go @@ -1,99 +1,26 @@ -package libolmpickle_test +package libolmpickle import ( + "crypto/aes" "testing" "github.com/stretchr/testify/assert" - - "maunium.net/go/mautrix/crypto/goolm/libolmpickle" ) -func TestEncoder(t *testing.T) { - var encoder libolmpickle.Encoder - encoder.WriteUInt32(4) - encoder.WriteUInt8(8) - encoder.WriteBool(false) - encoder.WriteEmptyBytes(10) - encoder.WriteBool(true) - encoder.Write([]byte("test")) - encoder.WriteUInt32(420_000) - assert.Equal(t, []byte{ - 0x00, 0x00, 0x00, 0x04, // 4 - 0x08, // 8 - 0x00, // false - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // ten empty bytes - 0x01, //true - 0x74, 0x65, 0x73, 0x74, // "test" (ASCII) - 0x00, 0x06, 0x68, 0xa0, // 420,000 - }, encoder.Bytes()) -} +func TestEncoding(t *testing.T) { + key := []byte("test key") + input := []byte("test") + //pad marshaled to get block size + toEncrypt := input + if len(input)%aes.BlockSize != 0 { + padding := aes.BlockSize - len(input)%aes.BlockSize + toEncrypt = make([]byte, len(input)+padding) + copy(toEncrypt, input) + } + encoded, err := Pickle(key, toEncrypt) + assert.NoError(t, err) -func TestPickleUInt32(t *testing.T) { - values := []uint32{ - 0xffffffff, - 0x00ff00ff, - 0xf0000000, - 0xf00f0000, - } - expected := [][]byte{ - {0xff, 0xff, 0xff, 0xff}, - {0x00, 0xff, 0x00, 0xff}, - {0xf0, 0x00, 0x00, 0x00}, - {0xf0, 0x0f, 0x00, 0x00}, - } - for i, value := range values { - var encoder libolmpickle.Encoder - encoder.WriteUInt32(value) - assert.Equal(t, expected[i], encoder.Bytes()) - } -} - -func TestPickleBool(t *testing.T) { - values := []bool{ - true, - false, - } - expected := [][]byte{ - {0x01}, - {0x00}, - } - for i, value := range values { - var encoder libolmpickle.Encoder - encoder.WriteBool(value) - assert.Equal(t, expected[i], encoder.Bytes()) - } -} - -func TestPickleUInt8(t *testing.T) { - values := []uint8{ - 0xff, - 0x1a, - } - expected := [][]byte{ - {0xff}, - {0x1a}, - } - for i, value := range values { - var encoder libolmpickle.Encoder - encoder.WriteUInt8(value) - assert.Equal(t, expected[i], encoder.Bytes()) - } -} - -func TestPickleBytes(t *testing.T) { - values := [][]byte{ - {0xff, 0xff, 0xff, 0xff}, - {0x00, 0xff, 0x00, 0xff}, - {0xf0, 0x00, 0x00, 0x00}, - } - expected := [][]byte{ - {0xff, 0xff, 0xff, 0xff}, - {0x00, 0xff, 0x00, 0xff}, - {0xf0, 0x00, 0x00, 0x00}, - } - for i, value := range values { - var encoder libolmpickle.Encoder - encoder.Write(value) - assert.Equal(t, expected[i], encoder.Bytes()) - } + decoded, err := Unpickle(key, encoded) + assert.NoError(t, err) + assert.Equal(t, toEncrypt, decoded) } diff --git a/crypto/goolm/utilities/pickle.go b/crypto/goolm/libolmpickle/picklejson.go similarity index 81% rename from crypto/goolm/utilities/pickle.go rename to crypto/goolm/libolmpickle/picklejson.go index 6ce35efe..f765391f 100644 --- a/crypto/goolm/utilities/pickle.go +++ b/crypto/goolm/libolmpickle/picklejson.go @@ -1,10 +1,10 @@ -package utilities +package libolmpickle import ( + "crypto/aes" "encoding/json" "fmt" - "maunium.net/go/mautrix/crypto/goolm/cipher" "maunium.net/go/mautrix/crypto/olm" ) @@ -21,12 +21,12 @@ func PickleAsJSON(object any, pickleVersion byte, key []byte) ([]byte, error) { toEncrypt := make([]byte, len(marshaled)) copy(toEncrypt, marshaled) //pad marshaled to get block size - if len(marshaled)%cipher.PickleBlockSize() != 0 { - padding := cipher.PickleBlockSize() - len(marshaled)%cipher.PickleBlockSize() + if len(marshaled)%aes.BlockSize != 0 { + padding := aes.BlockSize - len(marshaled)%aes.BlockSize toEncrypt = make([]byte, len(marshaled)+padding) copy(toEncrypt, marshaled) } - encrypted, err := cipher.Pickle(key, toEncrypt) + encrypted, err := Pickle(key, toEncrypt) if err != nil { return nil, fmt.Errorf("pickle encrypt: %w", err) } @@ -38,7 +38,7 @@ func UnpickleAsJSON(object any, pickled, key []byte, pickleVersion byte) error { if len(key) == 0 { return fmt.Errorf("unpickle: %w", olm.ErrNoKeyProvided) } - decrypted, err := cipher.Unpickle(key, pickled) + decrypted, err := Unpickle(key, pickled) if err != nil { return fmt.Errorf("unpickle decrypt: %w", err) } @@ -50,7 +50,7 @@ func UnpickleAsJSON(object any, pickled, key []byte, pickleVersion byte) error { } } if decrypted[0] != pickleVersion { - return fmt.Errorf("unpickle: %w", olm.ErrWrongPickleVersion) + return fmt.Errorf("unpickle: %w", olm.ErrUnknownJSONPickleVersion) } err = json.Unmarshal(decrypted[1:], object) if err != nil { diff --git a/crypto/goolm/megolm/megolm.go b/crypto/goolm/megolm/megolm.go index 416db111..3b5f1e4a 100644 --- a/crypto/goolm/megolm/megolm.go +++ b/crypto/goolm/megolm/megolm.go @@ -7,12 +7,11 @@ import ( "crypto/sha256" "fmt" - "maunium.net/go/mautrix/crypto/goolm/cipher" + "maunium.net/go/mautrix/crypto/goolm/aessha2" "maunium.net/go/mautrix/crypto/goolm/crypto" "maunium.net/go/mautrix/crypto/goolm/goolmbase64" "maunium.net/go/mautrix/crypto/goolm/libolmpickle" "maunium.net/go/mautrix/crypto/goolm/message" - "maunium.net/go/mautrix/crypto/goolm/utilities" "maunium.net/go/mautrix/crypto/olm" ) @@ -26,7 +25,7 @@ const ( RatchetPartLength = 256 / 8 // length of each ratchet part in bytes ) -var RatchetCipher = cipher.NewAESSHA256([]byte("MEGOLM_KEYS")) +var megolmKeysKDFInfo = []byte("MEGOLM_KEYS") // hasKeySeed are the seed for the different ratchet parts var hashKeySeeds [RatchetParts][]byte = [RatchetParts][]byte{ @@ -136,9 +135,8 @@ func (m *Ratchet) AdvanceTo(target uint32) { // Encrypt encrypts the message in a message.GroupMessage with MAC and signature. // The output is base64 encoded. -func (r *Ratchet) Encrypt(plaintext []byte, key *crypto.Ed25519KeyPair) ([]byte, error) { - var err error - encryptedText, err := RatchetCipher.Encrypt(r.Data[:], plaintext) +func (r *Ratchet) Encrypt(plaintext []byte, key crypto.Ed25519KeyPair) ([]byte, error) { + cipher, err := aessha2.NewAESSHA2(r.Data[:], megolmKeysKDFInfo) if err != nil { return nil, fmt.Errorf("cipher encrypt: %w", err) } @@ -146,9 +144,12 @@ func (r *Ratchet) Encrypt(plaintext []byte, key *crypto.Ed25519KeyPair) ([]byte, message := &message.GroupMessage{} message.Version = protocolVersion message.MessageIndex = r.Counter - message.Ciphertext = encryptedText - //creating the mac and signing is done in encode - output, err := message.EncodeAndMacAndSign(r.Data[:], RatchetCipher, key) + message.Ciphertext, err = cipher.Encrypt(plaintext) + if err != nil { + return nil, err + } + //creating the MAC and signing is done in encode + output, err := message.EncodeAndMACAndSign(cipher, key) if err != nil { return nil, err } @@ -178,7 +179,11 @@ func (r Ratchet) SessionExportMessage(key crypto.Ed25519PublicKey) ([]byte, erro // Decrypt decrypts the ciphertext and verifies the MAC but not the signature. func (r Ratchet) Decrypt(ciphertext []byte, signingkey *crypto.Ed25519PublicKey, msg *message.GroupMessage) ([]byte, error) { //verify mac - verifiedMAC, err := msg.VerifyMACInline(r.Data[:], RatchetCipher, ciphertext) + cipher, err := aessha2.NewAESSHA2(r.Data[:], megolmKeysKDFInfo) + if err != nil { + return nil, err + } + verifiedMAC, err := msg.VerifyMACInline(cipher, ciphertext) if err != nil { return nil, err } @@ -186,17 +191,17 @@ func (r Ratchet) Decrypt(ciphertext []byte, signingkey *crypto.Ed25519PublicKey, return nil, fmt.Errorf("decrypt: %w", olm.ErrBadMAC) } - return RatchetCipher.Decrypt(r.Data[:], msg.Ciphertext) + return cipher.Decrypt(msg.Ciphertext) } // PickleAsJSON returns a ratchet as a base64 string encrypted using the supplied key. The unencrypted representation of the Account is in JSON format. func (r Ratchet) PickleAsJSON(key []byte) ([]byte, error) { - return utilities.PickleAsJSON(r, megolmPickleVersion, key) + return libolmpickle.PickleAsJSON(r, megolmPickleVersion, key) } // UnpickleAsJSON updates a ratchet by a base64 encrypted string using the supplied key. The unencrypted representation has to be in JSON format. func (r *Ratchet) UnpickleAsJSON(pickled, key []byte) error { - return utilities.UnpickleAsJSON(r, pickled, key, megolmPickleVersion) + return libolmpickle.UnpickleAsJSON(r, pickled, key, megolmPickleVersion) } // UnpickleLibOlm decodes the unencryted value and populates the Ratchet accordingly. It returns the number of bytes read. diff --git a/crypto/goolm/message/decoder.go b/crypto/goolm/message/decoder.go index 9ce426b5..b06756a9 100644 --- a/crypto/goolm/message/decoder.go +++ b/crypto/goolm/message/decoder.go @@ -1,70 +1,33 @@ package message import ( + "bytes" "encoding/binary" + "fmt" "maunium.net/go/mautrix/crypto/olm" ) -// checkDecodeErr checks if there was an error during decode. -func checkDecodeErr(readBytes int) error { - if readBytes == 0 { - //end reached - return olm.ErrInputToSmall +type Decoder struct { + *bytes.Buffer +} + +func NewDecoder(buf []byte) *Decoder { + return &Decoder{bytes.NewBuffer(buf)} +} + +func (d *Decoder) ReadVarInt() (uint64, error) { + return binary.ReadUvarint(d) +} + +func (d *Decoder) ReadVarBytes() ([]byte, error) { + if n, err := d.ReadVarInt(); err != nil { + return nil, err + } else if n > uint64(d.Len()) { + return nil, fmt.Errorf("%w: var bytes length says %d, but only %d bytes left", olm.ErrInputToSmall, n, d.Available()) + } else { + out := make([]byte, n) + _, err = d.Read(out) + return out, err } - if readBytes < 0 { - return olm.ErrOverflow - } - return nil -} - -// decodeVarInt decodes a single big-endian encoded varint. -func decodeVarInt(input []byte) (uint32, int) { - value, readBytes := binary.Uvarint(input) - return uint32(value), readBytes -} - -// decodeVarString decodes the length of the string (varint) and returns the actual string -func decodeVarString(input []byte) ([]byte, int) { - stringLen, readBytes := decodeVarInt(input) - if readBytes <= 0 { - return nil, readBytes - } - input = input[readBytes:] - value := input[:stringLen] - readBytes += int(stringLen) - return value, readBytes -} - -// encodeVarIntByteLength returns the number of bytes needed to encode the uint32. -func encodeVarIntByteLength(input uint32) int { - result := 1 - for input >= 128 { - result++ - input >>= 7 - } - return result -} - -// encodeVarStringByteLength returns the number of bytes needed to encode the input. -func encodeVarStringByteLength(input []byte) int { - result := encodeVarIntByteLength(uint32(len(input))) - result += len(input) - return result -} - -// encodeVarInt encodes a single uint32 -func encodeVarInt(input uint32) []byte { - out := make([]byte, encodeVarIntByteLength(input)) - binary.PutUvarint(out, uint64(input)) - return out -} - -// encodeVarString encodes the length of the input (varint) and appends the actual input -func encodeVarString(input []byte) []byte { - out := make([]byte, encodeVarStringByteLength(input)) - length := encodeVarInt(uint32(len(input))) - copy(out, length) - copy(out[len(length):], input) - return out } diff --git a/crypto/goolm/message/encoder.go b/crypto/goolm/message/encoder.go new file mode 100644 index 00000000..95ab6d41 --- /dev/null +++ b/crypto/goolm/message/encoder.go @@ -0,0 +1,24 @@ +package message + +import "encoding/binary" + +type Encoder struct { + buf []byte +} + +func (e *Encoder) Bytes() []byte { + return e.buf +} + +func (e *Encoder) PutByte(val byte) { + e.buf = append(e.buf, val) +} + +func (e *Encoder) PutVarInt(val uint64) { + e.buf = binary.AppendUvarint(e.buf, val) +} + +func (e *Encoder) PutVarBytes(data []byte) { + e.PutVarInt(uint64(len(data))) + e.buf = append(e.buf, data...) +} diff --git a/crypto/goolm/message/decoder_test.go b/crypto/goolm/message/encoder_test.go similarity index 58% rename from crypto/goolm/message/decoder_test.go rename to crypto/goolm/message/encoder_test.go index 8b7561ad..1fe2ebdb 100644 --- a/crypto/goolm/message/decoder_test.go +++ b/crypto/goolm/message/encoder_test.go @@ -1,33 +1,13 @@ -package message +package message_test import ( "testing" "github.com/stretchr/testify/assert" + + "maunium.net/go/mautrix/crypto/goolm/message" ) -func TestEncodeLengthInt(t *testing.T) { - numbers := []uint32{127, 128, 16383, 16384, 32767} - expected := []int{1, 2, 2, 3, 3} - for curIndex := range numbers { - assert.Equal(t, expected[curIndex], encodeVarIntByteLength(numbers[curIndex])) - } -} - -func TestEncodeLengthString(t *testing.T) { - var strings [][]byte - var expected []int - strings = append(strings, []byte("test")) - expected = append(expected, 1+4) - strings = append(strings, []byte("this is a long message with a length of 127 so that the varint of the length is just one byte. just needs some padding---------")) - expected = append(expected, 1+127) - strings = append(strings, []byte("this is an even longer message with a length between 128 and 16383 so that the varint of the length needs two byte. just needs some padding again ---------")) - expected = append(expected, 2+155) - for curIndex := range strings { - assert.Equal(t, expected[curIndex], encodeVarStringByteLength(strings[curIndex])) - } -} - func TestEncodeInt(t *testing.T) { var ints []uint32 var expected [][]byte @@ -40,7 +20,9 @@ func TestEncodeInt(t *testing.T) { ints = append(ints, 16383) expected = append(expected, []byte{0b11111111, 0b01111111}) for curIndex := range ints { - assert.Equal(t, expected[curIndex], encodeVarInt(ints[curIndex])) + var encoder message.Encoder + encoder.PutVarInt(uint64(ints[curIndex])) + assert.Equal(t, expected[curIndex], encoder.Bytes()) } } @@ -70,6 +52,8 @@ func TestEncodeString(t *testing.T) { res = append(res, curTest...) //Add string itself expected = append(expected, res) for curIndex := range strings { - assert.Equal(t, expected[curIndex], encodeVarString(strings[curIndex])) + var encoder message.Encoder + encoder.PutVarBytes(strings[curIndex]) + assert.Equal(t, expected[curIndex], encoder.Bytes()) } } diff --git a/crypto/goolm/message/group_message.go b/crypto/goolm/message/group_message.go index b34bfa5e..c83540c1 100644 --- a/crypto/goolm/message/group_message.go +++ b/crypto/goolm/message/group_message.go @@ -2,9 +2,12 @@ package message import ( "bytes" + "fmt" + "io" - "maunium.net/go/mautrix/crypto/goolm/cipher" + "maunium.net/go/mautrix/crypto/goolm/aessha2" "maunium.net/go/mautrix/crypto/goolm/crypto" + "maunium.net/go/mautrix/crypto/olm" ) const ( @@ -22,105 +25,77 @@ type GroupMessage struct { } // Decodes decodes the input and populates the corresponding fileds. MAC and signature are ignored but have to be present. -func (r *GroupMessage) Decode(input []byte) error { +func (r *GroupMessage) Decode(input []byte) (err error) { r.Version = 0 r.MessageIndex = 0 r.Ciphertext = nil if len(input) == 0 { return nil } - //first Byte is always version - r.Version = input[0] - curPos := 1 - for curPos < len(input)-countMACBytesGroupMessage-crypto.Ed25519SignatureSize { - //Read Key - curKey, readBytes := decodeVarInt(input[curPos:]) - if err := checkDecodeErr(readBytes); err != nil { - return err - } - curPos += readBytes - if (curKey & 0b111) == 0 { - //The value is of type varint - value, readBytes := decodeVarInt(input[curPos:]) - if err := checkDecodeErr(readBytes); err != nil { - return err + + decoder := NewDecoder(input[:len(input)-countMACBytesGroupMessage-crypto.Ed25519SignatureSize]) + r.Version, err = decoder.ReadByte() // First byte is the version + if err != nil { + return + } + if r.Version != protocolVersion { + return fmt.Errorf("GroupMessage.Decode: %w (got %d, expected %d)", olm.ErrWrongProtocolVersion, r.Version, protocolVersion) + } + + for { + // Read Key + if curKey, err := decoder.ReadVarInt(); err != nil { + if err == io.EOF { + // No more keys to read + return nil } - curPos += readBytes - switch curKey { - case messageIndexTag: - r.MessageIndex = value + return err + } else if (curKey & 0b111) == 0 { + // The value is of type varint + if value, err := decoder.ReadVarInt(); err != nil { + return err + } else if curKey == messageIndexTag { + r.MessageIndex = uint32(value) r.HasMessageIndex = true } } else if (curKey & 0b111) == 2 { - //The value is of type string - value, readBytes := decodeVarString(input[curPos:]) - if err := checkDecodeErr(readBytes); err != nil { + // The value is of type string + if value, err := decoder.ReadVarBytes(); err != nil { return err - } - curPos += readBytes - switch curKey { - case cipherTextTag: + } else if curKey == cipherTextTag { r.Ciphertext = value } } } - - return nil } -// EncodeAndMacAndSign encodes the message, creates the mac with the key and the cipher and signs the message. +// EncodeAndMACAndSign encodes the message, creates the mac with the key and the cipher and signs the message. // If macKey or cipher is nil, no mac is appended. If signKey is nil, no signature is appended. -func (r *GroupMessage) EncodeAndMacAndSign(macKey []byte, cipher cipher.Cipher, signKey *crypto.Ed25519KeyPair) ([]byte, error) { - var lengthOfMessage int - lengthOfMessage += 1 //Version - lengthOfMessage += encodeVarIntByteLength(messageIndexTag) + encodeVarIntByteLength(r.MessageIndex) - lengthOfMessage += encodeVarIntByteLength(cipherTextTag) + encodeVarStringByteLength(r.Ciphertext) - out := make([]byte, lengthOfMessage) - out[0] = r.Version - curPos := 1 - encodedTag := encodeVarInt(messageIndexTag) - copy(out[curPos:], encodedTag) - curPos += len(encodedTag) - encodedValue := encodeVarInt(r.MessageIndex) - copy(out[curPos:], encodedValue) - curPos += len(encodedValue) - encodedTag = encodeVarInt(cipherTextTag) - copy(out[curPos:], encodedTag) - curPos += len(encodedTag) - encodedValue = encodeVarString(r.Ciphertext) - copy(out[curPos:], encodedValue) - curPos += len(encodedValue) - if len(macKey) != 0 && cipher != nil { - mac, err := r.MAC(macKey, cipher, out) - if err != nil { - return nil, err - } - out = append(out, mac[:countMACBytesGroupMessage]...) +func (r *GroupMessage) EncodeAndMACAndSign(cipher aessha2.AESSHA2, signKey crypto.Ed25519KeyPair) ([]byte, error) { + var encoder Encoder + encoder.PutByte(r.Version) + encoder.PutVarInt(messageIndexTag) + encoder.PutVarInt(uint64(r.MessageIndex)) + encoder.PutVarInt(cipherTextTag) + encoder.PutVarBytes(r.Ciphertext) + mac, err := r.MAC(cipher, encoder.Bytes()) + if err != nil { + return nil, err } - if signKey != nil { - signature, err := signKey.Sign(out) - if err != nil { - return nil, err - } - out = append(out, signature...) - } - return out, nil + ciphertextWithMAC := append(encoder.Bytes(), mac[:countMACBytesGroupMessage]...) + signature, err := signKey.Sign(ciphertextWithMAC) + return append(ciphertextWithMAC, signature...), err } // MAC returns the MAC of the message calculated with cipher and key. The length of the MAC is truncated to the correct length. -func (r *GroupMessage) MAC(key []byte, cipher cipher.Cipher, message []byte) ([]byte, error) { - mac, err := cipher.MAC(key, message) +func (r *GroupMessage) MAC(cipher aessha2.AESSHA2, ciphertext []byte) ([]byte, error) { + mac, err := cipher.MAC(ciphertext) if err != nil { return nil, err } return mac[:countMACBytesGroupMessage], nil } -// VerifySignature verifies the givenSignature to the calculated signature of the message. -func (r *GroupMessage) VerifySignature(key crypto.Ed25519PublicKey, message, givenSignature []byte) bool { - return key.Verify(message, givenSignature) -} - // VerifySignature verifies the signature taken from the message to the calculated signature of the message. func (r *GroupMessage) VerifySignatureInline(key crypto.Ed25519PublicKey, message []byte) bool { signature := message[len(message)-crypto.Ed25519SignatureSize:] @@ -129,8 +104,8 @@ func (r *GroupMessage) VerifySignatureInline(key crypto.Ed25519PublicKey, messag } // VerifyMAC verifies the givenMAC to the calculated MAC of the message. -func (r *GroupMessage) VerifyMAC(key []byte, cipher cipher.Cipher, message, givenMAC []byte) (bool, error) { - checkMac, err := r.MAC(key, cipher, message) +func (r *GroupMessage) VerifyMAC(cipher aessha2.AESSHA2, ciphertext, givenMAC []byte) (bool, error) { + checkMac, err := r.MAC(cipher, ciphertext) if err != nil { return false, err } @@ -138,10 +113,10 @@ func (r *GroupMessage) VerifyMAC(key []byte, cipher cipher.Cipher, message, give } // VerifyMACInline verifies the MAC taken from the message to the calculated MAC of the message. -func (r *GroupMessage) VerifyMACInline(key []byte, cipher cipher.Cipher, message []byte) (bool, error) { +func (r *GroupMessage) VerifyMACInline(cipher aessha2.AESSHA2, message []byte) (bool, error) { startMAC := len(message) - countMACBytesGroupMessage - crypto.Ed25519SignatureSize endMAC := startMAC + countMACBytesGroupMessage suplMac := message[startMAC:endMAC] message = message[:startMAC] - return r.VerifyMAC(key, cipher, message, suplMac) + return r.VerifyMAC(cipher, message, suplMac) } diff --git a/crypto/goolm/message/group_message_test.go b/crypto/goolm/message/group_message_test.go index d52cf6a3..272138c4 100644 --- a/crypto/goolm/message/group_message_test.go +++ b/crypto/goolm/message/group_message_test.go @@ -4,7 +4,10 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "maunium.net/go/mautrix/crypto/goolm/aessha2" + "maunium.net/go/mautrix/crypto/goolm/crypto" "maunium.net/go/mautrix/crypto/goolm/message" ) @@ -24,7 +27,6 @@ func TestGroupMessageDecode(t *testing.T) { } func TestGroupMessageEncode(t *testing.T) { - expectedRaw := []byte("\x03\x08\xC8\x01\x12\x0aciphertexthmacsha2signature") hmacsha256 := []byte("hmacsha2") sign := []byte("signature") msg := message.GroupMessage{ @@ -32,9 +34,29 @@ func TestGroupMessageEncode(t *testing.T) { MessageIndex: 200, Ciphertext: []byte("ciphertext"), } - encoded, err := msg.EncodeAndMacAndSign(nil, nil, nil) + + cipher, err := aessha2.NewAESSHA2(nil, nil) + require.NoError(t, err) + encoded, err := msg.EncodeAndMACAndSign(cipher, crypto.Ed25519GenerateFromSeed(make([]byte, 32))) assert.NoError(t, err) encoded = append(encoded, hmacsha256...) encoded = append(encoded, sign...) - assert.Equal(t, expectedRaw, encoded) + expected := []byte{ + 0x03, // Version + 0x08, + 0xC8, // 200 + 0x01, + 0x12, + 0x0a, + } + expected = append(expected, []byte("ciphertext")...) + expected = append(expected, []byte{ + 0x6f, 0x95, 0x35, 0x51, 0xdc, 0xdb, 0xcb, 0x03, 0x0b, 0x22, 0xa2, 0xa7, 0xa1, 0xb7, 0x4f, 0x1a, + 0xa3, 0xe9, 0x5c, 0x05, 0x5d, 0x56, 0xdc, 0x5b, 0x87, 0x73, 0x05, 0x42, 0x2a, 0x59, 0x9a, 0x9a, + 0x26, 0x7a, 0x8d, 0xba, 0x65, 0xb2, 0x17, 0x65, 0x51, 0x6f, 0x37, 0xf3, 0x8f, 0xa1, 0x70, 0xd0, + 0xc4, 0x06, 0x05, 0xdc, 0x17, 0x71, 0x5e, 0x63, 0x84, 0xbe, 0xec, 0x7b, 0xa0, 0xc4, 0x08, 0xb8, + 0x9b, 0xc5, 0x08, 0x16, 0xad, 0xe5, 0x43, 0x0c, + }...) + expected = append(expected, []byte("hmacsha2signature")...) + assert.Equal(t, expected, encoded) } diff --git a/crypto/goolm/message/message.go b/crypto/goolm/message/message.go index 8b721aeb..b161a2d1 100644 --- a/crypto/goolm/message/message.go +++ b/crypto/goolm/message/message.go @@ -2,9 +2,12 @@ package message import ( "bytes" + "fmt" + "io" - "maunium.net/go/mautrix/crypto/goolm/cipher" + "maunium.net/go/mautrix/crypto/goolm/aessha2" "maunium.net/go/mautrix/crypto/goolm/crypto" + "maunium.net/go/mautrix/crypto/olm" ) const ( @@ -24,7 +27,7 @@ type Message struct { } // Decodes decodes the input and populates the corresponding fileds. MAC is ignored but has to be present. -func (r *Message) Decode(input []byte) error { +func (r *Message) Decode(input []byte) (err error) { r.Version = 0 r.HasCounter = false r.Counter = 0 @@ -33,89 +36,63 @@ func (r *Message) Decode(input []byte) error { if len(input) == 0 { return nil } - //first Byte is always version - r.Version = input[0] - curPos := 1 - for curPos < len(input)-countMACBytesMessage { - //Read Key - curKey, readBytes := decodeVarInt(input[curPos:]) - if err := checkDecodeErr(readBytes); err != nil { - return err - } - curPos += readBytes - if (curKey & 0b111) == 0 { - //The value is of type varint - value, readBytes := decodeVarInt(input[curPos:]) - if err := checkDecodeErr(readBytes); err != nil { - return err + + decoder := NewDecoder(input[:len(input)-countMACBytesMessage]) + r.Version, err = decoder.ReadByte() // first byte is always version + if err != nil { + return + } + if r.Version != protocolVersion { + return fmt.Errorf("Message.Decode: %w (got %d, expected %d)", olm.ErrWrongProtocolVersion, r.Version, protocolVersion) + } + + for { + // Read Key + if curKey, err := decoder.ReadVarInt(); err != nil { + if err == io.EOF { + // No more keys to read + return nil } - curPos += readBytes - switch curKey { - case counterTag: + return err + } else if (curKey & 0b111) == 0 { + // The value is of type varint + if value, err := decoder.ReadVarInt(); err != nil { + return err + } else if curKey == counterTag { + r.Counter = uint32(value) r.HasCounter = true - r.Counter = value } } else if (curKey & 0b111) == 2 { - //The value is of type string - value, readBytes := decodeVarString(input[curPos:]) - if err := checkDecodeErr(readBytes); err != nil { + // The value is of type string + if value, err := decoder.ReadVarBytes(); err != nil { return err - } - curPos += readBytes - switch curKey { - case ratchetKeyTag: + } else if curKey == ratchetKeyTag { r.RatchetKey = value - case cipherTextKeyTag: + } else if curKey == cipherTextKeyTag { r.Ciphertext = value } } } - - return nil } // EncodeAndMAC encodes the message and creates the MAC with the key and the cipher. // If key or cipher is nil, no MAC is appended. -func (r *Message) EncodeAndMAC(key []byte, cipher cipher.Cipher) ([]byte, error) { - var lengthOfMessage int - lengthOfMessage += 1 //Version - lengthOfMessage += encodeVarIntByteLength(ratchetKeyTag) + encodeVarStringByteLength(r.RatchetKey) - lengthOfMessage += encodeVarIntByteLength(counterTag) + encodeVarIntByteLength(r.Counter) - lengthOfMessage += encodeVarIntByteLength(cipherTextKeyTag) + encodeVarStringByteLength(r.Ciphertext) - out := make([]byte, lengthOfMessage) - out[0] = r.Version - curPos := 1 - encodedTag := encodeVarInt(ratchetKeyTag) - copy(out[curPos:], encodedTag) - curPos += len(encodedTag) - encodedValue := encodeVarString(r.RatchetKey) - copy(out[curPos:], encodedValue) - curPos += len(encodedValue) - encodedTag = encodeVarInt(counterTag) - copy(out[curPos:], encodedTag) - curPos += len(encodedTag) - encodedValue = encodeVarInt(r.Counter) - copy(out[curPos:], encodedValue) - curPos += len(encodedValue) - encodedTag = encodeVarInt(cipherTextKeyTag) - copy(out[curPos:], encodedTag) - curPos += len(encodedTag) - encodedValue = encodeVarString(r.Ciphertext) - copy(out[curPos:], encodedValue) - curPos += len(encodedValue) - if len(key) != 0 && cipher != nil { - mac, err := cipher.MAC(key, out) - if err != nil { - return nil, err - } - out = append(out, mac[:countMACBytesMessage]...) - } - return out, nil +func (r *Message) EncodeAndMAC(cipher aessha2.AESSHA2) ([]byte, error) { + var encoder Encoder + encoder.PutByte(r.Version) + encoder.PutVarInt(ratchetKeyTag) + encoder.PutVarBytes(r.RatchetKey) + encoder.PutVarInt(counterTag) + encoder.PutVarInt(uint64(r.Counter)) + encoder.PutVarInt(cipherTextKeyTag) + encoder.PutVarBytes(r.Ciphertext) + mac, err := cipher.MAC(encoder.Bytes()) + return append(encoder.Bytes(), mac[:countMACBytesMessage]...), err } // VerifyMAC verifies the givenMAC to the calculated MAC of the message. -func (r *Message) VerifyMAC(key []byte, cipher cipher.Cipher, message, givenMAC []byte) (bool, error) { - checkMAC, err := cipher.MAC(key, message) +func (r *Message) VerifyMAC(key []byte, cipher aessha2.AESSHA2, ciphertext, givenMAC []byte) (bool, error) { + checkMAC, err := cipher.MAC(ciphertext) if err != nil { return false, err } @@ -123,7 +100,7 @@ func (r *Message) VerifyMAC(key []byte, cipher cipher.Cipher, message, givenMAC } // VerifyMACInline verifies the MAC taken from the message to the calculated MAC of the message. -func (r *Message) VerifyMACInline(key []byte, cipher cipher.Cipher, message []byte) (bool, error) { +func (r *Message) VerifyMACInline(key []byte, cipher aessha2.AESSHA2, message []byte) (bool, error) { givenMAC := message[len(message)-countMACBytesMessage:] return r.VerifyMAC(key, cipher, message[:len(message)-countMACBytesMessage], givenMAC) } diff --git a/crypto/goolm/message/message_test.go b/crypto/goolm/message/message_test.go index b5c3551b..f3aa7108 100644 --- a/crypto/goolm/message/message_test.go +++ b/crypto/goolm/message/message_test.go @@ -5,6 +5,7 @@ import ( "github.com/stretchr/testify/assert" + "maunium.net/go/mautrix/crypto/goolm/aessha2" "maunium.net/go/mautrix/crypto/goolm/message" ) @@ -24,7 +25,7 @@ func TestMessageDecode(t *testing.T) { } func TestMessageEncode(t *testing.T) { - expectedRaw := []byte("\x03\n\nratchetkey\x10\x01\"\nciphertexthmacsha2") + expectedRaw := []byte("\x03\n\nratchetkey\x10\x01\"\nciphertext\x95\x95\x92\x72\x04\x70\x56\xcdhmacsha2") hmacsha256 := []byte("hmacsha2") msg := message.Message{ Version: 3, @@ -32,7 +33,9 @@ func TestMessageEncode(t *testing.T) { RatchetKey: []byte("ratchetkey"), Ciphertext: []byte("ciphertext"), } - encoded, err := msg.EncodeAndMAC(nil, nil) + cipher, err := aessha2.NewAESSHA2(nil, nil) + assert.NoError(t, err) + encoded, err := msg.EncodeAndMAC(cipher) assert.NoError(t, err) encoded = append(encoded, hmacsha256...) assert.Equal(t, expectedRaw, encoded) diff --git a/crypto/goolm/message/prekey_message.go b/crypto/goolm/message/prekey_message.go index 1238a9a5..4e3d495d 100644 --- a/crypto/goolm/message/prekey_message.go +++ b/crypto/goolm/message/prekey_message.go @@ -1,11 +1,15 @@ package message import ( + "fmt" + "io" + "maunium.net/go/mautrix/crypto/goolm/crypto" + "maunium.net/go/mautrix/crypto/olm" ) const ( - oneTimeKeyIdTag = 0x0A + oneTimeKeyIDTag = 0x0A baseKeyTag = 0x12 identityKeyTag = 0x1A messageTag = 0x22 @@ -19,8 +23,13 @@ type PreKeyMessage struct { Message []byte `json:"message"` } +// TODO deduplicate constant with one in session/olm_session.go +const ( + protocolVersion = 0x3 +) + // Decodes decodes the input and populates the corresponding fileds. -func (r *PreKeyMessage) Decode(input []byte) error { +func (r *PreKeyMessage) Decode(input []byte) (err error) { r.Version = 0 r.IdentityKey = nil r.BaseKey = nil @@ -29,44 +38,55 @@ func (r *PreKeyMessage) Decode(input []byte) error { if len(input) == 0 { return nil } - //first Byte is always version - r.Version = input[0] - curPos := 1 - for curPos < len(input) { - //Read Key - curKey, readBytes := decodeVarInt(input[curPos:]) - if err := checkDecodeErr(readBytes); err != nil { - return err + + decoder := NewDecoder(input) + r.Version, err = decoder.ReadByte() // first byte is always version + if err != nil { + if err == io.EOF { + return olm.ErrInputToSmall } - curPos += readBytes - if (curKey & 0b111) == 0 { - //The value is of type varint - _, readBytes := decodeVarInt(input[curPos:]) - if err := checkDecodeErr(readBytes); err != nil { + return + } + if r.Version != protocolVersion { + return fmt.Errorf("PreKeyMessage.Decode: %w (got %d, expected %d)", olm.ErrWrongProtocolVersion, r.Version, protocolVersion) + } + + for { + // Read Key + if curKey, err := decoder.ReadVarInt(); err != nil { + if err == io.EOF { + return nil + } + return err + } else if (curKey & 0b111) == 0 { + // The value is of type varint + if _, err = decoder.ReadVarInt(); err != nil { + if err == io.EOF { + return olm.ErrInputToSmall + } return err } - curPos += readBytes } else if (curKey & 0b111) == 2 { - //The value is of type string - value, readBytes := decodeVarString(input[curPos:]) - if err := checkDecodeErr(readBytes); err != nil { + // The value is of type string + if value, err := decoder.ReadVarBytes(); err != nil { + if err == io.EOF { + return olm.ErrInputToSmall + } return err - } - curPos += readBytes - switch curKey { - case oneTimeKeyIdTag: - r.OneTimeKey = value - case baseKeyTag: - r.BaseKey = value - case identityKeyTag: - r.IdentityKey = value - case messageTag: - r.Message = value + } else { + switch curKey { + case oneTimeKeyIDTag: + r.OneTimeKey = value + case baseKeyTag: + r.BaseKey = value + case identityKeyTag: + r.IdentityKey = value + case messageTag: + r.Message = value + } } } } - - return nil } // CheckField verifies the fields. If theirIdentityKey is nil, it is not compared to the key in the message. @@ -84,37 +104,15 @@ func (r *PreKeyMessage) CheckFields(theirIdentityKey *crypto.Curve25519PublicKey // Encode encodes the message. func (r *PreKeyMessage) Encode() ([]byte, error) { - var lengthOfMessage int - lengthOfMessage += 1 //Version - lengthOfMessage += encodeVarIntByteLength(oneTimeKeyIdTag) + encodeVarStringByteLength(r.OneTimeKey) - lengthOfMessage += encodeVarIntByteLength(identityKeyTag) + encodeVarStringByteLength(r.IdentityKey) - lengthOfMessage += encodeVarIntByteLength(baseKeyTag) + encodeVarStringByteLength(r.BaseKey) - lengthOfMessage += encodeVarIntByteLength(messageTag) + encodeVarStringByteLength(r.Message) - out := make([]byte, lengthOfMessage) - out[0] = r.Version - curPos := 1 - encodedTag := encodeVarInt(oneTimeKeyIdTag) - copy(out[curPos:], encodedTag) - curPos += len(encodedTag) - encodedValue := encodeVarString(r.OneTimeKey) - copy(out[curPos:], encodedValue) - curPos += len(encodedValue) - encodedTag = encodeVarInt(identityKeyTag) - copy(out[curPos:], encodedTag) - curPos += len(encodedTag) - encodedValue = encodeVarString(r.IdentityKey) - copy(out[curPos:], encodedValue) - curPos += len(encodedValue) - encodedTag = encodeVarInt(baseKeyTag) - copy(out[curPos:], encodedTag) - curPos += len(encodedTag) - encodedValue = encodeVarString(r.BaseKey) - copy(out[curPos:], encodedValue) - curPos += len(encodedValue) - encodedTag = encodeVarInt(messageTag) - copy(out[curPos:], encodedTag) - curPos += len(encodedTag) - encodedValue = encodeVarString(r.Message) - copy(out[curPos:], encodedValue) - return out, nil + var encoder Encoder + encoder.PutByte(r.Version) + encoder.PutVarInt(oneTimeKeyIDTag) + encoder.PutVarBytes(r.OneTimeKey) + encoder.PutVarInt(identityKeyTag) + encoder.PutVarBytes(r.IdentityKey) + encoder.PutVarInt(baseKeyTag) + encoder.PutVarBytes(r.BaseKey) + encoder.PutVarInt(messageTag) + encoder.PutVarBytes(r.Message) + return encoder.Bytes(), nil } diff --git a/crypto/goolm/message/session_export.go b/crypto/goolm/message/session_export.go index 956868b2..d58dbb21 100644 --- a/crypto/goolm/message/session_export.go +++ b/crypto/goolm/message/session_export.go @@ -35,7 +35,7 @@ func (s *MegolmSessionExport) Decode(input []byte) error { return fmt.Errorf("decrypt: %w", olm.ErrBadInput) } if input[0] != sessionExportVersion { - return fmt.Errorf("decrypt: %w", olm.ErrBadVersion) + return fmt.Errorf("decrypt: %w", olm.ErrUnknownOlmPickleVersion) } s.Counter = binary.BigEndian.Uint32(input[1:5]) copy(s.RatchetData[:], input[5:133]) diff --git a/crypto/goolm/message/session_sharing.go b/crypto/goolm/message/session_sharing.go index 16240945..d04ef15a 100644 --- a/crypto/goolm/message/session_sharing.go +++ b/crypto/goolm/message/session_sharing.go @@ -42,7 +42,7 @@ func (s *MegolmSessionSharing) VerifyAndDecode(input []byte) error { } s.PublicKey = publicKey if input[0] != sessionSharingVersion { - return fmt.Errorf("verify: %w", olm.ErrBadVersion) + return fmt.Errorf("verify: %w", olm.ErrUnknownOlmPickleVersion) } s.Counter = binary.BigEndian.Uint32(input[1:5]) copy(s.RatchetData[:], input[5:133]) diff --git a/crypto/goolm/pk/decryption.go b/crypto/goolm/pk/decryption.go index ba94dc37..cdb20eb1 100644 --- a/crypto/goolm/pk/decryption.go +++ b/crypto/goolm/pk/decryption.go @@ -4,11 +4,10 @@ import ( "encoding/base64" "fmt" - "maunium.net/go/mautrix/crypto/goolm/cipher" + "maunium.net/go/mautrix/crypto/goolm/aessha2" "maunium.net/go/mautrix/crypto/goolm/crypto" "maunium.net/go/mautrix/crypto/goolm/goolmbase64" "maunium.net/go/mautrix/crypto/goolm/libolmpickle" - "maunium.net/go/mautrix/crypto/goolm/utilities" "maunium.net/go/mautrix/crypto/olm" "maunium.net/go/mautrix/id" ) @@ -57,43 +56,37 @@ func (s Decryption) PrivateKey() crypto.Curve25519PrivateKey { // Decrypt decrypts the ciphertext and verifies the MAC. The base64 encoded key is used to construct the shared secret. func (s Decryption) Decrypt(ephemeralKey, mac, ciphertext []byte) ([]byte, error) { - keyDecoded, err := base64.RawStdEncoding.DecodeString(string(ephemeralKey)) - if err != nil { + if keyDecoded, err := base64.RawStdEncoding.DecodeString(string(ephemeralKey)); err != nil { return nil, err - } - sharedSecret, err := s.KeyPair.SharedSecret(keyDecoded) - if err != nil { + } else if sharedSecret, err := s.KeyPair.SharedSecret(keyDecoded); err != nil { return nil, err - } - decodedMAC, err := goolmbase64.Decode(mac) - if err != nil { + } else if decodedMAC, err := goolmbase64.Decode(mac); err != nil { return nil, err - } - cipher := cipher.NewAESSHA256(nil) - verified, err := cipher.Verify(sharedSecret, ciphertext, decodedMAC) - if err != nil { + } else if cipher, err := aessha2.NewAESSHA2(sharedSecret, nil); err != nil { return nil, err - } - if !verified { + } else if verified, err := cipher.VerifyMAC(ciphertext, decodedMAC); err != nil { + return nil, err + } else if !verified { return nil, fmt.Errorf("decrypt: %w", olm.ErrBadMAC) + } else { + return cipher.Decrypt(ciphertext) } - return cipher.Decrypt(sharedSecret, ciphertext) } // PickleAsJSON returns an Decryption as a base64 string encrypted using the supplied key. The unencrypted representation of the Account is in JSON format. func (a Decryption) PickleAsJSON(key []byte) ([]byte, error) { - return utilities.PickleAsJSON(a, decryptionPickleVersionJSON, key) + return libolmpickle.PickleAsJSON(a, decryptionPickleVersionJSON, key) } // UnpickleAsJSON updates an Decryption by a base64 encrypted string using the supplied key. The unencrypted representation has to be in JSON format. func (a *Decryption) UnpickleAsJSON(pickled, key []byte) error { - return utilities.UnpickleAsJSON(a, pickled, key, decryptionPickleVersionJSON) + return libolmpickle.UnpickleAsJSON(a, pickled, key, decryptionPickleVersionJSON) } // Unpickle decodes the base64 encoded string and decrypts the result with the key. // The decrypted value is then passed to UnpickleLibOlm. func (a *Decryption) Unpickle(pickled, key []byte) error { - decrypted, err := cipher.Unpickle(key, pickled) + decrypted, err := libolmpickle.Unpickle(key, pickled) if err != nil { return err } @@ -110,13 +103,13 @@ func (a *Decryption) UnpickleLibOlm(unpickled []byte) error { if pickledVersion == decryptionPickleVersionLibOlm { return a.KeyPair.UnpickleLibOlm(decoder) } else { - return fmt.Errorf("unpickle olmSession: %w (found %d, expected %d)", olm.ErrBadVersion, pickledVersion, decryptionPickleVersionLibOlm) + return fmt.Errorf("unpickle olmSession: %w (found %d, expected %d)", olm.ErrUnknownOlmPickleVersion, pickledVersion, decryptionPickleVersionLibOlm) } } // Pickle returns a base64 encoded and with key encrypted pickled Decryption using PickleLibOlm(). func (a Decryption) Pickle(key []byte) ([]byte, error) { - return cipher.Pickle(key, a.PickleLibOlm()) + return libolmpickle.Pickle(key, a.PickleLibOlm()) } // PickleLibOlm pickles the [Decryption] into the encoder. diff --git a/crypto/goolm/pk/encryption.go b/crypto/goolm/pk/encryption.go index c99a9517..2897d9b0 100644 --- a/crypto/goolm/pk/encryption.go +++ b/crypto/goolm/pk/encryption.go @@ -5,7 +5,7 @@ import ( "maunium.net/go/mautrix/id" - "maunium.net/go/mautrix/crypto/goolm/cipher" + "maunium.net/go/mautrix/crypto/goolm/aessha2" "maunium.net/go/mautrix/crypto/goolm/crypto" "maunium.net/go/mautrix/crypto/goolm/goolmbase64" ) @@ -36,11 +36,14 @@ func (e Encryption) Encrypt(plaintext []byte, privateKey crypto.Curve25519Privat if err != nil { return nil, nil, err } - cipher := cipher.NewAESSHA256(nil) - ciphertext, err = cipher.Encrypt(sharedSecret, plaintext) + cipher, err := aessha2.NewAESSHA2(sharedSecret, nil) if err != nil { return nil, nil, err } - mac, err = cipher.MAC(sharedSecret, ciphertext) + ciphertext, err = cipher.Encrypt(plaintext) + if err != nil { + return nil, nil, err + } + mac, err = cipher.MAC(ciphertext) return ciphertext, goolmbase64.Encode(mac), err } diff --git a/crypto/goolm/pk/register.go b/crypto/goolm/pk/register.go index b7af6a5b..0e27b568 100644 --- a/crypto/goolm/pk/register.go +++ b/crypto/goolm/pk/register.go @@ -8,7 +8,7 @@ package pk import "maunium.net/go/mautrix/crypto/olm" -func init() { +func Register() { olm.InitNewPKSigningFromSeed = func(seed []byte) (olm.PKSigning, error) { return NewSigningFromSeed(seed) } diff --git a/crypto/goolm/ratchet/olm.go b/crypto/goolm/ratchet/olm.go index e53d126a..9901ada8 100644 --- a/crypto/goolm/ratchet/olm.go +++ b/crypto/goolm/ratchet/olm.go @@ -9,11 +9,10 @@ import ( "golang.org/x/crypto/hkdf" - "maunium.net/go/mautrix/crypto/goolm/cipher" + "maunium.net/go/mautrix/crypto/goolm/aessha2" "maunium.net/go/mautrix/crypto/goolm/crypto" "maunium.net/go/mautrix/crypto/goolm/libolmpickle" "maunium.net/go/mautrix/crypto/goolm/message" - "maunium.net/go/mautrix/crypto/goolm/utilities" "maunium.net/go/mautrix/crypto/olm" ) @@ -31,6 +30,8 @@ const ( sharedKeyLength = 32 ) +var olmKeysKDFInfo = []byte("OLM_KEYS") + // KdfInfo has the infos used for the kdf var KdfInfo = struct { Root []byte @@ -40,8 +41,6 @@ var KdfInfo = struct { Ratchet: []byte("OLM_RATCHET"), } -var RatchetCipher = cipher.NewAESSHA256([]byte("OLM_KEYS")) - // Ratchet represents the olm ratchet as described in // // https://gitlab.matrix.org/matrix-org/olm/-/blob/master/docs/olm.md @@ -68,8 +67,7 @@ type Ratchet struct { // New creates a new ratchet, setting the kdfInfos and cipher. func New() *Ratchet { - r := &Ratchet{} - return r + return &Ratchet{} } // InitializeAsBob initializes this ratchet from a receiving point of view (only first message). @@ -117,7 +115,11 @@ func (r *Ratchet) Encrypt(plaintext []byte) ([]byte, error) { messageKey := r.createMessageKeys(r.SenderChains.chainKey()) r.SenderChains.advance() - encryptedText, err := RatchetCipher.Encrypt(messageKey.Key, plaintext) + cipher, err := aessha2.NewAESSHA2(messageKey.Key, olmKeysKDFInfo) + if err != nil { + return nil, err + } + encryptedText, err := cipher.Encrypt(plaintext) if err != nil { return nil, fmt.Errorf("cipher encrypt: %w", err) } @@ -128,7 +130,7 @@ func (r *Ratchet) Encrypt(plaintext []byte) ([]byte, error) { message.RatchetKey = r.SenderChains.ratchetKey().PublicKey message.Ciphertext = encryptedText //creating the mac is done in encode - return message.EncodeAndMAC(messageKey.Key, RatchetCipher) + return message.EncodeAndMAC(cipher) } // Decrypt decrypts the ciphertext and verifies the MAC. @@ -140,7 +142,7 @@ func (r *Ratchet) Decrypt(input []byte) ([]byte, error) { return nil, err } if message.Version != protocolVersion { - return nil, fmt.Errorf("decrypt: %w", olm.ErrWrongProtocolVersion) + return nil, fmt.Errorf("decrypt: %w (got %d, expected %d)", olm.ErrWrongProtocolVersion, message.Version, protocolVersion) } if !message.HasCounter || len(message.RatchetKey) == 0 || len(message.Ciphertext) == 0 { return nil, fmt.Errorf("decrypt: %w", olm.ErrBadMessageFormat) @@ -165,15 +167,13 @@ func (r *Ratchet) Decrypt(input []byte) ([]byte, error) { } // Found the key for this message. Check the MAC. - verified, err := message.VerifyMACInline(r.SkippedMessageKeys[curSkippedIndex].MKey.Key, RatchetCipher, input) - if err != nil { + if cipher, err := aessha2.NewAESSHA2(r.SkippedMessageKeys[curSkippedIndex].MKey.Key, olmKeysKDFInfo); err != nil { return nil, err - } - if !verified { + } else if verified, err := message.VerifyMACInline(r.SkippedMessageKeys[curSkippedIndex].MKey.Key, cipher, input); err != nil { + return nil, err + } else if !verified { return nil, fmt.Errorf("decrypt from skipped message keys: %w", olm.ErrBadMAC) - } - result, err := RatchetCipher.Decrypt(r.SkippedMessageKeys[curSkippedIndex].MKey.Key, message.Ciphertext) - if err != nil { + } else if result, err := cipher.Decrypt(message.Ciphertext); err != nil { return nil, fmt.Errorf("cipher decrypt: %w", err) } else if len(result) != 0 { // Remove the key from the skipped keys now that we've @@ -235,14 +235,18 @@ func (r *Ratchet) decryptForExistingChain(chain *receiverChain, message *message } messageKey := r.createMessageKeys(chain.chainKey()) chain.advance() - verified, err := message.VerifyMACInline(messageKey.Key, RatchetCipher, rawMessage) + cipher, err := aessha2.NewAESSHA2(messageKey.Key, olmKeysKDFInfo) + if err != nil { + return nil, err + } + verified, err := message.VerifyMACInline(messageKey.Key, cipher, rawMessage) if err != nil { return nil, err } if !verified { return nil, fmt.Errorf("decrypt from existing chain: %w", olm.ErrBadMAC) } - return RatchetCipher.Decrypt(messageKey.Key, message.Ciphertext) + return cipher.Decrypt(message.Ciphertext) } // decryptForNewChain returns the decrypted message by creating a new chain and advancing the root key. @@ -276,12 +280,12 @@ func (r *Ratchet) decryptForNewChain(message *message.Message, rawMessage []byte // PickleAsJSON returns a ratchet as a base64 string encrypted using the supplied key. The unencrypted representation of the Account is in JSON format. func (r Ratchet) PickleAsJSON(key []byte) ([]byte, error) { - return utilities.PickleAsJSON(r, olmPickleVersion, key) + return libolmpickle.PickleAsJSON(r, olmPickleVersion, key) } // UnpickleAsJSON updates a ratchet by a base64 encrypted string using the supplied key. The unencrypted representation has to be in JSON format. func (r *Ratchet) UnpickleAsJSON(pickled, key []byte) error { - return utilities.UnpickleAsJSON(r, pickled, key, olmPickleVersion) + return libolmpickle.UnpickleAsJSON(r, pickled, key, olmPickleVersion) } // UnpickleLibOlm unpickles the unencryted value and populates the [Ratchet] diff --git a/crypto/goolm/ratchet/olm_test.go b/crypto/goolm/ratchet/olm_test.go index 6a8fefc3..2bf7ea0a 100644 --- a/crypto/goolm/ratchet/olm_test.go +++ b/crypto/goolm/ratchet/olm_test.go @@ -6,7 +6,6 @@ import ( "github.com/stretchr/testify/assert" - "maunium.net/go/mautrix/crypto/goolm/cipher" "maunium.net/go/mautrix/crypto/goolm/crypto" "maunium.net/go/mautrix/crypto/goolm/ratchet" ) @@ -23,7 +22,6 @@ func initializeRatchets() (*ratchet.Ratchet, *ratchet.Ratchet, error) { Root: []byte("Olm"), Ratchet: []byte("OlmRatchet"), } - ratchet.RatchetCipher = cipher.NewAESSHA256([]byte("OlmMessageKeys")) aliceRatchet := ratchet.New() bobRatchet := ratchet.New() diff --git a/crypto/goolm/register.go b/crypto/goolm/register.go index 80ed206b..800f567f 100644 --- a/crypto/goolm/register.go +++ b/crypto/goolm/register.go @@ -7,19 +7,23 @@ package goolm import ( - // Need to import these subpackages to ensure they are registered - _ "maunium.net/go/mautrix/crypto/goolm/account" - _ "maunium.net/go/mautrix/crypto/goolm/pk" - _ "maunium.net/go/mautrix/crypto/goolm/session" - + "maunium.net/go/mautrix/crypto/goolm/account" + "maunium.net/go/mautrix/crypto/goolm/pk" + "maunium.net/go/mautrix/crypto/goolm/session" "maunium.net/go/mautrix/crypto/olm" ) -func init() { +func Register() { + olm.Driver = "goolm" + olm.GetVersion = func() (major, minor, patch uint8) { return 3, 2, 15 } olm.SetPickleKeyImpl = func(key []byte) { panic("gob and json encoding is deprecated and not supported with goolm") } + + account.Register() + pk.Register() + session.Register() } diff --git a/crypto/goolm/session/megolm_inbound_session.go b/crypto/goolm/session/megolm_inbound_session.go index 4c107e92..7ccbd26d 100644 --- a/crypto/goolm/session/megolm_inbound_session.go +++ b/crypto/goolm/session/megolm_inbound_session.go @@ -4,13 +4,11 @@ import ( "encoding/base64" "fmt" - "maunium.net/go/mautrix/crypto/goolm/cipher" "maunium.net/go/mautrix/crypto/goolm/crypto" "maunium.net/go/mautrix/crypto/goolm/goolmbase64" "maunium.net/go/mautrix/crypto/goolm/libolmpickle" "maunium.net/go/mautrix/crypto/goolm/megolm" "maunium.net/go/mautrix/crypto/goolm/message" - "maunium.net/go/mautrix/crypto/goolm/utilities" "maunium.net/go/mautrix/crypto/olm" "maunium.net/go/mautrix/id" ) @@ -101,7 +99,7 @@ func (o *MegolmInboundSession) getRatchet(messageIndex uint32) (*megolm.Ratchet, } if (messageIndex - o.InitialRatchet.Counter) >= uint32(1<<31) { // the counter is before our initial ratchet - we can't decode this - return nil, fmt.Errorf("decrypt: %w", olm.ErrRatchetNotAvailable) + return nil, fmt.Errorf("decrypt: %w", olm.ErrUnknownMessageIndex) } // otherwise, start from the initial ratchet. Take a copy so that we don't overwrite the initial ratchet copiedRatchet := o.InitialRatchet @@ -128,7 +126,7 @@ func (o *MegolmInboundSession) Decrypt(ciphertext []byte) ([]byte, uint, error) return nil, 0, err } if msg.Version != protocolVersion { - return nil, 0, fmt.Errorf("decrypt: %w", olm.ErrWrongProtocolVersion) + return nil, 0, fmt.Errorf("decrypt: %w (got %d, expected %d)", olm.ErrWrongProtocolVersion, msg.Version, protocolVersion) } if msg.Ciphertext == nil || !msg.HasMessageIndex { return nil, 0, fmt.Errorf("decrypt: %w", olm.ErrBadMessageFormat) @@ -161,12 +159,12 @@ func (o *MegolmInboundSession) ID() id.SessionID { // PickleAsJSON returns an MegolmInboundSession as a base64 string encrypted using the supplied key. The unencrypted representation of the Account is in JSON format. func (o *MegolmInboundSession) PickleAsJSON(key []byte) ([]byte, error) { - return utilities.PickleAsJSON(o, megolmInboundSessionPickleVersionJSON, key) + return libolmpickle.PickleAsJSON(o, megolmInboundSessionPickleVersionJSON, key) } // UnpickleAsJSON updates an MegolmInboundSession by a base64 encrypted string using the supplied key. The unencrypted representation has to be in JSON format. func (o *MegolmInboundSession) UnpickleAsJSON(pickled, key []byte) error { - return utilities.UnpickleAsJSON(o, pickled, key, megolmInboundSessionPickleVersionJSON) + return libolmpickle.UnpickleAsJSON(o, pickled, key, megolmInboundSessionPickleVersionJSON) } // Export returns the base64-encoded ratchet key for this session, at the given @@ -192,7 +190,7 @@ func (o *MegolmInboundSession) Unpickle(pickled, key []byte) error { } else if len(pickled) == 0 { return olm.ErrEmptyInput } - decrypted, err := cipher.Unpickle(key, pickled) + decrypted, err := libolmpickle.Unpickle(key, pickled) if err != nil { return err } @@ -208,7 +206,7 @@ func (o *MegolmInboundSession) UnpickleLibOlm(value []byte) error { return err } if pickledVersion != megolmInboundSessionPickleVersionLibOlm && pickledVersion != 1 { - return fmt.Errorf("unpickle MegolmInboundSession: %w (found version %d)", olm.ErrBadVersion, pickledVersion) + return fmt.Errorf("unpickle MegolmInboundSession: %w (found version %d)", olm.ErrUnknownOlmPickleVersion, pickledVersion) } if err = o.InitialRatchet.UnpickleLibOlm(decoder); err != nil { @@ -234,7 +232,7 @@ func (o *MegolmInboundSession) Pickle(key []byte) ([]byte, error) { if len(key) == 0 { return nil, olm.ErrNoKeyProvided } - return cipher.Pickle(key, o.PickleLibOlm()) + return libolmpickle.Pickle(key, o.PickleLibOlm()) } // PickleLibOlm pickles the session returning the raw bytes. diff --git a/crypto/goolm/session/megolm_outbound_session.go b/crypto/goolm/session/megolm_outbound_session.go index b42dab53..7f923534 100644 --- a/crypto/goolm/session/megolm_outbound_session.go +++ b/crypto/goolm/session/megolm_outbound_session.go @@ -7,12 +7,10 @@ import ( "go.mau.fi/util/exerrors" - "maunium.net/go/mautrix/crypto/goolm/cipher" "maunium.net/go/mautrix/crypto/goolm/crypto" "maunium.net/go/mautrix/crypto/goolm/goolmbase64" "maunium.net/go/mautrix/crypto/goolm/libolmpickle" "maunium.net/go/mautrix/crypto/goolm/megolm" - "maunium.net/go/mautrix/crypto/goolm/utilities" "maunium.net/go/mautrix/crypto/olm" "maunium.net/go/mautrix/id" ) @@ -66,7 +64,7 @@ func (o *MegolmOutboundSession) Encrypt(plaintext []byte) ([]byte, error) { if len(plaintext) == 0 { return nil, olm.ErrEmptyInput } - encrypted, err := o.Ratchet.Encrypt(plaintext, &o.SigningKey) + encrypted, err := o.Ratchet.Encrypt(plaintext, o.SigningKey) return goolmbase64.Encode(encrypted), err } @@ -77,12 +75,12 @@ func (o *MegolmOutboundSession) ID() id.SessionID { // PickleAsJSON returns an Session as a base64 string encrypted using the supplied key. The unencrypted representation of the Account is in JSON format. func (o *MegolmOutboundSession) PickleAsJSON(key []byte) ([]byte, error) { - return utilities.PickleAsJSON(o, megolmOutboundSessionPickleVersion, key) + return libolmpickle.PickleAsJSON(o, megolmOutboundSessionPickleVersion, key) } // UnpickleAsJSON updates an Session by a base64 encrypted string with the key. The unencrypted representation has to be in JSON format. func (o *MegolmOutboundSession) UnpickleAsJSON(pickled, key []byte) error { - return utilities.UnpickleAsJSON(o, pickled, key, megolmOutboundSessionPickleVersion) + return libolmpickle.UnpickleAsJSON(o, pickled, key, megolmOutboundSessionPickleVersion) } // Unpickle decodes the base64 encoded string and decrypts the result with the key. @@ -91,7 +89,7 @@ func (o *MegolmOutboundSession) Unpickle(pickled, key []byte) error { if len(key) == 0 { return olm.ErrNoKeyProvided } - decrypted, err := cipher.Unpickle(key, pickled) + decrypted, err := libolmpickle.Unpickle(key, pickled) if err != nil { return err } @@ -103,8 +101,10 @@ func (o *MegolmOutboundSession) Unpickle(pickled, key []byte) error { func (o *MegolmOutboundSession) UnpickleLibOlm(buf []byte) error { decoder := libolmpickle.NewDecoder(buf) pickledVersion, err := decoder.ReadUInt32() - if pickledVersion != megolmOutboundSessionPickleVersionLibOlm { - return fmt.Errorf("unpickle MegolmInboundSession: %w (found version %d)", olm.ErrBadVersion, pickledVersion) + if err != nil { + return fmt.Errorf("unpickle MegolmOutboundSession: failed to read version: %w", err) + } else if pickledVersion != megolmOutboundSessionPickleVersionLibOlm { + return fmt.Errorf("unpickle MegolmInboundSession: %w (found version %d)", olm.ErrUnknownOlmPickleVersion, pickledVersion) } if err = o.Ratchet.UnpickleLibOlm(decoder); err != nil { return err @@ -117,7 +117,7 @@ func (o *MegolmOutboundSession) Pickle(key []byte) ([]byte, error) { if len(key) == 0 { return nil, olm.ErrNoKeyProvided } - return cipher.Pickle(key, o.PickleLibOlm()) + return libolmpickle.Pickle(key, o.PickleLibOlm()) } // PickleLibOlm pickles the session returning the raw bytes. diff --git a/crypto/goolm/session/olm_session.go b/crypto/goolm/session/olm_session.go index fcd9d0dc..a1cb8d66 100644 --- a/crypto/goolm/session/olm_session.go +++ b/crypto/goolm/session/olm_session.go @@ -7,13 +7,11 @@ import ( "fmt" "strings" - "maunium.net/go/mautrix/crypto/goolm/cipher" "maunium.net/go/mautrix/crypto/goolm/crypto" "maunium.net/go/mautrix/crypto/goolm/goolmbase64" "maunium.net/go/mautrix/crypto/goolm/libolmpickle" "maunium.net/go/mautrix/crypto/goolm/message" "maunium.net/go/mautrix/crypto/goolm/ratchet" - "maunium.net/go/mautrix/crypto/goolm/utilities" "maunium.net/go/mautrix/crypto/olm" "maunium.net/go/mautrix/id" ) @@ -170,11 +168,11 @@ func NewInboundOlmSession(identityKeyAlice *crypto.Curve25519PublicKey, received msg := message.Message{} err = msg.Decode(oneTimeMsg.Message) if err != nil { - return nil, fmt.Errorf("Message decode: %w", err) + return nil, fmt.Errorf("message decode: %w", err) } if len(msg.RatchetKey) == 0 { - return nil, fmt.Errorf("Message missing ratchet key: %w", olm.ErrBadMessageFormat) + return nil, fmt.Errorf("message missing ratchet key: %w", olm.ErrBadMessageFormat) } //Init Ratchet s.Ratchet.InitializeAsBob(secret, msg.RatchetKey) @@ -189,12 +187,12 @@ func NewInboundOlmSession(identityKeyAlice *crypto.Curve25519PublicKey, received // PickleAsJSON returns an Session as a base64 string encrypted using the supplied key. The unencrypted representation of the Account is in JSON format. func (a OlmSession) PickleAsJSON(key []byte) ([]byte, error) { - return utilities.PickleAsJSON(a, olmSessionPickleVersionJSON, key) + return libolmpickle.PickleAsJSON(a, olmSessionPickleVersionJSON, key) } // UnpickleAsJSON updates an Session by a base64 encrypted string with the key. The unencrypted representation has to be in JSON format. func (a *OlmSession) UnpickleAsJSON(pickled, key []byte) error { - return utilities.UnpickleAsJSON(a, pickled, key, olmSessionPickleVersionJSON) + return libolmpickle.UnpickleAsJSON(a, pickled, key, olmSessionPickleVersionJSON) } // ID returns an identifier for this Session. Will be the same for both ends of the conversation. @@ -205,7 +203,7 @@ func (s *OlmSession) ID() id.SessionID { copy(message[crypto.Curve25519PrivateKeyLength:], s.AliceBaseKey) copy(message[2*crypto.Curve25519PrivateKeyLength:], s.BobOneTimeKey) hash := sha256.Sum256(message) - res := id.SessionID(goolmbase64.Encode(hash[:])) + res := id.SessionID(base64.RawStdEncoding.EncodeToString(hash[:])) return res } @@ -327,7 +325,7 @@ func (s *OlmSession) Decrypt(crypttext string, msgType id.OlmMsgType) ([]byte, e if len(crypttext) == 0 { return nil, fmt.Errorf("decrypt: %w", olm.ErrEmptyInput) } - decodedCrypttext, err := goolmbase64.Decode([]byte(crypttext)) + decodedCrypttext, err := base64.RawStdEncoding.DecodeString(crypttext) if err != nil { return nil, err } @@ -355,7 +353,7 @@ func (o *OlmSession) Unpickle(pickled, key []byte) error { if len(pickled) == 0 { return olm.ErrEmptyInput } - decrypted, err := cipher.Unpickle(key, pickled) + decrypted, err := libolmpickle.Unpickle(key, pickled) if err != nil { return err } @@ -367,6 +365,9 @@ func (o *OlmSession) Unpickle(pickled, key []byte) error { func (o *OlmSession) UnpickleLibOlm(buf []byte) error { decoder := libolmpickle.NewDecoder(buf) pickledVersion, err := decoder.ReadUInt32() + if err != nil { + return fmt.Errorf("unpickle olmSession: failed to read version: %w", err) + } var includesChainIndex bool switch pickledVersion { @@ -375,7 +376,7 @@ func (o *OlmSession) UnpickleLibOlm(buf []byte) error { case uint32(0x80000001): includesChainIndex = true default: - return fmt.Errorf("unpickle olmSession: %w (found version %d)", olm.ErrBadVersion, pickledVersion) + return fmt.Errorf("unpickle olmSession: %w (found version %d)", olm.ErrUnknownOlmPickleVersion, pickledVersion) } if o.ReceivedMessage, err = decoder.ReadBool(); err != nil { @@ -396,7 +397,7 @@ func (s *OlmSession) Pickle(key []byte) ([]byte, error) { if len(key) == 0 { return nil, olm.ErrNoKeyProvided } - return cipher.Pickle(key, s.PickleLibOlm()) + return libolmpickle.Pickle(key, s.PickleLibOlm()) } // PickleLibOlm pickles the session and returns the raw bytes. diff --git a/crypto/goolm/session/register.go b/crypto/goolm/session/register.go index 09ed42d4..b95a44ac 100644 --- a/crypto/goolm/session/register.go +++ b/crypto/goolm/session/register.go @@ -10,11 +10,11 @@ import ( "maunium.net/go/mautrix/crypto/olm" ) -func init() { +func Register() { // Inbound Session olm.InitInboundGroupSessionFromPickled = func(pickled, key []byte) (olm.InboundGroupSession, error) { if len(pickled) == 0 { - return nil, olm.EmptyInput + return nil, olm.ErrEmptyInput } if len(key) == 0 { key = []byte(" ") @@ -23,13 +23,13 @@ func init() { } olm.InitNewInboundGroupSession = func(sessionKey []byte) (olm.InboundGroupSession, error) { if len(sessionKey) == 0 { - return nil, olm.EmptyInput + return nil, olm.ErrEmptyInput } return NewMegolmInboundSession(sessionKey) } olm.InitInboundGroupSessionImport = func(sessionKey []byte) (olm.InboundGroupSession, error) { if len(sessionKey) == 0 { - return nil, olm.EmptyInput + return nil, olm.ErrEmptyInput } return NewMegolmInboundSessionFromExport(sessionKey) } @@ -40,7 +40,7 @@ func init() { // Outbound Session olm.InitNewOutboundGroupSessionFromPickled = func(pickled, key []byte) (olm.OutboundGroupSession, error) { if len(pickled) == 0 { - return nil, olm.EmptyInput + return nil, olm.ErrEmptyInput } lenKey := len(key) if lenKey == 0 { diff --git a/crypto/keybackup.go b/crypto/keybackup.go index 4e9431bb..7b3c30db 100644 --- a/crypto/keybackup.go +++ b/crypto/keybackup.go @@ -2,6 +2,8 @@ package crypto import ( "context" + "encoding/base64" + "errors" "fmt" "time" @@ -11,6 +13,7 @@ import ( "maunium.net/go/mautrix/crypto/backup" "maunium.net/go/mautrix/crypto/olm" "maunium.net/go/mautrix/crypto/signatures" + "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" ) @@ -21,7 +24,7 @@ func (mach *OlmMachine) DownloadAndStoreLatestKeyBackup(ctx context.Context, meg ctx = log.WithContext(ctx) - versionInfo, err := mach.GetAndVerifyLatestKeyBackupVersion(ctx) + versionInfo, err := mach.GetAndVerifyLatestKeyBackupVersion(ctx, megolmBackupKey) if err != nil { return "", err } else if versionInfo == nil { @@ -32,7 +35,7 @@ func (mach *OlmMachine) DownloadAndStoreLatestKeyBackup(ctx context.Context, meg return versionInfo.Version, err } -func (mach *OlmMachine) GetAndVerifyLatestKeyBackupVersion(ctx context.Context) (*mautrix.RespRoomKeysVersion[backup.MegolmAuthData], error) { +func (mach *OlmMachine) GetAndVerifyLatestKeyBackupVersion(ctx context.Context, megolmBackupKey *backup.MegolmBackupKey) (*mautrix.RespRoomKeysVersion[backup.MegolmAuthData], error) { versionInfo, err := mach.Client.GetKeyBackupLatestVersion(ctx) if err != nil { return nil, err @@ -48,6 +51,24 @@ func (mach *OlmMachine) GetAndVerifyLatestKeyBackupVersion(ctx context.Context) Stringer("key_backup_version", versionInfo.Version). Logger() + // https://spec.matrix.org/v1.10/client-server-api/#server-side-key-backups + // "Clients must only store keys in backups after they have ensured that the auth_data is trusted. This can be done either... + // ...by deriving the public key from a private key that it obtained from a trusted source. Trusted sources for the private + // key include the user entering the key, retrieving the key stored in secret storage, or obtaining the key via secret sharing + // from a verified device belonging to the same user." + if megolmBackupKey != nil { + megolmBackupDerivedPublicKey := id.Ed25519(base64.RawStdEncoding.EncodeToString(megolmBackupKey.PublicKey().Bytes())) + if versionInfo.AuthData.PublicKey == megolmBackupDerivedPublicKey { + log.Debug().Msg("Key backup is trusted based on derived public key") + return versionInfo, nil + } + log.Debug(). + Stringer("expected_key", megolmBackupDerivedPublicKey). + Stringer("actual_key", versionInfo.AuthData.PublicKey). + Msg("key backup public keys do not match, proceeding to check device signatures") + } + + // "...or checking that it is signed by the user’s master cross-signing key or by a verified device belonging to the same user" userSignatures, ok := versionInfo.AuthData.Signatures[mach.Client.UserID] if !ok { return nil, fmt.Errorf("no signature from user %s found in key backup", mach.Client.UserID) @@ -74,7 +95,7 @@ func (mach *OlmMachine) GetAndVerifyLatestKeyBackupVersion(ctx context.Context) } else if device == nil { log.Warn().Err(err).Msg("Device does not exist, ignoring signature") continue - } else if !mach.IsDeviceTrusted(device) { + } else if !mach.IsDeviceTrusted(ctx, device) { log.Warn().Err(err).Msg("Device is not trusted") continue } else { @@ -87,6 +108,7 @@ func (mach *OlmMachine) GetAndVerifyLatestKeyBackupVersion(ctx context.Context) continue } else { // One of the signatures is valid, break from the loop. + log.Debug().Stringer("key_id", keyID).Msg("key backup is trusted based on matching signature") signatureVerified = true break } @@ -135,13 +157,23 @@ func (mach *OlmMachine) GetAndStoreKeyBackup(ctx context.Context, version id.Key return nil } -func (mach *OlmMachine) ImportRoomKeyFromBackup(ctx context.Context, version id.KeyBackupVersion, roomID id.RoomID, sessionID id.SessionID, keyBackupData *backup.MegolmSessionData) (*InboundGroupSession, error) { - log := zerolog.Ctx(ctx).With(). - Str("room_id", roomID.String()). - Str("session_id", sessionID.String()). - Logger() +var ( + ErrUnknownAlgorithmInKeyBackup = errors.New("ignoring room key in backup with weird algorithm") + ErrMismatchingSessionIDInKeyBackup = errors.New("mismatched session ID while creating inbound group session from key backup") + ErrFailedToStoreNewInboundGroupSessionFromBackup = errors.New("failed to store new inbound group session from key backup") +) + +func (mach *OlmMachine) ImportRoomKeyFromBackupWithoutSaving( + ctx context.Context, + version id.KeyBackupVersion, + roomID id.RoomID, + config *event.EncryptionEventContent, + sessionID id.SessionID, + keyBackupData *backup.MegolmSessionData, +) (*InboundGroupSession, error) { + log := zerolog.Ctx(ctx) if keyBackupData.Algorithm != id.AlgorithmMegolmV1 { - return nil, fmt.Errorf("ignoring room key in backup with weird algorithm %s", keyBackupData.Algorithm) + return nil, fmt.Errorf("%w %s", ErrUnknownAlgorithmInKeyBackup, keyBackupData.Algorithm) } igsInternal, err := olm.InboundGroupSessionImport([]byte(keyBackupData.SessionKey)) @@ -149,42 +181,60 @@ func (mach *OlmMachine) ImportRoomKeyFromBackup(ctx context.Context, version id. return nil, fmt.Errorf("failed to import inbound group session: %w", err) } else if igsInternal.ID() != sessionID { log.Warn(). + Stringer("room_id", roomID). + Stringer("session_id", sessionID). Stringer("actual_session_id", igsInternal.ID()). Msg("Mismatched session ID while creating inbound group session from key backup") - return nil, fmt.Errorf("mismatched session ID while creating inbound group session from key backup") + return nil, ErrMismatchingSessionIDInKeyBackup } var maxAge time.Duration var maxMessages int - if config, err := mach.StateStore.GetEncryptionEvent(ctx, roomID); err != nil { - log.Error().Err(err).Msg("Failed to get encryption event for room") - } else if config != nil { + if config != nil { maxAge = time.Duration(config.RotationPeriodMillis) * time.Millisecond maxMessages = config.RotationPeriodMessages } - firstKnownIndex := igsInternal.FirstKnownIndex() - if firstKnownIndex > 0 { - log.Warn().Uint32("first_known_index", firstKnownIndex).Msg("Importing partial session") - } - - igs := &InboundGroupSession{ + return &InboundGroupSession{ Internal: igsInternal, SigningKey: keyBackupData.SenderClaimedKeys.Ed25519, SenderKey: keyBackupData.SenderKey, RoomID: roomID, - ForwardingChains: append(keyBackupData.ForwardingKeyChain, keyBackupData.SenderKey.String()), + ForwardingChains: keyBackupData.ForwardingKeyChain, id: sessionID, ReceivedAt: time.Now().UTC(), MaxAge: maxAge.Milliseconds(), MaxMessages: maxMessages, KeyBackupVersion: version, - } - err = mach.CryptoStore.PutGroupSession(ctx, igs) - if err != nil { - return nil, fmt.Errorf("failed to store new inbound group session: %w", err) - } - mach.markSessionReceived(ctx, roomID, sessionID, firstKnownIndex) - return igs, nil + KeySource: id.KeySourceBackup, + }, nil +} + +func (mach *OlmMachine) ImportRoomKeyFromBackup(ctx context.Context, version id.KeyBackupVersion, roomID id.RoomID, sessionID id.SessionID, keyBackupData *backup.MegolmSessionData) (*InboundGroupSession, error) { + config, err := mach.StateStore.GetEncryptionEvent(ctx, roomID) + if err != nil { + zerolog.Ctx(ctx).Err(err). + Stringer("room_id", roomID). + Stringer("session_id", sessionID). + Msg("Failed to get encryption event for room") + } + imported, err := mach.ImportRoomKeyFromBackupWithoutSaving(ctx, version, roomID, config, sessionID, keyBackupData) + if err != nil { + return nil, err + } + firstKnownIndex := imported.Internal.FirstKnownIndex() + if firstKnownIndex > 0 { + zerolog.Ctx(ctx).Warn(). + Stringer("room_id", roomID). + Stringer("session_id", sessionID). + Uint32("first_known_index", firstKnownIndex). + Msg("Importing partial session") + } + err = mach.CryptoStore.PutGroupSession(ctx, imported) + if err != nil { + return nil, fmt.Errorf("%w: %w", ErrFailedToStoreNewInboundGroupSessionFromBackup, err) + } + mach.MarkSessionReceived(ctx, roomID, sessionID, firstKnownIndex) + return imported, nil } diff --git a/crypto/keyexport.go b/crypto/keyexport.go index 3d126db4..1904c8a5 100644 --- a/crypto/keyexport.go +++ b/crypto/keyexport.go @@ -1,4 +1,4 @@ -// Copyright (c) 2020 Tulir Asokan +// Copyright (c) 2025 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this @@ -16,15 +16,21 @@ import ( "encoding/base64" "encoding/binary" "encoding/json" + "errors" "fmt" "math" + "go.mau.fi/util/dbutil" + "go.mau.fi/util/exbytes" + "go.mau.fi/util/exerrors" "go.mau.fi/util/random" "golang.org/x/crypto/pbkdf2" "maunium.net/go/mautrix/id" ) +var ErrNoSessionsForExport = errors.New("no sessions provided for export") + type SenderClaimedKeys struct { Ed25519 id.Ed25519 `json:"ed25519"` } @@ -78,22 +84,14 @@ func makeExportKeys(passphrase string) (encryptionKey, hashKey, salt, iv []byte) return } -func exportSessions(sessions []*InboundGroupSession) ([]ExportedSession, error) { - export := make([]ExportedSession, len(sessions)) +func exportSessions(sessions []*InboundGroupSession) ([]*ExportedSession, error) { + export := make([]*ExportedSession, len(sessions)) + var err error for i, session := range sessions { - key, err := session.Internal.Export(session.Internal.FirstKnownIndex()) + export[i], err = session.export() if err != nil { return nil, fmt.Errorf("failed to export session: %w", err) } - export[i] = ExportedSession{ - Algorithm: id.AlgorithmMegolmV1, - ForwardingChains: session.ForwardingChains, - RoomID: session.RoomID, - SenderKey: session.SenderKey, - SenderClaimedKeys: SenderClaimedKeys{}, - SessionID: session.ID(), - SessionKey: string(key), - } } return export, nil } @@ -107,38 +105,73 @@ func exportSessionsJSON(sessions []*InboundGroupSession) ([]byte, error) { } func formatKeyExportData(data []byte) []byte { - base64Data := make([]byte, base64.StdEncoding.EncodedLen(len(data))) - base64.StdEncoding.Encode(base64Data, data) - - // Prefix + data and newline for each 76 characters of data + suffix + encodedLen := base64.StdEncoding.EncodedLen(len(data)) outputLength := len(exportPrefix) + - len(base64Data) + int(math.Ceil(float64(len(base64Data))/exportLineLengthLimit)) + + encodedLen + int(math.Ceil(float64(encodedLen)/exportLineLengthLimit)) + len(exportSuffix) + output := make([]byte, 0, outputLength) + outputWriter := (*exbytes.Writer)(&output) + base64Writer := base64.NewEncoder(base64.StdEncoding, outputWriter) + lineByteCount := base64.StdEncoding.DecodedLen(exportLineLengthLimit) + exerrors.Must(outputWriter.WriteString(exportPrefix)) + for i := 0; i < len(data); i += lineByteCount { + exerrors.Must(base64Writer.Write(data[i:min(i+lineByteCount, len(data))])) + if i+lineByteCount >= len(data) { + exerrors.PanicIfNotNil(base64Writer.Close()) + } + exerrors.PanicIfNotNil(outputWriter.WriteByte('\n')) + } + exerrors.Must(outputWriter.WriteString(exportSuffix)) + if len(output) != outputLength { + panic(fmt.Errorf("unexpected length %d / %d", len(output), outputLength)) + } + return output +} - var buf bytes.Buffer - buf.Grow(outputLength) - buf.WriteString(exportPrefix) - for ptr := 0; ptr < len(base64Data); ptr += exportLineLengthLimit { - buf.Write(base64Data[ptr:min(ptr+exportLineLengthLimit, len(base64Data))]) - buf.WriteRune('\n') +func ExportKeysIter(passphrase string, sessions dbutil.RowIter[*InboundGroupSession]) ([]byte, error) { + buf := bytes.NewBuffer(make([]byte, 0, 50*1024)) + enc := json.NewEncoder(buf) + buf.WriteByte('[') + err := sessions.Iter(func(session *InboundGroupSession) (bool, error) { + exported, err := session.export() + if err != nil { + return false, err + } + err = enc.Encode(exported) + if err != nil { + return false, err + } + buf.WriteByte(',') + return true, nil + }) + if err != nil { + return nil, err } - buf.WriteString(exportSuffix) - if buf.Len() != buf.Cap() || buf.Len() != outputLength { - panic(fmt.Errorf("unexpected length %d / %d / %d", buf.Len(), buf.Cap(), outputLength)) + output := buf.Bytes() + if len(output) == 1 { + return nil, ErrNoSessionsForExport } - return buf.Bytes() + output[len(output)-1] = ']' // Replace the last comma with a closing bracket + return EncryptKeyExport(passphrase, output) } // ExportKeys exports the given Megolm sessions with the format specified in the Matrix spec. // See https://spec.matrix.org/v1.2/client-server-api/#key-exports func ExportKeys(passphrase string, sessions []*InboundGroupSession) ([]byte, error) { - // Make all the keys necessary for exporting - encryptionKey, hashKey, salt, iv := makeExportKeys(passphrase) + if len(sessions) == 0 { + return nil, ErrNoSessionsForExport + } // Export all the given sessions and put them in JSON unencryptedData, err := exportSessionsJSON(sessions) if err != nil { return nil, err } + return EncryptKeyExport(passphrase, unencryptedData) +} + +func EncryptKeyExport(passphrase string, unencryptedData json.RawMessage) ([]byte, error) { + // Make all the keys necessary for exporting + encryptionKey, hashKey, salt, iv := makeExportKeys(passphrase) // The export data consists of: // 1 byte of export format version diff --git a/crypto/keyexport_test.go b/crypto/keyexport_test.go new file mode 100644 index 00000000..fd6f105d --- /dev/null +++ b/crypto/keyexport_test.go @@ -0,0 +1,35 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package crypto_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "go.mau.fi/util/exerrors" + "go.mau.fi/util/exfmt" + + "maunium.net/go/mautrix/crypto" + "maunium.net/go/mautrix/crypto/olm" +) + +func TestExportKeys(t *testing.T) { + acc := crypto.NewOlmAccount() + sess := exerrors.Must(crypto.NewInboundGroupSession( + acc.IdentityKey(), + acc.SigningKey(), + "!room:example.com", + exerrors.Must(olm.NewOutboundGroupSession()).Key(), + 7*exfmt.Day, + 100, + false, + )) + data, err := crypto.ExportKeys("meow", []*crypto.InboundGroupSession{sess}) + assert.NoError(t, err) + assert.Len(t, data, 893) +} diff --git a/crypto/keyimport.go b/crypto/keyimport.go index 108c67ac..3ffc74a5 100644 --- a/crypto/keyimport.go +++ b/crypto/keyimport.go @@ -36,6 +36,10 @@ var ( var exportPrefixBytes, exportSuffixBytes = []byte(exportPrefix), []byte(exportSuffix) func decodeKeyExport(data []byte) ([]byte, error) { + // Fix some types of corruption in the key export file before checking anything + if bytes.IndexByte(data, '\r') != -1 { + data = bytes.ReplaceAll(data, []byte{'\r', '\n'}, []byte{'\n'}) + } // If the valid prefix and suffix aren't there, it's probably not a Matrix key export if !bytes.HasPrefix(data, exportPrefixBytes) { return nil, ErrMissingExportPrefix @@ -104,26 +108,27 @@ func (mach *OlmMachine) importExportedRoomKey(ctx context.Context, session Expor return false, ErrMismatchingExportedSessionID } igs := &InboundGroupSession{ - Internal: igsInternal, - SigningKey: session.SenderClaimedKeys.Ed25519, - SenderKey: session.SenderKey, - RoomID: session.RoomID, - // TODO should we add something here to mark the signing key as unverified like key requests do? + Internal: igsInternal, + SigningKey: session.SenderClaimedKeys.Ed25519, + SenderKey: session.SenderKey, + RoomID: session.RoomID, ForwardingChains: session.ForwardingChains, - - ReceivedAt: time.Now().UTC(), + KeySource: id.KeySourceImport, + ReceivedAt: time.Now().UTC(), } existingIGS, _ := mach.CryptoStore.GetGroupSession(ctx, igs.RoomID, igs.ID()) firstKnownIndex := igs.Internal.FirstKnownIndex() if existingIGS != nil && existingIGS.Internal.FirstKnownIndex() <= firstKnownIndex { - // We already have an equivalent or better session in the store, so don't override it. + // We already have an equivalent or better session in the store, so don't override it, + // but do notify the session received callback just in case. + mach.MarkSessionReceived(ctx, session.RoomID, igs.ID(), existingIGS.Internal.FirstKnownIndex()) return false, nil } err = mach.CryptoStore.PutGroupSession(ctx, igs) if err != nil { return false, fmt.Errorf("failed to store imported session: %w", err) } - mach.markSessionReceived(ctx, session.RoomID, igs.ID(), firstKnownIndex) + mach.MarkSessionReceived(ctx, session.RoomID, igs.ID(), firstKnownIndex) return true, nil } diff --git a/crypto/keysharing.go b/crypto/keysharing.go index 0ccf006a..19a68c87 100644 --- a/crypto/keysharing.go +++ b/crypto/keysharing.go @@ -189,6 +189,7 @@ func (mach *OlmMachine) importForwardedRoomKey(ctx context.Context, evt *Decrypt MaxAge: maxAge.Milliseconds(), MaxMessages: maxMessages, IsScheduled: content.IsScheduled, + KeySource: id.KeySourceForward, } existingIGS, _ := mach.CryptoStore.GetGroupSession(ctx, igs.RoomID, igs.ID()) if existingIGS != nil && existingIGS.Internal.FirstKnownIndex() <= igs.Internal.FirstKnownIndex() { @@ -200,7 +201,7 @@ func (mach *OlmMachine) importForwardedRoomKey(ctx context.Context, evt *Decrypt log.Error().Err(err).Msg("Failed to store new inbound group session") return false } - mach.markSessionReceived(ctx, content.RoomID, content.SessionID, firstKnownIndex) + mach.MarkSessionReceived(ctx, content.RoomID, content.SessionID, firstKnownIndex) log.Debug().Msg("Received forwarded inbound group session") return true } @@ -214,6 +215,7 @@ func (mach *OlmMachine) rejectKeyRequest(ctx context.Context, rejection KeyShare RoomID: request.RoomID, Algorithm: request.Algorithm, SessionID: request.SessionID, + //lint:ignore SA1019 This is just echoing back the deprecated field SenderKey: request.SenderKey, Code: rejection.Code, Reason: rejection.Reason, @@ -263,9 +265,14 @@ func (mach *OlmMachine) defaultAllowKeyShare(ctx context.Context, device *id.Dev log.Err(err).Msg("Rejecting key request due to internal error when checking session sharing") return &KeyShareRejectNoResponse } else if !isShared { - // TODO differentiate session not shared with requester vs session not created by this device? - log.Debug().Msg("Rejecting key request for unshared session") - return &KeyShareRejectNotRecipient + igs, _ := mach.CryptoStore.GetGroupSession(ctx, evt.RoomID, evt.SessionID) + if igs != nil && igs.SenderKey == mach.OwnIdentity().IdentityKey { + log.Debug().Msg("Rejecting key request for unshared session") + return &KeyShareRejectNotRecipient + } + // Note: this case will also happen for redacted sessions and database errors + log.Debug().Msg("Rejecting key request for session created by another device") + return &KeyShareRejectNoResponse } log.Debug().Msg("Accepting key request for shared session") return nil @@ -275,7 +282,7 @@ func (mach *OlmMachine) defaultAllowKeyShare(ctx context.Context, device *id.Dev } else if device.Trust == id.TrustStateBlacklisted { log.Debug().Msg("Rejecting key request from blacklisted device") return &KeyShareRejectBlacklisted - } else if trustState := mach.ResolveTrust(device); trustState >= mach.ShareKeysMinTrust { + } else if trustState, _ := mach.ResolveTrustContext(ctx, device); trustState >= mach.ShareKeysMinTrust { log.Debug(). Str("min_trust", mach.SendKeysMinTrust.String()). Str("device_trust", trustState.String()). @@ -323,7 +330,9 @@ func (mach *OlmMachine) HandleRoomKeyRequest(ctx context.Context, sender id.User if err != nil { if errors.Is(err, ErrGroupSessionWithheld) { log.Debug().Err(err).Msg("Requested group session not available") - mach.rejectKeyRequest(ctx, KeyShareRejectUnavailable, device, content.Body) + if sender != mach.Client.UserID { + mach.rejectKeyRequest(ctx, KeyShareRejectUnavailable, device, content.Body) + } } else { log.Error().Err(err).Msg("Failed to get group session to forward") mach.rejectKeyRequest(ctx, KeyShareRejectInternalError, device, content.Body) @@ -331,7 +340,9 @@ func (mach *OlmMachine) HandleRoomKeyRequest(ctx context.Context, sender id.User return } else if igs == nil { log.Error().Msg("Didn't find group session to forward") - mach.rejectKeyRequest(ctx, KeyShareRejectUnavailable, device, content.Body) + if sender != mach.Client.UserID { + mach.rejectKeyRequest(ctx, KeyShareRejectUnavailable, device, content.Body) + } return } if internalID := igs.ID(); internalID != content.Body.SessionID { @@ -347,9 +358,6 @@ func (mach *OlmMachine) HandleRoomKeyRequest(ctx context.Context, sender id.User mach.rejectKeyRequest(ctx, KeyShareRejectInternalError, device, content.Body) return } - if igs.ForwardingChains == nil { - igs.ForwardingChains = []string{} - } forwardedRoomKey := event.Content{ Parsed: &event.ForwardedRoomKeyEventContent{ @@ -359,7 +367,7 @@ func (mach *OlmMachine) HandleRoomKeyRequest(ctx context.Context, sender id.User SessionID: igs.ID(), SessionKey: string(exportedKey), }, - SenderKey: content.Body.SenderKey, + SenderKey: igs.SenderKey, ForwardingKeyChain: igs.ForwardingChains, SenderClaimedKey: igs.SigningKey, }, diff --git a/crypto/libolm/account.go b/crypto/libolm/account.go index cddce7ce..0350f083 100644 --- a/crypto/libolm/account.go +++ b/crypto/libolm/account.go @@ -8,6 +8,7 @@ import ( "crypto/rand" "encoding/base64" "encoding/json" + "runtime" "unsafe" "github.com/tidwall/gjson" @@ -22,18 +23,6 @@ type Account struct { mem []byte } -func init() { - olm.InitNewAccount = func() (olm.Account, error) { - return NewAccount() - } - olm.InitBlankAccount = func() olm.Account { - return NewBlankAccount() - } - olm.InitNewAccountFromPickled = func(pickled, key []byte) (olm.Account, error) { - return AccountFromPickled(pickled, key) - } -} - // Ensure that [Account] implements [olm.Account]. var _ olm.Account = (*Account)(nil) @@ -44,7 +33,7 @@ var _ olm.Account = (*Account)(nil) // "INVALID_BASE64". func AccountFromPickled(pickled, key []byte) (*Account, error) { if len(pickled) == 0 { - return nil, olm.EmptyInput + return nil, olm.ErrEmptyInput } a := NewBlankAccount() return a, a.Unpickle(pickled, key) @@ -53,7 +42,7 @@ func AccountFromPickled(pickled, key []byte) (*Account, error) { func NewBlankAccount() *Account { memory := make([]byte, accountSize()) return &Account{ - int: C.olm_account(unsafe.Pointer(&memory[0])), + int: C.olm_account(unsafe.Pointer(unsafe.SliceData(memory))), mem: memory, } } @@ -64,12 +53,13 @@ func NewAccount() (*Account, error) { random := make([]byte, a.createRandomLen()+1) _, err := rand.Read(random) if err != nil { - panic(olm.NotEnoughGoRandom) + panic(olm.ErrNotEnoughGoRandom) } ret := C.olm_create_account( (*C.OlmAccount)(a.int), - unsafe.Pointer(&random[0]), + unsafe.Pointer(unsafe.SliceData(random)), C.size_t(len(random))) + runtime.KeepAlive(random) if ret == errorVal() { return nil, a.lastError() } else { @@ -138,14 +128,14 @@ func (a *Account) genOneTimeKeysRandomLen(num uint) uint { // supplied key. func (a *Account) Pickle(key []byte) ([]byte, error) { if len(key) == 0 { - return nil, olm.NoKeyProvided + return nil, olm.ErrNoKeyProvided } pickled := make([]byte, a.pickleLen()) r := C.olm_pickle_account( (*C.OlmAccount)(a.int), - unsafe.Pointer(&key[0]), + unsafe.Pointer(unsafe.SliceData(key)), C.size_t(len(key)), - unsafe.Pointer(&pickled[0]), + unsafe.Pointer(unsafe.SliceData(pickled)), C.size_t(len(pickled))) if r == errorVal() { return nil, a.lastError() @@ -155,13 +145,13 @@ func (a *Account) Pickle(key []byte) ([]byte, error) { func (a *Account) Unpickle(pickled, key []byte) error { if len(key) == 0 { - return olm.NoKeyProvided + return olm.ErrNoKeyProvided } r := C.olm_unpickle_account( (*C.OlmAccount)(a.int), - unsafe.Pointer(&key[0]), + unsafe.Pointer(unsafe.SliceData(key)), C.size_t(len(key)), - unsafe.Pointer(&pickled[0]), + unsafe.Pointer(unsafe.SliceData(pickled)), C.size_t(len(pickled))) if r == errorVal() { return a.lastError() @@ -208,7 +198,7 @@ func (a *Account) MarshalJSON() ([]byte, error) { // Deprecated func (a *Account) UnmarshalJSON(data []byte) error { if len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' { - return olm.InputNotJSONString + return olm.ErrInputNotJSONString } if a.int == nil { *a = *NewBlankAccount() @@ -221,7 +211,7 @@ func (a *Account) IdentityKeysJSON() ([]byte, error) { identityKeys := make([]byte, a.identityKeysLen()) r := C.olm_account_identity_keys( (*C.OlmAccount)(a.int), - unsafe.Pointer(&identityKeys[0]), + unsafe.Pointer(unsafe.SliceData(identityKeys)), C.size_t(len(identityKeys))) if r == errorVal() { return nil, a.lastError() @@ -245,15 +235,16 @@ func (a *Account) IdentityKeys() (id.Ed25519, id.Curve25519, error) { // Account. func (a *Account) Sign(message []byte) ([]byte, error) { if len(message) == 0 { - panic(olm.EmptyInput) + panic(olm.ErrEmptyInput) } signature := make([]byte, a.signatureLen()) r := C.olm_account_sign( (*C.OlmAccount)(a.int), - unsafe.Pointer(&message[0]), + unsafe.Pointer(unsafe.SliceData(message)), C.size_t(len(message)), - unsafe.Pointer(&signature[0]), + unsafe.Pointer(unsafe.SliceData(signature)), C.size_t(len(signature))) + runtime.KeepAlive(message) if r == errorVal() { panic(a.lastError()) } @@ -277,8 +268,9 @@ func (a *Account) OneTimeKeys() (map[string]id.Curve25519, error) { oneTimeKeysJSON := make([]byte, a.oneTimeKeysLen()) r := C.olm_account_one_time_keys( (*C.OlmAccount)(a.int), - unsafe.Pointer(&oneTimeKeysJSON[0]), - C.size_t(len(oneTimeKeysJSON))) + unsafe.Pointer(unsafe.SliceData(oneTimeKeysJSON)), + C.size_t(len(oneTimeKeysJSON)), + ) if r == errorVal() { return nil, a.lastError() } @@ -307,13 +299,15 @@ func (a *Account) GenOneTimeKeys(num uint) error { random := make([]byte, a.genOneTimeKeysRandomLen(num)+1) _, err := rand.Read(random) if err != nil { - return olm.NotEnoughGoRandom + return olm.ErrNotEnoughGoRandom } r := C.olm_account_generate_one_time_keys( (*C.OlmAccount)(a.int), C.size_t(num), - unsafe.Pointer(&random[0]), - C.size_t(len(random))) + unsafe.Pointer(unsafe.SliceData(random)), + C.size_t(len(random)), + ) + runtime.KeepAlive(random) if r == errorVal() { return a.lastError() } @@ -325,23 +319,29 @@ func (a *Account) GenOneTimeKeys(num uint) error { // keys couldn't be decoded as base64 then the error will be "INVALID_BASE64" func (a *Account) NewOutboundSession(theirIdentityKey, theirOneTimeKey id.Curve25519) (olm.Session, error) { if len(theirIdentityKey) == 0 || len(theirOneTimeKey) == 0 { - return nil, olm.EmptyInput + return nil, olm.ErrEmptyInput } s := NewBlankSession() random := make([]byte, s.createOutboundRandomLen()+1) _, err := rand.Read(random) if err != nil { - panic(olm.NotEnoughGoRandom) + panic(olm.ErrNotEnoughGoRandom) } + theirIdentityKeyCopy := []byte(theirIdentityKey) + theirOneTimeKeyCopy := []byte(theirOneTimeKey) r := C.olm_create_outbound_session( (*C.OlmSession)(s.int), (*C.OlmAccount)(a.int), - unsafe.Pointer(&([]byte(theirIdentityKey)[0])), - C.size_t(len(theirIdentityKey)), - unsafe.Pointer(&([]byte(theirOneTimeKey)[0])), - C.size_t(len(theirOneTimeKey)), - unsafe.Pointer(&random[0]), - C.size_t(len(random))) + unsafe.Pointer(unsafe.SliceData(theirIdentityKeyCopy)), + C.size_t(len(theirIdentityKeyCopy)), + unsafe.Pointer(unsafe.SliceData(theirOneTimeKeyCopy)), + C.size_t(len(theirOneTimeKeyCopy)), + unsafe.Pointer(unsafe.SliceData(random)), + C.size_t(len(random)), + ) + runtime.KeepAlive(random) + runtime.KeepAlive(theirIdentityKeyCopy) + runtime.KeepAlive(theirOneTimeKeyCopy) if r == errorVal() { return nil, s.lastError() } @@ -357,14 +357,17 @@ func (a *Account) NewOutboundSession(theirIdentityKey, theirOneTimeKey id.Curve2 // time key then the error will be "BAD_MESSAGE_KEY_ID". func (a *Account) NewInboundSession(oneTimeKeyMsg string) (olm.Session, error) { if len(oneTimeKeyMsg) == 0 { - return nil, olm.EmptyInput + return nil, olm.ErrEmptyInput } s := NewBlankSession() + oneTimeKeyMsgCopy := []byte(oneTimeKeyMsg) r := C.olm_create_inbound_session( (*C.OlmSession)(s.int), (*C.OlmAccount)(a.int), - unsafe.Pointer(&([]byte(oneTimeKeyMsg)[0])), - C.size_t(len(oneTimeKeyMsg))) + unsafe.Pointer(unsafe.SliceData(oneTimeKeyMsgCopy)), + C.size_t(len(oneTimeKeyMsgCopy)), + ) + runtime.KeepAlive(oneTimeKeyMsgCopy) if r == errorVal() { return nil, s.lastError() } @@ -380,16 +383,21 @@ func (a *Account) NewInboundSession(oneTimeKeyMsg string) (olm.Session, error) { // time key then the error will be "BAD_MESSAGE_KEY_ID". func (a *Account) NewInboundSessionFrom(theirIdentityKey *id.Curve25519, oneTimeKeyMsg string) (olm.Session, error) { if theirIdentityKey == nil || len(oneTimeKeyMsg) == 0 { - return nil, olm.EmptyInput + return nil, olm.ErrEmptyInput } + theirIdentityKeyCopy := []byte(*theirIdentityKey) + oneTimeKeyMsgCopy := []byte(oneTimeKeyMsg) s := NewBlankSession() r := C.olm_create_inbound_session_from( (*C.OlmSession)(s.int), (*C.OlmAccount)(a.int), - unsafe.Pointer(&([]byte(*theirIdentityKey)[0])), - C.size_t(len(*theirIdentityKey)), - unsafe.Pointer(&([]byte(oneTimeKeyMsg)[0])), - C.size_t(len(oneTimeKeyMsg))) + unsafe.Pointer(unsafe.SliceData(theirIdentityKeyCopy)), + C.size_t(len(theirIdentityKeyCopy)), + unsafe.Pointer(unsafe.SliceData(oneTimeKeyMsgCopy)), + C.size_t(len(oneTimeKeyMsgCopy)), + ) + runtime.KeepAlive(theirIdentityKeyCopy) + runtime.KeepAlive(oneTimeKeyMsgCopy) if r == errorVal() { return nil, s.lastError() } @@ -402,7 +410,8 @@ func (a *Account) NewInboundSessionFrom(theirIdentityKey *id.Curve25519, oneTime func (a *Account) RemoveOneTimeKeys(s olm.Session) error { r := C.olm_remove_one_time_keys( (*C.OlmAccount)(a.int), - (*C.OlmSession)(s.(*Session).int)) + (*C.OlmSession)(s.(*Session).int), + ) if r == errorVal() { return a.lastError() } diff --git a/crypto/libolm/error.go b/crypto/libolm/error.go index 9ca415ee..6fb5512b 100644 --- a/crypto/libolm/error.go +++ b/crypto/libolm/error.go @@ -11,21 +11,21 @@ import ( ) var errorMap = map[string]error{ - "NOT_ENOUGH_RANDOM": olm.NotEnoughRandom, - "OUTPUT_BUFFER_TOO_SMALL": olm.OutputBufferTooSmall, - "BAD_MESSAGE_VERSION": olm.BadMessageVersion, - "BAD_MESSAGE_FORMAT": olm.BadMessageFormat, - "BAD_MESSAGE_MAC": olm.BadMessageMAC, - "BAD_MESSAGE_KEY_ID": olm.BadMessageKeyID, - "INVALID_BASE64": olm.InvalidBase64, - "BAD_ACCOUNT_KEY": olm.BadAccountKey, - "UNKNOWN_PICKLE_VERSION": olm.UnknownPickleVersion, - "CORRUPTED_PICKLE": olm.CorruptedPickle, - "BAD_SESSION_KEY": olm.BadSessionKey, - "UNKNOWN_MESSAGE_INDEX": olm.UnknownMessageIndex, - "BAD_LEGACY_ACCOUNT_PICKLE": olm.BadLegacyAccountPickle, - "BAD_SIGNATURE": olm.BadSignature, - "INPUT_BUFFER_TOO_SMALL": olm.InputBufferTooSmall, + "NOT_ENOUGH_RANDOM": olm.ErrLibolmNotEnoughRandom, + "OUTPUT_BUFFER_TOO_SMALL": olm.ErrLibolmOutputBufferTooSmall, + "BAD_MESSAGE_VERSION": olm.ErrWrongProtocolVersion, + "BAD_MESSAGE_FORMAT": olm.ErrBadMessageFormat, + "BAD_MESSAGE_MAC": olm.ErrBadMAC, + "BAD_MESSAGE_KEY_ID": olm.ErrBadMessageKeyID, + "INVALID_BASE64": olm.ErrLibolmInvalidBase64, + "BAD_ACCOUNT_KEY": olm.ErrLibolmBadAccountKey, + "UNKNOWN_PICKLE_VERSION": olm.ErrUnknownOlmPickleVersion, + "CORRUPTED_PICKLE": olm.ErrLibolmCorruptedPickle, + "BAD_SESSION_KEY": olm.ErrLibolmBadSessionKey, + "UNKNOWN_MESSAGE_INDEX": olm.ErrUnknownMessageIndex, + "BAD_LEGACY_ACCOUNT_PICKLE": olm.ErrLibolmBadLegacyAccountPickle, + "BAD_SIGNATURE": olm.ErrBadSignature, + "INPUT_BUFFER_TOO_SMALL": olm.ErrInputToSmall, } func convertError(errCode string) error { diff --git a/crypto/libolm/inboundgroupsession.go b/crypto/libolm/inboundgroupsession.go index 1e25748d..8815ac32 100644 --- a/crypto/libolm/inboundgroupsession.go +++ b/crypto/libolm/inboundgroupsession.go @@ -7,6 +7,7 @@ import "C" import ( "bytes" "encoding/base64" + "runtime" "unsafe" "maunium.net/go/mautrix/crypto/olm" @@ -20,21 +21,6 @@ type InboundGroupSession struct { mem []byte } -func init() { - olm.InitInboundGroupSessionFromPickled = func(pickled, key []byte) (olm.InboundGroupSession, error) { - return InboundGroupSessionFromPickled(pickled, key) - } - olm.InitNewInboundGroupSession = func(sessionKey []byte) (olm.InboundGroupSession, error) { - return NewInboundGroupSession(sessionKey) - } - olm.InitInboundGroupSessionImport = func(sessionKey []byte) (olm.InboundGroupSession, error) { - return InboundGroupSessionImport(sessionKey) - } - olm.InitBlankInboundGroupSession = func() olm.InboundGroupSession { - return NewBlankInboundGroupSession() - } -} - // Ensure that [InboundGroupSession] implements [olm.InboundGroupSession]. var _ olm.InboundGroupSession = (*InboundGroupSession)(nil) @@ -45,7 +31,7 @@ var _ olm.InboundGroupSession = (*InboundGroupSession)(nil) // base64 couldn't be decoded then the error will be "INVALID_BASE64". func InboundGroupSessionFromPickled(pickled, key []byte) (*InboundGroupSession, error) { if len(pickled) == 0 { - return nil, olm.EmptyInput + return nil, olm.ErrEmptyInput } lenKey := len(key) if lenKey == 0 { @@ -62,13 +48,15 @@ func InboundGroupSessionFromPickled(pickled, key []byte) (*InboundGroupSession, // "OLM_BAD_SESSION_KEY". func NewInboundGroupSession(sessionKey []byte) (*InboundGroupSession, error) { if len(sessionKey) == 0 { - return nil, olm.EmptyInput + return nil, olm.ErrEmptyInput } s := NewBlankInboundGroupSession() r := C.olm_init_inbound_group_session( (*C.OlmInboundGroupSession)(s.int), - (*C.uint8_t)(&sessionKey[0]), - C.size_t(len(sessionKey))) + (*C.uint8_t)(unsafe.Pointer(unsafe.SliceData(sessionKey))), + C.size_t(len(sessionKey)), + ) + runtime.KeepAlive(sessionKey) if r == errorVal() { return nil, s.lastError() } @@ -81,13 +69,15 @@ func NewInboundGroupSession(sessionKey []byte) (*InboundGroupSession, error) { // error will be "OLM_BAD_SESSION_KEY". func InboundGroupSessionImport(sessionKey []byte) (*InboundGroupSession, error) { if len(sessionKey) == 0 { - return nil, olm.EmptyInput + return nil, olm.ErrEmptyInput } s := NewBlankInboundGroupSession() r := C.olm_import_inbound_group_session( (*C.OlmInboundGroupSession)(s.int), - (*C.uint8_t)(&sessionKey[0]), - C.size_t(len(sessionKey))) + (*C.uint8_t)(unsafe.Pointer(unsafe.SliceData(sessionKey))), + C.size_t(len(sessionKey)), + ) + runtime.KeepAlive(sessionKey) if r == errorVal() { return nil, s.lastError() } @@ -104,7 +94,7 @@ func inboundGroupSessionSize() uint { func NewBlankInboundGroupSession() *InboundGroupSession { memory := make([]byte, inboundGroupSessionSize()) return &InboundGroupSession{ - int: C.olm_inbound_group_session(unsafe.Pointer(&memory[0])), + int: C.olm_inbound_group_session(unsafe.Pointer(unsafe.SliceData(memory))), mem: memory, } } @@ -134,15 +124,17 @@ func (s *InboundGroupSession) pickleLen() uint { // InboundGroupSession using the supplied key. func (s *InboundGroupSession) Pickle(key []byte) ([]byte, error) { if len(key) == 0 { - return nil, olm.NoKeyProvided + return nil, olm.ErrNoKeyProvided } pickled := make([]byte, s.pickleLen()) r := C.olm_pickle_inbound_group_session( (*C.OlmInboundGroupSession)(s.int), - unsafe.Pointer(&key[0]), + unsafe.Pointer(unsafe.SliceData(key)), C.size_t(len(key)), - unsafe.Pointer(&pickled[0]), - C.size_t(len(pickled))) + unsafe.Pointer(unsafe.SliceData(pickled)), + C.size_t(len(pickled)), + ) + runtime.KeepAlive(key) if r == errorVal() { return nil, s.lastError() } @@ -151,16 +143,18 @@ func (s *InboundGroupSession) Pickle(key []byte) ([]byte, error) { func (s *InboundGroupSession) Unpickle(pickled, key []byte) error { if len(key) == 0 { - return olm.NoKeyProvided + return olm.ErrNoKeyProvided } else if len(pickled) == 0 { - return olm.EmptyInput + return olm.ErrEmptyInput } r := C.olm_unpickle_inbound_group_session( (*C.OlmInboundGroupSession)(s.int), - unsafe.Pointer(&key[0]), + unsafe.Pointer(unsafe.SliceData(key)), C.size_t(len(key)), - unsafe.Pointer(&pickled[0]), - C.size_t(len(pickled))) + unsafe.Pointer(unsafe.SliceData(pickled)), + C.size_t(len(pickled)), + ) + runtime.KeepAlive(key) if r == errorVal() { return s.lastError() } @@ -206,7 +200,7 @@ func (s *InboundGroupSession) MarshalJSON() ([]byte, error) { // Deprecated func (s *InboundGroupSession) UnmarshalJSON(data []byte) error { if len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' { - return olm.InputNotJSONString + return olm.ErrInputNotJSONString } if s == nil || s.int == nil { *s = *NewBlankInboundGroupSession() @@ -223,14 +217,16 @@ func (s *InboundGroupSession) UnmarshalJSON(data []byte) error { // will be "BAD_MESSAGE_FORMAT". func (s *InboundGroupSession) decryptMaxPlaintextLen(message []byte) (uint, error) { if len(message) == 0 { - return 0, olm.EmptyInput + return 0, olm.ErrEmptyInput } // olm_group_decrypt_max_plaintext_length destroys the input, so we have to clone it - message = bytes.Clone(message) + messageCopy := bytes.Clone(message) r := C.olm_group_decrypt_max_plaintext_length( (*C.OlmInboundGroupSession)(s.int), - (*C.uint8_t)(&message[0]), - C.size_t(len(message))) + (*C.uint8_t)(unsafe.Pointer(unsafe.SliceData(messageCopy))), + C.size_t(len(messageCopy)), + ) + runtime.KeepAlive(messageCopy) if r == errorVal() { return 0, s.lastError() } @@ -248,23 +244,24 @@ func (s *InboundGroupSession) decryptMaxPlaintextLen(message []byte) (uint, erro // was shared with us) the error will be "OLM_UNKNOWN_MESSAGE_INDEX". func (s *InboundGroupSession) Decrypt(message []byte) ([]byte, uint, error) { if len(message) == 0 { - return nil, 0, olm.EmptyInput + return nil, 0, olm.ErrEmptyInput } decryptMaxPlaintextLen, err := s.decryptMaxPlaintextLen(message) if err != nil { return nil, 0, err } - messageCopy := make([]byte, len(message)) - copy(messageCopy, message) + messageCopy := bytes.Clone(message) plaintext := make([]byte, decryptMaxPlaintextLen) var messageIndex uint32 r := C.olm_group_decrypt( (*C.OlmInboundGroupSession)(s.int), - (*C.uint8_t)(&messageCopy[0]), + (*C.uint8_t)(unsafe.Pointer(unsafe.SliceData(messageCopy))), C.size_t(len(messageCopy)), - (*C.uint8_t)(&plaintext[0]), + (*C.uint8_t)(unsafe.Pointer(unsafe.SliceData(plaintext))), C.size_t(len(plaintext)), - (*C.uint32_t)(&messageIndex)) + (*C.uint32_t)(unsafe.Pointer(&messageIndex)), + ) + runtime.KeepAlive(messageCopy) if r == errorVal() { return nil, 0, s.lastError() } @@ -281,8 +278,9 @@ func (s *InboundGroupSession) ID() id.SessionID { sessionID := make([]byte, s.sessionIdLen()) r := C.olm_inbound_group_session_id( (*C.OlmInboundGroupSession)(s.int), - (*C.uint8_t)(&sessionID[0]), - C.size_t(len(sessionID))) + (*C.uint8_t)(unsafe.Pointer(unsafe.SliceData(sessionID))), + C.size_t(len(sessionID)), + ) if r == errorVal() { panic(s.lastError()) } @@ -318,9 +316,10 @@ func (s *InboundGroupSession) Export(messageIndex uint32) ([]byte, error) { key := make([]byte, s.exportLen()) r := C.olm_export_inbound_group_session( (*C.OlmInboundGroupSession)(s.int), - (*C.uint8_t)(&key[0]), + (*C.uint8_t)(unsafe.Pointer(unsafe.SliceData(key))), C.size_t(len(key)), - C.uint32_t(messageIndex)) + C.uint32_t(messageIndex), + ) if r == errorVal() { return nil, s.lastError() } diff --git a/crypto/libolm/outboundgroupsession.go b/crypto/libolm/outboundgroupsession.go index a21f8d4a..ca5b68f7 100644 --- a/crypto/libolm/outboundgroupsession.go +++ b/crypto/libolm/outboundgroupsession.go @@ -7,6 +7,7 @@ import "C" import ( "crypto/rand" "encoding/base64" + "runtime" "unsafe" "maunium.net/go/mautrix/crypto/olm" @@ -20,18 +21,6 @@ type OutboundGroupSession struct { mem []byte } -func init() { - olm.InitNewOutboundGroupSessionFromPickled = func(pickled, key []byte) (olm.OutboundGroupSession, error) { - if len(pickled) == 0 { - return nil, olm.EmptyInput - } - s := NewBlankOutboundGroupSession() - return s, s.Unpickle(pickled, key) - } - olm.InitNewOutboundGroupSession = func() (olm.OutboundGroupSession, error) { return NewOutboundGroupSession() } - olm.InitNewBlankOutboundGroupSession = func() olm.OutboundGroupSession { return NewBlankOutboundGroupSession() } -} - // Ensure that [OutboundGroupSession] implements [olm.OutboundGroupSession]. var _ olm.OutboundGroupSession = (*OutboundGroupSession)(nil) @@ -44,8 +33,10 @@ func NewOutboundGroupSession() (*OutboundGroupSession, error) { } r := C.olm_init_outbound_group_session( (*C.OlmOutboundGroupSession)(s.int), - (*C.uint8_t)(&random[0]), - C.size_t(len(random))) + (*C.uint8_t)(unsafe.Pointer(unsafe.SliceData(random))), + C.size_t(len(random)), + ) + runtime.KeepAlive(random) if r == errorVal() { return nil, s.lastError() } @@ -62,7 +53,7 @@ func outboundGroupSessionSize() uint { func NewBlankOutboundGroupSession() *OutboundGroupSession { memory := make([]byte, outboundGroupSessionSize()) return &OutboundGroupSession{ - int: C.olm_outbound_group_session(unsafe.Pointer(&memory[0])), + int: C.olm_outbound_group_session(unsafe.Pointer(unsafe.SliceData(memory))), mem: memory, } } @@ -93,15 +84,17 @@ func (s *OutboundGroupSession) pickleLen() uint { // OutboundGroupSession using the supplied key. func (s *OutboundGroupSession) Pickle(key []byte) ([]byte, error) { if len(key) == 0 { - return nil, olm.NoKeyProvided + return nil, olm.ErrNoKeyProvided } pickled := make([]byte, s.pickleLen()) r := C.olm_pickle_outbound_group_session( (*C.OlmOutboundGroupSession)(s.int), - unsafe.Pointer(&key[0]), + unsafe.Pointer(unsafe.SliceData(key)), C.size_t(len(key)), - unsafe.Pointer(&pickled[0]), - C.size_t(len(pickled))) + unsafe.Pointer(unsafe.SliceData(pickled)), + C.size_t(len(pickled)), + ) + runtime.KeepAlive(key) if r == errorVal() { return nil, s.lastError() } @@ -110,14 +103,17 @@ func (s *OutboundGroupSession) Pickle(key []byte) ([]byte, error) { func (s *OutboundGroupSession) Unpickle(pickled, key []byte) error { if len(key) == 0 { - return olm.NoKeyProvided + return olm.ErrNoKeyProvided } r := C.olm_unpickle_outbound_group_session( (*C.OlmOutboundGroupSession)(s.int), - unsafe.Pointer(&key[0]), + unsafe.Pointer(unsafe.SliceData(key)), C.size_t(len(key)), - unsafe.Pointer(&pickled[0]), - C.size_t(len(pickled))) + unsafe.Pointer(unsafe.SliceData(pickled)), + C.size_t(len(pickled)), + ) + runtime.KeepAlive(pickled) + runtime.KeepAlive(key) if r == errorVal() { return s.lastError() } @@ -163,7 +159,7 @@ func (s *OutboundGroupSession) MarshalJSON() ([]byte, error) { // Deprecated func (s *OutboundGroupSession) UnmarshalJSON(data []byte) error { if len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' { - return olm.InputNotJSONString + return olm.ErrInputNotJSONString } if s == nil || s.int == nil { *s = *NewBlankOutboundGroupSession() @@ -187,15 +183,17 @@ func (s *OutboundGroupSession) encryptMsgLen(plainTextLen int) uint { // as base64. func (s *OutboundGroupSession) Encrypt(plaintext []byte) ([]byte, error) { if len(plaintext) == 0 { - return nil, olm.EmptyInput + return nil, olm.ErrEmptyInput } message := make([]byte, s.encryptMsgLen(len(plaintext))) r := C.olm_group_encrypt( (*C.OlmOutboundGroupSession)(s.int), - (*C.uint8_t)(&plaintext[0]), + (*C.uint8_t)(unsafe.Pointer(unsafe.SliceData(plaintext))), C.size_t(len(plaintext)), - (*C.uint8_t)(&message[0]), - C.size_t(len(message))) + (*C.uint8_t)(unsafe.Pointer(unsafe.SliceData(message))), + C.size_t(len(message)), + ) + runtime.KeepAlive(plaintext) if r == errorVal() { return nil, s.lastError() } @@ -212,8 +210,9 @@ func (s *OutboundGroupSession) ID() id.SessionID { sessionID := make([]byte, s.sessionIdLen()) r := C.olm_outbound_group_session_id( (*C.OlmOutboundGroupSession)(s.int), - (*C.uint8_t)(&sessionID[0]), - C.size_t(len(sessionID))) + (*C.uint8_t)(unsafe.Pointer(unsafe.SliceData(sessionID))), + C.size_t(len(sessionID)), + ) if r == errorVal() { panic(s.lastError()) } @@ -236,8 +235,9 @@ func (s *OutboundGroupSession) Key() string { sessionKey := make([]byte, s.sessionKeyLen()) r := C.olm_outbound_group_session_key( (*C.OlmOutboundGroupSession)(s.int), - (*C.uint8_t)(&sessionKey[0]), - C.size_t(len(sessionKey))) + (*C.uint8_t)(unsafe.Pointer(unsafe.SliceData(sessionKey))), + C.size_t(len(sessionKey)), + ) if r == errorVal() { panic(s.lastError()) } diff --git a/crypto/libolm/pk.go b/crypto/libolm/pk.go index db8d35c5..2683cf15 100644 --- a/crypto/libolm/pk.go +++ b/crypto/libolm/pk.go @@ -14,6 +14,7 @@ import "C" import ( "crypto/rand" "encoding/json" + "runtime" "unsafe" "github.com/tidwall/sjson" @@ -34,16 +35,6 @@ type PKSigning struct { // Ensure that [PKSigning] implements [olm.PKSigning]. var _ olm.PKSigning = (*PKSigning)(nil) -func init() { - olm.InitNewPKSigning = func() (olm.PKSigning, error) { return NewPKSigning() } - olm.InitNewPKSigningFromSeed = func(seed []byte) (olm.PKSigning, error) { - return NewPKSigningFromSeed(seed) - } - olm.InitNewPKDecryptionFromPrivateKey = func(privateKey []byte) (olm.PKDecryption, error) { - return NewPkDecryption(privateKey) - } -} - func pkSigningSize() uint { return uint(C.olm_pk_signing_size()) } @@ -63,7 +54,7 @@ func pkSigningSignatureLength() uint { func newBlankPKSigning() *PKSigning { memory := make([]byte, pkSigningSize()) return &PKSigning{ - int: C.olm_pk_signing(unsafe.Pointer(&memory[0])), + int: C.olm_pk_signing(unsafe.Pointer(unsafe.SliceData(memory))), mem: memory, } } @@ -73,9 +64,14 @@ func NewPKSigningFromSeed(seed []byte) (*PKSigning, error) { p := newBlankPKSigning() p.clear() pubKey := make([]byte, pkSigningPublicKeyLength()) - if C.olm_pk_signing_key_from_seed((*C.OlmPkSigning)(p.int), - unsafe.Pointer(&pubKey[0]), C.size_t(len(pubKey)), - unsafe.Pointer(&seed[0]), C.size_t(len(seed))) == errorVal() { + r := C.olm_pk_signing_key_from_seed( + (*C.OlmPkSigning)(p.int), + unsafe.Pointer(unsafe.SliceData(pubKey)), + C.size_t(len(pubKey)), + unsafe.Pointer(unsafe.SliceData(seed)), + C.size_t(len(seed)), + ) + if r == errorVal() { return nil, p.lastError() } p.publicKey = id.Ed25519(pubKey) @@ -90,7 +86,7 @@ func NewPKSigning() (*PKSigning, error) { seed := make([]byte, pkSigningSeedLength()) _, err := rand.Read(seed) if err != nil { - panic(olm.NotEnoughGoRandom) + panic(olm.ErrNotEnoughGoRandom) } pk, err := NewPKSigningFromSeed(seed) return pk, err @@ -112,8 +108,15 @@ func (p *PKSigning) clear() { // Sign creates a signature for the given message using this key. func (p *PKSigning) Sign(message []byte) ([]byte, error) { signature := make([]byte, pkSigningSignatureLength()) - if C.olm_pk_sign((*C.OlmPkSigning)(p.int), (*C.uint8_t)(unsafe.Pointer(&message[0])), C.size_t(len(message)), - (*C.uint8_t)(unsafe.Pointer(&signature[0])), C.size_t(len(signature))) == errorVal() { + r := C.olm_pk_sign( + (*C.OlmPkSigning)(p.int), + (*C.uint8_t)(unsafe.Pointer(unsafe.SliceData(message))), + C.size_t(len(message)), + (*C.uint8_t)(unsafe.Pointer(unsafe.SliceData(signature))), + C.size_t(len(signature)), + ) + runtime.KeepAlive(message) + if r == errorVal() { return nil, p.lastError() } return signature, nil @@ -157,15 +160,21 @@ func pkDecryptionPublicKeySize() uint { func NewPkDecryption(privateKey []byte) (*PKDecryption, error) { memory := make([]byte, pkDecryptionSize()) p := &PKDecryption{ - int: C.olm_pk_decryption(unsafe.Pointer(&memory[0])), + int: C.olm_pk_decryption(unsafe.Pointer(unsafe.SliceData(memory))), mem: memory, } p.clear() pubKey := make([]byte, pkDecryptionPublicKeySize()) - if C.olm_pk_key_from_private((*C.OlmPkDecryption)(p.int), - unsafe.Pointer(&pubKey[0]), C.size_t(len(pubKey)), - unsafe.Pointer(&privateKey[0]), C.size_t(len(privateKey))) == errorVal() { + r := C.olm_pk_key_from_private( + (*C.OlmPkDecryption)(p.int), + unsafe.Pointer(unsafe.SliceData(pubKey)), + C.size_t(len(pubKey)), + unsafe.Pointer(unsafe.SliceData(privateKey)), + C.size_t(len(privateKey)), + ) + runtime.KeepAlive(privateKey) + if r == errorVal() { return nil, p.lastError() } p.publicKey = pubKey @@ -178,14 +187,26 @@ func (p *PKDecryption) PublicKey() id.Curve25519 { } func (p *PKDecryption) Decrypt(ephemeralKey []byte, mac []byte, ciphertext []byte) ([]byte, error) { - maxPlaintextLength := uint(C.olm_pk_max_plaintext_length((*C.OlmPkDecryption)(p.int), C.size_t(len(ciphertext)))) + maxPlaintextLength := uint(C.olm_pk_max_plaintext_length( + (*C.OlmPkDecryption)(p.int), + C.size_t(len(ciphertext)), + )) plaintext := make([]byte, maxPlaintextLength) - size := C.olm_pk_decrypt((*C.OlmPkDecryption)(p.int), - unsafe.Pointer(&ephemeralKey[0]), C.size_t(len(ephemeralKey)), - unsafe.Pointer(&mac[0]), C.size_t(len(mac)), - unsafe.Pointer(&ciphertext[0]), C.size_t(len(ciphertext)), - unsafe.Pointer(&plaintext[0]), C.size_t(len(plaintext))) + size := C.olm_pk_decrypt( + (*C.OlmPkDecryption)(p.int), + unsafe.Pointer(unsafe.SliceData(ephemeralKey)), + C.size_t(len(ephemeralKey)), + unsafe.Pointer(unsafe.SliceData(mac)), + C.size_t(len(mac)), + unsafe.Pointer(unsafe.SliceData(ciphertext)), + C.size_t(len(ciphertext)), + unsafe.Pointer(unsafe.SliceData(plaintext)), + C.size_t(len(plaintext)), + ) + runtime.KeepAlive(ephemeralKey) + runtime.KeepAlive(mac) + runtime.KeepAlive(ciphertext) if size == errorVal() { return nil, p.lastError() } diff --git a/crypto/libolm/register.go b/crypto/libolm/register.go index a423a7d0..ddf84613 100644 --- a/crypto/libolm/register.go +++ b/crypto/libolm/register.go @@ -3,19 +3,73 @@ package libolm // #cgo LDFLAGS: -lolm -lstdc++ // #include import "C" -import "maunium.net/go/mautrix/crypto/olm" +import ( + "unsafe" + + "maunium.net/go/mautrix/crypto/olm" +) var pickleKey = []byte("maunium.net/go/mautrix/crypto/olm") -func init() { +func Register() { + olm.Driver = "libolm" + olm.GetVersion = func() (major, minor, patch uint8) { C.olm_get_library_version( - (*C.uint8_t)(&major), - (*C.uint8_t)(&minor), - (*C.uint8_t)(&patch)) + (*C.uint8_t)(unsafe.Pointer(&major)), + (*C.uint8_t)(unsafe.Pointer(&minor)), + (*C.uint8_t)(unsafe.Pointer(&patch))) return 3, 2, 15 } olm.SetPickleKeyImpl = func(key []byte) { pickleKey = key } + + olm.InitNewAccount = func() (olm.Account, error) { + return NewAccount() + } + olm.InitBlankAccount = func() olm.Account { + return NewBlankAccount() + } + olm.InitNewAccountFromPickled = func(pickled, key []byte) (olm.Account, error) { + return AccountFromPickled(pickled, key) + } + + olm.InitSessionFromPickled = func(pickled, key []byte) (olm.Session, error) { + return SessionFromPickled(pickled, key) + } + olm.InitNewBlankSession = func() olm.Session { + return NewBlankSession() + } + + olm.InitNewPKSigning = func() (olm.PKSigning, error) { return NewPKSigning() } + olm.InitNewPKSigningFromSeed = func(seed []byte) (olm.PKSigning, error) { + return NewPKSigningFromSeed(seed) + } + olm.InitNewPKDecryptionFromPrivateKey = func(privateKey []byte) (olm.PKDecryption, error) { + return NewPkDecryption(privateKey) + } + + olm.InitInboundGroupSessionFromPickled = func(pickled, key []byte) (olm.InboundGroupSession, error) { + return InboundGroupSessionFromPickled(pickled, key) + } + olm.InitNewInboundGroupSession = func(sessionKey []byte) (olm.InboundGroupSession, error) { + return NewInboundGroupSession(sessionKey) + } + olm.InitInboundGroupSessionImport = func(sessionKey []byte) (olm.InboundGroupSession, error) { + return InboundGroupSessionImport(sessionKey) + } + olm.InitBlankInboundGroupSession = func() olm.InboundGroupSession { + return NewBlankInboundGroupSession() + } + + olm.InitNewOutboundGroupSessionFromPickled = func(pickled, key []byte) (olm.OutboundGroupSession, error) { + if len(pickled) == 0 { + return nil, olm.ErrEmptyInput + } + s := NewBlankOutboundGroupSession() + return s, s.Unpickle(pickled, key) + } + olm.InitNewOutboundGroupSession = func() (olm.OutboundGroupSession, error) { return NewOutboundGroupSession() } + olm.InitNewBlankOutboundGroupSession = func() olm.OutboundGroupSession { return NewBlankOutboundGroupSession() } } diff --git a/crypto/libolm/session.go b/crypto/libolm/session.go index 4cc22809..1441df26 100644 --- a/crypto/libolm/session.go +++ b/crypto/libolm/session.go @@ -23,6 +23,7 @@ import "C" import ( "crypto/rand" "encoding/base64" + "runtime" "unsafe" "maunium.net/go/mautrix/crypto/olm" @@ -38,15 +39,6 @@ type Session struct { // Ensure that [Session] implements [olm.Session]. var _ olm.Session = (*Session)(nil) -func init() { - olm.InitSessionFromPickled = func(pickled, key []byte) (olm.Session, error) { - return SessionFromPickled(pickled, key) - } - olm.InitNewBlankSession = func() olm.Session { - return NewBlankSession() - } -} - // sessionSize is the size of a session object in bytes. func sessionSize() uint { return uint(C.olm_session_size()) @@ -59,7 +51,7 @@ func sessionSize() uint { // "INVALID_BASE64". func SessionFromPickled(pickled, key []byte) (*Session, error) { if len(pickled) == 0 { - return nil, olm.EmptyInput + return nil, olm.ErrEmptyInput } s := NewBlankSession() return s, s.Unpickle(pickled, key) @@ -68,7 +60,7 @@ func SessionFromPickled(pickled, key []byte) (*Session, error) { func NewBlankSession() *Session { memory := make([]byte, sessionSize()) return &Session{ - int: C.olm_session(unsafe.Pointer(&memory[0])), + int: C.olm_session(unsafe.Pointer(unsafe.SliceData(memory))), mem: memory, } } @@ -126,13 +118,16 @@ func (s *Session) encryptMsgLen(plainTextLen int) uint { // will be "BAD_MESSAGE_FORMAT". func (s *Session) decryptMaxPlaintextLen(message string, msgType id.OlmMsgType) (uint, error) { if len(message) == 0 { - return 0, olm.EmptyInput + return 0, olm.ErrEmptyInput } + messageCopy := []byte(message) r := C.olm_decrypt_max_plaintext_length( (*C.OlmSession)(s.int), C.size_t(msgType), - unsafe.Pointer(C.CString(message)), - C.size_t(len(message))) + unsafe.Pointer(unsafe.SliceData((messageCopy))), + C.size_t(len(messageCopy)), + ) + runtime.KeepAlive(messageCopy) if r == errorVal() { return 0, s.lastError() } @@ -143,15 +138,16 @@ func (s *Session) decryptMaxPlaintextLen(message string, msgType id.OlmMsgType) // supplied key. func (s *Session) Pickle(key []byte) ([]byte, error) { if len(key) == 0 { - return nil, olm.NoKeyProvided + return nil, olm.ErrNoKeyProvided } pickled := make([]byte, s.pickleLen()) r := C.olm_pickle_session( (*C.OlmSession)(s.int), - unsafe.Pointer(&key[0]), + unsafe.Pointer(unsafe.SliceData(key)), C.size_t(len(key)), - unsafe.Pointer(&pickled[0]), + unsafe.Pointer(unsafe.SliceData(pickled)), C.size_t(len(pickled))) + runtime.KeepAlive(key) if r == errorVal() { panic(s.lastError()) } @@ -162,14 +158,16 @@ func (s *Session) Pickle(key []byte) ([]byte, error) { // provided key. This function mutates the input pickled data slice. func (s *Session) Unpickle(pickled, key []byte) error { if len(key) == 0 { - return olm.NoKeyProvided + return olm.ErrNoKeyProvided } r := C.olm_unpickle_session( (*C.OlmSession)(s.int), - unsafe.Pointer(&key[0]), + unsafe.Pointer(unsafe.SliceData(key)), C.size_t(len(key)), - unsafe.Pointer(&pickled[0]), + unsafe.Pointer(unsafe.SliceData(pickled)), C.size_t(len(pickled))) + runtime.KeepAlive(pickled) + runtime.KeepAlive(key) if r == errorVal() { return s.lastError() } @@ -215,7 +213,7 @@ func (s *Session) MarshalJSON() ([]byte, error) { // Deprecated func (s *Session) UnmarshalJSON(data []byte) error { if len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' { - return olm.InputNotJSONString + return olm.ErrInputNotJSONString } if s == nil || s.int == nil { *s = *NewBlankSession() @@ -229,8 +227,9 @@ func (s *Session) ID() id.SessionID { sessionID := make([]byte, s.idLen()) r := C.olm_session_id( (*C.OlmSession)(s.int), - unsafe.Pointer(&sessionID[0]), - C.size_t(len(sessionID))) + unsafe.Pointer(unsafe.SliceData(sessionID)), + C.size_t(len(sessionID)), + ) if r == errorVal() { panic(s.lastError()) } @@ -257,12 +256,15 @@ func (s *Session) HasReceivedMessage() bool { // decoded then then the error will be "BAD_MESSAGE_FORMAT". func (s *Session) MatchesInboundSession(oneTimeKeyMsg string) (bool, error) { if len(oneTimeKeyMsg) == 0 { - return false, olm.EmptyInput + return false, olm.ErrEmptyInput } + oneTimeKeyMsgCopy := []byte(oneTimeKeyMsg) r := C.olm_matches_inbound_session( (*C.OlmSession)(s.int), - unsafe.Pointer(&([]byte(oneTimeKeyMsg))[0]), - C.size_t(len(oneTimeKeyMsg))) + unsafe.Pointer(unsafe.SliceData(oneTimeKeyMsgCopy)), + C.size_t(len(oneTimeKeyMsgCopy)), + ) + runtime.KeepAlive(oneTimeKeyMsgCopy) if r == 1 { return true, nil } else if r == 0 { @@ -282,14 +284,19 @@ func (s *Session) MatchesInboundSession(oneTimeKeyMsg string) (bool, error) { // decoded then then the error will be "BAD_MESSAGE_FORMAT". func (s *Session) MatchesInboundSessionFrom(theirIdentityKey, oneTimeKeyMsg string) (bool, error) { if len(theirIdentityKey) == 0 || len(oneTimeKeyMsg) == 0 { - return false, olm.EmptyInput + return false, olm.ErrEmptyInput } + theirIdentityKeyCopy := []byte(theirIdentityKey) + oneTimeKeyMsgCopy := []byte(oneTimeKeyMsg) r := C.olm_matches_inbound_session_from( (*C.OlmSession)(s.int), - unsafe.Pointer(&([]byte(theirIdentityKey))[0]), - C.size_t(len(theirIdentityKey)), - unsafe.Pointer(&([]byte(oneTimeKeyMsg))[0]), - C.size_t(len(oneTimeKeyMsg))) + unsafe.Pointer(unsafe.SliceData(theirIdentityKeyCopy)), + C.size_t(len(theirIdentityKeyCopy)), + unsafe.Pointer(unsafe.SliceData(oneTimeKeyMsgCopy)), + C.size_t(len(oneTimeKeyMsgCopy)), + ) + runtime.KeepAlive(theirIdentityKeyCopy) + runtime.KeepAlive(oneTimeKeyMsgCopy) if r == 1 { return true, nil } else if r == 0 { @@ -318,25 +325,28 @@ func (s *Session) EncryptMsgType() id.OlmMsgType { // as base64. func (s *Session) Encrypt(plaintext []byte) (id.OlmMsgType, []byte, error) { if len(plaintext) == 0 { - return 0, nil, olm.EmptyInput + return 0, nil, olm.ErrEmptyInput } // Make the slice be at least length 1 random := make([]byte, s.encryptRandomLen()+1) _, err := rand.Read(random) if err != nil { // TODO can we just return err here? - return 0, nil, olm.NotEnoughGoRandom + return 0, nil, olm.ErrNotEnoughGoRandom } messageType := s.EncryptMsgType() message := make([]byte, s.encryptMsgLen(len(plaintext))) r := C.olm_encrypt( (*C.OlmSession)(s.int), - unsafe.Pointer(&plaintext[0]), + unsafe.Pointer(unsafe.SliceData(plaintext)), C.size_t(len(plaintext)), - unsafe.Pointer(&random[0]), + unsafe.Pointer(unsafe.SliceData(random)), C.size_t(len(random)), - unsafe.Pointer(&message[0]), - C.size_t(len(message))) + unsafe.Pointer(unsafe.SliceData(message)), + C.size_t(len(message)), + ) + runtime.KeepAlive(plaintext) + runtime.KeepAlive(random) if r == errorVal() { return 0, nil, s.lastError() } @@ -352,7 +362,7 @@ func (s *Session) Encrypt(plaintext []byte) (id.OlmMsgType, []byte, error) { // "BAD_MESSAGE_MAC". func (s *Session) Decrypt(message string, msgType id.OlmMsgType) ([]byte, error) { if len(message) == 0 { - return nil, olm.EmptyInput + return nil, olm.ErrEmptyInput } decryptMaxPlaintextLen, err := s.decryptMaxPlaintextLen(message, msgType) if err != nil { @@ -363,10 +373,12 @@ func (s *Session) Decrypt(message string, msgType id.OlmMsgType) ([]byte, error) r := C.olm_decrypt( (*C.OlmSession)(s.int), C.size_t(msgType), - unsafe.Pointer(&(messageCopy)[0]), + unsafe.Pointer(unsafe.SliceData(messageCopy)), C.size_t(len(messageCopy)), - unsafe.Pointer(&plaintext[0]), - C.size_t(len(plaintext))) + unsafe.Pointer(unsafe.SliceData(plaintext)), + C.size_t(len(plaintext)), + ) + runtime.KeepAlive(messageCopy) if r == errorVal() { return nil, s.lastError() } @@ -383,6 +395,7 @@ func (s *Session) Describe() string { C.meowlm_session_describe( (*C.OlmSession)(s.int), desc, - C.size_t(maxDescribeSize)) + C.size_t(maxDescribeSize), + ) return C.GoString(desc) } diff --git a/crypto/machine.go b/crypto/machine.go index 7c1093f3..fa051f94 100644 --- a/crypto/machine.go +++ b/crypto/machine.go @@ -15,10 +15,12 @@ import ( "time" "github.com/rs/zerolog" + "go.mau.fi/util/ptr" "go.mau.fi/util/exzerolog" "maunium.net/go/mautrix" + "maunium.net/go/mautrix/crypto/olm" "maunium.net/go/mautrix/crypto/ssss" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" @@ -33,7 +35,12 @@ type OlmMachine struct { CryptoStore Store StateStore StateStore - PlaintextMentions bool + backgroundCtx context.Context + cancelBackgroundCtx context.CancelFunc + + PlaintextMentions bool + MSC4392Relations bool + AllowEncryptedState bool // Never ask the server for keys automatically as a side effect during Megolm decryption. DisableDecryptKeyFetching bool @@ -41,6 +48,8 @@ type OlmMachine struct { // Don't mark outbound Olm sessions as shared for devices they were initially sent to. DisableSharedGroupSessionTracking bool + IgnorePostDecryptionParseErrors bool + SendKeysMinTrust id.TrustState ShareKeysMinTrust id.TrustState @@ -61,6 +70,9 @@ type OlmMachine struct { devicesToUnwedgeLock sync.Mutex recentlyUnwedged map[id.IdentityKey]time.Time recentlyUnwedgedLock sync.Mutex + olmHashSavePoints []time.Time + lastHashDelete time.Time + olmHashSavePointLock sync.Mutex olmLock sync.Mutex megolmEncryptLock sync.Mutex @@ -124,6 +136,7 @@ func NewOlmMachine(client *mautrix.Client, log *zerolog.Logger, cryptoStore Stor recentlyUnwedged: make(map[id.IdentityKey]time.Time), secretListeners: make(map[string]chan<- string), } + mach.backgroundCtx, mach.cancelBackgroundCtx = context.WithCancel(context.Background()) mach.AllowKeyShare = mach.defaultAllowKeyShare return mach } @@ -136,6 +149,11 @@ func (mach *OlmMachine) machOrContextLog(ctx context.Context) *zerolog.Logger { return log } +func (mach *OlmMachine) SetBackgroundCtx(ctx context.Context) { + mach.cancelBackgroundCtx() + mach.backgroundCtx, mach.cancelBackgroundCtx = context.WithCancel(ctx) +} + // Load loads the Olm account information from the crypto store. If there's no olm account, a new one is created. // This must be called before using the machine. func (mach *OlmMachine) Load(ctx context.Context) (err error) { @@ -146,9 +164,23 @@ func (mach *OlmMachine) Load(ctx context.Context) (err error) { if mach.account == nil { mach.account = NewOlmAccount() } + zerolog.Ctx(ctx).Debug(). + Str("machine_ptr", fmt.Sprintf("%p", mach)). + Str("account_ptr", fmt.Sprintf("%p", mach.account.Internal)). + Str("olm_driver", olm.Driver). + Msg("Loaded olm account") return nil } +func (mach *OlmMachine) Destroy() { + mach.Log.Debug(). + Str("machine_ptr", fmt.Sprintf("%p", mach)). + Str("account_ptr", fmt.Sprintf("%p", ptr.Val(mach.account).Internal)). + Msg("Destroying olm machine") + mach.cancelBackgroundCtx() + // TODO actually destroy something? +} + func (mach *OlmMachine) saveAccount(ctx context.Context) error { err := mach.CryptoStore.PutAccount(ctx, mach.account) if err != nil { @@ -174,7 +206,7 @@ func (mach *OlmMachine) FlushStore(ctx context.Context) error { func (mach *OlmMachine) timeTrace(ctx context.Context, thing string, expectedDuration time.Duration) func() { start := time.Now() return func() { - duration := time.Now().Sub(start) + duration := time.Since(start) if duration > expectedDuration { zerolog.Ctx(ctx).Warn(). Str("action", thing). @@ -308,6 +340,7 @@ func (mach *OlmMachine) ProcessSyncResponse(ctx context.Context, resp *mautrix.R } mach.HandleOTKCounts(ctx, &resp.DeviceOTKCount) + mach.MarkOlmHashSavePoint(ctx) return true } @@ -350,20 +383,20 @@ func (mach *OlmMachine) HandleMemberEvent(ctx context.Context, evt *event.Event) Msg("Got membership state change, invalidating group session in room") err := mach.CryptoStore.RemoveOutboundGroupSession(ctx, evt.RoomID) if err != nil { - mach.Log.Warn().Str("room_id", evt.RoomID.String()).Msg("Failed to invalidate outbound group session") + mach.Log.Warn().Stringer("room_id", evt.RoomID).Msg("Failed to invalidate outbound group session") } } -func (mach *OlmMachine) HandleEncryptedEvent(ctx context.Context, evt *event.Event) { +func (mach *OlmMachine) HandleEncryptedEvent(ctx context.Context, evt *event.Event) *DecryptedOlmEvent { if _, ok := evt.Content.Parsed.(*event.EncryptedEventContent); !ok { mach.machOrContextLog(ctx).Warn().Msg("Passed invalid event to encrypted handler") - return + return nil } decryptedEvt, err := mach.decryptOlmEvent(ctx, evt) if err != nil { mach.machOrContextLog(ctx).Error().Err(err).Msg("Failed to decrypt to-device event") - return + return nil } log := mach.machOrContextLog(ctx).With(). @@ -392,6 +425,37 @@ func (mach *OlmMachine) HandleEncryptedEvent(ctx context.Context, evt *event.Eve log.Trace().Msg("Handled secret send event") default: log.Debug().Msg("Unhandled encrypted to-device event") + return decryptedEvt + } + return nil +} + +const olmHashSavePointCount = 5 +const olmHashDeleteMinInterval = 10 * time.Minute +const minSavePointInterval = 1 * time.Minute + +// MarkOlmHashSavePoint marks the current time as a save point for olm hashes and deletes old hashes if needed. +// +// This should be called after all to-device events in a sync have been processed. +// The function will then delete old olm hashes after enough syncs have happened +// (such that it's unlikely for the olm messages to repeat). +func (mach *OlmMachine) MarkOlmHashSavePoint(ctx context.Context) { + mach.olmHashSavePointLock.Lock() + defer mach.olmHashSavePointLock.Unlock() + if len(mach.olmHashSavePoints) > 0 && time.Since(mach.olmHashSavePoints[len(mach.olmHashSavePoints)-1]) < minSavePointInterval { + return + } + mach.olmHashSavePoints = append(mach.olmHashSavePoints, time.Now()) + if len(mach.olmHashSavePoints) > olmHashSavePointCount { + sp := mach.olmHashSavePoints[0] + mach.olmHashSavePoints = mach.olmHashSavePoints[1:] + if time.Since(mach.lastHashDelete) > olmHashDeleteMinInterval { + err := mach.CryptoStore.DeleteOldOlmHashes(ctx, sp) + mach.lastHashDelete = time.Now() + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to delete old olm hashes") + } + } } } @@ -539,10 +603,10 @@ func (mach *OlmMachine) createGroupSession(ctx context.Context, senderKey id.Sen } err = mach.CryptoStore.PutGroupSession(ctx, igs) if err != nil { - log.Err(err).Str("session_id", sessionID.String()).Msg("Failed to store new inbound group session") + log.Err(err).Stringer("session_id", sessionID).Msg("Failed to store new inbound group session") return fmt.Errorf("failed to store new inbound group session: %w", err) } - mach.markSessionReceived(ctx, roomID, sessionID, igs.Internal.FirstKnownIndex()) + mach.MarkSessionReceived(ctx, roomID, sessionID, igs.Internal.FirstKnownIndex()) log.Debug(). Str("session_id", sessionID.String()). Str("sender_key", senderKey.String()). @@ -553,7 +617,7 @@ func (mach *OlmMachine) createGroupSession(ctx context.Context, senderKey id.Sen return nil } -func (mach *OlmMachine) markSessionReceived(ctx context.Context, roomID id.RoomID, id id.SessionID, firstKnownIndex uint32) { +func (mach *OlmMachine) MarkSessionReceived(ctx context.Context, roomID id.RoomID, id id.SessionID, firstKnownIndex uint32) { if mach.SessionReceived != nil { mach.SessionReceived(ctx, roomID, id, firstKnownIndex) } @@ -666,7 +730,7 @@ func (mach *OlmMachine) ShareKeys(ctx context.Context, currentOTKCount int) erro start := time.Now() mach.otkUploadLock.Lock() defer mach.otkUploadLock.Unlock() - if mach.lastOTKUpload.Add(1*time.Minute).After(start) || currentOTKCount < 0 { + if mach.lastOTKUpload.Add(1*time.Minute).After(start) || (currentOTKCount < 0 && mach.account.Shared) { log.Debug().Msg("Checking OTK count from server due to suspiciously close share keys requests or negative OTK count") resp, err := mach.Client.UploadKeys(ctx, &mautrix.ReqUploadKeys{}) if err != nil { diff --git a/crypto/machine_bench_test.go b/crypto/machine_bench_test.go new file mode 100644 index 00000000..fd40d795 --- /dev/null +++ b/crypto/machine_bench_test.go @@ -0,0 +1,67 @@ +// Copyright (c) 2025 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package crypto_test + +import ( + "context" + "fmt" + "math/rand/v2" + "testing" + + "github.com/rs/zerolog" + globallog "github.com/rs/zerolog/log" // zerolog-allow-global-log + "github.com/stretchr/testify/require" + + "maunium.net/go/mautrix/crypto/cryptohelper" + "maunium.net/go/mautrix/id" + "maunium.net/go/mautrix/mockserver" +) + +func randomDeviceCount(r *rand.Rand) int { + k := 1 + for k < 10 && r.IntN(3) > 0 { + k++ + } + return k +} + +func BenchmarkOlmMachine_ShareGroupSession(b *testing.B) { + globallog.Logger = zerolog.Nop() + server := mockserver.Create(b) + server.PopOTKs = false + server.MemoryStore = false + var i int + var shareTargets []id.UserID + r := rand.New(rand.NewPCG(293, 0)) + var totalDeviceCount int + for i = 1; i < 1000; i++ { + userID := id.UserID(fmt.Sprintf("@user%d:localhost", i)) + deviceCount := randomDeviceCount(r) + for j := 0; j < deviceCount; j++ { + client, _ := server.Login(b, nil, userID, id.DeviceID(fmt.Sprintf("u%d_d%d", i, j))) + mach := client.Crypto.(*cryptohelper.CryptoHelper).Machine() + keysCache, err := mach.GenerateCrossSigningKeys() + require.NoError(b, err) + err = mach.PublishCrossSigningKeys(context.TODO(), keysCache, nil) + require.NoError(b, err) + } + totalDeviceCount += deviceCount + shareTargets = append(shareTargets, userID) + } + for b.Loop() { + client, _ := server.Login(b, nil, id.UserID(fmt.Sprintf("@benchuser%d:localhost", i)), id.DeviceID(fmt.Sprintf("u%d_d1", i))) + mach := client.Crypto.(*cryptohelper.CryptoHelper).Machine() + keysCache, err := mach.GenerateCrossSigningKeys() + require.NoError(b, err) + err = mach.PublishCrossSigningKeys(context.TODO(), keysCache, nil) + require.NoError(b, err) + err = mach.ShareGroupSession(context.TODO(), "!room:localhost", shareTargets) + require.NoError(b, err) + i++ + } + fmt.Println(totalDeviceCount, "devices total") +} diff --git a/crypto/machine_test.go b/crypto/machine_test.go index 59c86236..872c3ac4 100644 --- a/crypto/machine_test.go +++ b/crypto/machine_test.go @@ -36,20 +36,15 @@ func (mockStateStore) FindSharedRooms(context.Context, id.UserID) ([]id.RoomID, func newMachine(t *testing.T, userID id.UserID) *OlmMachine { client, err := mautrix.NewClient("http://localhost", userID, "token") - if err != nil { - t.Fatalf("Error creating client: %v", err) - } + require.NoError(t, err, "Error creating client") client.DeviceID = "device1" gobStore := NewMemoryStore(nil) - if err != nil { - t.Fatalf("Error creating Gob store: %v", err) - } + require.NoError(t, err, "Error creating Gob store") machine := NewOlmMachine(client, nil, gobStore, mockStateStore{}) - if err := machine.Load(context.TODO()); err != nil { - t.Fatalf("Error creating account: %v", err) - } + err = machine.Load(context.TODO()) + require.NoError(t, err, "Error creating account") return machine } @@ -82,9 +77,7 @@ func TestOlmMachineOlmMegolmSessions(t *testing.T) { // create outbound olm session for sending machine using OTK olmSession, err := machineOut.account.Internal.NewOutboundSession(machineIn.account.IdentityKey(), otk.Key) - if err != nil { - t.Errorf("Failed to create outbound olm session: %v", err) - } + require.NoError(t, err, "Error creating outbound olm session") // store sender device identity in receiving machine store machineIn.CryptoStore.PutDevices(context.TODO(), "user1", map[id.DeviceID]*id.Device{ @@ -121,29 +114,21 @@ func TestOlmMachineOlmMegolmSessions(t *testing.T) { Type: event.ToDeviceEncrypted, Sender: "user1", }, senderKey, content.Type, content.Body) - if err != nil { - t.Errorf("Error decrypting olm content: %v", err) - } + require.NoError(t, err, "Error decrypting olm ciphertext") + // store room key in new inbound group session roomKeyEvt := decrypted.Content.AsRoomKey() igs, err := NewInboundGroupSession(senderKey, signingKey, "room1", roomKeyEvt.SessionKey, 0, 0, false) - if err != nil { - t.Errorf("Error creating inbound megolm session: %v", err) - } - if err = machineIn.CryptoStore.PutGroupSession(context.TODO(), igs); err != nil { - t.Errorf("Error storing inbound megolm session: %v", err) - } + require.NoError(t, err, "Error creating inbound group session") + err = machineIn.CryptoStore.PutGroupSession(context.TODO(), igs) + require.NoError(t, err, "Error storing inbound group session") } // encrypt event with megolm session in sending machine eventContent := map[string]string{"hello": "world"} encryptedEvtContent, err := machineOut.EncryptMegolmEvent(context.TODO(), "room1", event.EventMessage, eventContent) - if err != nil { - t.Errorf("Error encrypting megolm event: %v", err) - } - if megolmOutSession.MessageCount != 1 { - t.Errorf("Megolm outbound session message count is not 1 but %d", megolmOutSession.MessageCount) - } + require.NoError(t, err, "Error encrypting megolm event") + assert.Equal(t, 1, megolmOutSession.MessageCount) encryptedEvt := &event.Event{ Content: event.Content{Parsed: encryptedEvtContent}, @@ -155,22 +140,12 @@ func TestOlmMachineOlmMegolmSessions(t *testing.T) { // decrypt event on receiving machine and confirm decryptedEvt, err := machineIn.DecryptMegolmEvent(context.TODO(), encryptedEvt) - if err != nil { - t.Errorf("Error decrypting megolm event: %v", err) - } - if decryptedEvt.Type != event.EventMessage { - t.Errorf("Expected event type %v, got %v", event.EventMessage, decryptedEvt.Type) - } - if decryptedEvt.Content.Raw["hello"] != "world" { - t.Errorf("Expected event content %v, got %v", eventContent, decryptedEvt.Content.Raw) - } + require.NoError(t, err, "Error decrypting megolm event") + assert.Equal(t, event.EventMessage, decryptedEvt.Type) + assert.Equal(t, "world", decryptedEvt.Content.Raw["hello"]) machineOut.EncryptMegolmEvent(context.TODO(), "room1", event.EventMessage, eventContent) - if megolmOutSession.Expired() { - t.Error("Megolm outbound session expired before 3rd message") - } + assert.False(t, megolmOutSession.Expired(), "Megolm outbound session expired before 3rd message") machineOut.EncryptMegolmEvent(context.TODO(), "room1", event.EventMessage, eventContent) - if !megolmOutSession.Expired() { - t.Error("Megolm outbound session not expired after 3rd message") - } + assert.True(t, megolmOutSession.Expired(), "Megolm outbound session not expired after 3rd message") } diff --git a/crypto/olm/account.go b/crypto/olm/account.go index 68393e8a..2ec5dd70 100644 --- a/crypto/olm/account.go +++ b/crypto/olm/account.go @@ -87,6 +87,8 @@ type Account interface { RemoveOneTimeKeys(s Session) error } +var Driver = "none" + var InitBlankAccount func() Account var InitNewAccount func() (Account, error) var InitNewAccountFromPickled func(pickled, key []byte) (Account, error) diff --git a/crypto/olm/errors.go b/crypto/olm/errors.go index 957d7928..9e522b2a 100644 --- a/crypto/olm/errors.go +++ b/crypto/olm/errors.go @@ -10,50 +10,67 @@ import "errors" // Those are the most common used errors var ( - ErrBadSignature = errors.New("bad signature") - ErrBadMAC = errors.New("bad mac") - ErrBadMessageFormat = errors.New("bad message format") - ErrBadVerification = errors.New("bad verification") - ErrWrongProtocolVersion = errors.New("wrong protocol version") - ErrEmptyInput = errors.New("empty input") - ErrNoKeyProvided = errors.New("no key") - ErrBadMessageKeyID = errors.New("bad message key id") - ErrRatchetNotAvailable = errors.New("ratchet not available: attempt to decode a message whose index is earlier than our earliest known session key") - ErrMsgIndexTooHigh = errors.New("message index too high") - ErrProtocolViolation = errors.New("not protocol message order") - ErrMessageKeyNotFound = errors.New("message key not found") - ErrChainTooHigh = errors.New("chain index too high") - ErrBadInput = errors.New("bad input") - ErrBadVersion = errors.New("wrong version") - ErrWrongPickleVersion = errors.New("wrong pickle version") - ErrInputToSmall = errors.New("input too small (truncated?)") - ErrOverflow = errors.New("overflow") + ErrBadSignature = errors.New("bad signature") + ErrBadMAC = errors.New("the message couldn't be decrypted (bad mac)") + ErrBadMessageFormat = errors.New("the message couldn't be decoded") + ErrBadVerification = errors.New("bad verification") + ErrWrongProtocolVersion = errors.New("wrong protocol version") + ErrEmptyInput = errors.New("empty input") + ErrNoKeyProvided = errors.New("no key provided") + ErrBadMessageKeyID = errors.New("the message references an unknown key ID") + ErrUnknownMessageIndex = errors.New("attempt to decode a message whose index is earlier than our earliest known session key") + ErrMsgIndexTooHigh = errors.New("message index too high") + ErrProtocolViolation = errors.New("not protocol message order") + ErrMessageKeyNotFound = errors.New("message key not found") + ErrChainTooHigh = errors.New("chain index too high") + ErrBadInput = errors.New("bad input") + ErrUnknownOlmPickleVersion = errors.New("unknown olm pickle version") + ErrUnknownJSONPickleVersion = errors.New("unknown JSON pickle version") + ErrInputToSmall = errors.New("input too small (truncated?)") ) // Error codes from go-olm var ( - EmptyInput = errors.New("empty input") - NoKeyProvided = errors.New("no pickle key provided") - NotEnoughGoRandom = errors.New("couldn't get enough randomness from crypto/rand") - SignatureNotFound = errors.New("input JSON doesn't contain signature from specified device") - InputNotJSONString = errors.New("input doesn't look like a JSON string") + ErrNotEnoughGoRandom = errors.New("couldn't get enough randomness from crypto/rand") + ErrInputNotJSONString = errors.New("input doesn't look like a JSON string") ) // Error codes from olm code var ( - NotEnoughRandom = errors.New("not enough entropy was supplied") - OutputBufferTooSmall = errors.New("supplied output buffer is too small") - BadMessageVersion = errors.New("the message version is unsupported") - BadMessageFormat = errors.New("the message couldn't be decoded") - BadMessageMAC = errors.New("the message couldn't be decrypted") - BadMessageKeyID = errors.New("the message references an unknown key ID") - InvalidBase64 = errors.New("the input base64 was invalid") - BadAccountKey = errors.New("the supplied account key is invalid") - UnknownPickleVersion = errors.New("the pickled object is too new") - CorruptedPickle = errors.New("the pickled object couldn't be decoded") - BadSessionKey = errors.New("attempt to initialise an inbound group session from an invalid session key") - UnknownMessageIndex = errors.New("attempt to decode a message whose index is earlier than our earliest known session key") - BadLegacyAccountPickle = errors.New("attempt to unpickle an account which uses pickle version 1") - BadSignature = errors.New("received message had a bad signature") - InputBufferTooSmall = errors.New("the input data was too small to be valid") + ErrLibolmInvalidBase64 = errors.New("the input base64 was invalid") + + ErrLibolmNotEnoughRandom = errors.New("not enough entropy was supplied") + ErrLibolmOutputBufferTooSmall = errors.New("supplied output buffer is too small") + ErrLibolmBadAccountKey = errors.New("the supplied account key is invalid") + ErrLibolmCorruptedPickle = errors.New("the pickled object couldn't be decoded") + ErrLibolmBadSessionKey = errors.New("attempt to initialise an inbound group session from an invalid session key") + ErrLibolmBadLegacyAccountPickle = errors.New("attempt to unpickle an account which uses pickle version 1") +) + +// Deprecated: use variables prefixed with Err +var ( + EmptyInput = ErrEmptyInput + BadSignature = ErrBadSignature + InvalidBase64 = ErrLibolmInvalidBase64 + BadMessageKeyID = ErrBadMessageKeyID + BadMessageFormat = ErrBadMessageFormat + BadMessageVersion = ErrWrongProtocolVersion + BadMessageMAC = ErrBadMAC + UnknownPickleVersion = ErrUnknownOlmPickleVersion + NotEnoughRandom = ErrLibolmNotEnoughRandom + OutputBufferTooSmall = ErrLibolmOutputBufferTooSmall + BadAccountKey = ErrLibolmBadAccountKey + CorruptedPickle = ErrLibolmCorruptedPickle + BadSessionKey = ErrLibolmBadSessionKey + UnknownMessageIndex = ErrUnknownMessageIndex + BadLegacyAccountPickle = ErrLibolmBadLegacyAccountPickle + InputBufferTooSmall = ErrInputToSmall + NoKeyProvided = ErrNoKeyProvided + + NotEnoughGoRandom = ErrNotEnoughGoRandom + InputNotJSONString = ErrInputNotJSONString + + ErrBadVersion = ErrUnknownJSONPickleVersion + ErrWrongPickleVersion = ErrUnknownJSONPickleVersion + ErrRatchetNotAvailable = ErrUnknownMessageIndex ) diff --git a/crypto/registergoolm.go b/crypto/registergoolm.go index f5cecafc..6b5b65fd 100644 --- a/crypto/registergoolm.go +++ b/crypto/registergoolm.go @@ -2,4 +2,10 @@ package crypto -import _ "maunium.net/go/mautrix/crypto/goolm" +import ( + "maunium.net/go/mautrix/crypto/goolm" +) + +func init() { + goolm.Register() +} diff --git a/crypto/registerlibolm.go b/crypto/registerlibolm.go index ab388a5c..ef78b6b5 100644 --- a/crypto/registerlibolm.go +++ b/crypto/registerlibolm.go @@ -2,4 +2,8 @@ package crypto -import _ "maunium.net/go/mautrix/crypto/libolm" +import "maunium.net/go/mautrix/crypto/libolm" + +func init() { + libolm.Register() +} diff --git a/crypto/sessions.go b/crypto/sessions.go index c22b5b58..ccc7b784 100644 --- a/crypto/sessions.go +++ b/crypto/sessions.go @@ -8,6 +8,7 @@ package crypto import ( "errors" + "fmt" "time" "maunium.net/go/mautrix/crypto/olm" @@ -17,8 +18,14 @@ import ( ) var ( - SessionNotShared = errors.New("session has not been shared") - SessionExpired = errors.New("session has expired") + ErrSessionNotShared = errors.New("session has not been shared") + ErrSessionExpired = errors.New("session has expired") +) + +// Deprecated: use variables prefixed with Err +var ( + SessionNotShared = ErrSessionNotShared + SessionExpired = ErrSessionExpired ) // OlmSessionList is a list of OlmSessions. @@ -110,6 +117,7 @@ type InboundGroupSession struct { MaxMessages int IsScheduled bool KeyBackupVersion id.KeyBackupVersion + KeySource id.KeySource id id.SessionID } @@ -124,11 +132,12 @@ func NewInboundGroupSession(senderKey id.SenderKey, signingKey id.Ed25519, roomI SigningKey: signingKey, SenderKey: senderKey, RoomID: roomID, - ForwardingChains: nil, + ForwardingChains: []string{}, ReceivedAt: time.Now().UTC(), MaxAge: maxAge.Milliseconds(), MaxMessages: maxMessages, IsScheduled: isScheduled, + KeySource: id.KeySourceDirect, }, nil } @@ -152,6 +161,22 @@ func (igs *InboundGroupSession) RatchetTo(index uint32) error { return nil } +func (igs *InboundGroupSession) export() (*ExportedSession, error) { + key, err := igs.Internal.Export(igs.Internal.FirstKnownIndex()) + if err != nil { + return nil, fmt.Errorf("failed to export session: %w", err) + } + return &ExportedSession{ + Algorithm: id.AlgorithmMegolmV1, + ForwardingChains: igs.ForwardingChains, + RoomID: igs.RoomID, + SenderKey: igs.SenderKey, + SenderClaimedKeys: SenderClaimedKeys{Ed25519: igs.SigningKey}, + SessionID: igs.ID(), + SessionKey: string(key), + }, nil +} + type OGSState int const ( @@ -238,9 +263,9 @@ func (ogs *OutboundGroupSession) Expired() bool { func (ogs *OutboundGroupSession) Encrypt(plaintext []byte) ([]byte, error) { if !ogs.Shared { - return nil, SessionNotShared + return nil, ErrSessionNotShared } else if ogs.Expired() { - return nil, SessionExpired + return nil, ErrSessionExpired } ogs.MessageCount++ ogs.LastEncryptedTime = time.Now() diff --git a/crypto/sharing.go b/crypto/sharing.go index c0f3e209..10e37ccc 100644 --- a/crypto/sharing.go +++ b/crypto/sharing.go @@ -173,6 +173,19 @@ func (mach *OlmMachine) receiveSecret(ctx context.Context, evt *DecryptedOlmEven return } + // https://spec.matrix.org/v1.10/client-server-api/#msecretsend + // "The recipient must ensure... that the device is a verified device owned by the recipient" + if senderDevice, err := mach.GetOrFetchDevice(ctx, evt.Sender, evt.SenderDevice); err != nil { + log.Err(err).Msg("Failed to get or fetch sender device, rejecting secret") + return + } else if senderDevice == nil { + log.Warn().Msg("Unknown sender device, rejecting secret") + return + } else if !mach.IsDeviceTrusted(ctx, senderDevice) { + log.Warn().Msg("Sender device is not verified, rejecting secret") + return + } + mach.secretLock.Lock() secretChan := mach.secretListeners[content.RequestID] mach.secretLock.Unlock() diff --git a/crypto/sql_store.go b/crypto/sql_store.go index e68f0df5..138cc557 100644 --- a/crypto/sql_store.go +++ b/crypto/sql_store.go @@ -13,6 +13,7 @@ import ( "encoding/json" "errors" "fmt" + "slices" "strings" "sync" "time" @@ -21,7 +22,7 @@ import ( "go.mau.fi/util/dbutil" "maunium.net/go/mautrix" - "maunium.net/go/mautrix/crypto/goolm/cipher" + "maunium.net/go/mautrix/crypto/goolm/libolmpickle" "maunium.net/go/mautrix/crypto/olm" "maunium.net/go/mautrix/crypto/sql_store_upgrade" "maunium.net/go/mautrix/event" @@ -249,6 +250,17 @@ func (store *SQLCryptoStore) GetLatestSession(ctx context.Context, key id.Sender } } +// GetNewestSessionCreationTS gets the creation timestamp of the most recently created session with the given sender key. +// This will exclude sessions that have never been used to encrypt or decrypt a message. +func (store *SQLCryptoStore) GetNewestSessionCreationTS(ctx context.Context, key id.SenderKey) (createdAt time.Time, err error) { + err = store.DB.QueryRow(ctx, "SELECT created_at FROM crypto_olm_session WHERE sender_key=$1 AND account_id=$2 AND (last_encrypted <> created_at OR last_decrypted <> created_at) ORDER BY created_at DESC LIMIT 1", + key, store.AccountID).Scan(&createdAt) + if errors.Is(err, sql.ErrNoRows) { + err = nil + } + return +} + // AddSession persists an Olm session for a sender in the database. func (store *SQLCryptoStore) AddSession(ctx context.Context, key id.SenderKey, session *OlmSession) error { store.olmSessionCacheLock.Lock() @@ -279,6 +291,29 @@ func (store *SQLCryptoStore) DeleteSession(ctx context.Context, _ id.SenderKey, return err } +func (store *SQLCryptoStore) PutOlmHash(ctx context.Context, messageHash [32]byte, receivedAt time.Time) error { + _, err := store.DB.Exec(ctx, "INSERT INTO crypto_olm_message_hash (account_id, received_at, message_hash) VALUES ($1, $2, $3) ON CONFLICT (message_hash) DO NOTHING", store.AccountID, receivedAt.UnixMilli(), messageHash[:]) + return err +} + +func (store *SQLCryptoStore) GetOlmHash(ctx context.Context, messageHash [32]byte) (receivedAt time.Time, err error) { + var receivedAtInt int64 + err = store.DB.QueryRow(ctx, "SELECT received_at FROM crypto_olm_message_hash WHERE message_hash=$1", messageHash[:]).Scan(&receivedAtInt) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + err = nil + } + return + } + receivedAt = time.UnixMilli(receivedAtInt) + return +} + +func (store *SQLCryptoStore) DeleteOldOlmHashes(ctx context.Context, beforeTS time.Time) error { + _, err := store.DB.Exec(ctx, "DELETE FROM crypto_olm_message_hash WHERE account_id = $1 AND received_at < $2", store.AccountID, beforeTS.UnixMilli()) + return err +} + func datePtr(t time.Time) *time.Time { if t.IsZero() { return nil @@ -292,6 +327,9 @@ func (store *SQLCryptoStore) PutGroupSession(ctx context.Context, session *Inbou if err != nil { return err } + if session.ForwardingChains == nil { + session.ForwardingChains = []string{} + } forwardingChains := strings.Join(session.ForwardingChains, ",") ratchetSafety, err := json.Marshal(&session.RatchetSafety) if err != nil { @@ -308,22 +346,23 @@ func (store *SQLCryptoStore) PutGroupSession(ctx context.Context, session *Inbou Int("max_messages", session.MaxMessages). Bool("is_scheduled", session.IsScheduled). Stringer("key_backup_version", session.KeyBackupVersion). + Stringer("key_source", session.KeySource). Msg("Upserting megolm inbound group session") _, err = store.DB.Exec(ctx, ` INSERT INTO crypto_megolm_inbound_session ( session_id, sender_key, signing_key, room_id, session, forwarding_chains, - ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version, account_id - ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13) + ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version, key_source, account_id + ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14) ON CONFLICT (session_id, account_id) DO UPDATE SET withheld_code=NULL, withheld_reason=NULL, sender_key=excluded.sender_key, signing_key=excluded.signing_key, room_id=excluded.room_id, session=excluded.session, forwarding_chains=excluded.forwarding_chains, ratchet_safety=excluded.ratchet_safety, received_at=excluded.received_at, max_age=excluded.max_age, max_messages=excluded.max_messages, is_scheduled=excluded.is_scheduled, - key_backup_version=excluded.key_backup_version + key_backup_version=excluded.key_backup_version, key_source=excluded.key_source `, session.ID(), session.SenderKey, session.SigningKey, session.RoomID, sessionBytes, forwardingChains, ratchetSafety, datePtr(session.ReceivedAt), dbutil.NumPtr(session.MaxAge), dbutil.NumPtr(session.MaxMessages), - session.IsScheduled, session.KeyBackupVersion, store.AccountID, + session.IsScheduled, session.KeyBackupVersion, session.KeySource, store.AccountID, ) return err } @@ -336,12 +375,13 @@ func (store *SQLCryptoStore) GetGroupSession(ctx context.Context, roomID id.Room var maxAge, maxMessages sql.NullInt64 var isScheduled bool var version id.KeyBackupVersion + var keySource id.KeySource err := store.DB.QueryRow(ctx, ` - SELECT sender_key, signing_key, session, forwarding_chains, withheld_code, withheld_reason, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version + SELECT sender_key, signing_key, session, forwarding_chains, withheld_code, withheld_reason, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version, key_source FROM crypto_megolm_inbound_session WHERE room_id=$1 AND session_id=$2 AND account_id=$3`, roomID, sessionID, store.AccountID, - ).Scan(&senderKey, &signingKey, &sessionBytes, &forwardingChains, &withheldCode, &withheldReason, &ratchetSafetyBytes, &receivedAt, &maxAge, &maxMessages, &isScheduled, &version) + ).Scan(&senderKey, &signingKey, &sessionBytes, &forwardingChains, &withheldCode, &withheldReason, &ratchetSafetyBytes, &receivedAt, &maxAge, &maxMessages, &isScheduled, &version, &keySource) if errors.Is(err, sql.ErrNoRows) { return nil, nil } else if err != nil { @@ -372,6 +412,7 @@ func (store *SQLCryptoStore) GetGroupSession(ctx context.Context, roomID id.Room MaxMessages: int(maxMessages.Int64), IsScheduled: isScheduled, KeyBackupVersion: version, + KeySource: keySource, }, nil } @@ -395,10 +436,7 @@ func (store *SQLCryptoStore) RedactGroupSessions(ctx context.Context, roomID id. AND session IS NOT NULL AND is_scheduled=false AND received_at IS NOT NULL RETURNING session_id `, event.RoomKeyWithheldBeeperRedacted, "Session redacted: "+reason, roomID, senderKey, store.AccountID) - if err != nil { - return nil, err - } - return dbutil.NewRowIter(res, dbutil.ScanSingleColumn[id.SessionID]).AsList() + return dbutil.NewRowIterWithError(res, dbutil.ScanSingleColumn[id.SessionID], err).AsList() } func (store *SQLCryptoStore) RedactExpiredGroupSessions(ctx context.Context) ([]id.SessionID, error) { @@ -426,10 +464,7 @@ func (store *SQLCryptoStore) RedactExpiredGroupSessions(ctx context.Context) ([] return nil, fmt.Errorf("unsupported dialect") } res, err := store.DB.Query(ctx, query, event.RoomKeyWithheldBeeperRedacted, "Session redacted: expired", store.AccountID) - if err != nil { - return nil, err - } - return dbutil.NewRowIter(res, dbutil.ScanSingleColumn[id.SessionID]).AsList() + return dbutil.NewRowIterWithError(res, dbutil.ScanSingleColumn[id.SessionID], err).AsList() } func (store *SQLCryptoStore) RedactOutdatedGroupSessions(ctx context.Context) ([]id.SessionID, error) { @@ -439,10 +474,7 @@ func (store *SQLCryptoStore) RedactOutdatedGroupSessions(ctx context.Context) ([ WHERE account_id=$3 AND session IS NOT NULL AND received_at IS NULL RETURNING session_id `, event.RoomKeyWithheldBeeperRedacted, "Session redacted: outdated", store.AccountID) - if err != nil { - return nil, err - } - return dbutil.NewRowIter(res, dbutil.ScanSingleColumn[id.SessionID]).AsList() + return dbutil.NewRowIterWithError(res, dbutil.ScanSingleColumn[id.SessionID], err).AsList() } func (store *SQLCryptoStore) PutWithheldGroupSession(ctx context.Context, content event.RoomKeyWithheldEventContent) error { @@ -484,6 +516,8 @@ func (store *SQLCryptoStore) postScanInboundGroupSession(sessionBytes, ratchetSa } if forwardingChains != "" { chains = strings.Split(forwardingChains, ",") + } else { + chains = []string{} } var rs RatchetSafety if len(ratchetSafetyBytes) > 0 { @@ -503,7 +537,8 @@ func (store *SQLCryptoStore) scanInboundGroupSession(rows dbutil.Scannable) (*In var maxAge, maxMessages sql.NullInt64 var isScheduled bool var version id.KeyBackupVersion - err := rows.Scan(&roomID, &senderKey, &signingKey, &sessionBytes, &forwardingChains, &ratchetSafetyBytes, &receivedAt, &maxAge, &maxMessages, &isScheduled, &version) + var keySource id.KeySource + err := rows.Scan(&roomID, &senderKey, &signingKey, &sessionBytes, &forwardingChains, &ratchetSafetyBytes, &receivedAt, &maxAge, &maxMessages, &isScheduled, &version, &keySource) if err != nil { return nil, err } @@ -523,12 +558,13 @@ func (store *SQLCryptoStore) scanInboundGroupSession(rows dbutil.Scannable) (*In MaxMessages: int(maxMessages.Int64), IsScheduled: isScheduled, KeyBackupVersion: version, + KeySource: keySource, }, nil } func (store *SQLCryptoStore) GetGroupSessionsForRoom(ctx context.Context, roomID id.RoomID) dbutil.RowIter[*InboundGroupSession] { rows, err := store.DB.Query(ctx, ` - SELECT room_id, sender_key, signing_key, session, forwarding_chains, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version + SELECT room_id, sender_key, signing_key, session, forwarding_chains, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version, key_source FROM crypto_megolm_inbound_session WHERE room_id=$1 AND account_id=$2 AND session IS NOT NULL`, roomID, store.AccountID, ) @@ -537,7 +573,7 @@ func (store *SQLCryptoStore) GetGroupSessionsForRoom(ctx context.Context, roomID func (store *SQLCryptoStore) GetAllGroupSessions(ctx context.Context) dbutil.RowIter[*InboundGroupSession] { rows, err := store.DB.Query(ctx, ` - SELECT room_id, sender_key, signing_key, session, forwarding_chains, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version + SELECT room_id, sender_key, signing_key, session, forwarding_chains, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version, key_source FROM crypto_megolm_inbound_session WHERE account_id=$1 AND session IS NOT NULL`, store.AccountID, ) @@ -546,7 +582,7 @@ func (store *SQLCryptoStore) GetAllGroupSessions(ctx context.Context) dbutil.Row func (store *SQLCryptoStore) GetGroupSessionsWithoutKeyBackupVersion(ctx context.Context, version id.KeyBackupVersion) dbutil.RowIter[*InboundGroupSession] { rows, err := store.DB.Query(ctx, ` - SELECT room_id, sender_key, signing_key, session, forwarding_chains, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version + SELECT room_id, sender_key, signing_key, session, forwarding_chains, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version, key_source FROM crypto_megolm_inbound_session WHERE account_id=$1 AND session IS NOT NULL AND key_backup_version != $2`, store.AccountID, version, ) @@ -633,6 +669,20 @@ func (store *SQLCryptoStore) IsOutboundGroupSessionShared(ctx context.Context, u // ValidateMessageIndex returns whether the given event information match the ones stored in the database // for the given sender key, session ID and index. If the index hasn't been stored, this will store it. func (store *SQLCryptoStore) ValidateMessageIndex(ctx context.Context, senderKey id.SenderKey, sessionID id.SessionID, eventID id.EventID, index uint, timestamp int64) (bool, error) { + if eventID == "" && timestamp == 0 { + var notOK bool + const validateEmptyQuery = ` + SELECT EXISTS(SELECT 1 FROM crypto_message_index WHERE sender_key=$1 AND session_id=$2 AND "index"=$3) + ` + err := store.DB.QueryRow(ctx, validateEmptyQuery, senderKey, sessionID, index).Scan(¬OK) + if notOK { + zerolog.Ctx(ctx).Debug(). + Uint("message_index", index). + Msg("Rejecting event without event ID and timestamp due to already knowing them") + } + return !notOK, err + } + const validateQuery = ` INSERT INTO crypto_message_index (sender_key, session_id, "index", event_id, timestamp) VALUES ($1, $2, $3, $4, $5) @@ -679,11 +729,8 @@ func (store *SQLCryptoStore) GetDevices(ctx context.Context, userID id.UserID) ( } rows, err := store.DB.Query(ctx, "SELECT user_id, device_id, identity_key, signing_key, trust, deleted, name FROM crypto_device WHERE user_id=$1 AND deleted=false", userID) - if err != nil { - return nil, err - } data := make(map[id.DeviceID]*id.Device) - err = dbutil.NewRowIter(rows, scanDevice).Iter(func(device *id.Device) (bool, error) { + err = dbutil.NewRowIterWithError(rows, scanDevice, err).Iter(func(device *id.Device) (bool, error) { data[device.DeviceID] = device return true, nil }) @@ -803,19 +850,18 @@ func (store *SQLCryptoStore) FilterTrackedUsers(ctx context.Context, users []id. placeholders, params := userIDsToParams(users) rows, err = store.DB.Query(ctx, "SELECT user_id FROM crypto_tracked_user WHERE user_id IN ("+placeholders+")", params...) } - if err != nil { - return users, err - } - return dbutil.NewRowIter(rows, dbutil.ScanSingleColumn[id.UserID]).AsList() + return dbutil.NewRowIterWithError(rows, dbutil.ScanSingleColumn[id.UserID], err).AsList() } // MarkTrackedUsersOutdated flags that the device list for given users are outdated. func (store *SQLCryptoStore) MarkTrackedUsersOutdated(ctx context.Context, users []id.UserID) (err error) { - if store.DB.Dialect == dbutil.Postgres && PostgresArrayWrapper != nil { - _, err = store.DB.Exec(ctx, "UPDATE crypto_tracked_user SET devices_outdated = true WHERE user_id = ANY($1)", PostgresArrayWrapper(users)) - } else { - placeholders, params := userIDsToParams(users) - _, err = store.DB.Exec(ctx, "UPDATE crypto_tracked_user SET devices_outdated = true WHERE user_id IN ("+placeholders+")", params...) + for chunk := range slices.Chunk(users, 1000) { + if store.DB.Dialect == dbutil.Postgres && PostgresArrayWrapper != nil { + _, err = store.DB.Exec(ctx, "UPDATE crypto_tracked_user SET devices_outdated = true WHERE user_id = ANY($1)", PostgresArrayWrapper(chunk)) + } else { + placeholders, params := userIDsToParams(chunk) + _, err = store.DB.Exec(ctx, "UPDATE crypto_tracked_user SET devices_outdated = true WHERE user_id IN ("+placeholders+")", params...) + } } return } @@ -823,10 +869,7 @@ func (store *SQLCryptoStore) MarkTrackedUsersOutdated(ctx context.Context, users // GetOutdatedTrackerUsers gets all tracked users whose devices need to be updated. func (store *SQLCryptoStore) GetOutdatedTrackedUsers(ctx context.Context) ([]id.UserID, error) { rows, err := store.DB.Query(ctx, "SELECT user_id FROM crypto_tracked_user WHERE devices_outdated = TRUE") - if err != nil { - return nil, err - } - return dbutil.NewRowIter(rows, dbutil.ScanSingleColumn[id.UserID]).AsList() + return dbutil.NewRowIterWithError(rows, dbutil.ScanSingleColumn[id.UserID], err).AsList() } // PutCrossSigningKey stores a cross-signing key of some user along with its usage. @@ -911,7 +954,7 @@ func (store *SQLCryptoStore) DropSignaturesByKey(ctx context.Context, userID id. } func (store *SQLCryptoStore) PutSecret(ctx context.Context, name id.Secret, value string) error { - bytes, err := cipher.Pickle(store.PickleKey, []byte(value)) + bytes, err := libolmpickle.Pickle(store.PickleKey, []byte(value)) if err != nil { return err } @@ -930,7 +973,7 @@ func (store *SQLCryptoStore) GetSecret(ctx context.Context, name id.Secret) (val } else if err != nil { return "", err } - bytes, err = cipher.Unpickle(store.PickleKey, bytes) + bytes, err = libolmpickle.Unpickle(store.PickleKey, bytes) return string(bytes), err } diff --git a/crypto/sql_store_upgrade/00-latest-revision.sql b/crypto/sql_store_upgrade/00-latest-revision.sql index 7cd3331c..3709f1e5 100644 --- a/crypto/sql_store_upgrade/00-latest-revision.sql +++ b/crypto/sql_store_upgrade/00-latest-revision.sql @@ -1,4 +1,4 @@ --- v0 -> v16 (compatible with v15+): Latest revision +-- v0 -> v19 (compatible with v15+): Latest revision CREATE TABLE IF NOT EXISTS crypto_account ( account_id TEXT PRIMARY KEY, device_id TEXT NOT NULL, @@ -45,6 +45,16 @@ CREATE TABLE IF NOT EXISTS crypto_olm_session ( ); CREATE INDEX crypto_olm_session_sender_key_idx ON crypto_olm_session (account_id, sender_key); +CREATE TABLE crypto_olm_message_hash ( + account_id TEXT NOT NULL, + received_at BIGINT NOT NULL, + message_hash bytea NOT NULL PRIMARY KEY, + + CONSTRAINT crypto_olm_message_hash_account_fkey FOREIGN KEY (account_id) + REFERENCES crypto_account (account_id) ON DELETE CASCADE ON UPDATE CASCADE +); +CREATE INDEX crypto_olm_message_hash_account_idx ON crypto_olm_message_hash (account_id); + CREATE TABLE IF NOT EXISTS crypto_megolm_inbound_session ( account_id TEXT, session_id CHAR(43), @@ -61,8 +71,11 @@ CREATE TABLE IF NOT EXISTS crypto_megolm_inbound_session ( max_messages INTEGER, is_scheduled BOOLEAN NOT NULL DEFAULT false, key_backup_version TEXT NOT NULL DEFAULT '', + key_source TEXT NOT NULL DEFAULT '', PRIMARY KEY (account_id, session_id) ); +-- Useful index to find keys that need backing up +CREATE INDEX crypto_megolm_inbound_session_backup_idx ON crypto_megolm_inbound_session(account_id, key_backup_version) WHERE session IS NOT NULL; CREATE TABLE IF NOT EXISTS crypto_megolm_outbound_session ( account_id TEXT, diff --git a/crypto/sql_store_upgrade/17-decrypted-olm-messages.sql b/crypto/sql_store_upgrade/17-decrypted-olm-messages.sql new file mode 100644 index 00000000..525bbb52 --- /dev/null +++ b/crypto/sql_store_upgrade/17-decrypted-olm-messages.sql @@ -0,0 +1,11 @@ +-- v17 (compatible with v15+): Add table for decrypted Olm message hashes +CREATE TABLE crypto_olm_message_hash ( + account_id TEXT NOT NULL, + received_at BIGINT NOT NULL, + message_hash bytea NOT NULL PRIMARY KEY, + + CONSTRAINT crypto_olm_message_hash_account_fkey FOREIGN KEY (account_id) + REFERENCES crypto_account (account_id) ON DELETE CASCADE ON UPDATE CASCADE +); + +CREATE INDEX crypto_olm_message_hash_account_idx ON crypto_olm_message_hash (account_id); diff --git a/crypto/sql_store_upgrade/18-megolm-inbound-session-backup-index.sql b/crypto/sql_store_upgrade/18-megolm-inbound-session-backup-index.sql new file mode 100644 index 00000000..da26da0f --- /dev/null +++ b/crypto/sql_store_upgrade/18-megolm-inbound-session-backup-index.sql @@ -0,0 +1,2 @@ +-- v18 (compatible with v15+): Add an index to the megolm_inbound_session table to make finding sessions to backup faster +CREATE INDEX crypto_megolm_inbound_session_backup_idx ON crypto_megolm_inbound_session(account_id, key_backup_version) WHERE session IS NOT NULL; diff --git a/crypto/sql_store_upgrade/19-megolm-session-source.sql b/crypto/sql_store_upgrade/19-megolm-session-source.sql new file mode 100644 index 00000000..f624222f --- /dev/null +++ b/crypto/sql_store_upgrade/19-megolm-session-source.sql @@ -0,0 +1,2 @@ +-- v19 (compatible with v15+): Store megolm session source +ALTER TABLE crypto_megolm_inbound_session ADD COLUMN key_source TEXT NOT NULL DEFAULT ''; diff --git a/crypto/ssss/client.go b/crypto/ssss/client.go index e30925d9..8691d032 100644 --- a/crypto/ssss/client.go +++ b/crypto/ssss/client.go @@ -95,6 +95,22 @@ func (mach *Machine) SetEncryptedAccountData(ctx context.Context, eventType even return mach.Client.SetAccountData(ctx, eventType.Type, &EncryptedAccountDataEventContent{Encrypted: encrypted}) } +// SetEncryptedAccountDataWithMetadata encrypts the given data with the given keys and stores it, +// alongside the unencrypted metadata, on the server. +func (mach *Machine) SetEncryptedAccountDataWithMetadata(ctx context.Context, eventType event.Type, data []byte, metadata map[string]any, keys ...*Key) error { + if len(keys) == 0 { + return ErrNoKeyGiven + } + encrypted := make(map[string]EncryptedKeyData, len(keys)) + for _, key := range keys { + encrypted[key.ID] = key.Encrypt(eventType.Type, data) + } + return mach.Client.SetAccountData(ctx, eventType.Type, &EncryptedAccountDataEventContent{ + Encrypted: encrypted, + Metadata: metadata, + }) +} + // GenerateAndUploadKey generates a new SSSS key and stores the metadata on the server. func (mach *Machine) GenerateAndUploadKey(ctx context.Context, passphrase string) (key *Key, err error) { key, err = NewKey(passphrase) diff --git a/crypto/ssss/key.go b/crypto/ssss/key.go index c973c1fe..78ebd8f3 100644 --- a/crypto/ssss/key.go +++ b/crypto/ssss/key.go @@ -7,6 +7,8 @@ package ssss import ( + "crypto/hmac" + "crypto/sha256" "encoding/base64" "fmt" "strings" @@ -57,7 +59,12 @@ func NewKey(passphrase string) (*Key, error) { // We store a certain hash in the key metadata so that clients can check if the user entered the correct key. ivBytes := random.Bytes(utils.AESCTRIVLength) keyData.IV = base64.RawStdEncoding.EncodeToString(ivBytes) - keyData.MAC = keyData.calculateHash(ssssKey) + macBytes, err := keyData.calculateHash(ssssKey) + if err != nil { + // This should never happen because we just generated the IV and key. + return nil, fmt.Errorf("failed to calculate hash: %w", err) + } + keyData.MAC = base64.RawStdEncoding.EncodeToString(macBytes) return &Key{ Key: ssssKey, @@ -103,12 +110,18 @@ func (key *Key) Decrypt(eventType string, data EncryptedKeyData) ([]byte, error) return nil, err } + mac, err := base64.RawStdEncoding.DecodeString(strings.TrimRight(data.MAC, "=")) + if err != nil { + return nil, err + } + // derive the AES and HMAC keys for the requested event type using the SSSS key aesKey, hmacKey := utils.DeriveKeysSHA256(key.Key, eventType) // compare the stored MAC with the one we calculated from the ciphertext - calcMac := utils.HMACSHA256B64(payload, hmacKey) - if strings.TrimRight(data.MAC, "=") != calcMac { + h := hmac.New(sha256.New, hmacKey[:]) + h.Write(payload) + if !hmac.Equal(h.Sum(nil), mac) { return nil, ErrKeyDataMACMismatch } diff --git a/crypto/ssss/meta.go b/crypto/ssss/meta.go index 210bcdcf..34775fa7 100644 --- a/crypto/ssss/meta.go +++ b/crypto/ssss/meta.go @@ -7,7 +7,10 @@ package ssss import ( + "crypto/hmac" + "crypto/sha256" "encoding/base64" + "errors" "fmt" "strings" @@ -33,8 +36,10 @@ func (kd *KeyMetadata) VerifyPassphrase(keyID, passphrase string) (*Key, error) ssssKey, err := kd.Passphrase.GetKey(passphrase) if err != nil { return nil, err - } else if !kd.VerifyKey(ssssKey) { - return nil, ErrIncorrectSSSSKey + } + err = kd.verifyKey(ssssKey) + if err != nil && !errors.Is(err, ErrUnverifiableKey) { + return nil, err } return &Key{ @@ -49,33 +54,70 @@ func (kd *KeyMetadata) VerifyRecoveryKey(keyID, recoveryKey string) (*Key, error ssssKey := utils.DecodeBase58RecoveryKey(recoveryKey) if ssssKey == nil { return nil, ErrInvalidRecoveryKey - } else if !kd.VerifyKey(ssssKey) { - return nil, ErrIncorrectSSSSKey + } + err := kd.verifyKey(ssssKey) + if err != nil && !errors.Is(err, ErrUnverifiableKey) { + return nil, err } return &Key{ ID: keyID, Key: ssssKey, Metadata: kd, - }, nil + }, err +} + +func (kd *KeyMetadata) verifyKey(key []byte) error { + if kd.MAC == "" || kd.IV == "" { + return ErrUnverifiableKey + } + unpaddedMAC := strings.TrimRight(kd.MAC, "=") + expectedMACLength := base64.RawStdEncoding.EncodedLen(utils.SHAHashLength) + if len(unpaddedMAC) != expectedMACLength { + return fmt.Errorf("%w: invalid mac length %d (expected %d)", ErrCorruptedKeyMetadata, len(unpaddedMAC), expectedMACLength) + } + expectedMAC, err := base64.RawStdEncoding.DecodeString(unpaddedMAC) + if err != nil { + return fmt.Errorf("%w: failed to decode mac: %w", ErrCorruptedKeyMetadata, err) + } + calculatedMAC, err := kd.calculateHash(key) + if err != nil { + return err + } + // This doesn't really need to be constant time since it's fully local, but might as well be. + if !hmac.Equal(expectedMAC, calculatedMAC) { + return ErrIncorrectSSSSKey + } + return nil } // VerifyKey verifies the SSSS key is valid by calculating and comparing its MAC. func (kd *KeyMetadata) VerifyKey(key []byte) bool { - return strings.TrimRight(kd.MAC, "=") == kd.calculateHash(key) + return kd.verifyKey(key) == nil } // calculateHash calculates the hash used for checking if the key is entered correctly as described // in the spec: https://matrix.org/docs/spec/client_server/unstable#m-secret-storage-v1-aes-hmac-sha2 -func (kd *KeyMetadata) calculateHash(key []byte) string { +func (kd *KeyMetadata) calculateHash(key []byte) ([]byte, error) { aesKey, hmacKey := utils.DeriveKeysSHA256(key, "") + unpaddedIV := strings.TrimRight(kd.IV, "=") + expectedIVLength := base64.RawStdEncoding.EncodedLen(utils.AESCTRIVLength) + if len(unpaddedIV) < expectedIVLength || len(unpaddedIV) > expectedIVLength*3 { + return nil, fmt.Errorf("%w: invalid iv length %d (expected %d)", ErrCorruptedKeyMetadata, len(unpaddedIV), expectedIVLength) + } + rawIVBytes, err := base64.RawStdEncoding.DecodeString(unpaddedIV) + if err != nil { + return nil, fmt.Errorf("%w: failed to decode iv: %w", ErrCorruptedKeyMetadata, err) + } + // TODO log a warning for non-16 byte IVs? + // Certain broken clients like nheko generated 32-byte IVs where only the first 16 bytes were used. + ivBytes := *(*[utils.AESCTRIVLength]byte)(rawIVBytes[:utils.AESCTRIVLength]) - var ivBytes [utils.AESCTRIVLength]byte - _, _ = base64.RawStdEncoding.Decode(ivBytes[:], []byte(strings.TrimRight(kd.IV, "="))) - - cipher := utils.XorA256CTR(make([]byte, utils.AESCTRKeyLength), aesKey, ivBytes) - - return utils.HMACSHA256B64(cipher, hmacKey) + zeroes := make([]byte, utils.AESCTRKeyLength) + encryptedZeroes := utils.XorA256CTR(zeroes, aesKey, ivBytes) + h := hmac.New(sha256.New, hmacKey[:]) + h.Write(encryptedZeroes) + return h.Sum(nil), nil } // PassphraseMetadata represents server-side metadata about a SSSS key passphrase. diff --git a/crypto/ssss/meta_test.go b/crypto/ssss/meta_test.go index 96c97282..d59809c7 100644 --- a/crypto/ssss/meta_test.go +++ b/crypto/ssss/meta_test.go @@ -8,10 +8,10 @@ package ssss_test import ( "encoding/json" - "errors" "testing" "github.com/stretchr/testify/assert" + "go.mau.fi/util/exerrors" "maunium.net/go/mautrix/crypto/ssss" ) @@ -41,12 +41,42 @@ const key2Meta = ` } ` +const key2MetaUnverified = ` +{ + "algorithm": "m.secret_storage.v1.aes-hmac-sha2" +} +` + +const key2MetaLongIV = ` +{ + "algorithm": "m.secret_storage.v1.aes-hmac-sha2", + "iv": "O0BOvTqiIAYjC+RMcyHfW2f/gdxjceTxoYtNlpPduJ8=", + "mac": "7k6OruQlWg0UmQjxGZ0ad4Q6DdwkgnoI7G6X3IjBYtI=" +} +` + +const key2MetaBrokenIV = ` +{ + "algorithm": "m.secret_storage.v1.aes-hmac-sha2", + "iv": "MeowMeowMeow", + "mac": "7k6OruQlWg0UmQjxGZ0ad4Q6DdwkgnoI7G6X3IjBYtI=" +} +` + +const key2MetaBrokenMAC = ` +{ + "algorithm": "m.secret_storage.v1.aes-hmac-sha2", + "iv": "O0BOvTqiIAYjC+RMcyHfWw==", + "mac": "7k6OruQlWg0UmQjxGZ0ad4Q6DdwkgnoI7G6X3IjBYtIMeowMeowMeow" +} +` + const key2ID = "NVe5vK6lZS9gEMQLJw0yqkzmE5Mr7dLv" const key2RecoveryKey = "EsUC xSxt XJgQ dz19 8WBZ rHdE GZo7 ybsn EFmG Y5HY MDAG GNWe" -func getKey1Meta() *ssss.KeyMetadata { +func getKeyMeta(meta string) *ssss.KeyMetadata { var km ssss.KeyMetadata - err := json.Unmarshal([]byte(key1Meta), &km) + err := json.Unmarshal([]byte(meta), &km) if err != nil { panic(err) } @@ -54,36 +84,15 @@ func getKey1Meta() *ssss.KeyMetadata { } func getKey1() *ssss.Key { - km := getKey1Meta() - key, err := km.VerifyRecoveryKey(key1ID, key1RecoveryKey) - if err != nil { - panic(err) - } - key.ID = key1ID - return key -} - -func getKey2Meta() *ssss.KeyMetadata { - var km ssss.KeyMetadata - err := json.Unmarshal([]byte(key2Meta), &km) - if err != nil { - panic(err) - } - return &km + return exerrors.Must(getKeyMeta(key1Meta).VerifyRecoveryKey(key1ID, key1RecoveryKey)) } func getKey2() *ssss.Key { - km := getKey2Meta() - key, err := km.VerifyRecoveryKey(key2ID, key2RecoveryKey) - if err != nil { - panic(err) - } - key.ID = key2ID - return key + return exerrors.Must(getKeyMeta(key2Meta).VerifyRecoveryKey(key2ID, key2RecoveryKey)) } func TestKeyMetadata_VerifyRecoveryKey_Correct(t *testing.T) { - km := getKey1Meta() + km := getKeyMeta(key1Meta) key, err := km.VerifyRecoveryKey(key1ID, key1RecoveryKey) assert.NoError(t, err) assert.NotNil(t, key) @@ -91,29 +100,45 @@ func TestKeyMetadata_VerifyRecoveryKey_Correct(t *testing.T) { } func TestKeyMetadata_VerifyRecoveryKey_Correct2(t *testing.T) { - km := getKey2Meta() + km := getKeyMeta(key2Meta) key, err := km.VerifyRecoveryKey(key2ID, key2RecoveryKey) assert.NoError(t, err) assert.NotNil(t, key) assert.Equal(t, key2RecoveryKey, key.RecoveryKey()) } +func TestKeyMetadata_VerifyRecoveryKey_NonCompliant_LongIV(t *testing.T) { + km := getKeyMeta(key2MetaLongIV) + key, err := km.VerifyRecoveryKey(key2ID, key2RecoveryKey) + assert.NoError(t, err) + assert.NotNil(t, key) + assert.Equal(t, key2RecoveryKey, key.RecoveryKey()) +} + +func TestKeyMetadata_VerifyRecoveryKey_Unverified(t *testing.T) { + km := getKeyMeta(key2MetaUnverified) + key, err := km.VerifyRecoveryKey(key2ID, key2RecoveryKey) + assert.ErrorIs(t, err, ssss.ErrUnverifiableKey) + assert.NotNil(t, key) + assert.Equal(t, key2RecoveryKey, key.RecoveryKey()) +} + func TestKeyMetadata_VerifyRecoveryKey_Invalid(t *testing.T) { - km := getKey1Meta() + km := getKeyMeta(key1Meta) key, err := km.VerifyRecoveryKey(key1ID, "foo") - assert.True(t, errors.Is(err, ssss.ErrInvalidRecoveryKey), "unexpected error: %v", err) + assert.ErrorIs(t, err, ssss.ErrInvalidRecoveryKey) assert.Nil(t, key) } func TestKeyMetadata_VerifyRecoveryKey_Incorrect(t *testing.T) { - km := getKey1Meta() + km := getKeyMeta(key1Meta) key, err := km.VerifyRecoveryKey(key2ID, key2RecoveryKey) - assert.True(t, errors.Is(err, ssss.ErrIncorrectSSSSKey), "unexpected error: %v", err) + assert.ErrorIs(t, err, ssss.ErrIncorrectSSSSKey) assert.Nil(t, key) } func TestKeyMetadata_VerifyPassphrase_Correct(t *testing.T) { - km := getKey1Meta() + km := getKeyMeta(key1Meta) key, err := km.VerifyPassphrase(key1ID, key1Passphrase) assert.NoError(t, err) assert.NotNil(t, key) @@ -121,15 +146,29 @@ func TestKeyMetadata_VerifyPassphrase_Correct(t *testing.T) { } func TestKeyMetadata_VerifyPassphrase_Incorrect(t *testing.T) { - km := getKey1Meta() + km := getKeyMeta(key1Meta) key, err := km.VerifyPassphrase(key1ID, "incorrect horse battery staple") - assert.True(t, errors.Is(err, ssss.ErrIncorrectSSSSKey), "unexpected error %v", err) + assert.ErrorIs(t, err, ssss.ErrIncorrectSSSSKey) assert.Nil(t, key) } func TestKeyMetadata_VerifyPassphrase_NotSet(t *testing.T) { - km := getKey2Meta() + km := getKeyMeta(key2Meta) key, err := km.VerifyPassphrase(key2ID, "hmm") - assert.True(t, errors.Is(err, ssss.ErrNoPassphrase), "unexpected error %v", err) + assert.ErrorIs(t, err, ssss.ErrNoPassphrase) + assert.Nil(t, key) +} + +func TestKeyMetadata_VerifyRecoveryKey_CorruptedIV(t *testing.T) { + km := getKeyMeta(key2MetaBrokenIV) + key, err := km.VerifyRecoveryKey(key2ID, key2RecoveryKey) + assert.ErrorIs(t, err, ssss.ErrCorruptedKeyMetadata) + assert.Nil(t, key) +} + +func TestKeyMetadata_VerifyRecoveryKey_CorruptedMAC(t *testing.T) { + km := getKeyMeta(key2MetaBrokenMAC) + key, err := km.VerifyRecoveryKey(key2ID, key2RecoveryKey) + assert.ErrorIs(t, err, ssss.ErrCorruptedKeyMetadata) assert.Nil(t, key) } diff --git a/crypto/ssss/types.go b/crypto/ssss/types.go index 60852c55..b7465d3e 100644 --- a/crypto/ssss/types.go +++ b/crypto/ssss/types.go @@ -26,6 +26,8 @@ var ( ErrUnsupportedPassphraseAlgorithm = errors.New("unsupported passphrase KDF algorithm") ErrIncorrectSSSSKey = errors.New("incorrect SSSS key") ErrInvalidRecoveryKey = errors.New("invalid recovery key") + ErrCorruptedKeyMetadata = errors.New("corrupted recovery key metadata") + ErrUnverifiableKey = errors.New("cannot verify recovery key: missing MAC or IV in metadata") ) // Algorithm is the identifier for an SSSS encryption algorithm. @@ -56,6 +58,7 @@ type EncryptedKeyData struct { type EncryptedAccountDataEventContent struct { Encrypted map[string]EncryptedKeyData `json:"encrypted"` + Metadata map[string]any `json:"com.beeper.metadata,omitzero"` } func (ed *EncryptedAccountDataEventContent) Decrypt(eventType string, key *Key) ([]byte, error) { diff --git a/crypto/store.go b/crypto/store.go index 9a3a4394..7620cf35 100644 --- a/crypto/store.go +++ b/crypto/store.go @@ -12,8 +12,10 @@ import ( "slices" "sort" "sync" + "time" "go.mau.fi/util/dbutil" + "go.mau.fi/util/exsync" "golang.org/x/exp/maps" "maunium.net/go/mautrix/event" @@ -43,14 +45,23 @@ type Store interface { HasSession(context.Context, id.SenderKey) bool // GetSessions returns all Olm sessions in the store with the given sender key. GetSessions(context.Context, id.SenderKey) (OlmSessionList, error) - // GetLatestSession returns the session with the highest session ID (lexiographically sorting). - // It's usually safe to return the most recently added session if sorting by session ID is too difficult. + // GetLatestSession returns the most recent session that should be used for encrypting outbound messages. + // It's usually the one with the most recent successful decryption or the highest ID lexically. GetLatestSession(context.Context, id.SenderKey) (*OlmSession, error) + // GetNewestSessionCreationTS returns the creation timestamp of the most recently created session for the given sender key. + GetNewestSessionCreationTS(context.Context, id.SenderKey) (time.Time, error) // UpdateSession updates a session that has previously been inserted with AddSession. UpdateSession(context.Context, id.SenderKey, *OlmSession) error // DeleteSession deletes the given session that has been previously inserted with AddSession. DeleteSession(context.Context, id.SenderKey, *OlmSession) error + // PutOlmHash marks a given olm message hash as handled. + PutOlmHash(context.Context, [32]byte, time.Time) error + // GetOlmHash gets the time that a given olm hash was handled. + GetOlmHash(context.Context, [32]byte) (time.Time, error) + // DeleteOldOlmHashes deletes all olm hashes that were handled before the given time. + DeleteOldOlmHashes(context.Context, time.Time) error + // PutGroupSession inserts an inbound Megolm session into the store. If an earlier withhold event has been inserted // with PutWithheldGroupSession, this call should replace that. However, PutWithheldGroupSession must not replace // sessions inserted with this call. @@ -176,6 +187,7 @@ type MemoryStore struct { KeySignatures map[id.UserID]map[id.Ed25519]map[id.UserID]map[id.Ed25519]string OutdatedUsers map[id.UserID]struct{} Secrets map[id.Secret]string + OlmHashes *exsync.Set[[32]byte] } var _ Store = (*MemoryStore)(nil) @@ -198,6 +210,7 @@ func NewMemoryStore(saveCallback func() error) *MemoryStore { KeySignatures: make(map[id.UserID]map[id.Ed25519]map[id.UserID]map[id.Ed25519]string), OutdatedUsers: make(map[id.UserID]struct{}), Secrets: make(map[id.Secret]string), + OlmHashes: exsync.NewSet[[32]byte](), } } @@ -263,6 +276,23 @@ func (gs *MemoryStore) HasSession(_ context.Context, senderKey id.SenderKey) boo return ok && len(sessions) > 0 && !sessions[0].Expired() } +func (gs *MemoryStore) PutOlmHash(_ context.Context, hash [32]byte, receivedAt time.Time) error { + gs.OlmHashes.Add(hash) + return nil +} + +func (gs *MemoryStore) GetOlmHash(_ context.Context, hash [32]byte) (time.Time, error) { + if gs.OlmHashes.Has(hash) { + // The time isn't that important, so we just return the current time + return time.Now(), nil + } + return time.Time{}, nil +} + +func (gs *MemoryStore) DeleteOldOlmHashes(_ context.Context, beforeTS time.Time) error { + return nil +} + func (gs *MemoryStore) GetLatestSession(_ context.Context, senderKey id.SenderKey) (*OlmSession, error) { gs.lock.RLock() defer gs.lock.RUnlock() @@ -270,7 +300,16 @@ func (gs *MemoryStore) GetLatestSession(_ context.Context, senderKey id.SenderKe if !ok || len(sessions) == 0 { return nil, nil } - return sessions[0], nil + return sessions[len(sessions)-1], nil +} + +func (gs *MemoryStore) GetNewestSessionCreationTS(ctx context.Context, senderKey id.SenderKey) (createdAt time.Time, err error) { + var sess *OlmSession + sess, err = gs.GetLatestSession(ctx, senderKey) + if sess != nil { + createdAt = sess.CreationTime + } + return } func (gs *MemoryStore) getGroupSessions(roomID id.RoomID) map[id.SessionID]*InboundGroupSession { @@ -486,6 +525,9 @@ func (gs *MemoryStore) ValidateMessageIndex(_ context.Context, senderKey id.Send } val, ok := gs.MessageIndices[key] if !ok { + if eventID == "" && timestamp == 0 { + return true, nil + } gs.MessageIndices[key] = messageIndexValue{ EventID: eventID, Timestamp: timestamp, diff --git a/crypto/store_test.go b/crypto/store_test.go index a7c4d75a..7a47243e 100644 --- a/crypto/store_test.go +++ b/crypto/store_test.go @@ -13,6 +13,7 @@ import ( "testing" _ "github.com/mattn/go-sqlite3" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.mau.fi/util/dbutil" @@ -29,22 +30,14 @@ const groupSession = "9ZbsRqJuETbjnxPpKv29n3dubP/m5PSLbr9I9CIWS2O86F/Og1JZXhqT+4 func getCryptoStores(t *testing.T) map[string]Store { rawDB, err := sql.Open("sqlite3", ":memory:?_busy_timeout=5000") - if err != nil { - t.Fatalf("Error opening db: %v", err) - } + require.NoError(t, err, "Error opening raw database") db, err := dbutil.NewWithDB(rawDB, "sqlite3") - if err != nil { - t.Fatalf("Error opening db: %v", err) - } + require.NoError(t, err, "Error creating database wrapper") sqlStore := NewSQLCryptoStore(db, nil, "accid", id.DeviceID("dev"), []byte("test")) - if err = sqlStore.DB.Upgrade(context.TODO()); err != nil { - t.Fatalf("Error creating tables: %v", err) - } + err = sqlStore.DB.Upgrade(context.TODO()) + require.NoError(t, err, "Error upgrading database") gobStore := NewMemoryStore(nil) - if err != nil { - t.Fatalf("Error creating Gob store: %v", err) - } return map[string]Store{ "sql": sqlStore, @@ -56,9 +49,10 @@ func TestPutNextBatch(t *testing.T) { stores := getCryptoStores(t) store := stores["sql"].(*SQLCryptoStore) store.PutNextBatch(context.Background(), "batch1") - if batch, _ := store.GetNextBatch(context.Background()); batch != "batch1" { - t.Errorf("Expected batch1, got %v", batch) - } + + batch, err := store.GetNextBatch(context.Background()) + require.NoError(t, err, "Error retrieving next batch") + assert.Equal(t, "batch1", batch) } func TestPutAccount(t *testing.T) { @@ -68,15 +62,9 @@ func TestPutAccount(t *testing.T) { acc := NewOlmAccount() store.PutAccount(context.TODO(), acc) retrieved, err := store.GetAccount(context.TODO()) - if err != nil { - t.Fatalf("Error retrieving account: %v", err) - } - if acc.IdentityKey() != retrieved.IdentityKey() { - t.Errorf("Stored identity key %v, got %v", acc.IdentityKey(), retrieved.IdentityKey()) - } - if acc.SigningKey() != retrieved.SigningKey() { - t.Errorf("Stored signing key %v, got %v", acc.SigningKey(), retrieved.SigningKey()) - } + require.NoError(t, err, "Error retrieving account") + assert.Equal(t, acc.IdentityKey(), retrieved.IdentityKey(), "Identity key does not match") + assert.Equal(t, acc.SigningKey(), retrieved.SigningKey(), "Signing key does not match") }) } } @@ -86,18 +74,36 @@ func TestValidateMessageIndex(t *testing.T) { for storeName, store := range stores { t.Run(storeName, func(t *testing.T) { acc := NewOlmAccount() - if ok, _ := store.ValidateMessageIndex(context.TODO(), acc.IdentityKey(), "sess1", "event1", 0, 1000); !ok { - t.Error("First message not validated successfully") - } - if ok, _ := store.ValidateMessageIndex(context.TODO(), acc.IdentityKey(), "sess1", "event1", 0, 1001); ok { - t.Error("First message validated successfully after changing timestamp") - } - if ok, _ := store.ValidateMessageIndex(context.TODO(), acc.IdentityKey(), "sess1", "event2", 0, 1000); ok { - t.Error("First message validated successfully after changing event ID") - } - if ok, _ := store.ValidateMessageIndex(context.TODO(), acc.IdentityKey(), "sess1", "event1", 0, 1000); !ok { - t.Error("First message not validated successfully for a second time") - } + + // Validating without event ID and timestamp before we have them should work + ok, err := store.ValidateMessageIndex(context.TODO(), acc.IdentityKey(), "sess1", "", 0, 0) + require.NoError(t, err, "Error validating message index") + assert.True(t, ok, "First message validation should be valid") + + // First message should validate successfully + ok, err = store.ValidateMessageIndex(context.TODO(), acc.IdentityKey(), "sess1", "event1", 0, 1000) + require.NoError(t, err, "Error validating message index") + assert.True(t, ok, "First message validation should be valid") + + // Edit the timestamp and ensure validate fails + ok, err = store.ValidateMessageIndex(context.TODO(), acc.IdentityKey(), "sess1", "event1", 0, 1001) + require.NoError(t, err, "Error validating message index after timestamp change") + assert.False(t, ok, "First message validation should fail after timestamp change") + + // Edit the event ID and ensure validate fails + ok, err = store.ValidateMessageIndex(context.TODO(), acc.IdentityKey(), "sess1", "event2", 0, 1000) + require.NoError(t, err, "Error validating message index after event ID change") + assert.False(t, ok, "First message validation should fail after event ID change") + + // Validate again with the original parameters and ensure that it still passes + ok, err = store.ValidateMessageIndex(context.TODO(), acc.IdentityKey(), "sess1", "event1", 0, 1000) + require.NoError(t, err, "Error validating message index") + assert.True(t, ok, "First message validation should be valid") + + // Validating without event ID and timestamp must fail if we already know them + ok, err = store.ValidateMessageIndex(context.TODO(), acc.IdentityKey(), "sess1", "", 0, 0) + require.NoError(t, err, "Error validating message index") + assert.False(t, ok, "First message validation should be invalid") }) } } @@ -106,43 +112,26 @@ func TestStoreOlmSession(t *testing.T) { stores := getCryptoStores(t) for storeName, store := range stores { t.Run(storeName, func(t *testing.T) { - if store.HasSession(context.TODO(), olmSessID) { - t.Error("Found Olm session before inserting it") - } + require.False(t, store.HasSession(context.TODO(), olmSessID), "Found Olm session before inserting it") + olmInternal, err := olm.SessionFromPickled([]byte(olmPickled), []byte("test")) - if err != nil { - t.Fatalf("Error creating internal Olm session: %v", err) - } + require.NoError(t, err, "Error creating internal Olm session") olmSess := OlmSession{ id: olmSessID, Internal: olmInternal, } err = store.AddSession(context.TODO(), olmSessID, &olmSess) - if err != nil { - t.Errorf("Error storing Olm session: %v", err) - } - if !store.HasSession(context.TODO(), olmSessID) { - t.Error("Not found Olm session after inserting it") - } + require.NoError(t, err, "Error storing Olm session") + assert.True(t, store.HasSession(context.TODO(), olmSessID), "Olm session not found after inserting it") retrieved, err := store.GetLatestSession(context.TODO(), olmSessID) - if err != nil { - t.Errorf("Failed retrieving Olm session: %v", err) - } - - if retrieved.ID() != olmSessID { - t.Errorf("Expected session ID to be %v, got %v", olmSessID, retrieved.ID()) - } + require.NoError(t, err, "Error retrieving Olm session") + assert.EqualValues(t, olmSessID, retrieved.ID()) pickled, err := retrieved.Internal.Pickle([]byte("test")) - if err != nil { - t.Fatalf("Error pickling Olm session: %v", err) - } - - if string(pickled) != olmPickled { - t.Error("Pickled Olm session does not match original") - } + require.NoError(t, err, "Error pickling Olm session") + assert.EqualValues(t, pickled, olmPickled, "Pickled Olm session does not match original") }) } } @@ -154,9 +143,7 @@ func TestStoreMegolmSession(t *testing.T) { acc := NewOlmAccount() internal, err := olm.InboundGroupSessionFromPickled([]byte(groupSession), []byte("test")) - if err != nil { - t.Fatalf("Error creating internal inbound group session: %v", err) - } + require.NoError(t, err, "Error creating internal inbound group session") igs := &InboundGroupSession{ Internal: internal, @@ -166,20 +153,14 @@ func TestStoreMegolmSession(t *testing.T) { } err = store.PutGroupSession(context.TODO(), igs) - if err != nil { - t.Errorf("Error storing inbound group session: %v", err) - } + require.NoError(t, err, "Error storing inbound group session") retrieved, err := store.GetGroupSession(context.TODO(), "room1", igs.ID()) - if err != nil { - t.Errorf("Error retrieving inbound group session: %v", err) - } + require.NoError(t, err, "Error retrieving inbound group session") - if pickled, err := retrieved.Internal.Pickle([]byte("test")); err != nil { - t.Fatalf("Error pickling inbound group session: %v", err) - } else if string(pickled) != groupSession { - t.Error("Pickled inbound group session does not match original") - } + pickled, err := retrieved.Internal.Pickle([]byte("test")) + require.NoError(t, err, "Error pickling inbound group session") + assert.EqualValues(t, pickled, groupSession, "Pickled inbound group session does not match original") }) } } @@ -189,40 +170,24 @@ func TestStoreOutboundMegolmSession(t *testing.T) { for storeName, store := range stores { t.Run(storeName, func(t *testing.T) { sess, err := store.GetOutboundGroupSession(context.TODO(), "room1") - if sess != nil { - t.Error("Got outbound session before inserting") - } - if err != nil { - t.Errorf("Error retrieving outbound session: %v", err) - } + require.NoError(t, err, "Error retrieving outbound session") + require.Nil(t, sess, "Got outbound session before inserting") outbound, err := NewOutboundGroupSession("room1", nil) require.NoError(t, err) err = store.AddOutboundGroupSession(context.TODO(), outbound) - if err != nil { - t.Errorf("Error inserting outbound session: %v", err) - } + require.NoError(t, err, "Error inserting outbound session") sess, err = store.GetOutboundGroupSession(context.TODO(), "room1") - if sess == nil { - t.Error("Did not get outbound session after inserting") - } - if err != nil { - t.Errorf("Error retrieving outbound session: %v", err) - } + require.NoError(t, err, "Error retrieving outbound session") + assert.NotNil(t, sess, "Did not get outbound session after inserting") err = store.RemoveOutboundGroupSession(context.TODO(), "room1") - if err != nil { - t.Errorf("Error deleting outbound session: %v", err) - } + require.NoError(t, err, "Error deleting outbound session") sess, err = store.GetOutboundGroupSession(context.TODO(), "room1") - if sess != nil { - t.Error("Got outbound session after deleting") - } - if err != nil { - t.Errorf("Error retrieving outbound session: %v", err) - } + require.NoError(t, err, "Error retrieving outbound session after deletion") + assert.Nil(t, sess, "Got outbound session after deleting") }) } } @@ -244,58 +209,41 @@ func TestStoreOutboundMegolmSessionSharing(t *testing.T) { t.Run(storeName, func(t *testing.T) { device := resetDevice() err := store.PutDevice(context.TODO(), "user1", device) - if err != nil { - t.Errorf("Error storing devices: %v", err) - } + require.NoError(t, err, "Error storing device") shared, err := store.IsOutboundGroupSessionShared(context.TODO(), device.UserID, device.IdentityKey, "session1") - if err != nil { - t.Errorf("Error checking if outbound group session is shared: %v", err) - } else if shared { - t.Errorf("Outbound group session shared when it shouldn't") - } + require.NoError(t, err, "Error checking if outbound group session is shared") + assert.False(t, shared, "Outbound group session should not be shared initially") err = store.MarkOutboundGroupSessionShared(context.TODO(), device.UserID, device.IdentityKey, "session1") - if err != nil { - t.Errorf("Error marking outbound group session as shared: %v", err) - } + require.NoError(t, err, "Error marking outbound group session as shared") shared, err = store.IsOutboundGroupSessionShared(context.TODO(), device.UserID, device.IdentityKey, "session1") - if err != nil { - t.Errorf("Error checking if outbound group session is shared: %v", err) - } else if !shared { - t.Errorf("Outbound group session not shared when it should") - } + require.NoError(t, err, "Error checking if outbound group session is shared") + assert.True(t, shared, "Outbound group session should be shared after marking it as such") device = resetDevice() err = store.PutDevice(context.TODO(), "user1", device) - if err != nil { - t.Errorf("Error storing devices: %v", err) - } + require.NoError(t, err, "Error storing device after resetting") shared, err = store.IsOutboundGroupSessionShared(context.TODO(), device.UserID, device.IdentityKey, "session1") - if err != nil { - t.Errorf("Error checking if outbound group session is shared: %v", err) - } else if shared { - t.Errorf("Outbound group session shared when it shouldn't") - } + require.NoError(t, err, "Error checking if outbound group session is shared") + assert.False(t, shared, "Outbound group session should not be shared after resetting device") }) } } func TestStoreDevices(t *testing.T) { + devicesToCreate := 17 stores := getCryptoStores(t) for storeName, store := range stores { t.Run(storeName, func(t *testing.T) { outdated, err := store.GetOutdatedTrackedUsers(context.TODO()) - if err != nil { - t.Errorf("Error filtering tracked users: %v", err) - } - if len(outdated) > 0 { - t.Errorf("Got %d outdated tracked users when expected none", len(outdated)) - } + require.NoError(t, err, "Error filtering tracked users") + assert.Empty(t, outdated, "Expected no outdated tracked users initially") + deviceMap := make(map[id.DeviceID]*id.Device) - for i := 0; i < 17; i++ { + for i := 0; i < devicesToCreate; i++ { iStr := strconv.Itoa(i) acc := NewOlmAccount() deviceMap[id.DeviceID("dev"+iStr)] = &id.Device{ @@ -306,59 +254,33 @@ func TestStoreDevices(t *testing.T) { } } err = store.PutDevices(context.TODO(), "user1", deviceMap) - if err != nil { - t.Errorf("Error storing devices: %v", err) - } + require.NoError(t, err, "Error storing devices") devs, err := store.GetDevices(context.TODO(), "user1") - if err != nil { - t.Errorf("Error getting devices: %v", err) - } - if len(devs) != 17 { - t.Errorf("Stored 17 devices, got back %v", len(devs)) - } - if devs["dev0"].IdentityKey != deviceMap["dev0"].IdentityKey { - t.Errorf("First device identity key does not match") - } - if devs["dev16"].IdentityKey != deviceMap["dev16"].IdentityKey { - t.Errorf("Last device identity key does not match") - } + require.NoError(t, err, "Error getting devices") + assert.Len(t, devs, devicesToCreate, "Expected to get %d devices back", devicesToCreate) + assert.Equal(t, deviceMap, devs, "Stored devices do not match retrieved devices") filtered, err := store.FilterTrackedUsers(context.TODO(), []id.UserID{"user0", "user1", "user2"}) - if err != nil { - t.Errorf("Error filtering tracked users: %v", err) - } else if len(filtered) != 1 || filtered[0] != "user1" { - t.Errorf("Expected to get 'user1' from filter, got %v", filtered) - } + require.NoError(t, err, "Error filtering tracked users") + assert.Equal(t, []id.UserID{"user1"}, filtered, "Expected to get 'user1' from filter") outdated, err = store.GetOutdatedTrackedUsers(context.TODO()) - if err != nil { - t.Errorf("Error filtering tracked users: %v", err) - } - if len(outdated) > 0 { - t.Errorf("Got %d outdated tracked users when expected none", len(outdated)) - } + require.NoError(t, err, "Error filtering tracked users") + assert.Empty(t, outdated, "Expected no outdated tracked users after initial storage") + err = store.MarkTrackedUsersOutdated(context.TODO(), []id.UserID{"user0", "user1"}) - if err != nil { - t.Errorf("Error marking tracked users outdated: %v", err) - } + require.NoError(t, err, "Error marking tracked users outdated") + outdated, err = store.GetOutdatedTrackedUsers(context.TODO()) - if err != nil { - t.Errorf("Error filtering tracked users: %v", err) - } - if len(outdated) != 1 || outdated[0] != id.UserID("user1") { - t.Errorf("Got outdated tracked users %v when expected 'user1'", outdated) - } + require.NoError(t, err, "Error filtering tracked users") + assert.Equal(t, []id.UserID{"user1"}, outdated, "Expected 'user1' to be marked as outdated") + err = store.PutDevices(context.TODO(), "user1", deviceMap) - if err != nil { - t.Errorf("Error storing devices: %v", err) - } + require.NoError(t, err, "Error storing devices again") + outdated, err = store.GetOutdatedTrackedUsers(context.TODO()) - if err != nil { - t.Errorf("Error filtering tracked users: %v", err) - } - if len(outdated) > 0 { - t.Errorf("Got outdated tracked users %v when expected none", outdated) - } + require.NoError(t, err, "Error filtering tracked users") + assert.Empty(t, outdated, "Expected no outdated tracked users after re-storing devices") }) } } @@ -369,16 +291,11 @@ func TestStoreSecrets(t *testing.T) { t.Run(storeName, func(t *testing.T) { storedSecret := "trustno1" err := store.PutSecret(context.TODO(), id.SecretMegolmBackupV1, storedSecret) - if err != nil { - t.Errorf("Error storing secret: %v", err) - } + require.NoError(t, err, "Error storing secret") secret, err := store.GetSecret(context.TODO(), id.SecretMegolmBackupV1) - if err != nil { - t.Errorf("Error storing secret: %v", err) - } else if secret != storedSecret { - t.Errorf("Stored secret did not match: '%s' != '%s'", secret, storedSecret) - } + require.NoError(t, err, "Error retrieving secret") + assert.Equal(t, storedSecret, secret, "Retrieved secret does not match stored secret") }) } } diff --git a/crypto/utils/utils_test.go b/crypto/utils/utils_test.go index c4f01a68..b12fd9e2 100644 --- a/crypto/utils/utils_test.go +++ b/crypto/utils/utils_test.go @@ -9,6 +9,9 @@ package utils import ( "encoding/base64" "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestAES256Ctr(t *testing.T) { @@ -16,9 +19,7 @@ func TestAES256Ctr(t *testing.T) { key, iv := GenAttachmentA256CTR() enc := XorA256CTR([]byte(expected), key, iv) dec := XorA256CTR(enc, key, iv) - if string(dec) != expected { - t.Errorf("Expected decrypted using generated key/iv to be `%v`, got %v", expected, string(dec)) - } + assert.EqualValues(t, expected, dec, "Decrypted text should match original") var key2 [AESCTRKeyLength]byte var iv2 [AESCTRIVLength]byte @@ -29,9 +30,7 @@ func TestAES256Ctr(t *testing.T) { iv2[i] = byte(i) + 32 } dec2 := XorA256CTR([]byte{0x29, 0xc3, 0xff, 0x02, 0x21, 0xaf, 0x67, 0x73, 0x6e, 0xad, 0x9d}, key2, iv2) - if string(dec2) != expected { - t.Errorf("Expected decrypted using constant key/iv to be `%v`, got %v", expected, string(dec2)) - } + assert.EqualValues(t, expected, dec2, "Decrypted text with constant key/iv should match original") } func TestPBKDF(t *testing.T) { @@ -42,9 +41,7 @@ func TestPBKDF(t *testing.T) { key := PBKDF2SHA512([]byte("Hello world"), salt, 1000, 256) expected := "ffk9YdbVE1cgqOWgDaec0lH+rJzO+MuCcxpIn3Z6D0E=" keyB64 := base64.StdEncoding.EncodeToString([]byte(key)) - if keyB64 != expected { - t.Errorf("Expected base64 of generated key to be `%v`, got `%v`", expected, keyB64) - } + assert.Equal(t, expected, keyB64) } func TestDecodeSSSSKey(t *testing.T) { @@ -53,13 +50,10 @@ func TestDecodeSSSSKey(t *testing.T) { expected := "QCFDrXZYLEFnwf4NikVm62rYGJS2mNBEmAWLC3CgNPw=" decodedB64 := base64.StdEncoding.EncodeToString(decoded[:]) - if expected != decodedB64 { - t.Errorf("Expected decoded recovery key b64 to be `%v`, got `%v`", expected, decodedB64) - } + assert.Equal(t, expected, decodedB64) - if encoded := EncodeBase58RecoveryKey(decoded); encoded != recoveryKey { - t.Errorf("Expected recovery key to be `%v`, got `%v`", recoveryKey, encoded) - } + encoded := EncodeBase58RecoveryKey(decoded) + assert.Equal(t, recoveryKey, encoded) } func TestKeyDerivationAndHMAC(t *testing.T) { @@ -69,15 +63,11 @@ func TestKeyDerivationAndHMAC(t *testing.T) { aesKey, hmacKey := DeriveKeysSHA256(decoded[:], "m.cross_signing.master") ciphertextBytes, err := base64.StdEncoding.DecodeString("Fx16KlJ9vkd3Dd6CafIq5spaH5QmK5BALMzbtFbQznG2j1VARKK+klc4/Qo=") - if err != nil { - t.Error(err) - } + require.NoError(t, err) calcMac := HMACSHA256B64(ciphertextBytes, hmacKey) expectedMac := "0DABPNIZsP9iTOh1o6EM0s7BfHHXb96dN7Eca88jq2E" - if calcMac != expectedMac { - t.Errorf("Expected MAC `%v`, got `%v`", expectedMac, calcMac) - } + assert.Equal(t, expectedMac, calcMac) var ivBytes [AESCTRIVLength]byte decodedIV, _ := base64.StdEncoding.DecodeString("zxT/W5LpZ0Q819pfju6hZw==") @@ -85,7 +75,5 @@ func TestKeyDerivationAndHMAC(t *testing.T) { decrypted := string(XorA256CTR(ciphertextBytes, aesKey, ivBytes)) expectedDec := "Ec8eZDyvVkO3EDsEG6ej5c0cCHnX7PINqFXZjnaTV2s=" - if expectedDec != decrypted { - t.Errorf("Expected decrypted text to be `%v`, got `%v`", expectedDec, decrypted) - } + assert.Equal(t, expectedDec, decrypted) } diff --git a/crypto/verificationhelper/callbacks_test.go b/crypto/verificationhelper/callbacks_test.go index 5faf2009..3b943f28 100644 --- a/crypto/verificationhelper/callbacks_test.go +++ b/crypto/verificationhelper/callbacks_test.go @@ -17,20 +17,26 @@ import ( type MockVerificationCallbacks interface { GetRequestedVerifications() map[id.UserID][]id.VerificationTransactionID GetScanQRCodeTransactions() []id.VerificationTransactionID + GetVerificationsReadyTransactions() []id.VerificationTransactionID GetQRCodeShown(id.VerificationTransactionID) *verificationhelper.QRCode } type baseVerificationCallbacks struct { scanQRCodeTransactions []id.VerificationTransactionID verificationsRequested map[id.UserID][]id.VerificationTransactionID + verificationsReady []id.VerificationTransactionID qrCodesShown map[id.VerificationTransactionID]*verificationhelper.QRCode qrCodesScanned map[id.VerificationTransactionID]struct{} doneTransactions map[id.VerificationTransactionID]struct{} verificationCancellation map[id.VerificationTransactionID]*event.VerificationCancelEventContent emojisShown map[id.VerificationTransactionID][]rune + emojiDescriptionsShown map[id.VerificationTransactionID][]string decimalsShown map[id.VerificationTransactionID][]int } +var _ verificationhelper.RequiredCallbacks = (*baseVerificationCallbacks)(nil) +var _ MockVerificationCallbacks = (*baseVerificationCallbacks)(nil) + func newBaseVerificationCallbacks() *baseVerificationCallbacks { return &baseVerificationCallbacks{ verificationsRequested: map[id.UserID][]id.VerificationTransactionID{}, @@ -39,6 +45,7 @@ func newBaseVerificationCallbacks() *baseVerificationCallbacks { doneTransactions: map[id.VerificationTransactionID]struct{}{}, verificationCancellation: map[id.VerificationTransactionID]*event.VerificationCancelEventContent{}, emojisShown: map[id.VerificationTransactionID][]rune{}, + emojiDescriptionsShown: map[id.VerificationTransactionID][]string{}, decimalsShown: map[id.VerificationTransactionID][]int{}, } } @@ -51,6 +58,10 @@ func (c *baseVerificationCallbacks) GetScanQRCodeTransactions() []id.Verificatio return c.scanQRCodeTransactions } +func (c *baseVerificationCallbacks) GetVerificationsReadyTransactions() []id.VerificationTransactionID { + return c.verificationsReady +} + func (c *baseVerificationCallbacks) GetQRCodeShown(txnID id.VerificationTransactionID) *verificationhelper.QRCode { return c.qrCodesShown[txnID] } @@ -69,8 +80,8 @@ func (c *baseVerificationCallbacks) GetVerificationCancellation(txnID id.Verific return c.verificationCancellation[txnID] } -func (c *baseVerificationCallbacks) GetEmojisShown(txnID id.VerificationTransactionID) []rune { - return c.emojisShown[txnID] +func (c *baseVerificationCallbacks) GetEmojisAndDescriptionsShown(txnID id.VerificationTransactionID) ([]rune, []string) { + return c.emojisShown[txnID], c.emojiDescriptionsShown[txnID] } func (c *baseVerificationCallbacks) GetDecimalsShown(txnID id.VerificationTransactionID) []int { @@ -81,6 +92,16 @@ func (c *baseVerificationCallbacks) VerificationRequested(ctx context.Context, t c.verificationsRequested[from] = append(c.verificationsRequested[from], txnID) } +func (c *baseVerificationCallbacks) VerificationReady(ctx context.Context, txnID id.VerificationTransactionID, otherDeviceID id.DeviceID, supportsSAS, allowScanQRCode bool, qrCode *verificationhelper.QRCode) { + c.verificationsReady = append(c.verificationsReady, txnID) + if allowScanQRCode { + c.scanQRCodeTransactions = append(c.scanQRCodeTransactions, txnID) + } + if qrCode != nil { + c.qrCodesShown[txnID] = qrCode + } +} + func (c *baseVerificationCallbacks) VerificationCancelled(ctx context.Context, txnID id.VerificationTransactionID, code event.VerificationCancelCode, reason string) { c.verificationCancellation[txnID] = &event.VerificationCancelEventContent{ Code: code, @@ -88,7 +109,7 @@ func (c *baseVerificationCallbacks) VerificationCancelled(ctx context.Context, t } } -func (c *baseVerificationCallbacks) VerificationDone(ctx context.Context, txnID id.VerificationTransactionID) { +func (c *baseVerificationCallbacks) VerificationDone(ctx context.Context, txnID id.VerificationTransactionID, method event.VerificationMethod) { c.doneTransactions[txnID] = struct{}{} } @@ -96,6 +117,8 @@ type sasVerificationCallbacks struct { *baseVerificationCallbacks } +var _ verificationhelper.ShowSASCallbacks = (*sasVerificationCallbacks)(nil) + func newSASVerificationCallbacks() *sasVerificationCallbacks { return &sasVerificationCallbacks{newBaseVerificationCallbacks()} } @@ -104,39 +127,34 @@ func newSASVerificationCallbacksWithBase(base *baseVerificationCallbacks) *sasVe return &sasVerificationCallbacks{base} } -func (c *sasVerificationCallbacks) ShowSAS(ctx context.Context, txnID id.VerificationTransactionID, emojis []rune, decimals []int) { +func (c *sasVerificationCallbacks) ShowSAS(ctx context.Context, txnID id.VerificationTransactionID, emojis []rune, emojiDescriptions []string, decimals []int) { c.emojisShown[txnID] = emojis + c.emojiDescriptionsShown[txnID] = emojiDescriptions c.decimalsShown[txnID] = decimals } -type qrCodeVerificationCallbacks struct { +type showQRCodeVerificationCallbacks struct { *baseVerificationCallbacks } -func newQRCodeVerificationCallbacks() *qrCodeVerificationCallbacks { - return &qrCodeVerificationCallbacks{newBaseVerificationCallbacks()} +var _ verificationhelper.ShowQRCodeCallbacks = (*showQRCodeVerificationCallbacks)(nil) + +func newShowQRCodeVerificationCallbacks() *showQRCodeVerificationCallbacks { + return &showQRCodeVerificationCallbacks{newBaseVerificationCallbacks()} } -func newQRCodeVerificationCallbacksWithBase(base *baseVerificationCallbacks) *qrCodeVerificationCallbacks { - return &qrCodeVerificationCallbacks{base} +func newShowQRCodeVerificationCallbacksWithBase(base *baseVerificationCallbacks) *showQRCodeVerificationCallbacks { + return &showQRCodeVerificationCallbacks{base} } -func (c *qrCodeVerificationCallbacks) ScanQRCode(ctx context.Context, txnID id.VerificationTransactionID) { - c.scanQRCodeTransactions = append(c.scanQRCodeTransactions, txnID) -} - -func (c *qrCodeVerificationCallbacks) ShowQRCode(ctx context.Context, txnID id.VerificationTransactionID, qrCode *verificationhelper.QRCode) { - c.qrCodesShown[txnID] = qrCode -} - -func (c *qrCodeVerificationCallbacks) QRCodeScanned(ctx context.Context, txnID id.VerificationTransactionID) { +func (c *showQRCodeVerificationCallbacks) QRCodeScanned(ctx context.Context, txnID id.VerificationTransactionID) { c.qrCodesScanned[txnID] = struct{}{} } type allVerificationCallbacks struct { *baseVerificationCallbacks *sasVerificationCallbacks - *qrCodeVerificationCallbacks + *showQRCodeVerificationCallbacks } func newAllVerificationCallbacks() *allVerificationCallbacks { @@ -144,6 +162,6 @@ func newAllVerificationCallbacks() *allVerificationCallbacks { return &allVerificationCallbacks{ base, newSASVerificationCallbacksWithBase(base), - newQRCodeVerificationCallbacksWithBase(base), + newShowQRCodeVerificationCallbacksWithBase(base), } } diff --git a/crypto/verificationhelper/mockserver_test.go b/crypto/verificationhelper/mockserver_test.go deleted file mode 100644 index b6bf3d2c..00000000 --- a/crypto/verificationhelper/mockserver_test.go +++ /dev/null @@ -1,255 +0,0 @@ -// Copyright (c) 2024 Sumner Evans -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package verificationhelper_test - -import ( - "context" - "encoding/json" - "io" - "net/http" - "net/http/httptest" - "net/url" - "strings" - "testing" - - "github.com/gorilla/mux" - "github.com/rs/zerolog/log" // zerolog-allow-global-log - "github.com/stretchr/testify/require" - "go.mau.fi/util/random" - - "maunium.net/go/mautrix" - "maunium.net/go/mautrix/crypto" - "maunium.net/go/mautrix/crypto/cryptohelper" - "maunium.net/go/mautrix/event" - "maunium.net/go/mautrix/id" -) - -// mockServer is a mock Matrix server that wraps an [httptest.Server] to allow -// testing of the interactive verification process. -type mockServer struct { - *httptest.Server - - AccessTokenToUserID map[string]id.UserID - DeviceInbox map[id.UserID]map[id.DeviceID][]event.Event - AccountData map[id.UserID]map[event.Type]json.RawMessage - DeviceKeys map[id.UserID]map[id.DeviceID]mautrix.DeviceKeys - MasterKeys map[id.UserID]mautrix.CrossSigningKeys - SelfSigningKeys map[id.UserID]mautrix.CrossSigningKeys - UserSigningKeys map[id.UserID]mautrix.CrossSigningKeys -} - -func DecodeVarsMiddleware(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - vars := mux.Vars(r) - var err error - for k, v := range vars { - vars[k], err = url.PathUnescape(v) - if err != nil { - panic(err) - } - } - next.ServeHTTP(w, r) - }) -} - -func createMockServer(t *testing.T) *mockServer { - t.Helper() - - server := mockServer{ - AccessTokenToUserID: map[string]id.UserID{}, - DeviceInbox: map[id.UserID]map[id.DeviceID][]event.Event{}, - AccountData: map[id.UserID]map[event.Type]json.RawMessage{}, - DeviceKeys: map[id.UserID]map[id.DeviceID]mautrix.DeviceKeys{}, - MasterKeys: map[id.UserID]mautrix.CrossSigningKeys{}, - SelfSigningKeys: map[id.UserID]mautrix.CrossSigningKeys{}, - UserSigningKeys: map[id.UserID]mautrix.CrossSigningKeys{}, - } - - router := mux.NewRouter().SkipClean(true).StrictSlash(false).UseEncodedPath() - router.Use(DecodeVarsMiddleware) - router.HandleFunc("/_matrix/client/v3/login", server.postLogin).Methods(http.MethodPost) - router.HandleFunc("/_matrix/client/v3/keys/query", server.postKeysQuery).Methods(http.MethodPost) - router.HandleFunc("/_matrix/client/v3/sendToDevice/{type}/{txn}", server.putSendToDevice).Methods(http.MethodPut) - router.HandleFunc("/_matrix/client/v3/user/{userID}/account_data/{type}", server.putAccountData).Methods(http.MethodPut) - router.HandleFunc("/_matrix/client/v3/keys/device_signing/upload", server.postDeviceSigningUpload).Methods(http.MethodPost) - router.HandleFunc("/_matrix/client/v3/keys/signatures/upload", server.emptyResp).Methods(http.MethodPost) - router.HandleFunc("/_matrix/client/v3/keys/upload", server.postKeysUpload).Methods(http.MethodPost) - - server.Server = httptest.NewServer(router) - return &server -} - -func (ms *mockServer) getUserID(r *http.Request) id.UserID { - authHeader := r.Header.Get("Authorization") - authHeader = strings.TrimPrefix(authHeader, "Bearer ") - userID, ok := ms.AccessTokenToUserID[authHeader] - if !ok { - panic("no user ID found for access token " + authHeader) - } - return userID -} - -func (s *mockServer) emptyResp(w http.ResponseWriter, _ *http.Request) { - w.Write([]byte("{}")) -} - -func (s *mockServer) postLogin(w http.ResponseWriter, r *http.Request) { - var loginReq mautrix.ReqLogin - json.NewDecoder(r.Body).Decode(&loginReq) - - deviceID := loginReq.DeviceID - if deviceID == "" { - deviceID = id.DeviceID(random.String(10)) - } - - accessToken := random.String(30) - userID := id.UserID(loginReq.Identifier.User) - s.AccessTokenToUserID[accessToken] = userID - - json.NewEncoder(w).Encode(&mautrix.RespLogin{ - AccessToken: accessToken, - DeviceID: deviceID, - UserID: userID, - }) -} - -func (s *mockServer) putSendToDevice(w http.ResponseWriter, r *http.Request) { - vars := mux.Vars(r) - var req mautrix.ReqSendToDevice - json.NewDecoder(r.Body).Decode(&req) - evtType := event.Type{Type: vars["type"], Class: event.ToDeviceEventType} - - for user, devices := range req.Messages { - for device, content := range devices { - if _, ok := s.DeviceInbox[user]; !ok { - s.DeviceInbox[user] = map[id.DeviceID][]event.Event{} - } - content.ParseRaw(evtType) - s.DeviceInbox[user][device] = append(s.DeviceInbox[user][device], event.Event{ - Sender: s.getUserID(r), - Type: evtType, - Content: *content, - }) - } - } - s.emptyResp(w, r) -} - -func (s *mockServer) putAccountData(w http.ResponseWriter, r *http.Request) { - vars := mux.Vars(r) - userID := id.UserID(vars["userID"]) - eventType := event.Type{Type: vars["type"], Class: event.AccountDataEventType} - - jsonData, _ := io.ReadAll(r.Body) - if _, ok := s.AccountData[userID]; !ok { - s.AccountData[userID] = map[event.Type]json.RawMessage{} - } - s.AccountData[userID][eventType] = json.RawMessage(jsonData) - s.emptyResp(w, r) -} - -func (s *mockServer) postKeysQuery(w http.ResponseWriter, r *http.Request) { - var req mautrix.ReqQueryKeys - json.NewDecoder(r.Body).Decode(&req) - resp := mautrix.RespQueryKeys{ - MasterKeys: map[id.UserID]mautrix.CrossSigningKeys{}, - UserSigningKeys: map[id.UserID]mautrix.CrossSigningKeys{}, - SelfSigningKeys: map[id.UserID]mautrix.CrossSigningKeys{}, - DeviceKeys: map[id.UserID]map[id.DeviceID]mautrix.DeviceKeys{}, - } - for user := range req.DeviceKeys { - resp.MasterKeys[user] = s.MasterKeys[user] - resp.UserSigningKeys[user] = s.UserSigningKeys[user] - resp.SelfSigningKeys[user] = s.SelfSigningKeys[user] - resp.DeviceKeys[user] = s.DeviceKeys[user] - } - json.NewEncoder(w).Encode(&resp) -} - -func (s *mockServer) postKeysUpload(w http.ResponseWriter, r *http.Request) { - var req mautrix.ReqUploadKeys - json.NewDecoder(r.Body).Decode(&req) - - userID := s.getUserID(r) - if _, ok := s.DeviceKeys[userID]; !ok { - s.DeviceKeys[userID] = map[id.DeviceID]mautrix.DeviceKeys{} - } - s.DeviceKeys[userID][req.DeviceKeys.DeviceID] = *req.DeviceKeys - - json.NewEncoder(w).Encode(&mautrix.RespUploadKeys{ - OneTimeKeyCounts: mautrix.OTKCount{SignedCurve25519: 50}, - }) -} - -func (s *mockServer) postDeviceSigningUpload(w http.ResponseWriter, r *http.Request) { - var req mautrix.UploadCrossSigningKeysReq - json.NewDecoder(r.Body).Decode(&req) - - userID := s.getUserID(r) - s.MasterKeys[userID] = req.Master - s.SelfSigningKeys[userID] = req.SelfSigning - s.UserSigningKeys[userID] = req.UserSigning - - s.emptyResp(w, r) -} - -func (ms *mockServer) Login(t *testing.T, ctx context.Context, userID id.UserID, deviceID id.DeviceID) (*mautrix.Client, crypto.Store) { - t.Helper() - client, err := mautrix.NewClient(ms.URL, "", "") - require.NoError(t, err) - client.StateStore = mautrix.NewMemoryStateStore() - - _, err = client.Login(ctx, &mautrix.ReqLogin{ - Type: mautrix.AuthTypePassword, - Identifier: mautrix.UserIdentifier{ - Type: mautrix.IdentifierTypeUser, - User: userID.String(), - }, - DeviceID: deviceID, - Password: "password", - StoreCredentials: true, - }) - require.NoError(t, err) - - cryptoStore := crypto.NewMemoryStore(nil) - cryptoHelper, err := cryptohelper.NewCryptoHelper(client, []byte("test"), cryptoStore) - require.NoError(t, err) - client.Crypto = cryptoHelper - - err = cryptoHelper.Init(ctx) - require.NoError(t, err) - - machineLog := log.Logger.With(). - Stringer("my_user_id", userID). - Stringer("my_device_id", deviceID). - Logger() - cryptoHelper.Machine().Log = &machineLog - - err = cryptoHelper.Machine().ShareKeys(ctx, 50) - require.NoError(t, err) - - return client, cryptoStore -} - -func (ms *mockServer) dispatchToDevice(t *testing.T, ctx context.Context, client *mautrix.Client) { - t.Helper() - - for _, evt := range ms.DeviceInbox[client.UserID][client.DeviceID] { - client.Syncer.(*mautrix.DefaultSyncer).Dispatch(ctx, &evt) - ms.DeviceInbox[client.UserID][client.DeviceID] = ms.DeviceInbox[client.UserID][client.DeviceID][1:] - } -} - -func addDeviceID(ctx context.Context, cryptoStore crypto.Store, userID id.UserID, deviceID id.DeviceID) { - err := cryptoStore.PutDevice(ctx, userID, &id.Device{ - UserID: userID, - DeviceID: deviceID, - }) - if err != nil { - panic(err) - } -} diff --git a/crypto/verificationhelper/qrcode.go b/crypto/verificationhelper/qrcode.go index a28d8fc3..11698152 100644 --- a/crypto/verificationhelper/qrcode.go +++ b/crypto/verificationhelper/qrcode.go @@ -82,6 +82,10 @@ func NewQRCodeFromBytes(data []byte) (*QRCode, error) { // // [Section 11.12.2.4.1]: https://spec.matrix.org/v1.9/client-server-api/#qr-code-format func (q *QRCode) Bytes() []byte { + if q == nil { + return nil + } + var buf bytes.Buffer buf.WriteString("MATRIX") // Header buf.WriteByte(0x02) // Version diff --git a/crypto/verificationhelper/reciprocate.go b/crypto/verificationhelper/reciprocate.go index ef69f23c..d8827b8b 100644 --- a/crypto/verificationhelper/reciprocate.go +++ b/crypto/verificationhelper/reciprocate.go @@ -182,7 +182,7 @@ func (vh *VerificationHelper) HandleScannedQRData(ctx context.Context, data []by if err = vh.store.DeleteVerification(ctx, txn.TransactionID); err != nil { return err } - vh.verificationDone(ctx, txn.TransactionID) + vh.verificationDone(ctx, txn.TransactionID, txn.StartEventContent.Method) } else { return vh.store.SaveVerificationTransaction(ctx, txn) } @@ -212,27 +212,34 @@ func (vh *VerificationHelper) ConfirmQRCodeScanned(ctx context.Context, txnID id log.Info().Msg("Confirming QR code scanned") + // Get their device + theirDevice, err := vh.mach.GetOrFetchDevice(ctx, txn.TheirUserID, txn.TheirDeviceID) + if err != nil { + return err + } + + // Trust their device + theirDevice.Trust = id.TrustStateVerified + err = vh.mach.CryptoStore.PutDevice(ctx, txn.TheirUserID, theirDevice) + if err != nil { + return fmt.Errorf("failed to update device trust state after verifying: %w", err) + } + if txn.TheirUserID == vh.client.UserID { - // Self-signing situation. Trust their device. - - // Get their device - theirDevice, err := vh.mach.GetOrFetchDevice(ctx, txn.TheirUserID, txn.TheirDeviceID) - if err != nil { - return err - } - - // Trust their device - theirDevice.Trust = id.TrustStateVerified - err = vh.mach.CryptoStore.PutDevice(ctx, txn.TheirUserID, theirDevice) - if err != nil { - return fmt.Errorf("failed to update device trust state after verifying: %w", err) - } - - // Cross-sign their device with the self-signing key + // Self-signing situation. + // + // If we have the cross-signing keys, then we need to sign their device + // using the self-signing key. Otherwise, they have the master private + // key, so we need to trust the master public key. if vh.mach.CrossSigningKeys != nil { err = vh.mach.SignOwnDevice(ctx, theirDevice) if err != nil { - return fmt.Errorf("failed to sign their device: %w", err) + return fmt.Errorf("failed to sign our own new device: %w", err) + } + } else { + err = vh.mach.SignOwnMasterKey(ctx) + if err != nil { + return fmt.Errorf("failed to sign our own master key: %w", err) } } } else { @@ -256,35 +263,37 @@ func (vh *VerificationHelper) ConfirmQRCodeScanned(ctx context.Context, txnID id if err = vh.store.DeleteVerification(ctx, txn.TransactionID); err != nil { return err } - vh.verificationDone(ctx, txn.TransactionID) + vh.verificationDone(ctx, txn.TransactionID, txn.StartEventContent.Method) } else { return vh.store.SaveVerificationTransaction(ctx, txn) } return nil } -func (vh *VerificationHelper) generateAndShowQRCode(ctx context.Context, txn *VerificationTransaction) error { +func (vh *VerificationHelper) generateQRCode(ctx context.Context, txn *VerificationTransaction) (*QRCode, error) { log := vh.getLog(ctx).With(). Str("verification_action", "generate and show QR code"). Stringer("transaction_id", txn.TransactionID). Logger() ctx = log.WithContext(ctx) - if vh.showQRCode == nil { - log.Info().Msg("Ignoring QR code generation request as showing a QR code is not enabled on this device") - return nil + + if !slices.Contains(vh.supportedMethods, event.VerificationMethodReciprocate) || + !slices.Contains(txn.TheirSupportedMethods, event.VerificationMethodReciprocate) { + log.Info().Msg("Ignoring QR code generation request as reciprocating is not supported by both devices") + return nil, nil } else if !slices.Contains(txn.TheirSupportedMethods, event.VerificationMethodQRCodeScan) { log.Info().Msg("Ignoring QR code generation request as other device cannot scan QR codes") - return nil + return nil, nil } ownCrossSigningPublicKeys := vh.mach.GetOwnCrossSigningPublicKeys(ctx) if ownCrossSigningPublicKeys == nil || len(ownCrossSigningPublicKeys.MasterKey) == 0 { - return errors.New("failed to get own cross-signing master public key") + return nil, errors.New("failed to get own cross-signing master public key") } ownMasterKeyTrusted, err := vh.mach.CryptoStore.IsKeySignedBy(ctx, vh.client.UserID, ownCrossSigningPublicKeys.MasterKey, vh.client.UserID, vh.mach.OwnIdentity().SigningKey) if err != nil { - return err + return nil, err } mode := QRCodeModeCrossSigning if vh.client.UserID == txn.TheirUserID { @@ -297,7 +306,7 @@ func (vh *VerificationHelper) generateAndShowQRCode(ctx context.Context, txn *Ve } else { // This is a cross-signing situation. if !ownMasterKeyTrusted { - return errors.New("cannot cross-sign other device when own master key is not trusted") + return nil, errors.New("cannot cross-sign other device when own master key is not trusted") } mode = QRCodeModeCrossSigning } @@ -311,7 +320,7 @@ func (vh *VerificationHelper) generateAndShowQRCode(ctx context.Context, txn *Ve // Key 2 is the other user's master signing key. theirSigningKeys, err := vh.mach.GetCrossSigningPublicKeys(ctx, txn.TheirUserID) if err != nil { - return err + return nil, err } key2 = theirSigningKeys.MasterKey.Bytes() case QRCodeModeSelfVerifyingMasterKeyTrusted: @@ -321,7 +330,7 @@ func (vh *VerificationHelper) generateAndShowQRCode(ctx context.Context, txn *Ve // Key 2 is the other device's key. theirDevice, err := vh.mach.GetOrFetchDevice(ctx, txn.TheirUserID, txn.TheirDeviceID) if err != nil { - return err + return nil, err } key2 = theirDevice.SigningKey.Bytes() case QRCodeModeSelfVerifyingMasterKeyUntrusted: @@ -336,6 +345,5 @@ func (vh *VerificationHelper) generateAndShowQRCode(ctx context.Context, txn *Ve qrCode := NewQRCode(mode, txn.TransactionID, [32]byte(key1), [32]byte(key2)) txn.QRCodeSharedSecret = qrCode.SharedSecret - vh.showQRCode(ctx, txn.TransactionID, qrCode) - return nil + return qrCode, nil } diff --git a/crypto/verificationhelper/sas.go b/crypto/verificationhelper/sas.go index 81728bd4..e6392c79 100644 --- a/crypto/verificationhelper/sas.go +++ b/crypto/verificationhelper/sas.go @@ -13,6 +13,7 @@ import ( "crypto/hmac" "crypto/rand" "crypto/sha256" + "crypto/subtle" "encoding/base64" "encoding/json" "errors" @@ -34,7 +35,7 @@ import ( // [StartInRoomVerification] functions. func (vh *VerificationHelper) StartSAS(ctx context.Context, txnID id.VerificationTransactionID) error { log := vh.getLog(ctx).With(). - Str("verification_action", "accept verification"). + Str("verification_action", "start SAS"). Stringer("transaction_id", txnID). Logger() ctx = log.WithContext(ctx) @@ -45,7 +46,7 @@ func (vh *VerificationHelper) StartSAS(ctx context.Context, txnID id.Verificatio if err != nil { return fmt.Errorf("failed to get verification transaction %s: %w", txnID, err) } else if txn.VerificationState != VerificationStateReady { - return errors.New("transaction is not in ready state") + return fmt.Errorf("transaction is not in ready state: %s", txn.VerificationState.String()) } else if txn.StartEventContent != nil { return errors.New("start event already sent or received") } @@ -110,6 +111,7 @@ func (vh *VerificationHelper) ConfirmSAS(ctx context.Context, txnID id.Verificat keys := map[id.KeyID]jsonbytes.UnpaddedBytes{} log.Info().Msg("Signing keys") + var masterKey string // My device key myDevice := vh.mach.OwnIdentity() @@ -122,8 +124,9 @@ func (vh *VerificationHelper) ConfirmSAS(ctx context.Context, txnID id.Verificat // Master signing key crossSigningKeys := vh.mach.GetOwnCrossSigningPublicKeys(ctx) if crossSigningKeys != nil { - crossSigningKeyID := id.NewKeyID(id.KeyAlgorithmEd25519, crossSigningKeys.MasterKey.String()) - keys[crossSigningKeyID], err = vh.verificationMACHKDF(txn, vh.client.UserID, vh.client.DeviceID, txn.TheirUserID, txn.TheirDeviceID, crossSigningKeyID.String(), crossSigningKeys.MasterKey.String()) + masterKey = crossSigningKeys.MasterKey.String() + crossSigningKeyID := id.NewKeyID(id.KeyAlgorithmEd25519, masterKey) + keys[crossSigningKeyID], err = vh.verificationMACHKDF(txn, vh.client.UserID, vh.client.DeviceID, txn.TheirUserID, txn.TheirDeviceID, crossSigningKeyID.String(), masterKey) if err != nil { return err } @@ -147,10 +150,16 @@ func (vh *VerificationHelper) ConfirmSAS(ctx context.Context, txnID id.Verificat if err != nil { return err } + log.Info().Msg("Sent our MAC event") txn.SentOurMAC = true if txn.ReceivedTheirMAC { txn.VerificationState = VerificationStateSASMACExchanged + + if err := vh.trustKeysAfterMACCheck(ctx, txn, masterKey); err != nil { + return fmt.Errorf("failed to trust keys: %w", err) + } + err := vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationDone, &event.VerificationDoneEventContent{}) if err != nil { return err @@ -168,7 +177,7 @@ func (vh *VerificationHelper) ConfirmSAS(ctx context.Context, txnID id.Verificat func (vh *VerificationHelper) onVerificationStartSAS(ctx context.Context, txn VerificationTransaction, evt *event.Event) error { startEvt := evt.Content.AsVerificationStart() log := vh.getLog(ctx).With(). - Str("verification_action", "start_sas"). + Str("verification_action", "start SAS"). Stringer("transaction_id", txn.TransactionID). Logger() ctx = log.WithContext(ctx) @@ -215,28 +224,29 @@ func (vh *VerificationHelper) onVerificationStartSAS(ctx context.Context, txn Ve } txn.MACMethod = macMethod txn.EphemeralKey = &ECDHPrivateKey{ephemeralKey} - txn.StartEventContent = startEvt - commitment, err := calculateCommitment(ephemeralKey.PublicKey(), startEvt) - if err != nil { - return fmt.Errorf("failed to calculate commitment: %w", err) - } + if !txn.StartedByUs { + commitment, err := calculateCommitment(ephemeralKey.PublicKey(), txn) + if err != nil { + return fmt.Errorf("failed to calculate commitment: %w", err) + } - err = vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationAccept, &event.VerificationAcceptEventContent{ - Commitment: commitment, - Hash: hashAlgorithm, - KeyAgreementProtocol: keyAggreementProtocol, - MessageAuthenticationCode: macMethod, - ShortAuthenticationString: sasMethods, - }) - if err != nil { - return fmt.Errorf("failed to send accept event: %w", err) + err = vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationAccept, &event.VerificationAcceptEventContent{ + Commitment: commitment, + Hash: hashAlgorithm, + KeyAgreementProtocol: keyAggreementProtocol, + MessageAuthenticationCode: macMethod, + ShortAuthenticationString: sasMethods, + }) + if err != nil { + return fmt.Errorf("failed to send accept event: %w", err) + } + txn.VerificationState = VerificationStateSASAccepted } - txn.VerificationState = VerificationStateSASAccepted return vh.store.SaveVerificationTransaction(ctx, txn) } -func calculateCommitment(ephemeralPubKey *ecdh.PublicKey, startEvt *event.VerificationStartEventContent) ([]byte, error) { +func calculateCommitment(ephemeralPubKey *ecdh.PublicKey, txn VerificationTransaction) ([]byte, error) { // The commitmentHashInput is the hash (encoded as unpadded base64) of the // concatenation of the device's ephemeral public key (encoded as // unpadded base64) and the canonical JSON representation of the @@ -246,7 +256,7 @@ func calculateCommitment(ephemeralPubKey *ecdh.PublicKey, startEvt *event.Verifi // hashing it, but we are just stuck on that. commitmentHashInput := sha256.New() commitmentHashInput.Write([]byte(base64.RawStdEncoding.EncodeToString(ephemeralPubKey.Bytes()))) - encodedStartEvt, err := json.Marshal(startEvt) + encodedStartEvt, err := json.Marshal(txn.StartEventContent) if err != nil { return nil, err } @@ -330,19 +340,13 @@ func (vh *VerificationHelper) onVerificationKey(ctx context.Context, txn Verific if txn.EphemeralPublicKeyShared { // Verify that the commitment hash is correct - commitment, err := calculateCommitment(publicKey, txn.StartEventContent) + commitment, err := calculateCommitment(publicKey, txn) if err != nil { log.Err(err).Msg("Failed to calculate commitment") return } if !bytes.Equal(commitment, txn.Commitment) { - err := vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationCancel, &event.VerificationCancelEventContent{ - Code: event.VerificationCancelCodeKeyMismatch, - Reason: "The key was not the one we expected.", - }) - if err != nil { - log.Err(err).Msg("Failed to send cancellation event") - } + vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeKeyMismatch, "The key was not the one we expected") return } } else { @@ -365,6 +369,7 @@ func (vh *VerificationHelper) onVerificationKey(ctx context.Context, txn Verific var decimals []int var emojis []rune + var emojiDescriptions []string if slices.Contains(txn.StartEventContent.ShortAuthenticationString, event.SASMethodDecimal) { decimals = []int{ (int(sasBytes[0])<<5 | int(sasBytes[1])>>3) + 1000, @@ -380,9 +385,10 @@ func (vh *VerificationHelper) onVerificationKey(ctx context.Context, txn Verific // Right shift the number and then mask the lowest 6 bits. emojiIdx := (sasNum >> uint(48-(i+1)*6)) & 0b111111 emojis = append(emojis, allEmojis[emojiIdx]) + emojiDescriptions = append(emojiDescriptions, allEmojiDescriptions[emojiIdx]) } } - vh.showSAS(ctx, txn.TransactionID, emojis, decimals) + vh.showSAS(ctx, txn.TransactionID, emojis, emojiDescriptions, decimals) if err := vh.store.SaveVerificationTransaction(ctx, txn); err != nil { log.Err(err).Msg("failed to save verification transaction") @@ -580,6 +586,73 @@ var allEmojis = []rune{ '📌', } +var allEmojiDescriptions = []string{ + "Dog", + "Cat", + "Lion", + "Horse", + "Unicorn", + "Pig", + "Elephant", + "Rabbit", + "Panda", + "Rooster", + "Penguin", + "Turtle", + "Fish", + "Octopus", + "Butterfly", + "Flower", + "Tree", + "Cactus", + "Mushroom", + "Globe", + "Moon", + "Cloud", + "Fire", + "Banana", + "Apple", + "Strawberry", + "Corn", + "Pizza", + "Cake", + "Heart", + "Smiley", + "Robot", + "Hat", + "Glasses", + "Spanner", + "Santa", + "Thumbs Up", + "Umbrella", + "Hourglass", + "Clock", + "Gift", + "Light Bulb", + "Book", + "Pencil", + "Paperclip", + "Scissors", + "Lock", + "Key", + "Hammer", + "Telephone", + "Flag", + "Train", + "Bicycle", + "Aeroplane", + "Rocket", + "Trophy", + "Ball", + "Guitar", + "Trumpet", + "Bell", + "Anchor", + "Headphones", + "Folder", + "Pin", +} + func (vh *VerificationHelper) onVerificationMAC(ctx context.Context, txn VerificationTransaction, evt *event.Event) { log := vh.getLog(ctx).With(). Str("verification_action", "mac"). @@ -593,12 +666,15 @@ func (vh *VerificationHelper) onVerificationMAC(ctx context.Context, txn Verific // Verifying Keys MAC log.Info().Msg("Verifying MAC for all sent keys") var hasTheirDeviceKey bool + var masterKey string var keyIDs []string for keyID := range macEvt.MAC { keyIDs = append(keyIDs, keyID.String()) _, kID := keyID.Parse() if kID == txn.TheirDeviceID.String() { hasTheirDeviceKey = true + } else { + masterKey = kID } } slices.Sort(keyIDs) @@ -617,8 +693,9 @@ func (vh *VerificationHelper) onVerificationMAC(ctx context.Context, txn Verific } // Verify the MAC for each key + var theirDevice *id.Device for keyID, mac := range macEvt.MAC { - log.Info().Str("key_id", keyID.String()).Msg("Received MAC for key") + log.Info().Stringer("key_id", keyID).Msg("Received MAC for key") alg, kID := keyID.Parse() if alg != id.KeyAlgorithmEd25519 { @@ -627,8 +704,11 @@ func (vh *VerificationHelper) onVerificationMAC(ctx context.Context, txn Verific } var key string - var theirDevice *id.Device if kID == txn.TheirDeviceID.String() { + if theirDevice != nil { + vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeInvalidMessage, "two keys found for their device ID") + return + } theirDevice, err = vh.mach.GetOrFetchDevice(ctx, txn.TheirUserID, txn.TheirDeviceID) if err != nil { vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUser, "failed to fetch their device: %w", err) @@ -653,26 +733,22 @@ func (vh *VerificationHelper) onVerificationMAC(ctx context.Context, txn Verific vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUser, "failed to calculate key MAC: %w", err) return } - if !bytes.Equal(expectedMAC, mac) { + if subtle.ConstantTimeCompare(expectedMAC, mac) == 0 { vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeSASMismatch, "MAC mismatch for key %s", keyID) return } - - // Trust their device - if kID == txn.TheirDeviceID.String() { - theirDevice.Trust = id.TrustStateVerified - err = vh.mach.CryptoStore.PutDevice(ctx, txn.TheirUserID, theirDevice) - if err != nil { - vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUser, "failed to update device trust state after verifying: %w", err) - return - } - } } log.Info().Msg("All MACs verified") txn.ReceivedTheirMAC = true if txn.SentOurMAC { txn.VerificationState = VerificationStateSASMACExchanged + + if err := vh.trustKeysAfterMACCheck(ctx, txn, masterKey); err != nil { + vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUser, "failed to trust keys: %w", err) + return + } + err := vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationDone, &event.VerificationDoneEventContent{}) if err != nil { vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUser, "failed to send verification done event: %w", err) @@ -685,3 +761,52 @@ func (vh *VerificationHelper) onVerificationMAC(ctx context.Context, txn Verific log.Err(err).Msg("failed to save verification transaction") } } + +func (vh *VerificationHelper) trustKeysAfterMACCheck(ctx context.Context, txn VerificationTransaction, masterKey string) error { + theirDevice, err := vh.mach.GetOrFetchDevice(ctx, txn.TheirUserID, txn.TheirDeviceID) + if err != nil { + return fmt.Errorf("failed to fetch their device: %w", err) + } + // Trust their device + theirDevice.Trust = id.TrustStateVerified + err = vh.mach.CryptoStore.PutDevice(ctx, txn.TheirUserID, theirDevice) + if err != nil { + return fmt.Errorf("failed to update device trust state after verifying: %w", err) + } + + if txn.TheirUserID == vh.client.UserID { + // Self-signing situation. + // + // If we have the cross-signing keys, then we need to sign their device + // using the self-signing key. Otherwise, they have the master private + // key, so we need to trust the master public key. + if vh.mach.CrossSigningKeys != nil { + err = vh.mach.SignOwnDevice(ctx, theirDevice) + if err != nil { + return fmt.Errorf("failed to sign our own new device: %w", err) + } + } else { + err = vh.mach.SignOwnMasterKey(ctx) + if err != nil { + return fmt.Errorf("failed to sign our own master key: %w", err) + } + } + } else if masterKey != "" { + // Cross-signing situation. + // + // The master key was included in the list of keys to verify, so verify + // that it matches what we expect and sign their master key using the + // user-signing key. + theirSigningKeys, err := vh.mach.GetCrossSigningPublicKeys(ctx, txn.TheirUserID) + if err != nil { + return fmt.Errorf("couldn't get %s's cross-signing keys: %w", txn.TheirUserID, err) + } else if theirSigningKeys.MasterKey.String() != masterKey { + return fmt.Errorf("master keys do not match") + } + + if err := vh.mach.SignUser(ctx, txn.TheirUserID, theirSigningKeys.MasterKey); err != nil { + return fmt.Errorf("failed to sign %s's master key: %w", txn.TheirUserID, err) + } + } + return nil +} diff --git a/crypto/verificationhelper/verificationhelper.go b/crypto/verificationhelper/verificationhelper.go index de943976..0a781c16 100644 --- a/crypto/verificationhelper/verificationhelper.go +++ b/crypto/verificationhelper/verificationhelper.go @@ -15,6 +15,7 @@ import ( "time" "github.com/rs/zerolog" + "go.mau.fi/util/exslices" "go.mau.fi/util/jsontime" "golang.org/x/exp/maps" "golang.org/x/exp/slices" @@ -32,30 +33,26 @@ type RequiredCallbacks interface { // from another device. VerificationRequested(ctx context.Context, txnID id.VerificationTransactionID, from id.UserID, fromDevice id.DeviceID) + // VerificationReady is called when a verification request has been + // accepted by both parties. + VerificationReady(ctx context.Context, txnID id.VerificationTransactionID, otherDeviceID id.DeviceID, supportsSAS, supportsScanQRCode bool, qrCode *QRCode) + // VerificationCancelled is called when the verification is cancelled. VerificationCancelled(ctx context.Context, txnID id.VerificationTransactionID, code event.VerificationCancelCode, reason string) // VerificationDone is called when the verification is done. - VerificationDone(ctx context.Context, txnID id.VerificationTransactionID) + VerificationDone(ctx context.Context, txnID id.VerificationTransactionID, method event.VerificationMethod) } type ShowSASCallbacks interface { // ShowSAS is a callback that is called when the SAS verification has // generated a short authentication string to show. It is guaranteed that - // either the emojis list, or the decimals list, or both will be present. - ShowSAS(ctx context.Context, txnID id.VerificationTransactionID, emojis []rune, decimals []int) + // either the emojis and emoji descriptions lists, or the decimals list, or + // both will be present. + ShowSAS(ctx context.Context, txnID id.VerificationTransactionID, emojis []rune, emojiDescriptions []string, decimals []int) } type ShowQRCodeCallbacks interface { - // ScanQRCode is called when another device has sent a - // m.key.verification.ready event and indicated that they are capable of - // showing a QR code. - ScanQRCode(ctx context.Context, txnID id.VerificationTransactionID) - - // ShowQRCode is called when the verification has been accepted and a QR - // code should be shown to the user. - ShowQRCode(ctx context.Context, txnID id.VerificationTransactionID, qrCode *QRCode) - // QRCodeScanned is called when the other user has scanned the QR code and // sent the m.key.verification.start event. QRCodeScanned(ctx context.Context, txnID id.VerificationTransactionID) @@ -67,24 +64,25 @@ type VerificationHelper struct { store VerificationStore activeTransactionsLock sync.Mutex - // activeTransactions map[id.VerificationTransactionID]*verificationTransaction // supportedMethods are the methods that *we* support supportedMethods []event.VerificationMethod verificationRequested func(ctx context.Context, txnID id.VerificationTransactionID, from id.UserID, fromDevice id.DeviceID) + verificationReady func(ctx context.Context, txnID id.VerificationTransactionID, otherDeviceID id.DeviceID, supportsSAS, supportsScanQRCode bool, qrCode *QRCode) verificationCancelledCallback func(ctx context.Context, txnID id.VerificationTransactionID, code event.VerificationCancelCode, reason string) - verificationDone func(ctx context.Context, txnID id.VerificationTransactionID) + verificationDone func(ctx context.Context, txnID id.VerificationTransactionID, method event.VerificationMethod) - showSAS func(ctx context.Context, txnID id.VerificationTransactionID, emojis []rune, decimals []int) - - scanQRCode func(ctx context.Context, txnID id.VerificationTransactionID) - showQRCode func(ctx context.Context, txnID id.VerificationTransactionID, qrCode *QRCode) - qrCodeScaned func(ctx context.Context, txnID id.VerificationTransactionID) + // showSAS is a callback that will be called after the SAS verification + // dance is complete and we want the client to show the emojis/decimals + showSAS func(ctx context.Context, txnID id.VerificationTransactionID, emojis []rune, emojiDescriptions []string, decimals []int) + // qrCodeScanned is a callback that will be called when the other device + // scanned the QR code we are showing + qrCodeScanned func(ctx context.Context, txnID id.VerificationTransactionID) } var _ mautrix.VerificationHelper = (*VerificationHelper)(nil) -func NewVerificationHelper(client *mautrix.Client, mach *crypto.OlmMachine, store VerificationStore, callbacks any, supportsScan bool) *VerificationHelper { +func NewVerificationHelper(client *mautrix.Client, mach *crypto.OlmMachine, store VerificationStore, callbacks any, supportsQRShow, supportsQRScan, supportsSAS bool) *VerificationHelper { if client.Crypto == nil { panic("client.Crypto is nil") } @@ -103,28 +101,33 @@ func NewVerificationHelper(client *mautrix.Client, mach *crypto.OlmMachine, stor panic("callbacks must implement RequiredCallbacks") } else { helper.verificationRequested = c.VerificationRequested + helper.verificationReady = c.VerificationReady helper.verificationCancelledCallback = c.VerificationCancelled helper.verificationDone = c.VerificationDone } - supportedMethods := map[event.VerificationMethod]struct{}{} - if c, ok := callbacks.(ShowSASCallbacks); ok { - supportedMethods[event.VerificationMethodSAS] = struct{}{} - helper.showSAS = c.ShowSAS + if supportsSAS { + if c, ok := callbacks.(ShowSASCallbacks); !ok { + panic("callbacks must implement showSAS if supportsSAS is true") + } else { + helper.supportedMethods = append(helper.supportedMethods, event.VerificationMethodSAS) + helper.showSAS = c.ShowSAS + } } - if c, ok := callbacks.(ShowQRCodeCallbacks); ok { - supportedMethods[event.VerificationMethodQRCodeShow] = struct{}{} - supportedMethods[event.VerificationMethodReciprocate] = struct{}{} - helper.scanQRCode = c.ScanQRCode - helper.showQRCode = c.ShowQRCode - helper.qrCodeScaned = c.QRCodeScanned + if supportsQRShow { + if c, ok := callbacks.(ShowQRCodeCallbacks); !ok { + panic("callbacks must implement ShowQRCodeCallbacks if supportsQRShow is true") + } else { + helper.supportedMethods = append(helper.supportedMethods, event.VerificationMethodQRCodeShow) + helper.supportedMethods = append(helper.supportedMethods, event.VerificationMethodReciprocate) + helper.qrCodeScanned = c.QRCodeScanned + } } - if supportsScan { - supportedMethods[event.VerificationMethodQRCodeScan] = struct{}{} - supportedMethods[event.VerificationMethodReciprocate] = struct{}{} + if supportsQRScan { + helper.supportedMethods = append(helper.supportedMethods, event.VerificationMethodQRCodeScan) + helper.supportedMethods = append(helper.supportedMethods, event.VerificationMethodReciprocate) } - - helper.supportedMethods = maps.Keys(supportedMethods) + helper.supportedMethods = exslices.DeduplicateUnsorted(helper.supportedMethods) return &helper } @@ -375,6 +378,9 @@ func (vh *VerificationHelper) StartInRoomVerification(ctx context.Context, roomI // be the transaction ID of a verification request that was received via the // VerificationRequested callback in [RequiredCallbacks]. func (vh *VerificationHelper) AcceptVerification(ctx context.Context, txnID id.VerificationTransactionID) error { + vh.activeTransactionsLock.Lock() + defer vh.activeTransactionsLock.Unlock() + log := vh.getLog(ctx).With(). Str("verification_action", "accept verification"). Stringer("transaction_id", txnID). @@ -419,13 +425,19 @@ func (vh *VerificationHelper) AcceptVerification(ctx context.Context, txnID id.V } txn.VerificationState = VerificationStateReady - if vh.scanQRCode != nil && slices.Contains(txn.TheirSupportedMethods, event.VerificationMethodQRCodeShow) { - vh.scanQRCode(ctx, txn.TransactionID) - } + supportsSAS := slices.Contains(vh.supportedMethods, event.VerificationMethodSAS) && + slices.Contains(txn.TheirSupportedMethods, event.VerificationMethodSAS) + supportsReciprocate := slices.Contains(vh.supportedMethods, event.VerificationMethodReciprocate) && + slices.Contains(txn.TheirSupportedMethods, event.VerificationMethodReciprocate) + supportsScanQRCode := supportsReciprocate && + slices.Contains(vh.supportedMethods, event.VerificationMethodQRCodeScan) && + slices.Contains(txn.TheirSupportedMethods, event.VerificationMethodQRCodeShow) - if err := vh.generateAndShowQRCode(ctx, &txn); err != nil { + qrCode, err := vh.generateQRCode(ctx, &txn) + if err != nil { return err } + vh.verificationReady(ctx, txn.TransactionID, txn.TheirDeviceID, supportsSAS, supportsScanQRCode, qrCode) return vh.store.SaveVerificationTransaction(ctx, txn) } @@ -733,13 +745,23 @@ func (vh *VerificationHelper) onVerificationReady(ctx context.Context, txn Verif } } - if vh.scanQRCode != nil && slices.Contains(txn.TheirSupportedMethods, event.VerificationMethodQRCodeShow) { - vh.scanQRCode(ctx, txn.TransactionID) + supportsSAS := slices.Contains(vh.supportedMethods, event.VerificationMethodSAS) && + slices.Contains(txn.TheirSupportedMethods, event.VerificationMethodSAS) + supportsReciprocate := slices.Contains(vh.supportedMethods, event.VerificationMethodReciprocate) && + slices.Contains(txn.TheirSupportedMethods, event.VerificationMethodReciprocate) + supportsScanQRCode := supportsReciprocate && + slices.Contains(vh.supportedMethods, event.VerificationMethodQRCodeScan) && + slices.Contains(txn.TheirSupportedMethods, event.VerificationMethodQRCodeShow) + + qrCode, err := vh.generateQRCode(ctx, &txn) + if err != nil { + vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeInternalError, "failed to generate QR code: %w", err) + return } - if err := vh.generateAndShowQRCode(ctx, &txn); err != nil { - vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeInternalError, "failed to generate and show QR code: %w", err) - } else if err := vh.store.SaveVerificationTransaction(ctx, txn); err != nil { + vh.verificationReady(ctx, txn.TransactionID, txn.TheirDeviceID, supportsSAS, supportsScanQRCode, qrCode) + + if err := vh.store.SaveVerificationTransaction(ctx, txn); err != nil { vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeInternalError, "failed to save verification transaction: %w", err) } } @@ -799,6 +821,8 @@ func (vh *VerificationHelper) onVerificationStart(ctx context.Context, txn Verif } else if txn.VerificationState != VerificationStateReady { vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUnexpectedMessage, "got start event for transaction that is not in ready state") return + } else { + txn.StartEventContent = startEvt } switch startEvt.Method { @@ -810,12 +834,12 @@ func (vh *VerificationHelper) onVerificationStart(ctx context.Context, txn Verif } case event.VerificationMethodReciprocate: log.Info().Msg("Received reciprocate start event") - if !bytes.Equal(txn.QRCodeSharedSecret, startEvt.Secret) { + if !bytes.Equal(txn.QRCodeSharedSecret, txn.StartEventContent.Secret) { vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeKeyMismatch, "reciprocated shared secret does not match") return } txn.VerificationState = VerificationStateOurQRScanned - vh.qrCodeScaned(ctx, txn.TransactionID) + vh.qrCodeScanned(ctx, txn.TransactionID) if err := vh.store.SaveVerificationTransaction(ctx, txn); err != nil { log.Err(err).Msg("failed to save verification transaction") } @@ -823,8 +847,8 @@ func (vh *VerificationHelper) onVerificationStart(ctx context.Context, txn Verif // Note that we should never get m.qr_code.show.v1 or m.qr_code.scan.v1 // here, since the start command for scanning and showing QR codes // should be of type m.reciprocate.v1. - log.Error().Str("method", string(startEvt.Method)).Msg("Unsupported verification method in start event") - vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUnknownMethod, fmt.Sprintf("unknown method %s", startEvt.Method)) + log.Error().Str("method", string(txn.StartEventContent.Method)).Msg("Unsupported verification method in start event") + vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUnknownMethod, "unknown method %s", txn.StartEventContent.Method) } } @@ -851,7 +875,7 @@ func (vh *VerificationHelper) onVerificationDone(ctx context.Context, txn Verifi if err := vh.store.DeleteVerification(ctx, txn.TransactionID); err != nil { log.Err(err).Msg("Delete verification failed") } - vh.verificationDone(ctx, txn.TransactionID) + vh.verificationDone(ctx, txn.TransactionID, txn.StartEventContent.Method) } else if err := vh.store.SaveVerificationTransaction(ctx, txn); err != nil { log.Err(err).Msg("failed to save verification transaction") } diff --git a/crypto/verificationhelper/verificationhelper_qr_crosssign_test.go b/crypto/verificationhelper/verificationhelper_qr_crosssign_test.go index aace2230..5e3f146b 100644 --- a/crypto/verificationhelper/verificationhelper_qr_crosssign_test.go +++ b/crypto/verificationhelper/verificationhelper_qr_crosssign_test.go @@ -32,7 +32,6 @@ func TestCrossSignVerification_ScanQRAndConfirmScan(t *testing.T) { for _, tc := range testCases { t.Run(fmt.Sprintf("sendingScansQR=%t", tc.sendingScansQR), func(t *testing.T) { ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLoginAliceBob(t, ctx) - defer ts.Close() sendingCallbacks, receivingCallbacks, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine) var err error @@ -51,10 +50,10 @@ func TestCrossSignVerification_ScanQRAndConfirmScan(t *testing.T) { // event on the sending device. txnID, err := sendingHelper.StartVerification(ctx, bobUserID) require.NoError(t, err) - ts.dispatchToDevice(t, ctx, receivingClient) + ts.DispatchToDevice(t, ctx, receivingClient) err = receivingHelper.AcceptVerification(ctx, txnID) require.NoError(t, err) - ts.dispatchToDevice(t, ctx, sendingClient) + ts.DispatchToDevice(t, ctx, sendingClient) receivingShownQRCode := receivingCallbacks.GetQRCodeShown(txnID) require.NotNil(t, receivingShownQRCode) @@ -83,7 +82,7 @@ func TestCrossSignVerification_ScanQRAndConfirmScan(t *testing.T) { // Handle the start and done events on the receiving client and // confirm the scan. - ts.dispatchToDevice(t, ctx, receivingClient) + ts.DispatchToDevice(t, ctx, receivingClient) // Ensure that the receiving device detected that its QR code // was scanned. @@ -98,7 +97,7 @@ func TestCrossSignVerification_ScanQRAndConfirmScan(t *testing.T) { doneEvt = sendingInbox[0].Content.AsVerificationDone() assert.Equal(t, txnID, doneEvt.TransactionID) - ts.dispatchToDevice(t, ctx, sendingClient) + ts.DispatchToDevice(t, ctx, sendingClient) } else { // receiving scans QR // Emulate scanning the QR code shown by the sending device on // the receiving device. @@ -121,7 +120,7 @@ func TestCrossSignVerification_ScanQRAndConfirmScan(t *testing.T) { // Handle the start and done events on the receiving client and // confirm the scan. - ts.dispatchToDevice(t, ctx, sendingClient) + ts.DispatchToDevice(t, ctx, sendingClient) // Ensure that the sending device detected that its QR code was // scanned. @@ -136,7 +135,7 @@ func TestCrossSignVerification_ScanQRAndConfirmScan(t *testing.T) { doneEvt = receivingInbox[0].Content.AsVerificationDone() assert.Equal(t, txnID, doneEvt.TransactionID) - ts.dispatchToDevice(t, ctx, receivingClient) + ts.DispatchToDevice(t, ctx, receivingClient) } // Ensure that both devices have marked the verification as done. diff --git a/crypto/verificationhelper/verificationhelper_qr_self_test.go b/crypto/verificationhelper/verificationhelper_qr_self_test.go index 937cc414..ea918cd4 100644 --- a/crypto/verificationhelper/verificationhelper_qr_self_test.go +++ b/crypto/verificationhelper/verificationhelper_qr_self_test.go @@ -36,7 +36,6 @@ func TestSelfVerification_Accept_QRContents(t *testing.T) { for _, tc := range testCases { t.Run(fmt.Sprintf("sendingGenerated=%t receivingGenerated=%t err=%s", tc.sendingGeneratedCrossSigningKeys, tc.receivingGeneratedCrossSigningKeys, tc.expectedAcceptError), func(t *testing.T) { ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLoginTwoAlice(t, ctx) - defer ts.Close() sendingCallbacks, receivingCallbacks, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine) var err error @@ -62,7 +61,7 @@ func TestSelfVerification_Accept_QRContents(t *testing.T) { // event on the sending device. txnID, err := sendingHelper.StartVerification(ctx, aliceUserID) require.NoError(t, err) - ts.dispatchToDevice(t, ctx, receivingClient) + ts.DispatchToDevice(t, ctx, receivingClient) err = receivingHelper.AcceptVerification(ctx, txnID) if tc.expectedAcceptError != "" { @@ -72,7 +71,7 @@ func TestSelfVerification_Accept_QRContents(t *testing.T) { require.NoError(t, err) } - ts.dispatchToDevice(t, ctx, sendingClient) + ts.DispatchToDevice(t, ctx, sendingClient) receivingShownQRCode := receivingCallbacks.GetQRCodeShown(txnID) require.NotNil(t, receivingShownQRCode) @@ -135,7 +134,6 @@ func TestSelfVerification_ScanQRAndConfirmScan(t *testing.T) { for _, tc := range testCases { t.Run(fmt.Sprintf("sendingGeneratedCrossSigningKeys=%t sendingScansQR=%t", tc.sendingGeneratedCrossSigningKeys, tc.sendingScansQR), func(t *testing.T) { ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLoginTwoAlice(t, ctx) - defer ts.Close() sendingCallbacks, receivingCallbacks, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine) var err error @@ -152,10 +150,10 @@ func TestSelfVerification_ScanQRAndConfirmScan(t *testing.T) { // event on the sending device. txnID, err := sendingHelper.StartVerification(ctx, aliceUserID) require.NoError(t, err) - ts.dispatchToDevice(t, ctx, receivingClient) + ts.DispatchToDevice(t, ctx, receivingClient) err = receivingHelper.AcceptVerification(ctx, txnID) require.NoError(t, err) - ts.dispatchToDevice(t, ctx, sendingClient) + ts.DispatchToDevice(t, ctx, sendingClient) receivingShownQRCode := receivingCallbacks.GetQRCodeShown(txnID) require.NotNil(t, receivingShownQRCode) @@ -184,7 +182,7 @@ func TestSelfVerification_ScanQRAndConfirmScan(t *testing.T) { // Handle the start and done events on the receiving client and // confirm the scan. - ts.dispatchToDevice(t, ctx, receivingClient) + ts.DispatchToDevice(t, ctx, receivingClient) // Ensure that the receiving device detected that its QR code // was scanned. @@ -199,7 +197,7 @@ func TestSelfVerification_ScanQRAndConfirmScan(t *testing.T) { doneEvt = sendingInbox[0].Content.AsVerificationDone() assert.Equal(t, txnID, doneEvt.TransactionID) - ts.dispatchToDevice(t, ctx, sendingClient) + ts.DispatchToDevice(t, ctx, sendingClient) } else { // receiving scans QR // Emulate scanning the QR code shown by the sending device on // the receiving device. @@ -222,7 +220,7 @@ func TestSelfVerification_ScanQRAndConfirmScan(t *testing.T) { // Handle the start and done events on the receiving client and // confirm the scan. - ts.dispatchToDevice(t, ctx, sendingClient) + ts.DispatchToDevice(t, ctx, sendingClient) // Ensure that the sending device detected that its QR code was // scanned. @@ -237,7 +235,7 @@ func TestSelfVerification_ScanQRAndConfirmScan(t *testing.T) { doneEvt = receivingInbox[0].Content.AsVerificationDone() assert.Equal(t, txnID, doneEvt.TransactionID) - ts.dispatchToDevice(t, ctx, receivingClient) + ts.DispatchToDevice(t, ctx, receivingClient) } // Ensure that both devices have marked the verification as done. @@ -251,7 +249,6 @@ func TestSelfVerification_ScanQRTransactionIDCorrupted(t *testing.T) { ctx := log.Logger.WithContext(context.TODO()) ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLoginTwoAlice(t, ctx) - defer ts.Close() sendingCallbacks, receivingCallbacks, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine) var err error @@ -263,10 +260,10 @@ func TestSelfVerification_ScanQRTransactionIDCorrupted(t *testing.T) { // event on the sending device. txnID, err := sendingHelper.StartVerification(ctx, aliceUserID) require.NoError(t, err) - ts.dispatchToDevice(t, ctx, receivingClient) + ts.DispatchToDevice(t, ctx, receivingClient) err = receivingHelper.AcceptVerification(ctx, txnID) require.NoError(t, err) - ts.dispatchToDevice(t, ctx, sendingClient) + ts.DispatchToDevice(t, ctx, sendingClient) receivingShownQRCodeBytes := receivingCallbacks.GetQRCodeShown(txnID).Bytes() sendingShownQRCodeBytes := sendingCallbacks.GetQRCodeShown(txnID).Bytes() @@ -310,7 +307,6 @@ func TestSelfVerification_ScanQRKeyCorrupted(t *testing.T) { for _, tc := range testCases { t.Run(fmt.Sprintf("sendingGeneratedCrossSigningKeys=%t sendingScansQR=%t corrupt=%d", tc.sendingGeneratedCrossSigningKeys, tc.sendingScansQR, tc.corruptByte), func(t *testing.T) { ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLoginTwoAlice(t, ctx) - defer ts.Close() sendingCallbacks, receivingCallbacks, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine) var err error @@ -327,10 +323,10 @@ func TestSelfVerification_ScanQRKeyCorrupted(t *testing.T) { // event on the sending device. txnID, err := sendingHelper.StartVerification(ctx, aliceUserID) require.NoError(t, err) - ts.dispatchToDevice(t, ctx, receivingClient) + ts.DispatchToDevice(t, ctx, receivingClient) err = receivingHelper.AcceptVerification(ctx, txnID) require.NoError(t, err) - ts.dispatchToDevice(t, ctx, sendingClient) + ts.DispatchToDevice(t, ctx, sendingClient) receivingShownQRCodeBytes := receivingCallbacks.GetQRCodeShown(txnID).Bytes() sendingShownQRCodeBytes := sendingCallbacks.GetQRCodeShown(txnID).Bytes() @@ -348,7 +344,7 @@ func TestSelfVerification_ScanQRKeyCorrupted(t *testing.T) { // Ensure that the receiving device received a cancellation. receivingInbox := ts.DeviceInbox[aliceUserID][receivingDeviceID] assert.Len(t, receivingInbox, 1) - ts.dispatchToDevice(t, ctx, receivingClient) + ts.DispatchToDevice(t, ctx, receivingClient) cancellation := receivingCallbacks.GetVerificationCancellation(txnID) require.NotNil(t, cancellation) assert.Equal(t, event.VerificationCancelCodeKeyMismatch, cancellation.Code) @@ -362,7 +358,7 @@ func TestSelfVerification_ScanQRKeyCorrupted(t *testing.T) { // Ensure that the sending device received a cancellation. sendingInbox := ts.DeviceInbox[aliceUserID][sendingDeviceID] assert.Len(t, sendingInbox, 1) - ts.dispatchToDevice(t, ctx, sendingClient) + ts.DispatchToDevice(t, ctx, sendingClient) cancellation := sendingCallbacks.GetVerificationCancellation(txnID) require.NotNil(t, cancellation) assert.Equal(t, event.VerificationCancelCodeKeyMismatch, cancellation.Code) diff --git a/crypto/verificationhelper/verificationhelper_sas_test.go b/crypto/verificationhelper/verificationhelper_sas_test.go index 20e52e0f..283eca84 100644 --- a/crypto/verificationhelper/verificationhelper_sas_test.go +++ b/crypto/verificationhelper/verificationhelper_sas_test.go @@ -36,7 +36,6 @@ func TestVerification_SAS(t *testing.T) { for _, tc := range testCases { t.Run(fmt.Sprintf("sendingGenerated=%t sendingStartsSAS=%t sendingConfirmsFirst=%t", tc.sendingGeneratedCrossSigningKeys, tc.sendingStartsSAS, tc.sendingConfirmsFirst), func(t *testing.T) { ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLoginTwoAlice(t, ctx) - defer ts.Close() sendingCallbacks, receivingCallbacks, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine) var err error @@ -60,10 +59,10 @@ func TestVerification_SAS(t *testing.T) { // event on the sending device. txnID, err := sendingHelper.StartVerification(ctx, aliceUserID) require.NoError(t, err) - ts.dispatchToDevice(t, ctx, receivingClient) + ts.DispatchToDevice(t, ctx, receivingClient) err = receivingHelper.AcceptVerification(ctx, txnID) require.NoError(t, err) - ts.dispatchToDevice(t, ctx, sendingClient) + ts.DispatchToDevice(t, ctx, sendingClient) // Test that the start event is correct var startEvt *event.VerificationStartEventContent @@ -102,7 +101,7 @@ func TestVerification_SAS(t *testing.T) { if tc.sendingStartsSAS { // Process the verification start event on the receiving // device. - ts.dispatchToDevice(t, ctx, receivingClient) + ts.DispatchToDevice(t, ctx, receivingClient) // Receiving device sent the accept event to the sending device sendingInbox := ts.DeviceInbox[aliceUserID][sendingDeviceID] @@ -110,7 +109,7 @@ func TestVerification_SAS(t *testing.T) { acceptEvt = sendingInbox[0].Content.AsVerificationAccept() } else { // Process the verification start event on the sending device. - ts.dispatchToDevice(t, ctx, sendingClient) + ts.DispatchToDevice(t, ctx, sendingClient) // Sending device sent the accept event to the receiving device receivingInbox := ts.DeviceInbox[aliceUserID][receivingDeviceID] @@ -129,7 +128,7 @@ func TestVerification_SAS(t *testing.T) { var firstKeyEvt *event.VerificationKeyEventContent if tc.sendingStartsSAS { // Process the verification accept event on the sending device. - ts.dispatchToDevice(t, ctx, sendingClient) + ts.DispatchToDevice(t, ctx, sendingClient) // Sending device sends first key event to the receiving // device. @@ -139,7 +138,7 @@ func TestVerification_SAS(t *testing.T) { } else { // Process the verification accept event on the receiving // device. - ts.dispatchToDevice(t, ctx, receivingClient) + ts.DispatchToDevice(t, ctx, receivingClient) // Receiving device sends first key event to the sending // device. @@ -155,7 +154,7 @@ func TestVerification_SAS(t *testing.T) { var secondKeyEvt *event.VerificationKeyEventContent if tc.sendingStartsSAS { // Process the first key event on the receiving device. - ts.dispatchToDevice(t, ctx, receivingClient) + ts.DispatchToDevice(t, ctx, receivingClient) // Receiving device sends second key event to the sending // device. @@ -165,10 +164,12 @@ func TestVerification_SAS(t *testing.T) { // Ensure that the receiving device showed emojis and SAS numbers. assert.Len(t, receivingCallbacks.GetDecimalsShown(txnID), 3) - assert.Len(t, receivingCallbacks.GetEmojisShown(txnID), 7) + emojis, descriptions := receivingCallbacks.GetEmojisAndDescriptionsShown(txnID) + assert.Len(t, emojis, 7) + assert.Len(t, descriptions, 7) } else { // Process the first key event on the sending device. - ts.dispatchToDevice(t, ctx, sendingClient) + ts.DispatchToDevice(t, ctx, sendingClient) // Sending device sends second key event to the receiving // device. @@ -178,7 +179,9 @@ func TestVerification_SAS(t *testing.T) { // Ensure that the sending device showed emojis and SAS numbers. assert.Len(t, sendingCallbacks.GetDecimalsShown(txnID), 3) - assert.Len(t, sendingCallbacks.GetEmojisShown(txnID), 7) + emojis, descriptions := sendingCallbacks.GetEmojisAndDescriptionsShown(txnID) + assert.Len(t, emojis, 7) + assert.Len(t, descriptions, 7) } assert.Equal(t, txnID, secondKeyEvt.TransactionID) assert.NotEmpty(t, secondKeyEvt.Key) @@ -187,13 +190,16 @@ func TestVerification_SAS(t *testing.T) { // Ensure that the SAS codes are the same. if tc.sendingStartsSAS { // Process the second key event on the sending device. - ts.dispatchToDevice(t, ctx, sendingClient) + ts.DispatchToDevice(t, ctx, sendingClient) } else { // Process the second key event on the receiving device. - ts.dispatchToDevice(t, ctx, receivingClient) + ts.DispatchToDevice(t, ctx, receivingClient) } assert.Equal(t, sendingCallbacks.GetDecimalsShown(txnID), receivingCallbacks.GetDecimalsShown(txnID)) - assert.Equal(t, sendingCallbacks.GetEmojisShown(txnID), receivingCallbacks.GetEmojisShown(txnID)) + sendingEmojis, sendingDescriptions := sendingCallbacks.GetEmojisAndDescriptionsShown(txnID) + receivingEmojis, receivingDescriptions := receivingCallbacks.GetEmojisAndDescriptionsShown(txnID) + assert.Equal(t, sendingEmojis, receivingEmojis) + assert.Equal(t, sendingDescriptions, receivingDescriptions) // Test that the first MAC event is correct var firstMACEvt *event.VerificationMACEventContent @@ -267,12 +273,88 @@ func TestVerification_SAS(t *testing.T) { // Test the transaction is done on both sides. We have to dispatch // twice to process and drain all of the events. - ts.dispatchToDevice(t, ctx, sendingClient) - ts.dispatchToDevice(t, ctx, receivingClient) - ts.dispatchToDevice(t, ctx, sendingClient) - ts.dispatchToDevice(t, ctx, receivingClient) + ts.DispatchToDevice(t, ctx, sendingClient) + ts.DispatchToDevice(t, ctx, receivingClient) + ts.DispatchToDevice(t, ctx, sendingClient) + ts.DispatchToDevice(t, ctx, receivingClient) assert.True(t, sendingCallbacks.IsVerificationDone(txnID)) assert.True(t, receivingCallbacks.IsVerificationDone(txnID)) }) } } + +func TestVerification_SAS_BothCallStart(t *testing.T) { + ctx := log.Logger.WithContext(context.TODO()) + + ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLoginTwoAlice(t, ctx) + sendingCallbacks, receivingCallbacks, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine) + var err error + + var sendingRecoveryKey string + var sendingCrossSigningKeysCache *crypto.CrossSigningKeysCache + + sendingRecoveryKey, sendingCrossSigningKeysCache, err = sendingMachine.GenerateAndUploadCrossSigningKeys(ctx, nil, "") + require.NoError(t, err) + assert.NotEmpty(t, sendingRecoveryKey) + assert.NotNil(t, sendingCrossSigningKeysCache) + + // Send the verification request from the sender device and accept + // it on the receiving device and receive the verification ready + // event on the sending device. + txnID, err := sendingHelper.StartVerification(ctx, aliceUserID) + require.NoError(t, err) + ts.DispatchToDevice(t, ctx, receivingClient) + err = receivingHelper.AcceptVerification(ctx, txnID) + require.NoError(t, err) + ts.DispatchToDevice(t, ctx, sendingClient) + + err = sendingHelper.StartSAS(ctx, txnID) + require.NoError(t, err) + + err = receivingHelper.StartSAS(ctx, txnID) + require.NoError(t, err) + + // Ensure that both devices have received the verification start event. + receivingInbox := ts.DeviceInbox[aliceUserID][receivingDeviceID] + assert.Len(t, receivingInbox, 1) + assert.Equal(t, txnID, receivingInbox[0].Content.AsVerificationStart().TransactionID) + sendingInbox := ts.DeviceInbox[aliceUserID][sendingDeviceID] + assert.Len(t, sendingInbox, 1) + assert.Equal(t, txnID, sendingInbox[0].Content.AsVerificationStart().TransactionID) + + // Process the start event from the receiving client to the sending client. + ts.DispatchToDevice(t, ctx, sendingClient) + receivingInbox = ts.DeviceInbox[aliceUserID][receivingDeviceID] + assert.Len(t, receivingInbox, 2) + assert.Equal(t, txnID, receivingInbox[0].Content.AsVerificationStart().TransactionID) + assert.Equal(t, txnID, receivingInbox[1].Content.AsVerificationAccept().TransactionID) + + // Process the rest of the events until we need to confirm the SAS. + for len(ts.DeviceInbox[aliceUserID][sendingDeviceID]) > 0 || len(ts.DeviceInbox[aliceUserID][receivingDeviceID]) > 0 { + ts.DispatchToDevice(t, ctx, receivingClient) + ts.DispatchToDevice(t, ctx, sendingClient) + } + + // Confirm the SAS only the receiving device. + receivingHelper.ConfirmSAS(ctx, txnID) + ts.DispatchToDevice(t, ctx, sendingClient) + + // Verification is not done until both devices confirm the SAS. + assert.False(t, sendingCallbacks.IsVerificationDone(txnID)) + assert.False(t, receivingCallbacks.IsVerificationDone(txnID)) + + // Now, confirm it on the sending device. + sendingHelper.ConfirmSAS(ctx, txnID) + + // Dispatching the events to the receiving device should get us to the done + // state on the receiving device. + ts.DispatchToDevice(t, ctx, receivingClient) + assert.False(t, sendingCallbacks.IsVerificationDone(txnID)) + assert.True(t, receivingCallbacks.IsVerificationDone(txnID)) + + // Dispatching the events to the sending client should get us to the done + // state on the sending device. + ts.DispatchToDevice(t, ctx, sendingClient) + assert.True(t, sendingCallbacks.IsVerificationDone(txnID)) + assert.True(t, receivingCallbacks.IsVerificationDone(txnID)) +} diff --git a/crypto/verificationhelper/verificationhelper_test.go b/crypto/verificationhelper/verificationhelper_test.go index af4a28c3..ce5ec5b4 100644 --- a/crypto/verificationhelper/verificationhelper_test.go +++ b/crypto/verificationhelper/verificationhelper_test.go @@ -19,6 +19,7 @@ import ( "maunium.net/go/mautrix/crypto/verificationhelper" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" + "maunium.net/go/mautrix/mockserver" ) var aliceUserID = id.UserID("@alice:example.org") @@ -31,9 +32,19 @@ func init() { zerolog.DefaultContextLogger = &log.Logger } -func initServerAndLoginTwoAlice(t *testing.T, ctx context.Context) (ts *mockServer, sendingClient, receivingClient *mautrix.Client, sendingCryptoStore, receivingCryptoStore crypto.Store, sendingMachine, receivingMachine *crypto.OlmMachine) { +func addDeviceID(ctx context.Context, cryptoStore crypto.Store, userID id.UserID, deviceID id.DeviceID) { + err := cryptoStore.PutDevice(ctx, userID, &id.Device{ + UserID: userID, + DeviceID: deviceID, + }) + if err != nil { + panic(err) + } +} + +func initServerAndLoginTwoAlice(t *testing.T, ctx context.Context) (ts *mockserver.MockServer, sendingClient, receivingClient *mautrix.Client, sendingCryptoStore, receivingCryptoStore crypto.Store, sendingMachine, receivingMachine *crypto.OlmMachine) { t.Helper() - ts = createMockServer(t) + ts = mockserver.Create(t) sendingClient, sendingCryptoStore = ts.Login(t, ctx, aliceUserID, sendingDeviceID) sendingMachine = sendingClient.Crypto.(*cryptohelper.CryptoHelper).Machine() @@ -47,9 +58,9 @@ func initServerAndLoginTwoAlice(t *testing.T, ctx context.Context) (ts *mockServ return } -func initServerAndLoginAliceBob(t *testing.T, ctx context.Context) (ts *mockServer, sendingClient, receivingClient *mautrix.Client, sendingCryptoStore, receivingCryptoStore crypto.Store, sendingMachine, receivingMachine *crypto.OlmMachine) { +func initServerAndLoginAliceBob(t *testing.T, ctx context.Context) (ts *mockserver.MockServer, sendingClient, receivingClient *mautrix.Client, sendingCryptoStore, receivingCryptoStore crypto.Store, sendingMachine, receivingMachine *crypto.OlmMachine) { t.Helper() - ts = createMockServer(t) + ts = mockserver.Create(t) sendingClient, sendingCryptoStore = ts.Login(t, ctx, aliceUserID, sendingDeviceID) sendingMachine = sendingClient.Crypto.(*cryptohelper.CryptoHelper).Machine() @@ -71,7 +82,7 @@ func initDefaultCallbacks(t *testing.T, ctx context.Context, sendingClient, rece senderVerificationStore, err := NewSQLiteVerificationStore(ctx, senderVerificationDB) require.NoError(t, err) - sendingHelper = verificationhelper.NewVerificationHelper(sendingClient, sendingMachine, senderVerificationStore, sendingCallbacks, true) + sendingHelper = verificationhelper.NewVerificationHelper(sendingClient, sendingMachine, senderVerificationStore, sendingCallbacks, true, true, true) require.NoError(t, sendingHelper.Init(ctx)) receivingCallbacks = newAllVerificationCallbacks() @@ -79,7 +90,7 @@ func initDefaultCallbacks(t *testing.T, ctx context.Context, sendingClient, rece require.NoError(t, err) receiverVerificationStore, err := NewSQLiteVerificationStore(ctx, receiverVerificationDB) require.NoError(t, err) - receivingHelper = verificationhelper.NewVerificationHelper(receivingClient, receivingMachine, receiverVerificationStore, receivingCallbacks, true) + receivingHelper = verificationhelper.NewVerificationHelper(receivingClient, receivingMachine, receiverVerificationStore, receivingCallbacks, true, true, true) require.NoError(t, receivingHelper.Init(ctx)) return } @@ -89,32 +100,41 @@ func TestVerification_Start(t *testing.T) { receivingDeviceID2 := id.DeviceID("receiving2") testCases := []struct { + supportsShow bool supportsScan bool + supportsSAS bool callbacks MockVerificationCallbacks startVerificationErrMsg string expectedVerificationMethods []event.VerificationMethod }{ - {false, newBaseVerificationCallbacks(), "no supported verification methods", nil}, - {true, newBaseVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate}}, - {false, newSASVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodSAS}}, - {true, newSASVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate, event.VerificationMethodSAS}}, - {true, newQRCodeVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeScan, event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate}}, - {false, newQRCodeVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate}}, - {false, newAllVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate, event.VerificationMethodSAS}}, - {true, newAllVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeScan, event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate, event.VerificationMethodSAS}}, + {false, false, false, newBaseVerificationCallbacks(), "no supported verification methods", nil}, + {false, true, false, newBaseVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate}}, + + {false, false, false, newShowQRCodeVerificationCallbacks(), "no supported verification methods", nil}, + {true, false, false, newShowQRCodeVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate}}, + {false, true, false, newShowQRCodeVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate}}, + {true, true, false, newShowQRCodeVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeShow, event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate}}, + + {false, false, true, newSASVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodSAS}}, + {false, true, true, newSASVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodSAS, event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate}}, + + {false, false, false, newAllVerificationCallbacks(), "no supported verification methods", nil}, + {false, false, true, newAllVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodSAS}}, + {false, true, true, newAllVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodSAS, event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate}}, + {true, false, true, newAllVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodSAS, event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate}}, + {true, true, true, newAllVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodSAS, event.VerificationMethodQRCodeShow, event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate}}, } for i, tc := range testCases { t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { - ts := createMockServer(t) - defer ts.Close() + ts := mockserver.Create(t) client, cryptoStore := ts.Login(t, ctx, aliceUserID, sendingDeviceID) addDeviceID(ctx, cryptoStore, aliceUserID, sendingDeviceID) addDeviceID(ctx, cryptoStore, aliceUserID, receivingDeviceID) addDeviceID(ctx, cryptoStore, aliceUserID, receivingDeviceID2) - senderHelper := verificationhelper.NewVerificationHelper(client, client.Crypto.(*cryptohelper.CryptoHelper).Machine(), nil, tc.callbacks, tc.supportsScan) + senderHelper := verificationhelper.NewVerificationHelper(client, client.Crypto.(*cryptohelper.CryptoHelper).Machine(), nil, tc.callbacks, tc.supportsShow, tc.supportsScan, tc.supportsSAS) err := senderHelper.Init(ctx) require.NoError(t, err) @@ -124,7 +144,7 @@ func TestVerification_Start(t *testing.T) { return } - assert.NoError(t, err) + require.NoError(t, err) assert.NotEmpty(t, txnID) toDeviceInbox := ts.DeviceInbox[aliceUserID] @@ -138,7 +158,7 @@ func TestVerification_Start(t *testing.T) { assert.NotEmpty(t, toDeviceInbox[receivingDeviceID]) assert.NotEmpty(t, toDeviceInbox[receivingDeviceID2]) assert.Equal(t, toDeviceInbox[receivingDeviceID], toDeviceInbox[receivingDeviceID2]) - assert.Len(t, toDeviceInbox[receivingDeviceID], 1) + require.Len(t, toDeviceInbox[receivingDeviceID], 1) // Ensure that the verification request is correct. verificationRequest := toDeviceInbox[receivingDeviceID][0].Content.AsVerificationRequest() @@ -156,12 +176,11 @@ func TestVerification_StartThenCancel(t *testing.T) { for _, sendingCancels := range []bool{true, false} { t.Run(fmt.Sprintf("sendingCancels=%t", sendingCancels), func(t *testing.T) { ts, sendingClient, receivingClient, sendingCryptoStore, receivingCryptoStore, sendingMachine, receivingMachine := initServerAndLoginTwoAlice(t, ctx) - defer ts.Close() _, _, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine) bystanderClient, _ := ts.Login(t, ctx, aliceUserID, bystanderDeviceID) bystanderMachine := bystanderClient.Crypto.(*cryptohelper.CryptoHelper).Machine() - bystanderHelper := verificationhelper.NewVerificationHelper(bystanderClient, bystanderMachine, nil, newAllVerificationCallbacks(), true) + bystanderHelper := verificationhelper.NewVerificationHelper(bystanderClient, bystanderMachine, nil, newAllVerificationCallbacks(), true, true, true) require.NoError(t, bystanderHelper.Init(ctx)) require.NoError(t, sendingCryptoStore.PutDevice(ctx, aliceUserID, bystanderMachine.OwnIdentity())) @@ -176,13 +195,13 @@ func TestVerification_StartThenCancel(t *testing.T) { receivingInbox := ts.DeviceInbox[aliceUserID][receivingDeviceID] assert.Len(t, receivingInbox, 1) assert.Equal(t, txnID, receivingInbox[0].Content.AsVerificationRequest().TransactionID) - ts.dispatchToDevice(t, ctx, receivingClient) + ts.DispatchToDevice(t, ctx, receivingClient) // Process the request event on the bystander device. bystanderInbox := ts.DeviceInbox[aliceUserID][bystanderDeviceID] assert.Len(t, bystanderInbox, 1) assert.Equal(t, txnID, bystanderInbox[0].Content.AsVerificationRequest().TransactionID) - ts.dispatchToDevice(t, ctx, bystanderClient) + ts.DispatchToDevice(t, ctx, bystanderClient) // Cancel the verification request. var cancelEvt *event.VerificationCancelEventContent @@ -221,7 +240,7 @@ func TestVerification_StartThenCancel(t *testing.T) { if !sendingCancels { // Process the cancellation event on the sending device. - ts.dispatchToDevice(t, ctx, sendingClient) + ts.DispatchToDevice(t, ctx, sendingClient) // Ensure that the cancellation event was sent to the bystander device. assert.Len(t, ts.DeviceInbox[aliceUserID][bystanderDeviceID], 1) @@ -237,8 +256,7 @@ func TestVerification_StartThenCancel(t *testing.T) { func TestVerification_Accept_NoSupportedMethods(t *testing.T) { ctx := log.Logger.WithContext(context.TODO()) - ts := createMockServer(t) - defer ts.Close() + ts := mockserver.Create(t) sendingClient, sendingCryptoStore := ts.Login(t, ctx, aliceUserID, sendingDeviceID) receivingClient, _ := ts.Login(t, ctx, aliceUserID, receivingDeviceID) @@ -251,12 +269,12 @@ func TestVerification_Accept_NoSupportedMethods(t *testing.T) { assert.NotEmpty(t, recoveryKey) assert.NotNil(t, cache) - sendingHelper := verificationhelper.NewVerificationHelper(sendingClient, sendingMachine, nil, newAllVerificationCallbacks(), true) + sendingHelper := verificationhelper.NewVerificationHelper(sendingClient, sendingMachine, nil, newAllVerificationCallbacks(), true, true, true) err = sendingHelper.Init(ctx) require.NoError(t, err) receivingCallbacks := newBaseVerificationCallbacks() - receivingHelper := verificationhelper.NewVerificationHelper(receivingClient, receivingClient.Crypto.(*cryptohelper.CryptoHelper).Machine(), nil, receivingCallbacks, false) + receivingHelper := verificationhelper.NewVerificationHelper(receivingClient, receivingClient.Crypto.(*cryptohelper.CryptoHelper).Machine(), nil, receivingCallbacks, false, false, false) err = receivingHelper.Init(ctx) require.NoError(t, err) @@ -264,7 +282,7 @@ func TestVerification_Accept_NoSupportedMethods(t *testing.T) { require.NoError(t, err) require.NotEmpty(t, txnID) - ts.dispatchToDevice(t, ctx, receivingClient) + ts.DispatchToDevice(t, ctx, receivingClient) // Ensure that the receiver ignored the request because it // doesn't support any of the verification methods in the @@ -277,33 +295,44 @@ func TestVerification_Accept_CorrectMethodsPresented(t *testing.T) { testCases := []struct { sendingSupportsScan bool + sendingSupportsShow bool receivingSupportsScan bool + receivingSupportsShow bool + sendingSupportsSAS bool + receivingSupportsSAS bool sendingCallbacks MockVerificationCallbacks receivingCallbacks MockVerificationCallbacks expectedVerificationMethods []event.VerificationMethod }{ - {false, false, newSASVerificationCallbacks(), newSASVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodSAS}}, - {true, false, newQRCodeVerificationCallbacks(), newQRCodeVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate}}, - {false, true, newQRCodeVerificationCallbacks(), newQRCodeVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate}}, - {true, false, newAllVerificationCallbacks(), newAllVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate, event.VerificationMethodSAS}}, - {true, true, newAllVerificationCallbacks(), newAllVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodQRCodeShow, event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate, event.VerificationMethodSAS}}, + // TODO + {false, false, false, false, true, true, newSASVerificationCallbacks(), newSASVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodSAS}}, + {true, false, true, false, true, true, newSASVerificationCallbacks(), newSASVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodSAS}}, + + {true, false, false, true, false, false, newShowQRCodeVerificationCallbacks(), newShowQRCodeVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodReciprocate, event.VerificationMethodQRCodeShow}}, + {false, true, true, false, false, false, newShowQRCodeVerificationCallbacks(), newShowQRCodeVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodReciprocate, event.VerificationMethodQRCodeScan}}, + {true, false, true, true, false, false, newShowQRCodeVerificationCallbacks(), newShowQRCodeVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodReciprocate, event.VerificationMethodQRCodeShow}}, + {false, true, true, true, false, false, newShowQRCodeVerificationCallbacks(), newShowQRCodeVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodReciprocate, event.VerificationMethodQRCodeScan}}, + {true, true, true, false, false, false, newShowQRCodeVerificationCallbacks(), newShowQRCodeVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodReciprocate, event.VerificationMethodQRCodeScan}}, + {true, true, false, true, false, false, newShowQRCodeVerificationCallbacks(), newShowQRCodeVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodReciprocate, event.VerificationMethodQRCodeShow}}, + {true, true, true, true, false, false, newShowQRCodeVerificationCallbacks(), newShowQRCodeVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodReciprocate, event.VerificationMethodQRCodeScan, event.VerificationMethodQRCodeShow}}, + + {true, true, true, true, true, true, newAllVerificationCallbacks(), newAllVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodSAS, event.VerificationMethodReciprocate, event.VerificationMethodQRCodeScan, event.VerificationMethodQRCodeShow}}, } for i, tc := range testCases { t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLoginTwoAlice(t, ctx) - defer ts.Close() recoveryKey, sendingCrossSigningKeysCache, err := sendingMachine.GenerateAndUploadCrossSigningKeys(ctx, nil, "") assert.NoError(t, err) assert.NotEmpty(t, recoveryKey) assert.NotNil(t, sendingCrossSigningKeysCache) - sendingHelper := verificationhelper.NewVerificationHelper(sendingClient, sendingMachine, nil, tc.sendingCallbacks, tc.sendingSupportsScan) + sendingHelper := verificationhelper.NewVerificationHelper(sendingClient, sendingMachine, nil, tc.sendingCallbacks, tc.sendingSupportsShow, tc.sendingSupportsScan, tc.sendingSupportsSAS) err = sendingHelper.Init(ctx) require.NoError(t, err) - receivingHelper := verificationhelper.NewVerificationHelper(receivingClient, receivingMachine, nil, tc.receivingCallbacks, tc.receivingSupportsScan) + receivingHelper := verificationhelper.NewVerificationHelper(receivingClient, receivingMachine, nil, tc.receivingCallbacks, tc.receivingSupportsShow, tc.receivingSupportsScan, tc.receivingSupportsSAS) err = receivingHelper.Init(ctx) require.NoError(t, err) @@ -311,7 +340,7 @@ func TestVerification_Accept_CorrectMethodsPresented(t *testing.T) { require.NoError(t, err) // Process the verification request on the receiving device. - ts.dispatchToDevice(t, ctx, receivingClient) + ts.DispatchToDevice(t, ctx, receivingClient) // Ensure that the receiving device received a verification // request with the correct transaction ID. @@ -321,16 +350,13 @@ func TestVerification_Accept_CorrectMethodsPresented(t *testing.T) { err = receivingHelper.AcceptVerification(ctx, txnID) require.NoError(t, err) - _, sendingIsQRCallbacks := tc.sendingCallbacks.(*qrCodeVerificationCallbacks) - _, sendingIsAllCallbacks := tc.sendingCallbacks.(*allVerificationCallbacks) - sendingCanShowQR := sendingIsQRCallbacks || sendingIsAllCallbacks - _, receivingIsQRCallbacks := tc.receivingCallbacks.(*qrCodeVerificationCallbacks) - _, receivingIsAllCallbacks := tc.receivingCallbacks.(*allVerificationCallbacks) - receivingCanShowQR := receivingIsQRCallbacks || receivingIsAllCallbacks + // Ensure that the receiving device get a notification about the + // transaction being ready. + assert.Contains(t, tc.receivingCallbacks.GetVerificationsReadyTransactions(), txnID) // Ensure that if the receiving device should show a QR code that // it has the correct content. - if tc.sendingSupportsScan && receivingCanShowQR { + if tc.sendingSupportsScan && tc.receivingSupportsShow { receivingShownQRCode := tc.receivingCallbacks.GetQRCodeShown(txnID) require.NotNil(t, receivingShownQRCode) assert.Equal(t, txnID, receivingShownQRCode.TransactionID) @@ -339,7 +365,7 @@ func TestVerification_Accept_CorrectMethodsPresented(t *testing.T) { // Check for whether the receiving device should be scanning a QR // code. - if tc.receivingSupportsScan && sendingCanShowQR { + if tc.receivingSupportsScan && tc.sendingSupportsShow { assert.Contains(t, tc.receivingCallbacks.GetScanQRCodeTransactions(), txnID) } @@ -354,11 +380,15 @@ func TestVerification_Accept_CorrectMethodsPresented(t *testing.T) { // Receive the m.key.verification.ready event on the sending // device. - ts.dispatchToDevice(t, ctx, sendingClient) + ts.DispatchToDevice(t, ctx, sendingClient) + + // Ensure that the sending device got a notification about the + // transaction being ready. + assert.Contains(t, tc.sendingCallbacks.GetVerificationsReadyTransactions(), txnID) // Ensure that if the sending device should show a QR code that it // has the correct content. - if tc.receivingSupportsScan && sendingCanShowQR { + if tc.receivingSupportsScan && tc.sendingSupportsShow { sendingShownQRCode := tc.sendingCallbacks.GetQRCodeShown(txnID) require.NotNil(t, sendingShownQRCode) assert.Equal(t, txnID, sendingShownQRCode.TransactionID) @@ -367,7 +397,7 @@ func TestVerification_Accept_CorrectMethodsPresented(t *testing.T) { // Check for whether the sending device should be scanning a QR // code. - if tc.sendingSupportsScan && receivingCanShowQR { + if tc.sendingSupportsScan && tc.receivingSupportsShow { assert.Contains(t, tc.sendingCallbacks.GetScanQRCodeTransactions(), txnID) } }) @@ -379,7 +409,6 @@ func TestVerification_Accept_CorrectMethodsPresented(t *testing.T) { func TestVerification_Accept_CancelOnNonParticipatingDevices(t *testing.T) { ctx := log.Logger.WithContext(context.TODO()) ts, sendingClient, receivingClient, sendingCryptoStore, receivingCryptoStore, sendingMachine, receivingMachine := initServerAndLoginTwoAlice(t, ctx) - defer ts.Close() _, _, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine) nonParticipatingDeviceID1 := id.DeviceID("non-participating1") @@ -396,12 +425,12 @@ func TestVerification_Accept_CancelOnNonParticipatingDevices(t *testing.T) { // the receiving device. txnID, err := sendingHelper.StartVerification(ctx, aliceUserID) require.NoError(t, err) - ts.dispatchToDevice(t, ctx, receivingClient) + ts.DispatchToDevice(t, ctx, receivingClient) err = receivingHelper.AcceptVerification(ctx, txnID) require.NoError(t, err) // Receive the m.key.verification.ready event on the sending device. - ts.dispatchToDevice(t, ctx, sendingClient) + ts.DispatchToDevice(t, ctx, sendingClient) // The sending and receiving devices should not have any cancellation // events in their inboxes. @@ -421,7 +450,6 @@ func TestVerification_Accept_CancelOnNonParticipatingDevices(t *testing.T) { func TestVerification_ErrorOnDoubleAccept(t *testing.T) { ctx := log.Logger.WithContext(context.TODO()) ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLoginTwoAlice(t, ctx) - defer ts.Close() _, _, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine) _, _, err := sendingMachine.GenerateAndUploadCrossSigningKeys(ctx, nil, "") @@ -429,7 +457,7 @@ func TestVerification_ErrorOnDoubleAccept(t *testing.T) { txnID, err := sendingHelper.StartVerification(ctx, aliceUserID) require.NoError(t, err) - ts.dispatchToDevice(t, ctx, receivingClient) + ts.DispatchToDevice(t, ctx, receivingClient) err = receivingHelper.AcceptVerification(ctx, txnID) require.NoError(t, err) err = receivingHelper.AcceptVerification(ctx, txnID) @@ -449,7 +477,6 @@ func TestVerification_ErrorOnDoubleAccept(t *testing.T) { func TestVerification_CancelOnDoubleStart(t *testing.T) { ctx := log.Logger.WithContext(context.TODO()) ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLoginTwoAlice(t, ctx) - defer ts.Close() sendingCallbacks, receivingCallbacks, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine) _, _, err := sendingMachine.GenerateAndUploadCrossSigningKeys(ctx, nil, "") @@ -458,15 +485,15 @@ func TestVerification_CancelOnDoubleStart(t *testing.T) { // Send and accept the first verification request. txnID1, err := sendingHelper.StartVerification(ctx, aliceUserID) require.NoError(t, err) - ts.dispatchToDevice(t, ctx, receivingClient) + ts.DispatchToDevice(t, ctx, receivingClient) err = receivingHelper.AcceptVerification(ctx, txnID1) require.NoError(t, err) - ts.dispatchToDevice(t, ctx, sendingClient) // Process the m.key.verification.ready event + ts.DispatchToDevice(t, ctx, sendingClient) // Process the m.key.verification.ready event // Send a second verification request txnID2, err := sendingHelper.StartVerification(ctx, aliceUserID) require.NoError(t, err) - ts.dispatchToDevice(t, ctx, receivingClient) + ts.DispatchToDevice(t, ctx, receivingClient) // Ensure that the sending device received a cancellation event for both of // the ongoing transactions. @@ -484,7 +511,7 @@ func TestVerification_CancelOnDoubleStart(t *testing.T) { assert.NotNil(t, receivingCallbacks.GetVerificationCancellation(txnID1)) assert.NotNil(t, receivingCallbacks.GetVerificationCancellation(txnID2)) - ts.dispatchToDevice(t, ctx, sendingClient) // Process the m.key.verification.cancel events + ts.DispatchToDevice(t, ctx, sendingClient) // Process the m.key.verification.cancel events assert.NotNil(t, sendingCallbacks.GetVerificationCancellation(txnID1)) assert.NotNil(t, sendingCallbacks.GetVerificationCancellation(txnID2)) } diff --git a/crypto/verificationhelper/verificationstore_test.go b/crypto/verificationhelper/verificationstore_test.go index a3b1895d..e64153b1 100644 --- a/crypto/verificationhelper/verificationstore_test.go +++ b/crypto/verificationhelper/verificationstore_test.go @@ -3,6 +3,7 @@ package verificationhelper_test import ( "context" "database/sql" + "errors" _ "github.com/mattn/go-sqlite3" "github.com/rs/zerolog" @@ -42,20 +43,17 @@ func NewSQLiteVerificationStore(ctx context.Context, db *sql.DB) (*SQLiteVerific func (s *SQLiteVerificationStore) GetAllVerificationTransactions(ctx context.Context) ([]verificationhelper.VerificationTransaction, error) { rows, err := s.db.QueryContext(ctx, selectVerifications) - if err != nil { - return nil, err - } - return dbutil.NewRowIter(rows, func(dbutil.Scannable) (txn verificationhelper.VerificationTransaction, err error) { + return dbutil.NewRowIterWithError(rows, func(dbutil.Scannable) (txn verificationhelper.VerificationTransaction, err error) { err = rows.Scan(&dbutil.JSON{Data: &txn}) return - }).AsList() + }, err).AsList() } func (vq *SQLiteVerificationStore) GetVerificationTransaction(ctx context.Context, txnID id.VerificationTransactionID) (txn verificationhelper.VerificationTransaction, err error) { zerolog.Ctx(ctx).Warn().Stringer("transaction_id", txnID).Msg("Getting verification transaction") row := vq.db.QueryRowContext(ctx, getVerificationByTransactionID, txnID) err = row.Scan(&dbutil.JSON{Data: &txn}) - if err == sql.ErrNoRows { + if errors.Is(err, sql.ErrNoRows) { err = verificationhelper.ErrUnknownVerificationTransaction } return @@ -64,7 +62,7 @@ func (vq *SQLiteVerificationStore) GetVerificationTransaction(ctx context.Contex func (vq *SQLiteVerificationStore) FindVerificationTransactionForUserDevice(ctx context.Context, userID id.UserID, deviceID id.DeviceID) (txn verificationhelper.VerificationTransaction, err error) { row := vq.db.QueryRowContext(ctx, getVerificationByUserDeviceID, userID, deviceID) err = row.Scan(&dbutil.JSON{Data: &txn}) - if err == sql.ErrNoRows { + if errors.Is(err, sql.ErrNoRows) { err = verificationhelper.ErrUnknownVerificationTransaction } return diff --git a/error.go b/error.go index 0133e80e..4711b3dc 100644 --- a/error.go +++ b/error.go @@ -13,6 +13,7 @@ import ( "net/http" "go.mau.fi/util/exhttp" + "go.mau.fi/util/exmaps" "golang.org/x/exp/maps" ) @@ -66,11 +67,28 @@ var ( MIncompatibleRoomVersion = RespError{ErrCode: "M_INCOMPATIBLE_ROOM_VERSION"} // The client specified a parameter that has the wrong value. MInvalidParam = RespError{ErrCode: "M_INVALID_PARAM", StatusCode: http.StatusBadRequest} + // The client specified a room key backup version that is not the current room key backup version for the user. + MWrongRoomKeysVersion = RespError{ErrCode: "M_WRONG_ROOM_KEYS_VERSION", StatusCode: http.StatusForbidden} MURLNotSet = RespError{ErrCode: "M_URL_NOT_SET"} MBadStatus = RespError{ErrCode: "M_BAD_STATUS"} MConnectionTimeout = RespError{ErrCode: "M_CONNECTION_TIMEOUT"} MConnectionFailed = RespError{ErrCode: "M_CONNECTION_FAILED"} + + MUnredactedContentDeleted = RespError{ErrCode: "FI.MAU.MSC2815_UNREDACTED_CONTENT_DELETED"} + MUnredactedContentNotReceived = RespError{ErrCode: "FI.MAU.MSC2815_UNREDACTED_CONTENT_NOT_RECEIVED"} +) + +var ( + ErrClientIsNil = errors.New("client is nil") + ErrClientHasNoHomeserver = errors.New("client has no homeserver set") + + ErrResponseTooLong = errors.New("response content length too long") + ErrBodyReadReachedLimit = errors.New("reached response size limit while reading body") + + // Special error that indicates we should retry canceled contexts. Note that on it's own this + // is useless, the context itself must also be replaced. + ErrContextCancelRetry = errors.New("retry canceled context") ) // HTTPError An HTTP Error response, which may wrap an underlying native Go Error. @@ -96,10 +114,9 @@ func (e HTTPError) Error() string { if e.WrappedError != nil { return fmt.Sprintf("%s: %v", e.Message, e.WrappedError) } else if e.RespError != nil { - return fmt.Sprintf("failed to %s %s: %s (HTTP %d): %s", e.Request.Method, e.Request.URL.Path, - e.RespError.ErrCode, e.Response.StatusCode, e.RespError.Err) + return fmt.Sprintf("%s (HTTP %d): %s", e.RespError.ErrCode, e.Response.StatusCode, e.RespError.Err) } else { - msg := fmt.Sprintf("failed to %s %s: HTTP %d", e.Request.Method, e.Request.URL.Path, e.Response.StatusCode) + msg := fmt.Sprintf("HTTP %d", e.Response.StatusCode) if len(e.ResponseBody) > 0 { msg = fmt.Sprintf("%s: %s", msg, e.ResponseBody) } @@ -123,7 +140,10 @@ type RespError struct { Err string ExtraData map[string]any - StatusCode int + StatusCode int + ExtraHeader map[string]string + + CanRetry bool } func (e *RespError) UnmarshalJSON(data []byte) error { @@ -133,16 +153,17 @@ func (e *RespError) UnmarshalJSON(data []byte) error { } e.ErrCode, _ = e.ExtraData["errcode"].(string) e.Err, _ = e.ExtraData["error"].(string) + e.CanRetry, _ = e.ExtraData["com.beeper.can_retry"].(bool) return nil } func (e *RespError) MarshalJSON() ([]byte, error) { - data := maps.Clone(e.ExtraData) - if data == nil { - data = make(map[string]any) - } + data := exmaps.NonNilClone(e.ExtraData) data["errcode"] = e.ErrCode data["error"] = e.Err + if e.CanRetry { + data["com.beeper.can_retry"] = e.CanRetry + } return json.Marshal(data) } @@ -154,6 +175,9 @@ func (e RespError) Write(w http.ResponseWriter) { if statusCode == 0 { statusCode = http.StatusInternalServerError } + for key, value := range e.ExtraHeader { + w.Header().Set(key, value) + } exhttp.WriteJSONResponse(w, statusCode, &e) } @@ -170,6 +194,29 @@ func (e RespError) WithStatus(status int) RespError { return e } +func (e RespError) WithCanRetry(canRetry bool) RespError { + e.CanRetry = canRetry + return e +} + +func (e RespError) WithExtraData(extraData map[string]any) RespError { + e.ExtraData = exmaps.NonNilClone(e.ExtraData) + maps.Copy(e.ExtraData, extraData) + return e +} + +func (e RespError) WithExtraHeader(key, value string) RespError { + e.ExtraHeader = exmaps.NonNilClone(e.ExtraHeader) + e.ExtraHeader[key] = value + return e +} + +func (e RespError) WithExtraHeaders(headers map[string]string) RespError { + e.ExtraHeader = exmaps.NonNilClone(e.ExtraHeader) + maps.Copy(e.ExtraHeader, headers) + return e +} + // Error returns the errcode and error message. func (e RespError) Error() string { return e.ErrCode + ": " + e.Err diff --git a/event/accountdata.go b/event/accountdata.go index 30ca35a2..223919a1 100644 --- a/event/accountdata.go +++ b/event/accountdata.go @@ -105,3 +105,15 @@ func (bmec *BeeperMuteEventContent) GetMutedUntilTime() time.Time { } return time.Time{} } + +func (bmec *BeeperMuteEventContent) GetMuteDuration() time.Duration { + ts := bmec.GetMutedUntilTime() + now := time.Now() + if ts.Before(now) { + return 0 + } else if ts == MutedForever { + return -1 + } else { + return ts.Sub(now) + } +} diff --git a/event/beeper.go b/event/beeper.go index 7ea0d068..a1a60b35 100644 --- a/event/beeper.go +++ b/event/beeper.go @@ -9,7 +9,12 @@ package event import ( "encoding/base32" "encoding/binary" + "encoding/json" "fmt" + "html" + "regexp" + "strconv" + "strings" "maunium.net/go/mautrix/id" ) @@ -48,6 +53,8 @@ type BeeperMessageStatusEventContent struct { LastRetry id.EventID `json:"last_retry,omitempty"` + TargetTxnID string `json:"relates_to_txn_id,omitempty"` + MutateEventKey string `json:"mutate_event_key,omitempty"` // Indicates the set of users to whom the event was delivered. If nil, then @@ -57,6 +64,18 @@ type BeeperMessageStatusEventContent struct { DeliveredToUsers *[]id.UserID `json:"delivered_to_users,omitempty"` } +type BeeperRelatesTo struct { + EventID id.EventID `json:"event_id,omitempty"` + RoomID id.RoomID `json:"room_id,omitempty"` + Type RelationType `json:"rel_type,omitempty"` +} + +type BeeperTranscriptionEventContent struct { + Text []ExtensibleText `json:"m.text,omitempty"` + Model string `json:"com.beeper.transcription.model,omitempty"` + RelatesTo BeeperRelatesTo `json:"com.beeper.relates_to,omitempty"` +} + type BeeperRetryMetadata struct { OriginalEventID id.EventID `json:"original_event_id"` RetryCount int `json:"retry_count"` @@ -69,18 +88,54 @@ type BeeperRoomKeyAckEventContent struct { FirstMessageIndex int `json:"first_message_index"` } +type BeeperChatDeleteEventContent struct { + DeleteForEveryone bool `json:"delete_for_everyone,omitempty"` + FromMessageRequest bool `json:"from_message_request,omitempty"` +} + +type BeeperAcceptMessageRequestEventContent struct { + // Whether this was triggered by a message rather than an explicit event + IsImplicit bool `json:"-"` +} + +type BeeperSendStateEventContent struct { + Type string `json:"type"` + StateKey string `json:"state_key"` + Content Content `json:"content"` +} + +type IntOrString int + +func (ios *IntOrString) UnmarshalJSON(data []byte) error { + if len(data) > 0 && data[0] == '"' { + var str string + err := json.Unmarshal(data, &str) + if err != nil { + return err + } + intVal, err := strconv.Atoi(str) + if err != nil { + return err + } + *ios = IntOrString(intVal) + return nil + } + return json.Unmarshal(data, (*int)(ios)) +} + type LinkPreview struct { CanonicalURL string `json:"og:url,omitempty"` Title string `json:"og:title,omitempty"` Type string `json:"og:type,omitempty"` Description string `json:"og:description,omitempty"` + SiteName string `json:"og:site_name,omitempty"` ImageURL id.ContentURIString `json:"og:image,omitempty"` - ImageSize int `json:"matrix:image:size,omitempty"` - ImageWidth int `json:"og:image:width,omitempty"` - ImageHeight int `json:"og:image:height,omitempty"` - ImageType string `json:"og:image:type,omitempty"` + ImageSize IntOrString `json:"matrix:image:size,omitempty"` + ImageWidth IntOrString `json:"og:image:width,omitempty"` + ImageHeight IntOrString `json:"og:image:height,omitempty"` + ImageType string `json:"og:image:type,omitempty"` } // BeeperLinkPreview contains the data for a bundled URL preview as specified in MSC4095 @@ -91,6 +146,7 @@ type BeeperLinkPreview struct { MatchedURL string `json:"matched_url,omitempty"` ImageEncryption *EncryptedFileInfo `json:"beeper:image:encryption,omitempty"` + ImageBlurhash string `json:"matrix:image:blurhash,omitempty"` } type BeeperProfileExtra struct { @@ -107,6 +163,64 @@ type BeeperPerMessageProfile struct { Displayname string `json:"displayname,omitempty"` AvatarURL *id.ContentURIString `json:"avatar_url,omitempty"` AvatarFile *EncryptedFileInfo `json:"avatar_file,omitempty"` + HasFallback bool `json:"has_fallback,omitempty"` +} + +type BeeperActionMessageType string + +const ( + BeeperActionMessageCall BeeperActionMessageType = "call" +) + +type BeeperActionMessageCallType string + +const ( + BeeperActionMessageCallTypeVoice BeeperActionMessageCallType = "voice" + BeeperActionMessageCallTypeVideo BeeperActionMessageCallType = "video" +) + +type BeeperActionMessage struct { + Type BeeperActionMessageType `json:"type"` + CallType BeeperActionMessageCallType `json:"call_type,omitempty"` +} + +func (content *MessageEventContent) AddPerMessageProfileFallback() { + if content.BeeperPerMessageProfile == nil || content.BeeperPerMessageProfile.HasFallback || content.BeeperPerMessageProfile.Displayname == "" { + return + } + content.BeeperPerMessageProfile.HasFallback = true + content.EnsureHasHTML() + content.Body = fmt.Sprintf("%s: %s", content.BeeperPerMessageProfile.Displayname, content.Body) + content.FormattedBody = fmt.Sprintf( + "%s: %s", + html.EscapeString(content.BeeperPerMessageProfile.Displayname), + content.FormattedBody, + ) +} + +var HTMLProfileFallbackRegex = regexp.MustCompile(`([^<]+): `) + +func (content *MessageEventContent) RemovePerMessageProfileFallback() { + if content.NewContent != nil && content.NewContent != content { + content.NewContent.RemovePerMessageProfileFallback() + } + if content == nil || content.BeeperPerMessageProfile == nil || !content.BeeperPerMessageProfile.HasFallback || content.BeeperPerMessageProfile.Displayname == "" { + return + } + content.BeeperPerMessageProfile.HasFallback = false + content.Body = strings.TrimPrefix(content.Body, content.BeeperPerMessageProfile.Displayname+": ") + if content.Format == FormatHTML { + content.FormattedBody = HTMLProfileFallbackRegex.ReplaceAllLiteralString(content.FormattedBody, "") + } +} + +type BeeperAIStreamEventContent struct { + TurnID string `json:"turn_id"` + Seq int `json:"seq"` + Part map[string]any `json:"part"` + TargetEvent id.EventID `json:"target_event,omitempty"` + AgentID string `json:"agent_id,omitempty"` + RelatesTo *RelatesTo `json:"m.relates_to,omitempty"` } type BeeperEncodedOrder struct { diff --git a/event/capabilities.d.ts b/event/capabilities.d.ts new file mode 100644 index 00000000..26aeb347 --- /dev/null +++ b/event/capabilities.d.ts @@ -0,0 +1,225 @@ +/** + * The content of the `com.beeper.room_features` state event. + */ +export interface RoomFeatures { + /** + * Supported formatting features. If omitted, no formatting is supported. + * + * Capability level 0 means the corresponding HTML tags/attributes are ignored + * and will be treated as if they don't exist, which means that children will + * be rendered, but attributes will be dropped. + */ + formatting?: Record + /** + * Supported file message types and their features. + * + * If a message type isn't listed here, it should be treated as support level -2 (will be rejected). + */ + file?: Record + /** + * Supported state event types and their parameters. Currently, there are no parameters, + * but it is likely there will be some in the future (like max name/topic length, avatar mime types, etc.). + * + * Events that are not listed or have a support level of zero or below should be treated as unsupported. + * + * Clients should at least check `m.room.name`, `m.room.topic`, and `m.room.avatar` here. + * `m.room.member` will not be listed here, as it's controlled by the member_actions field. + * `com.beeper.disappearing_timer` should be listed here, but the parameters are in the disappearing_timer field for now. + */ + state?: Record + /** + * Supported member actions and their support levels. + * + * Actions that are not listed or have a support level of zero or below should be treated as unsupported. + */ + member_actions?: Record + + /** Maximum length of normal text messages. */ + max_text_length?: integer + + /** Whether location messages (`m.location`) are supported. */ + location_message?: CapabilitySupportLevel + /** Whether polls are supported. */ + poll?: CapabilitySupportLevel + /** Whether replying in a thread is supported. */ + thread?: CapabilitySupportLevel + /** Whether replying to a specific message is supported. */ + reply?: CapabilitySupportLevel + + /** Whether edits are supported. */ + edit?: CapabilitySupportLevel + /** How many times can an individual message be edited. */ + edit_max_count?: integer + /** How old messages can be edited, in seconds. */ + edit_max_age?: seconds + /** Whether deleting messages for everyone is supported */ + delete?: CapabilitySupportLevel + /** How old messages can be deleted for everyone, in seconds. */ + delete_max_age?: seconds + /** Whether deleting messages just for yourself is supported. No message age limit. */ + delete_for_me?: boolean + /** Allowed configuration options for disappearing timers. */ + disappearing_timer?: DisappearingTimerCapability + + /** Whether reactions are supported. */ + reaction?: CapabilitySupportLevel + /** How many reactions can be added to a single message. */ + reaction_count?: integer + /** + * The Unicode emojis allowed for reactions. If omitted, all emojis are allowed. + * Emojis in this list must include variation selector 16 if allowed in the Unicode spec. + */ + allowed_reactions?: string[] + /** Whether custom emoji reactions are allowed. */ + custom_emoji_reactions?: boolean + + /** Whether deleting the chat for yourself is supported. */ + delete_chat?: boolean + /** Whether deleting the chat for all participants is supported. */ + delete_chat_for_everyone?: boolean + /** What can be done with message requests? */ + message_request?: { + accept_with_message?: CapabilitySupportLevel + accept_with_button?: CapabilitySupportLevel + } +} + +declare type integer = number +declare type seconds = integer +declare type milliseconds = integer +declare type MIMEClass = "image" | "audio" | "video" | "text" | "font" | "model" | "application" +declare type MIMETypeOrPattern = + "*/*" + | `${MIMEClass}/*` + | `${MIMEClass}/${string}` + | `${MIMEClass}/${string}; ${string}` + +export enum MemberAction { + Ban = "ban", + Kick = "kick", + Leave = "leave", + RevokeInvite = "revoke_invite", + Invite = "invite", +} + +declare type EventType = string + +// This is an object for future extensibility (e.g. max name/topic length) +export interface StateFeatures { + level: CapabilitySupportLevel +} + +export enum CapabilityMsgType { + // Real message types used in the `msgtype` field + Image = "m.image", + File = "m.file", + Audio = "m.audio", + Video = "m.video", + + // Pseudo types only used in capabilities + /** An `m.audio` message that has `"org.matrix.msc3245.voice": {}` */ + Voice = "org.matrix.msc3245.voice", + /** An `m.video` message that has `"info": {"fi.mau.gif": true}`, or an `m.image` message of type `image/gif` */ + GIF = "fi.mau.gif", + /** An `m.sticker` event, no `msgtype` field */ + Sticker = "m.sticker", +} + +export interface FileFeatures { + /** + * The supported MIME types or type patterns and their support levels. + * + * If a mime type doesn't match any pattern provided, + * it should be treated as support level -2 (will be rejected). + */ + mime_types: Record + + /** The support level for captions within this file message type */ + caption?: CapabilitySupportLevel + /** The maximum length for captions (only applicable if captions are supported). */ + max_caption_length?: integer + /** The maximum file size as bytes. */ + max_size?: integer + /** For images and videos, the maximum width as pixels. */ + max_width?: integer + /** For images and videos, the maximum height as pixels. */ + max_height?: integer + /** For videos and audio files, the maximum duration as seconds. */ + max_duration?: seconds + + /** Can this type of file be sent as view-once media? */ + view_once?: boolean +} + +export enum DisappearingType { + None = "", + AfterRead = "after_read", + AfterSend = "after_send", +} + +export interface DisappearingTimerCapability { + types: DisappearingType[] + /** Allowed timer values. If omitted, any timer is allowed. */ + timers?: milliseconds[] + /** + * Whether clients should omit the empty disappearing_timer object in messages that they don't want to disappear + * + * Generally, bridged rooms will want the object to be always present, while native Matrix rooms don't, + * so the hardcoded features for Matrix rooms should set this to true, while bridges will not. + */ + omit_empty_timer?: true +} + +/** + * The support level for a feature. These are integers rather than booleans + * to accurately represent what the bridge is doing and hopefully make the + * state event more generally useful. Our clients should check for > 0 to + * determine if the feature should be allowed. + */ +export enum CapabilitySupportLevel { + /** The feature is unsupported and messages using it will be rejected. */ + Rejected = -2, + /** The feature is unsupported and has no fallback. The message will go through, but data may be lost. */ + Dropped = -1, + /** The feature is unsupported, but may have a fallback. The nature of the fallback depends on the context. */ + Unsupported = 0, + /** The feature is partially supported (e.g. it may be converted to a different format). */ + PartialSupport = 1, + /** The feature is fully supported and can be safely used. */ + FullySupported = 2, +} + +/** + * A formatting feature that consists of specific HTML tags and/or attributes. + */ +export enum FormattingFeature { + Bold = "bold", // strong, b + Italic = "italic", // em, i + Underline = "underline", // u + Strikethrough = "strikethrough", // del, s + InlineCode = "inline_code", // code + CodeBlock = "code_block", // pre + code + SyntaxHighlighting = "code_block.syntax_highlighting", //

+	Blockquote = "blockquote", // blockquote
+	InlineLink = "inline_link", // a
+	UserLink = "user_link", // 
+	RoomLink = "room_link", // 
+	EventLink = "event_link", // 
+	AtRoomMention = "at_room_mention", // @room (no html tag)
+	UnorderedList = "unordered_list", // ul + li
+	OrderedList = "ordered_list", // ol + li
+	ListStart = "ordered_list.start", // 
    + ListJumpValue = "ordered_list.jump_value", //
  1. + CustomEmoji = "custom_emoji", // + Spoiler = "spoiler", // + SpoilerReason = "spoiler.reason", // + TextForegroundColor = "color.foreground", // + TextBackgroundColor = "color.background", // + HorizontalLine = "horizontal_line", // hr + Headers = "headers", // h1, h2, h3, h4, h5, h6 + Superscript = "superscript", // sup + Subscript = "subscript", // sub + Math = "math", // + DetailsSummary = "details_summary", //
    ......
    + Table = "table", // table, thead, tbody, tr, th, td +} diff --git a/event/capabilities.go b/event/capabilities.go new file mode 100644 index 00000000..a86c726b --- /dev/null +++ b/event/capabilities.go @@ -0,0 +1,414 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package event + +import ( + "crypto/sha256" + "encoding/base64" + "encoding/binary" + "fmt" + "io" + "mime" + "slices" + "strings" + + "go.mau.fi/util/exerrors" + "go.mau.fi/util/jsontime" + "go.mau.fi/util/ptr" + "golang.org/x/exp/constraints" + "golang.org/x/exp/maps" +) + +type RoomFeatures struct { + ID string `json:"id,omitempty"` + + // N.B. New fields need to be added to the Hash function to be included in the deduplication hash. + + Formatting FormattingFeatureMap `json:"formatting,omitempty"` + File FileFeatureMap `json:"file,omitempty"` + State StateFeatureMap `json:"state,omitempty"` + MemberActions MemberFeatureMap `json:"member_actions,omitempty"` + + MaxTextLength int `json:"max_text_length,omitempty"` + + LocationMessage CapabilitySupportLevel `json:"location_message,omitempty"` + Poll CapabilitySupportLevel `json:"poll,omitempty"` + Thread CapabilitySupportLevel `json:"thread,omitempty"` + Reply CapabilitySupportLevel `json:"reply,omitempty"` + + Edit CapabilitySupportLevel `json:"edit,omitempty"` + EditMaxCount int `json:"edit_max_count,omitempty"` + EditMaxAge *jsontime.Seconds `json:"edit_max_age,omitempty"` + Delete CapabilitySupportLevel `json:"delete,omitempty"` + DeleteForMe bool `json:"delete_for_me,omitempty"` + DeleteMaxAge *jsontime.Seconds `json:"delete_max_age,omitempty"` + + DisappearingTimer *DisappearingTimerCapability `json:"disappearing_timer,omitempty"` + + Reaction CapabilitySupportLevel `json:"reaction,omitempty"` + ReactionCount int `json:"reaction_count,omitempty"` + AllowedReactions []string `json:"allowed_reactions,omitempty"` + CustomEmojiReactions bool `json:"custom_emoji_reactions,omitempty"` + + ReadReceipts bool `json:"read_receipts,omitempty"` + TypingNotifications bool `json:"typing_notifications,omitempty"` + Archive bool `json:"archive,omitempty"` + MarkAsUnread bool `json:"mark_as_unread,omitempty"` + DeleteChat bool `json:"delete_chat,omitempty"` + DeleteChatForEveryone bool `json:"delete_chat_for_everyone,omitempty"` + + MessageRequest *MessageRequestFeatures `json:"message_request,omitempty"` + + PerMessageProfileRelay bool `json:"-"` +} + +func (rf *RoomFeatures) GetID() string { + if rf.ID != "" { + return rf.ID + } + return base64.RawURLEncoding.EncodeToString(rf.Hash()) +} + +func (rf *RoomFeatures) Clone() *RoomFeatures { + if rf == nil { + return nil + } + clone := *rf + clone.File = clone.File.Clone() + clone.Formatting = maps.Clone(clone.Formatting) + clone.State = clone.State.Clone() + clone.MemberActions = clone.MemberActions.Clone() + clone.EditMaxAge = ptr.Clone(clone.EditMaxAge) + clone.DeleteMaxAge = ptr.Clone(clone.DeleteMaxAge) + clone.DisappearingTimer = clone.DisappearingTimer.Clone() + clone.AllowedReactions = slices.Clone(clone.AllowedReactions) + clone.MessageRequest = clone.MessageRequest.Clone() + return &clone +} + +type MemberFeatureMap map[MemberAction]CapabilitySupportLevel + +func (mfm MemberFeatureMap) Clone() MemberFeatureMap { + return maps.Clone(mfm) +} + +type MemberAction string + +const ( + MemberActionBan MemberAction = "ban" + MemberActionKick MemberAction = "kick" + MemberActionLeave MemberAction = "leave" + MemberActionRevokeInvite MemberAction = "revoke_invite" + MemberActionInvite MemberAction = "invite" +) + +type StateFeatureMap map[string]*StateFeatures + +func (sfm StateFeatureMap) Clone() StateFeatureMap { + dup := maps.Clone(sfm) + for key, value := range dup { + dup[key] = value.Clone() + } + return dup +} + +type StateFeatures struct { + Level CapabilitySupportLevel `json:"level"` +} + +func (sf *StateFeatures) Clone() *StateFeatures { + if sf == nil { + return nil + } + clone := *sf + return &clone +} + +func (sf *StateFeatures) Hash() []byte { + return sf.Level.Hash() +} + +type FormattingFeatureMap map[FormattingFeature]CapabilitySupportLevel + +type FileFeatureMap map[CapabilityMsgType]*FileFeatures + +func (ffm FileFeatureMap) Clone() FileFeatureMap { + dup := maps.Clone(ffm) + for key, value := range dup { + dup[key] = value.Clone() + } + return dup +} + +type DisappearingTimerCapability struct { + Types []DisappearingType `json:"types"` + Timers []jsontime.Milliseconds `json:"timers,omitempty"` + + OmitEmptyTimer bool `json:"omit_empty_timer,omitempty"` +} + +func (dtc *DisappearingTimerCapability) Clone() *DisappearingTimerCapability { + if dtc == nil { + return nil + } + clone := *dtc + clone.Types = slices.Clone(clone.Types) + clone.Timers = slices.Clone(clone.Timers) + return &clone +} + +func (dtc *DisappearingTimerCapability) Supports(content *BeeperDisappearingTimer) bool { + if dtc == nil || content == nil || content.Type == DisappearingTypeNone { + return true + } + return slices.Contains(dtc.Types, content.Type) && (dtc.Timers == nil || slices.Contains(dtc.Timers, content.Timer)) +} + +type MessageRequestFeatures struct { + AcceptWithMessage CapabilitySupportLevel `json:"accept_with_message,omitempty"` + AcceptWithButton CapabilitySupportLevel `json:"accept_with_button,omitempty"` +} + +func (mrf *MessageRequestFeatures) Clone() *MessageRequestFeatures { + return ptr.Clone(mrf) +} + +func (mrf *MessageRequestFeatures) Hash() []byte { + if mrf == nil { + return nil + } + hasher := sha256.New() + hashValue(hasher, "accept_with_message", mrf.AcceptWithMessage) + hashValue(hasher, "accept_with_button", mrf.AcceptWithButton) + return hasher.Sum(nil) +} + +type CapabilityMsgType = MessageType + +// Message types which are used for event capability signaling, but aren't real values for the msgtype field. +const ( + CapMsgVoice CapabilityMsgType = "org.matrix.msc3245.voice" + CapMsgGIF CapabilityMsgType = "fi.mau.gif" + CapMsgSticker CapabilityMsgType = "m.sticker" +) + +type CapabilitySupportLevel int + +func (csl CapabilitySupportLevel) Partial() bool { + return csl >= CapLevelPartialSupport +} + +func (csl CapabilitySupportLevel) Full() bool { + return csl >= CapLevelFullySupported +} + +func (csl CapabilitySupportLevel) Reject() bool { + return csl <= CapLevelRejected +} + +const ( + CapLevelRejected CapabilitySupportLevel = -2 // The feature is unsupported and messages using it will be rejected. + CapLevelDropped CapabilitySupportLevel = -1 // The feature is unsupported and has no fallback. The message will go through, but data may be lost. + CapLevelUnsupported CapabilitySupportLevel = 0 // The feature is unsupported, but may have a fallback. + CapLevelPartialSupport CapabilitySupportLevel = 1 // The feature is partially supported (e.g. it may be converted to a different format). + CapLevelFullySupported CapabilitySupportLevel = 2 // The feature is fully supported and can be safely used. +) + +type FormattingFeature string + +const ( + FmtBold FormattingFeature = "bold" // strong, b + FmtItalic FormattingFeature = "italic" // em, i + FmtUnderline FormattingFeature = "underline" // u + FmtStrikethrough FormattingFeature = "strikethrough" // del, s + FmtInlineCode FormattingFeature = "inline_code" // code + FmtCodeBlock FormattingFeature = "code_block" // pre + code + FmtSyntaxHighlighting FormattingFeature = "code_block.syntax_highlighting" //
    
    +	FmtBlockquote          FormattingFeature = "blockquote"                     // blockquote
    +	FmtInlineLink          FormattingFeature = "inline_link"                    // a
    +	FmtUserLink            FormattingFeature = "user_link"                      // 
    +	FmtRoomLink            FormattingFeature = "room_link"                      // 
    +	FmtEventLink           FormattingFeature = "event_link"                     // 
    +	FmtAtRoomMention       FormattingFeature = "at_room_mention"                // @room (no html tag)
    +	FmtUnorderedList       FormattingFeature = "unordered_list"                 // ul + li
    +	FmtOrderedList         FormattingFeature = "ordered_list"                   // ol + li
    +	FmtListStart           FormattingFeature = "ordered_list.start"             // 
      + FmtListJumpValue FormattingFeature = "ordered_list.jump_value" //
    1. + FmtCustomEmoji FormattingFeature = "custom_emoji" // + FmtSpoiler FormattingFeature = "spoiler" // + FmtSpoilerReason FormattingFeature = "spoiler.reason" // + FmtTextForegroundColor FormattingFeature = "color.foreground" // + FmtTextBackgroundColor FormattingFeature = "color.background" // + FmtHorizontalLine FormattingFeature = "horizontal_line" // hr + FmtHeaders FormattingFeature = "headers" // h1, h2, h3, h4, h5, h6 + FmtSuperscript FormattingFeature = "superscript" // sup + FmtSubscript FormattingFeature = "subscript" // sub + FmtMath FormattingFeature = "math" // + FmtDetailsSummary FormattingFeature = "details_summary" //
      ......
      + FmtTable FormattingFeature = "table" // table, thead, tbody, tr, th, td +) + +type FileFeatures struct { + // N.B. New fields need to be added to the Hash function to be included in the deduplication hash. + + MimeTypes map[string]CapabilitySupportLevel `json:"mime_types"` + + Caption CapabilitySupportLevel `json:"caption,omitempty"` + MaxCaptionLength int `json:"max_caption_length,omitempty"` + + MaxSize int64 `json:"max_size,omitempty"` + MaxWidth int `json:"max_width,omitempty"` + MaxHeight int `json:"max_height,omitempty"` + MaxDuration *jsontime.Seconds `json:"max_duration,omitempty"` + + ViewOnce bool `json:"view_once,omitempty"` +} + +func (ff *FileFeatures) GetMimeSupport(inputType string) CapabilitySupportLevel { + match, ok := ff.MimeTypes[inputType] + if ok { + return match + } + if strings.IndexByte(inputType, ';') != -1 { + plainMime, _, _ := mime.ParseMediaType(inputType) + if plainMime != "" { + if match, ok = ff.MimeTypes[plainMime]; ok { + return match + } + } + } + if slash := strings.IndexByte(inputType, '/'); slash > 0 { + generalType := fmt.Sprintf("%s/*", inputType[:slash]) + if match, ok = ff.MimeTypes[generalType]; ok { + return match + } + } + match, ok = ff.MimeTypes["*/*"] + if ok { + return match + } + return CapLevelRejected +} + +type hashable interface { + Hash() []byte +} + +func hashMap[Key ~string, Value hashable](w io.Writer, name string, data map[Key]Value) { + keys := maps.Keys(data) + slices.Sort(keys) + exerrors.Must(w.Write([]byte(name))) + for _, key := range keys { + exerrors.Must(w.Write([]byte(key))) + exerrors.Must(w.Write(data[key].Hash())) + exerrors.Must(w.Write([]byte{0})) + } +} + +func hashValue(w io.Writer, name string, data hashable) { + exerrors.Must(w.Write([]byte(name))) + exerrors.Must(w.Write(data.Hash())) +} + +func hashInt[T constraints.Integer](w io.Writer, name string, data T) { + exerrors.Must(w.Write(binary.BigEndian.AppendUint64([]byte(name), uint64(data)))) +} + +func hashBool[T ~bool](w io.Writer, name string, data T) { + exerrors.Must(w.Write([]byte(name))) + if data { + exerrors.Must(w.Write([]byte{1})) + } else { + exerrors.Must(w.Write([]byte{0})) + } +} + +func (csl CapabilitySupportLevel) Hash() []byte { + return []byte{byte(csl + 128)} +} + +func (rf *RoomFeatures) Hash() []byte { + hasher := sha256.New() + + hashMap(hasher, "formatting", rf.Formatting) + hashMap(hasher, "file", rf.File) + hashMap(hasher, "state", rf.State) + hashMap(hasher, "member_actions", rf.MemberActions) + + hashInt(hasher, "max_text_length", rf.MaxTextLength) + + hashValue(hasher, "location_message", rf.LocationMessage) + hashValue(hasher, "poll", rf.Poll) + hashValue(hasher, "thread", rf.Thread) + hashValue(hasher, "reply", rf.Reply) + + hashValue(hasher, "edit", rf.Edit) + hashInt(hasher, "edit_max_count", rf.EditMaxCount) + hashInt(hasher, "edit_max_age", rf.EditMaxAge.Get()) + + hashValue(hasher, "delete", rf.Delete) + hashBool(hasher, "delete_for_me", rf.DeleteForMe) + hashInt(hasher, "delete_max_age", rf.DeleteMaxAge.Get()) + hashValue(hasher, "disappearing_timer", rf.DisappearingTimer) + + hashValue(hasher, "reaction", rf.Reaction) + hashInt(hasher, "reaction_count", rf.ReactionCount) + hasher.Write([]byte("allowed_reactions")) + for _, reaction := range rf.AllowedReactions { + hasher.Write([]byte(reaction)) + } + hashBool(hasher, "custom_emoji_reactions", rf.CustomEmojiReactions) + + hashBool(hasher, "read_receipts", rf.ReadReceipts) + hashBool(hasher, "typing_notifications", rf.TypingNotifications) + hashBool(hasher, "archive", rf.Archive) + hashBool(hasher, "mark_as_unread", rf.MarkAsUnread) + hashBool(hasher, "delete_chat", rf.DeleteChat) + hashBool(hasher, "delete_chat_for_everyone", rf.DeleteChatForEveryone) + hashValue(hasher, "message_request", rf.MessageRequest) + + return hasher.Sum(nil) +} + +func (dtc *DisappearingTimerCapability) Hash() []byte { + if dtc == nil { + return nil + } + hasher := sha256.New() + hasher.Write([]byte("types")) + for _, t := range dtc.Types { + hasher.Write([]byte(t)) + } + hasher.Write([]byte("timers")) + for _, timer := range dtc.Timers { + hashInt(hasher, "", timer.Milliseconds()) + } + return hasher.Sum(nil) +} + +func (ff *FileFeatures) Hash() []byte { + hasher := sha256.New() + hashMap(hasher, "mime_types", ff.MimeTypes) + hashValue(hasher, "caption", ff.Caption) + hashInt(hasher, "max_caption_length", ff.MaxCaptionLength) + hashInt(hasher, "max_size", ff.MaxSize) + hashInt(hasher, "max_width", ff.MaxWidth) + hashInt(hasher, "max_height", ff.MaxHeight) + hashInt(hasher, "max_duration", ff.MaxDuration.Get()) + hashBool(hasher, "view_once", ff.ViewOnce) + return hasher.Sum(nil) +} + +func (ff *FileFeatures) Clone() *FileFeatures { + if ff == nil { + return nil + } + clone := *ff + clone.MimeTypes = maps.Clone(clone.MimeTypes) + clone.MaxDuration = ptr.Clone(clone.MaxDuration) + return &clone +} diff --git a/event/cmdschema/content.go b/event/cmdschema/content.go new file mode 100644 index 00000000..ce07c4c0 --- /dev/null +++ b/event/cmdschema/content.go @@ -0,0 +1,78 @@ +// Copyright (c) 2026 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package cmdschema + +import ( + "crypto/sha256" + "encoding/base64" + "fmt" + "reflect" + "slices" + + "go.mau.fi/util/exsync" + "go.mau.fi/util/ptr" + + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +type EventContent struct { + Command string `json:"command"` + Aliases []string `json:"aliases,omitempty"` + Parameters []*Parameter `json:"parameters,omitempty"` + Description *event.ExtensibleTextContainer `json:"description,omitempty"` + TailParam string `json:"fi.mau.tail_parameter,omitempty"` +} + +func (ec *EventContent) Validate() error { + if ec == nil { + return fmt.Errorf("event content is nil") + } else if ec.Command == "" { + return fmt.Errorf("command is empty") + } + var tailFound bool + dupMap := exsync.NewSet[string]() + for i, p := range ec.Parameters { + if err := p.Validate(); err != nil { + return fmt.Errorf("parameter %q (#%d) is invalid: %w", ptr.Val(p).Key, i+1, err) + } else if !dupMap.Add(p.Key) { + return fmt.Errorf("duplicate parameter key %q at #%d", p.Key, i+1) + } else if p.Key == ec.TailParam { + tailFound = true + } else if tailFound && !p.Optional { + return fmt.Errorf("required parameter %q (#%d) is after tail parameter %q", p.Key, i+1, ec.TailParam) + } + } + if ec.TailParam != "" && !tailFound { + return fmt.Errorf("tail parameter %q not found in parameters", ec.TailParam) + } + return nil +} + +func (ec *EventContent) IsValid() bool { + return ec.Validate() == nil +} + +func (ec *EventContent) StateKey(owner id.UserID) string { + hash := sha256.Sum256([]byte(ec.Command + owner.String())) + return base64.StdEncoding.EncodeToString(hash[:]) +} + +func (ec *EventContent) Equals(other *EventContent) bool { + if ec == nil || other == nil { + return ec == other + } + return ec.Command == other.Command && + slices.Equal(ec.Aliases, other.Aliases) && + slices.EqualFunc(ec.Parameters, other.Parameters, (*Parameter).Equals) && + ec.Description.Equals(other.Description) && + ec.TailParam == other.TailParam +} + +func init() { + event.TypeMap[event.StateMSC4391BotCommand] = reflect.TypeOf(EventContent{}) +} diff --git a/event/cmdschema/parameter.go b/event/cmdschema/parameter.go new file mode 100644 index 00000000..4193b297 --- /dev/null +++ b/event/cmdschema/parameter.go @@ -0,0 +1,286 @@ +// Copyright (c) 2026 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package cmdschema + +import ( + "encoding/json" + "fmt" + "slices" + + "go.mau.fi/util/exslices" + + "maunium.net/go/mautrix/event" +) + +type Parameter struct { + Key string `json:"key"` + Schema *ParameterSchema `json:"schema"` + Optional bool `json:"optional,omitempty"` + Description *event.ExtensibleTextContainer `json:"description,omitempty"` + DefaultValue any `json:"fi.mau.default_value,omitempty"` +} + +func (p *Parameter) Equals(other *Parameter) bool { + if p == nil || other == nil { + return p == other + } + return p.Key == other.Key && + p.Schema.Equals(other.Schema) && + p.Optional == other.Optional && + p.Description.Equals(other.Description) && + p.DefaultValue == other.DefaultValue // TODO this won't work for room/event ID values +} + +func (p *Parameter) Validate() error { + if p == nil { + return fmt.Errorf("parameter is nil") + } else if p.Key == "" { + return fmt.Errorf("key is empty") + } + return p.Schema.Validate() +} + +func (p *Parameter) IsValid() bool { + return p.Validate() == nil +} + +func (p *Parameter) GetDefaultValue() any { + if p != nil && p.DefaultValue != nil { + return p.DefaultValue + } else if p == nil || p.Optional { + return nil + } + return p.Schema.GetDefaultValue() +} + +type PrimitiveType string + +const ( + PrimitiveTypeString PrimitiveType = "string" + PrimitiveTypeInteger PrimitiveType = "integer" + PrimitiveTypeBoolean PrimitiveType = "boolean" + PrimitiveTypeServerName PrimitiveType = "server_name" + PrimitiveTypeUserID PrimitiveType = "user_id" + PrimitiveTypeRoomID PrimitiveType = "room_id" + PrimitiveTypeRoomAlias PrimitiveType = "room_alias" + PrimitiveTypeEventID PrimitiveType = "event_id" +) + +func (pt PrimitiveType) Schema() *ParameterSchema { + return &ParameterSchema{ + SchemaType: SchemaTypePrimitive, + Type: pt, + } +} + +func (pt PrimitiveType) IsValid() bool { + switch pt { + case PrimitiveTypeString, + PrimitiveTypeInteger, + PrimitiveTypeBoolean, + PrimitiveTypeServerName, + PrimitiveTypeUserID, + PrimitiveTypeRoomID, + PrimitiveTypeRoomAlias, + PrimitiveTypeEventID: + return true + default: + return false + } +} + +type SchemaType string + +const ( + SchemaTypePrimitive SchemaType = "primitive" + SchemaTypeArray SchemaType = "array" + SchemaTypeUnion SchemaType = "union" + SchemaTypeLiteral SchemaType = "literal" +) + +type ParameterSchema struct { + SchemaType SchemaType `json:"schema_type"` + Type PrimitiveType `json:"type,omitempty"` // Only for primitive + Items *ParameterSchema `json:"items,omitempty"` // Only for array + Variants []*ParameterSchema `json:"variants,omitempty"` // Only for union + Value any `json:"value,omitempty"` // Only for literal +} + +func Literal(value any) *ParameterSchema { + return &ParameterSchema{ + SchemaType: SchemaTypeLiteral, + Value: value, + } +} + +func Enum(values ...any) *ParameterSchema { + return Union(exslices.CastFunc(values, Literal)...) +} + +func flattenUnion(variants []*ParameterSchema) []*ParameterSchema { + var flattened []*ParameterSchema + for _, variant := range variants { + switch variant.SchemaType { + case SchemaTypeArray: + panic(fmt.Errorf("illegal array schema in union")) + case SchemaTypeUnion: + flattened = append(flattened, flattenUnion(variant.Variants)...) + default: + flattened = append(flattened, variant) + } + } + return flattened +} + +func Union(variants ...*ParameterSchema) *ParameterSchema { + needsFlattening := false + for _, variant := range variants { + if variant.SchemaType == SchemaTypeArray { + panic(fmt.Errorf("illegal array schema in union")) + } else if variant.SchemaType == SchemaTypeUnion { + needsFlattening = true + } + } + if needsFlattening { + variants = flattenUnion(variants) + } + return &ParameterSchema{ + SchemaType: SchemaTypeUnion, + Variants: variants, + } +} + +func Array(items *ParameterSchema) *ParameterSchema { + if items.SchemaType == SchemaTypeArray { + panic(fmt.Errorf("illegal array schema in array")) + } + return &ParameterSchema{ + SchemaType: SchemaTypeArray, + Items: items, + } +} + +func (ps *ParameterSchema) GetDefaultValue() any { + if ps == nil { + return nil + } + switch ps.SchemaType { + case SchemaTypePrimitive: + switch ps.Type { + case PrimitiveTypeInteger: + return 0 + case PrimitiveTypeBoolean: + return false + default: + return "" + } + case SchemaTypeArray: + return []any{} + case SchemaTypeUnion: + if len(ps.Variants) > 0 { + return ps.Variants[0].GetDefaultValue() + } + return nil + case SchemaTypeLiteral: + return ps.Value + default: + return nil + } +} + +func (ps *ParameterSchema) IsValid() bool { + return ps.validate("") == nil +} + +func (ps *ParameterSchema) Validate() error { + return ps.validate("") +} + +func (ps *ParameterSchema) validate(parent SchemaType) error { + if ps == nil { + return fmt.Errorf("schema is nil") + } + switch ps.SchemaType { + case SchemaTypePrimitive: + if !ps.Type.IsValid() { + return fmt.Errorf("invalid primitive type %s", ps.Type) + } else if ps.Items != nil || ps.Variants != nil || ps.Value != nil { + return fmt.Errorf("primitive schema has extra fields") + } + return nil + case SchemaTypeArray: + if parent != "" { + return fmt.Errorf("arrays can't be nested in other types") + } else if err := ps.Items.validate(ps.SchemaType); err != nil { + return fmt.Errorf("item schema is invalid: %w", err) + } else if ps.Type != "" || ps.Variants != nil || ps.Value != nil { + return fmt.Errorf("array schema has extra fields") + } + return nil + case SchemaTypeUnion: + if len(ps.Variants) == 0 { + return fmt.Errorf("no variants specified for union") + } else if parent != "" && parent != SchemaTypeArray { + return fmt.Errorf("unions can't be nested in anything other than arrays") + } + for i, v := range ps.Variants { + if err := v.validate(ps.SchemaType); err != nil { + return fmt.Errorf("variant #%d is invalid: %w", i+1, err) + } + } + if ps.Type != "" || ps.Items != nil || ps.Value != nil { + return fmt.Errorf("union schema has extra fields") + } + return nil + case SchemaTypeLiteral: + switch typedVal := ps.Value.(type) { + case string, float64, int, int64, json.Number, bool, RoomIDValue, *RoomIDValue: + // ok + case map[string]any: + if typedVal["type"] != "event_id" && typedVal["type"] != "room_id" { + return fmt.Errorf("literal value has invalid map data") + } + default: + return fmt.Errorf("literal value has unsupported type %T", ps.Value) + } + if ps.Type != "" || ps.Items != nil || ps.Variants != nil { + return fmt.Errorf("literal schema has extra fields") + } + return nil + default: + return fmt.Errorf("invalid schema type %s", ps.SchemaType) + } +} + +func (ps *ParameterSchema) Equals(other *ParameterSchema) bool { + if ps == nil || other == nil { + return ps == other + } + return ps.SchemaType == other.SchemaType && + ps.Type == other.Type && + ps.Items.Equals(other.Items) && + slices.EqualFunc(ps.Variants, other.Variants, (*ParameterSchema).Equals) && + ps.Value == other.Value // TODO this won't work for room/event ID values +} + +func (ps *ParameterSchema) AllowsPrimitive(prim PrimitiveType) bool { + switch ps.SchemaType { + case SchemaTypePrimitive: + return ps.Type == prim + case SchemaTypeUnion: + for _, variant := range ps.Variants { + if variant.AllowsPrimitive(prim) { + return true + } + } + return false + case SchemaTypeArray: + return ps.Items.AllowsPrimitive(prim) + default: + return false + } +} diff --git a/event/cmdschema/parse.go b/event/cmdschema/parse.go new file mode 100644 index 00000000..92e69b60 --- /dev/null +++ b/event/cmdschema/parse.go @@ -0,0 +1,478 @@ +// Copyright (c) 2026 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package cmdschema + +import ( + "encoding/json" + "errors" + "fmt" + "regexp" + "strconv" + "strings" + + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +const botArrayOpener = "<" +const botArrayCloser = ">" + +func parseQuoted(val string) (parsed, remaining string, quoted bool) { + if len(val) == 0 { + return + } + if !strings.HasPrefix(val, `"`) { + spaceIdx := strings.IndexByte(val, ' ') + if spaceIdx == -1 { + parsed = val + } else { + parsed = val[:spaceIdx] + remaining = strings.TrimLeft(val[spaceIdx+1:], " ") + } + return + } + val = val[1:] + var buf strings.Builder + for { + quoteIdx := strings.IndexByte(val, '"') + var valUntilQuote string + if quoteIdx == -1 { + valUntilQuote = val + } else { + valUntilQuote = val[:quoteIdx] + } + escapeIdx := strings.IndexByte(valUntilQuote, '\\') + if escapeIdx >= 0 { + buf.WriteString(val[:escapeIdx]) + if len(val) > escapeIdx+1 { + buf.WriteByte(val[escapeIdx+1]) + } + val = val[min(escapeIdx+2, len(val)):] + } else if quoteIdx >= 0 { + buf.WriteString(val[:quoteIdx]) + val = val[quoteIdx+1:] + break + } else if buf.Len() == 0 { + // Unterminated quote, no escape characters, val is the whole input + return val, "", true + } else { + // Unterminated quote, but there were escape characters previously + buf.WriteString(val) + val = "" + break + } + } + return buf.String(), strings.TrimLeft(val, " "), true +} + +// ParseInput tries to parse the given text into a bot command event matching this command definition. +// +// If the prefix doesn't match, this will return a nil content and nil error. +// If the prefix does match, some content is always returned, but there may still be an error if parsing failed. +func (ec *EventContent) ParseInput(owner id.UserID, sigils []string, input string) (content *event.MessageEventContent, err error) { + prefix := ec.parsePrefix(input, sigils, owner.String()) + if prefix == "" { + return nil, nil + } + content = &event.MessageEventContent{ + MsgType: event.MsgText, + Body: input, + Mentions: &event.Mentions{UserIDs: []id.UserID{owner}}, + MSC4391BotCommand: &event.MSC4391BotCommandInput{ + Command: ec.Command, + }, + } + content.MSC4391BotCommand.Arguments, err = ec.ParseArguments(input[len(prefix):]) + return content, err +} + +func (ec *EventContent) ParseArguments(input string) (json.RawMessage, error) { + args := make(map[string]any) + var retErr error + setError := func(err error) { + if err != nil && retErr == nil { + retErr = err + } + } + processParameter := func(param *Parameter, isLast, isTail, isNamed bool) { + origInput := input + var nextVal string + var wasQuoted bool + if param.Schema.SchemaType == SchemaTypeArray { + hasOpener := strings.HasPrefix(input, botArrayOpener) + arrayClosed := false + if hasOpener { + input = input[len(botArrayOpener):] + if strings.HasPrefix(input, botArrayCloser) { + input = strings.TrimLeft(input[len(botArrayCloser):], " ") + arrayClosed = true + } + } + var collector []any + for len(input) > 0 && !arrayClosed { + //origInput = input + nextVal, input, wasQuoted = parseQuoted(input) + if !wasQuoted && hasOpener && strings.HasSuffix(nextVal, botArrayCloser) { + // The value wasn't quoted and has the array delimiter at the end, close the array + nextVal = strings.TrimRight(nextVal, botArrayCloser) + arrayClosed = true + } else if hasOpener && strings.HasPrefix(input, botArrayCloser) { + // The value was quoted or there was a space, and the next character is the + // array delimiter, close the array + input = strings.TrimLeft(input[len(botArrayCloser):], " ") + arrayClosed = true + } else if !hasOpener && !isLast { + // For array arguments in the middle without the <> delimiters, stop after the first item + arrayClosed = true + } + parsedVal, err := param.Schema.Items.ParseString(nextVal) + if err == nil { + collector = append(collector, parsedVal) + } else if hasOpener || isLast { + setError(fmt.Errorf("failed to parse item #%d of array %s: %w", len(collector)+1, param.Key, err)) + } else { + //input = origInput + } + } + args[param.Key] = collector + } else { + nextVal, input, wasQuoted = parseQuoted(input) + if (isLast || isTail) && !wasQuoted && len(input) > 0 { + // If the last argument is not quoted, just treat the rest of the string + // as the argument without escapes (arguments with escapes should be quoted). + nextVal += " " + input + input = "" + } + // Special case for named boolean parameters: if no value is given, treat it as true + if nextVal == "" && !wasQuoted && isNamed && param.Schema.AllowsPrimitive(PrimitiveTypeBoolean) { + args[param.Key] = true + return + } + if nextVal == "" && !wasQuoted && !isNamed && !param.Optional { + setError(fmt.Errorf("missing value for required parameter %s", param.Key)) + } + parsedVal, err := param.Schema.ParseString(nextVal) + if err != nil { + args[param.Key] = param.GetDefaultValue() + // For optional parameters that fail to parse, restore the input and try passing it as the next parameter + if param.Optional && !isLast && !isNamed { + input = strings.TrimLeft(origInput, " ") + } else if !param.Optional || isNamed { + setError(fmt.Errorf("failed to parse %s: %w", param.Key, err)) + } + } else { + args[param.Key] = parsedVal + } + } + } + skipParams := make([]bool, len(ec.Parameters)) + for i, param := range ec.Parameters { + for strings.HasPrefix(input, "--") { + nameEndIdx := strings.IndexAny(input, " =") + if nameEndIdx == -1 { + nameEndIdx = len(input) + } + overrideParam, paramIdx := ec.parameterByName(input[2:nameEndIdx]) + if overrideParam != nil { + // Trim the equals sign, but leave spaces alone to let parseQuoted treat it as empty input + input = strings.TrimPrefix(input[nameEndIdx:], "=") + skipParams[paramIdx] = true + processParameter(overrideParam, false, false, true) + } else { + break + } + } + isTail := param.Key == ec.TailParam + if skipParams[i] || (param.Optional && !isTail) { + continue + } + processParameter(param, i == len(ec.Parameters)-1, isTail, false) + } + jsonArgs, marshalErr := json.Marshal(args) + if marshalErr != nil { + return nil, fmt.Errorf("failed to marshal arguments: %w", marshalErr) + } + return jsonArgs, retErr +} + +func (ec *EventContent) parameterByName(name string) (*Parameter, int) { + for i, param := range ec.Parameters { + if strings.EqualFold(param.Key, name) { + return param, i + } + } + return nil, -1 +} + +func (ec *EventContent) parsePrefix(origInput string, sigils []string, owner string) (prefix string) { + input := origInput + var chosenSigil string + for _, sigil := range sigils { + if strings.HasPrefix(input, sigil) { + chosenSigil = sigil + break + } + } + if chosenSigil == "" { + return "" + } + input = input[len(chosenSigil):] + var chosenAlias string + if !strings.HasPrefix(input, ec.Command) { + for _, alias := range ec.Aliases { + if strings.HasPrefix(input, alias) { + chosenAlias = alias + break + } + } + if chosenAlias == "" { + return "" + } + } else { + chosenAlias = ec.Command + } + input = strings.TrimPrefix(input[len(chosenAlias):], owner) + if input == "" || input[0] == ' ' { + input = strings.TrimLeft(input, " ") + return origInput[:len(origInput)-len(input)] + } + return "" +} + +func (pt PrimitiveType) ValidateValue(value any) bool { + _, err := pt.NormalizeValue(value) + return err == nil +} + +func normalizeNumber(value any) (int, error) { + switch typedValue := value.(type) { + case int: + return typedValue, nil + case int64: + return int(typedValue), nil + case float64: + return int(typedValue), nil + case json.Number: + if i, err := typedValue.Int64(); err != nil { + return 0, fmt.Errorf("failed to parse json.Number: %w", err) + } else { + return int(i), nil + } + default: + return 0, fmt.Errorf("unsupported type %T for integer", value) + } +} + +func (pt PrimitiveType) NormalizeValue(value any) (any, error) { + switch pt { + case PrimitiveTypeInteger: + return normalizeNumber(value) + case PrimitiveTypeBoolean: + bv, ok := value.(bool) + if !ok { + return nil, fmt.Errorf("unsupported type %T for boolean", value) + } + return bv, nil + case PrimitiveTypeString, PrimitiveTypeServerName: + str, ok := value.(string) + if !ok { + return nil, fmt.Errorf("unsupported type %T for string", value) + } + return str, pt.validateStringValue(str) + case PrimitiveTypeUserID, PrimitiveTypeRoomAlias: + str, ok := value.(string) + if !ok { + return nil, fmt.Errorf("unsupported type %T for user ID or room alias", value) + } else if plainErr := pt.validateStringValue(str); plainErr == nil { + return str, nil + } else if parsed, err := id.ParseMatrixURIOrMatrixToURL(str); err != nil { + return nil, fmt.Errorf("couldn't parse %q as plain ID nor matrix URI: %w / %w", value, plainErr, err) + } else if parsed.Sigil1 == '@' && pt == PrimitiveTypeUserID { + return parsed.UserID(), nil + } else if parsed.Sigil1 == '#' && pt == PrimitiveTypeRoomAlias { + return parsed.RoomAlias(), nil + } else { + return nil, fmt.Errorf("unexpected sigil %c for user ID or room alias", parsed.Sigil1) + } + case PrimitiveTypeRoomID, PrimitiveTypeEventID: + riv, err := NormalizeRoomIDValue(value) + if err != nil { + return nil, err + } + return riv, riv.Validate() + default: + return nil, fmt.Errorf("cannot normalize value for argument type %s", pt) + } +} + +func (pt PrimitiveType) validateStringValue(value string) error { + switch pt { + case PrimitiveTypeString: + return nil + case PrimitiveTypeServerName: + if !id.ValidateServerName(value) { + return fmt.Errorf("invalid server name: %q", value) + } + return nil + case PrimitiveTypeUserID: + _, _, err := id.UserID(value).ParseAndValidateRelaxed() + return err + case PrimitiveTypeRoomAlias: + sigil, localpart, serverName := id.ParseCommonIdentifier(value) + if sigil != '#' || localpart == "" || serverName == "" { + return fmt.Errorf("invalid room alias: %q", value) + } else if !id.ValidateServerName(serverName) { + return fmt.Errorf("invalid server name in room alias: %q", serverName) + } + return nil + default: + panic(fmt.Errorf("validateStringValue called with invalid type %s", pt)) + } +} + +func parseBoolean(val string) (bool, error) { + if len(val) == 0 { + return false, fmt.Errorf("cannot parse empty string as boolean") + } + switch strings.ToLower(val) { + case "t", "true", "y", "yes", "1": + return true, nil + case "f", "false", "n", "no", "0": + return false, nil + default: + return false, fmt.Errorf("invalid boolean string: %q", val) + } +} + +var markdownLinkRegex = regexp.MustCompile(`^\[.+]\(([^)]+)\)$`) + +func parseRoomOrEventID(value string) (*RoomIDValue, error) { + if strings.HasPrefix(value, "[") && strings.Contains(value, "](") && strings.HasSuffix(value, ")") { + matches := markdownLinkRegex.FindStringSubmatch(value) + if len(matches) == 2 { + value = matches[1] + } + } + parsed, err := id.ParseMatrixURIOrMatrixToURL(value) + if err != nil && strings.HasPrefix(value, "!") { + return &RoomIDValue{ + Type: PrimitiveTypeRoomID, + RoomID: id.RoomID(value), + }, nil + } + if err != nil { + return nil, err + } else if parsed.Sigil1 != '!' { + return nil, fmt.Errorf("unexpected sigil %c for room ID", parsed.Sigil1) + } else if parsed.MXID2 != "" && parsed.Sigil2 != '$' { + return nil, fmt.Errorf("unexpected sigil %c for event ID", parsed.Sigil2) + } + valType := PrimitiveTypeRoomID + if parsed.MXID2 != "" { + valType = PrimitiveTypeEventID + } + return &RoomIDValue{ + Type: valType, + RoomID: parsed.RoomID(), + Via: parsed.Via, + EventID: parsed.EventID(), + }, nil +} + +func (pt PrimitiveType) ParseString(value string) (any, error) { + switch pt { + case PrimitiveTypeInteger: + return strconv.Atoi(value) + case PrimitiveTypeBoolean: + return parseBoolean(value) + case PrimitiveTypeString, PrimitiveTypeServerName, PrimitiveTypeUserID: + return value, pt.validateStringValue(value) + case PrimitiveTypeRoomAlias: + plainErr := pt.validateStringValue(value) + if plainErr == nil { + return value, nil + } + parsed, err := id.ParseMatrixURIOrMatrixToURL(value) + if err != nil { + return nil, fmt.Errorf("couldn't parse %q as plain room alias nor matrix URI: %w / %w", value, plainErr, err) + } else if parsed.Sigil1 != '#' { + return nil, fmt.Errorf("unexpected sigil %c for room alias", parsed.Sigil1) + } + return parsed.RoomAlias(), nil + case PrimitiveTypeRoomID, PrimitiveTypeEventID: + parsed, err := parseRoomOrEventID(value) + if err != nil { + return nil, err + } else if pt != parsed.Type { + return nil, fmt.Errorf("mismatching argument type: expected %s but got %s", pt, parsed.Type) + } + return parsed, nil + default: + return nil, fmt.Errorf("cannot parse string for argument type %s", pt) + } +} + +func (ps *ParameterSchema) ParseString(value string) (any, error) { + if ps == nil { + return nil, fmt.Errorf("parameter schema is nil") + } + switch ps.SchemaType { + case SchemaTypePrimitive: + return ps.Type.ParseString(value) + case SchemaTypeLiteral: + switch typedValue := ps.Value.(type) { + case string: + if value == typedValue { + return typedValue, nil + } else { + return nil, fmt.Errorf("literal value %q does not match %q", typedValue, value) + } + case int, int64, float64, json.Number: + expectedVal, _ := normalizeNumber(typedValue) + intVal, err := strconv.Atoi(value) + if err != nil { + return nil, fmt.Errorf("failed to parse integer literal: %w", err) + } else if intVal != expectedVal { + return nil, fmt.Errorf("literal value %d does not match %d", expectedVal, intVal) + } + return intVal, nil + case bool: + boolVal, err := parseBoolean(value) + if err != nil { + return nil, fmt.Errorf("failed to parse boolean literal: %w", err) + } else if boolVal != typedValue { + return nil, fmt.Errorf("literal value %t does not match %t", typedValue, boolVal) + } + return boolVal, nil + case RoomIDValue, *RoomIDValue, map[string]any, json.RawMessage: + expectedVal, _ := NormalizeRoomIDValue(typedValue) + parsed, err := parseRoomOrEventID(value) + if err != nil { + return nil, fmt.Errorf("failed to parse room or event ID literal: %w", err) + } else if !parsed.Equals(expectedVal) { + return nil, fmt.Errorf("literal value %s does not match %s", expectedVal, parsed) + } + return parsed, nil + default: + return nil, fmt.Errorf("unsupported literal type %T", ps.Value) + } + case SchemaTypeUnion: + var errs []error + for _, variant := range ps.Variants { + if parsed, err := variant.ParseString(value); err == nil { + return parsed, nil + } else { + errs = append(errs, err) + } + } + return nil, fmt.Errorf("no union variant matched: %w", errors.Join(errs...)) + case SchemaTypeArray: + return nil, fmt.Errorf("cannot parse string for array schema type") + default: + return nil, fmt.Errorf("unknown schema type %s", ps.SchemaType) + } +} diff --git a/event/cmdschema/parse_test.go b/event/cmdschema/parse_test.go new file mode 100644 index 00000000..1e0d1817 --- /dev/null +++ b/event/cmdschema/parse_test.go @@ -0,0 +1,118 @@ +// Copyright (c) 2026 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package cmdschema + +import ( + "bytes" + "encoding/json" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "go.mau.fi/util/exbytes" + "go.mau.fi/util/exerrors" + + "maunium.net/go/mautrix/event/cmdschema/testdata" +) + +type QuoteParseOutput struct { + Parsed string + Remaining string + Quoted bool +} + +func (qpo *QuoteParseOutput) UnmarshalJSON(data []byte) error { + var arr []any + if err := json.Unmarshal(data, &arr); err != nil { + return err + } + qpo.Parsed = arr[0].(string) + qpo.Remaining = arr[1].(string) + qpo.Quoted = arr[2].(bool) + return nil +} + +type QuoteParseTestData struct { + Name string `json:"name"` + Input string `json:"input"` + Output QuoteParseOutput `json:"output"` +} + +func loadFile[T any](name string) (into T) { + quoteData := exerrors.Must(testdata.FS.ReadFile(name)) + exerrors.PanicIfNotNil(json.Unmarshal(quoteData, &into)) + return +} + +func TestParseQuoted(t *testing.T) { + qptd := loadFile[[]QuoteParseTestData]("parse_quote.json") + for _, test := range qptd { + t.Run(test.Name, func(t *testing.T) { + parsed, remaining, quoted := parseQuoted(test.Input) + assert.Equalf(t, test.Output, QuoteParseOutput{ + Parsed: parsed, + Remaining: remaining, + Quoted: quoted, + }, "Failed with input `%s`", test.Input) + // Note: can't just test that requoted == input, because some inputs + // have unnecessary escapes which won't survive roundtripping + t.Run("roundtrip", func(t *testing.T) { + requoted := quoteString(parsed) + " " + remaining + reparsed, newRemaining, _ := parseQuoted(requoted) + assert.Equal(t, parsed, reparsed) + assert.Equal(t, remaining, newRemaining) + }) + }) + } +} + +type CommandTestData struct { + Spec *EventContent + Tests []*CommandTestUnit +} + +type CommandTestUnit struct { + Name string `json:"name"` + Input string `json:"input"` + Broken string `json:"broken,omitempty"` + Error bool `json:"error"` + Output json.RawMessage `json:"output"` +} + +func compactJSON(input json.RawMessage) json.RawMessage { + var buf bytes.Buffer + exerrors.PanicIfNotNil(json.Compact(&buf, input)) + return buf.Bytes() +} + +func TestMSC4391BotCommandEventContent_ParseInput(t *testing.T) { + for _, cmd := range exerrors.Must(testdata.FS.ReadDir("commands")) { + t.Run(strings.TrimSuffix(cmd.Name(), ".json"), func(t *testing.T) { + ctd := loadFile[CommandTestData]("commands/" + cmd.Name()) + for _, test := range ctd.Tests { + outputStr := exbytes.UnsafeString(compactJSON(test.Output)) + t.Run(test.Name, func(t *testing.T) { + if test.Broken != "" { + t.Skip(test.Broken) + } + output, err := ctd.Spec.ParseInput("@testbot", []string{"/"}, test.Input) + if test.Error { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + if outputStr == "null" { + assert.Nil(t, output) + } else { + assert.Equal(t, ctd.Spec.Command, output.MSC4391BotCommand.Command) + assert.Equalf(t, outputStr, exbytes.UnsafeString(output.MSC4391BotCommand.Arguments), "Input: %s", test.Input) + } + }) + } + }) + } +} diff --git a/event/cmdschema/roomid.go b/event/cmdschema/roomid.go new file mode 100644 index 00000000..98c421fc --- /dev/null +++ b/event/cmdschema/roomid.go @@ -0,0 +1,135 @@ +// Copyright (c) 2026 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package cmdschema + +import ( + "encoding/json" + "fmt" + "slices" + "strings" + + "maunium.net/go/mautrix/id" +) + +var ParameterSchemaJoinableRoom = Union( + PrimitiveTypeRoomID.Schema(), + PrimitiveTypeRoomAlias.Schema(), +) + +type RoomIDValue struct { + Type PrimitiveType `json:"type"` + RoomID id.RoomID `json:"id"` + Via []string `json:"via,omitempty"` + EventID id.EventID `json:"event_id,omitempty"` +} + +func NormalizeRoomIDValue(input any) (riv *RoomIDValue, err error) { + switch typedValue := input.(type) { + case map[string]any, json.RawMessage: + var raw json.RawMessage + if raw, err = json.Marshal(input); err != nil { + err = fmt.Errorf("failed to roundtrip room ID value: %w", err) + } else if err = json.Unmarshal(raw, &riv); err != nil { + err = fmt.Errorf("failed to roundtrip room ID value: %w", err) + } + case *RoomIDValue: + riv = typedValue + case RoomIDValue: + riv = &typedValue + default: + err = fmt.Errorf("unsupported type %T for room or event ID", input) + } + return +} + +func (riv *RoomIDValue) String() string { + return riv.URI().String() +} + +func (riv *RoomIDValue) URI() *id.MatrixURI { + if riv == nil { + return nil + } + switch riv.Type { + case PrimitiveTypeRoomID: + return riv.RoomID.URI(riv.Via...) + case PrimitiveTypeEventID: + return riv.RoomID.EventURI(riv.EventID, riv.Via...) + default: + return nil + } +} + +func (riv *RoomIDValue) Equals(other *RoomIDValue) bool { + if riv == nil || other == nil { + return riv == other + } + return riv.Type == other.Type && + riv.RoomID == other.RoomID && + riv.EventID == other.EventID && + slices.Equal(riv.Via, other.Via) +} + +func (riv *RoomIDValue) Validate() error { + if riv == nil { + return fmt.Errorf("value is nil") + } + switch riv.Type { + case PrimitiveTypeRoomID: + if riv.EventID != "" { + return fmt.Errorf("event ID must be empty for room ID type") + } + case PrimitiveTypeEventID: + if !strings.HasPrefix(riv.EventID.String(), "$") { + return fmt.Errorf("event ID not valid: %q", riv.EventID) + } + default: + return fmt.Errorf("unexpected type %s for room/event ID value", riv.Type) + } + for _, via := range riv.Via { + if !id.ValidateServerName(via) { + return fmt.Errorf("invalid server name %q in vias", via) + } + } + sigil, localpart, serverName := id.ParseCommonIdentifier(riv.RoomID) + if sigil != '!' { + return fmt.Errorf("room ID does not start with !: %q", riv.RoomID) + } else if localpart == "" && serverName == "" { + return fmt.Errorf("room ID has empty localpart and server name: %q", riv.RoomID) + } else if serverName != "" && !id.ValidateServerName(serverName) { + return fmt.Errorf("invalid server name %q in room ID", serverName) + } + return nil +} + +func (riv *RoomIDValue) IsValid() bool { + return riv.Validate() == nil +} + +type RoomIDOrString string + +func (ros *RoomIDOrString) UnmarshalJSON(data []byte) error { + if len(data) == 0 { + return fmt.Errorf("empty data for room ID or string") + } + if data[0] == '"' { + var str string + if err := json.Unmarshal(data, &str); err != nil { + return err + } + *ros = RoomIDOrString(str) + return nil + } + var riv RoomIDValue + if err := json.Unmarshal(data, &riv); err != nil { + return err + } else if err = riv.Validate(); err != nil { + return err + } + *ros = RoomIDOrString(riv.String()) + return nil +} diff --git a/event/cmdschema/stringify.go b/event/cmdschema/stringify.go new file mode 100644 index 00000000..c5c57c53 --- /dev/null +++ b/event/cmdschema/stringify.go @@ -0,0 +1,122 @@ +// Copyright (c) 2026 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package cmdschema + +import ( + "encoding/json" + "strconv" + "strings" +) + +var quoteEscaper = strings.NewReplacer( + `"`, `\"`, + `\`, `\\`, +) + +const charsToQuote = ` \` + botArrayOpener + botArrayCloser + +func quoteString(val string) string { + if val == "" { + return `""` + } + val = quoteEscaper.Replace(val) + if strings.ContainsAny(val, charsToQuote) { + return `"` + val + `"` + } + return val +} + +func (ec *EventContent) StringifyArgs(args any) string { + var argMap map[string]any + switch typedArgs := args.(type) { + case json.RawMessage: + err := json.Unmarshal(typedArgs, &argMap) + if err != nil { + return "" + } + case map[string]any: + argMap = typedArgs + default: + if b, err := json.Marshal(args); err != nil { + return "" + } else if err = json.Unmarshal(b, &argMap); err != nil { + return "" + } + } + parts := make([]string, 0, len(ec.Parameters)) + for i, param := range ec.Parameters { + isLast := i == len(ec.Parameters)-1 + val := argMap[param.Key] + if val == nil { + val = param.DefaultValue + if val == nil && !param.Optional { + val = param.Schema.GetDefaultValue() + } + } + if val == nil { + continue + } + var stringified string + if param.Schema.SchemaType == SchemaTypeArray { + stringified = arrayArgumentToString(val, isLast) + } else { + stringified = singleArgumentToString(val) + } + if stringified != "" { + parts = append(parts, stringified) + } + } + return strings.Join(parts, " ") +} + +func arrayArgumentToString(val any, isLast bool) string { + valArr, ok := val.([]any) + if !ok { + return "" + } + parts := make([]string, 0, len(valArr)) + for _, elem := range valArr { + stringified := singleArgumentToString(elem) + if stringified != "" { + parts = append(parts, stringified) + } + } + joinedParts := strings.Join(parts, " ") + if isLast && len(parts) > 0 { + return joinedParts + } + return botArrayOpener + joinedParts + botArrayCloser +} + +func singleArgumentToString(val any) string { + switch typedVal := val.(type) { + case string: + return quoteString(typedVal) + case json.Number: + return typedVal.String() + case bool: + return strconv.FormatBool(typedVal) + case int: + return strconv.Itoa(typedVal) + case int64: + return strconv.FormatInt(typedVal, 10) + case float64: + return strconv.FormatInt(int64(typedVal), 10) + case map[string]any, json.RawMessage, RoomIDValue, *RoomIDValue: + normalized, err := NormalizeRoomIDValue(typedVal) + if err != nil { + return "" + } + uri := normalized.URI() + if uri == nil { + return "" + } + return quoteString(uri.String()) + default: + return "" + } +} diff --git a/event/cmdschema/testdata/commands.schema.json b/event/cmdschema/testdata/commands.schema.json new file mode 100644 index 00000000..e53382db --- /dev/null +++ b/event/cmdschema/testdata/commands.schema.json @@ -0,0 +1,281 @@ +{ + "$schema": "https://json-schema.org/draft/2020-12/schema#", + "$id": "commands.schema.json", + "title": "ParseInput test cases", + "description": "JSON schema for test case files containing command specifications and test cases", + "type": "object", + "required": [ + "spec", + "tests" + ], + "additionalProperties": false, + "properties": { + "spec": { + "title": "MSC4391 Command Description", + "description": "JSON schema defining the structure of a bot command event content", + "type": "object", + "required": [ + "command" + ], + "additionalProperties": false, + "properties": { + "command": { + "type": "string", + "description": "The command name that triggers this bot command" + }, + "aliases": { + "type": "array", + "description": "Alternative names/aliases for this command", + "items": { + "type": "string" + } + }, + "parameters": { + "type": "array", + "description": "List of parameters accepted by this command", + "items": { + "$ref": "#/$defs/Parameter" + } + }, + "description": { + "$ref": "#/$defs/ExtensibleTextContainer", + "description": "Human-readable description of the command" + }, + "fi.mau.tail_parameter": { + "type": "string", + "description": "The key of the parameter that accepts remaining arguments as tail text" + }, + "source": { + "type": "string", + "description": "The user ID of the bot that responds to this command" + } + } + }, + "tests": { + "type": "array", + "description": "Array of test cases for the command", + "items": { + "type": "object", + "description": "A single test case for command parsing", + "required": [ + "name", + "input" + ], + "additionalProperties": false, + "properties": { + "name": { + "type": "string", + "description": "The name of the test case" + }, + "input": { + "type": "string", + "description": "The command input string to parse" + }, + "output": { + "description": "The expected parsed parameter values, or null if the parsing is expected to fail", + "oneOf": [ + { + "type": "object", + "additionalProperties": true + }, + { + "type": "null" + } + ] + }, + "error": { + "type": "boolean", + "description": "Whether parsing should result in an error. May still produce output.", + "default": false + } + } + } + } + }, + "$defs": { + "ExtensibleTextContainer": { + "type": "object", + "description": "Container for text that can have multiple representations", + "required": [ + "m.text" + ], + "properties": { + "m.text": { + "type": "array", + "description": "Array of text representations in different formats", + "items": { + "$ref": "#/$defs/ExtensibleText" + } + } + } + }, + "ExtensibleText": { + "type": "object", + "description": "A text representation with a specific MIME type", + "required": [ + "body" + ], + "properties": { + "body": { + "type": "string", + "description": "The text content" + }, + "mimetype": { + "type": "string", + "description": "The MIME type of the text (e.g., text/plain, text/html)", + "default": "text/plain", + "examples": [ + "text/plain", + "text/html" + ] + } + } + }, + "Parameter": { + "type": "object", + "description": "A parameter definition for a command", + "required": [ + "key", + "schema" + ], + "additionalProperties": false, + "properties": { + "key": { + "type": "string", + "description": "The identifier for this parameter" + }, + "schema": { + "$ref": "#/$defs/ParameterSchema", + "description": "The schema defining the type and structure of this parameter" + }, + "optional": { + "type": "boolean", + "description": "Whether this parameter is optional", + "default": false + }, + "description": { + "$ref": "#/$defs/ExtensibleTextContainer", + "description": "Human-readable description of this parameter" + }, + "fi.mau.default_value": { + "description": "Default value for this parameter if not provided" + } + } + }, + "ParameterSchema": { + "type": "object", + "description": "Schema definition for a parameter value", + "required": [ + "schema_type" + ], + "additionalProperties": false, + "properties": { + "schema_type": { + "type": "string", + "enum": [ + "primitive", + "array", + "union", + "literal" + ], + "description": "The type of schema" + } + }, + "allOf": [ + { + "if": { + "properties": { + "schema_type": { + "const": "primitive" + } + } + }, + "then": { + "required": [ + "type" + ], + "properties": { + "type": { + "type": "string", + "enum": [ + "string", + "integer", + "boolean", + "server_name", + "user_id", + "room_id", + "room_alias", + "event_id" + ], + "description": "The primitive type (only for schema_type: primitive)" + } + } + } + }, + { + "if": { + "properties": { + "schema_type": { + "const": "array" + } + } + }, + "then": { + "required": [ + "items" + ], + "properties": { + "items": { + "$ref": "#/$defs/ParameterSchema", + "description": "The schema for array items (only for schema_type: array)" + } + } + } + }, + { + "if": { + "properties": { + "schema_type": { + "const": "union" + } + } + }, + "then": { + "required": [ + "variants" + ], + "properties": { + "variants": { + "type": "array", + "description": "The possible variants (only for schema_type: union)", + "items": { + "$ref": "#/$defs/ParameterSchema" + }, + "minItems": 1 + } + } + } + }, + { + "if": { + "properties": { + "schema_type": { + "const": "literal" + } + } + }, + "then": { + "required": [ + "value" + ], + "properties": { + "value": { + "description": "The literal value (only for schema_type: literal)" + } + } + } + } + ] + } + } +} diff --git a/event/cmdschema/testdata/commands/flags.json b/event/cmdschema/testdata/commands/flags.json new file mode 100644 index 00000000..6ce1f4da --- /dev/null +++ b/event/cmdschema/testdata/commands/flags.json @@ -0,0 +1,126 @@ +{ + "$schema": "../commands.schema.json#", + "spec": { + "command": "flag", + "source": "@testbot", + "parameters": [ + { + "key": "meow", + "schema": { + "schema_type": "primitive", + "type": "string" + } + }, + { + "key": "user", + "schema": { + "schema_type": "primitive", + "type": "user_id" + }, + "optional": true + }, + { + "key": "woof", + "schema": { + "schema_type": "primitive", + "type": "boolean" + }, + "optional": true, + "fi.mau.default_value": false + } + ], + "fi.mau.tail_parameter": "user" + }, + "tests": [ + { + "name": "no flags", + "input": "/flag mrrp", + "output": { + "meow": "mrrp", + "user": null + } + }, + { + "name": "no flags, has tail", + "input": "/flag mrrp @user:example.com", + "output": { + "meow": "mrrp", + "user": "@user:example.com" + } + }, + { + "name": "named flag at start", + "input": "/flag --woof=yes mrrp @user:example.com", + "output": { + "meow": "mrrp", + "user": "@user:example.com", + "woof": true + } + }, + { + "name": "boolean flag without value", + "input": "/flag --woof mrrp @user:example.com", + "output": { + "meow": "mrrp", + "user": "@user:example.com", + "woof": true + } + }, + { + "name": "user id flag without value", + "input": "/flag --user --woof mrrp", + "error": true, + "output": { + "meow": "mrrp", + "user": null, + "woof": true + } + }, + { + "name": "named flag in the middle", + "input": "/flag mrrp --woof=yes @user:example.com", + "output": { + "meow": "mrrp", + "user": "@user:example.com", + "woof": true + } + }, + { + "name": "named flag in the middle with different value", + "input": "/flag mrrp --woof=no @user:example.com", + "output": { + "meow": "mrrp", + "user": "@user:example.com", + "woof": false + } + }, + { + "name": "all variables named", + "input": "/flag --woof=no --meow=mrrp --user=@user:example.com", + "output": { + "meow": "mrrp", + "user": "@user:example.com", + "woof": false + } + }, + { + "name": "all variables named with quotes", + "input": "/flag --woof --meow=\"meow meow mrrp\" --user=\"@user:example.com\"", + "output": { + "meow": "meow meow mrrp", + "user": "@user:example.com", + "woof": true + } + }, + { + "name": "invalid value for named parameter", + "input": "/flag --user=meowings mrrp --woof", + "error": true, + "output": { + "meow": "mrrp", + "user": null, + "woof": true + } + } + ] +} diff --git a/event/cmdschema/testdata/commands/room_id_or_alias.json b/event/cmdschema/testdata/commands/room_id_or_alias.json new file mode 100644 index 00000000..1351c292 --- /dev/null +++ b/event/cmdschema/testdata/commands/room_id_or_alias.json @@ -0,0 +1,85 @@ +{ + "$schema": "../commands.schema.json#", + "spec": { + "command": "test room reference", + "source": "@testbot", + "parameters": [ + { + "key": "room", + "schema": { + "schema_type": "union", + "variants": [ + { + "schema_type": "primitive", + "type": "room_id" + }, + { + "schema_type": "primitive", + "type": "room_alias" + } + ] + } + } + ] + }, + "tests": [ + { + "name": "room alias", + "input": "/test room reference #test:matrix.org", + "output": { + "room": "#test:matrix.org" + } + }, + { + "name": "room id", + "input": "/test room reference !aiwVrNhPwbGBNjqlNu:matrix.org", + "output": { + "room": { + "type": "room_id", + "id": "!aiwVrNhPwbGBNjqlNu:matrix.org" + } + } + }, + { + "name": "room id matrix.to link", + "input": "/test room reference https://matrix.to/#/!aiwVrNhPwbGBNjqlNu:matrix.org?via=example.com", + "output": { + "room": { + "type": "room_id", + "id": "!aiwVrNhPwbGBNjqlNu:matrix.org", + "via": [ + "example.com" + ] + } + } + }, + { + "name": "room id matrix.to link with url encoding", + "input": "/test room reference https://matrix.to/#/!%23test%2Froom%0Aversion%20%3Cu%3E11%3C%2Fu%3E%2C%20with%20%40%F0%9F%90%88%EF%B8%8F%3Amaunium.net?via=maunium.net", + "broken": "Go's url.URL does url decoding on the fragment, which breaks splitting the path segments properly", + "output": { + "room": { + "type": "room_id", + "id": "!#test/room\nversion 11, with @🐈️:maunium.net", + "via": [ + "maunium.net" + ] + } + } + }, + { + "name": "room id matrix: URI", + "input": "/test room reference matrix:roomid/mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ?via=maunium.net&via=matrix.org", + "output": { + "room": { + "type": "room_id", + "id": "!mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ", + "via": [ + "maunium.net", + "matrix.org" + ] + } + } + } + ] +} diff --git a/event/cmdschema/testdata/commands/room_reference_list.json b/event/cmdschema/testdata/commands/room_reference_list.json new file mode 100644 index 00000000..aa266054 --- /dev/null +++ b/event/cmdschema/testdata/commands/room_reference_list.json @@ -0,0 +1,106 @@ +{ + "$schema": "../commands.schema.json#", + "spec": { + "command": "test room reference", + "source": "@testbot", + "parameters": [ + { + "key": "rooms", + "schema": { + "schema_type": "array", + "items": { + "schema_type": "union", + "variants": [ + { + "schema_type": "primitive", + "type": "room_id" + }, + { + "schema_type": "primitive", + "type": "room_alias" + } + ] + } + } + } + ] + }, + "tests": [ + { + "name": "room alias", + "input": "/test room reference #test:matrix.org", + "output": { + "rooms": [ + "#test:matrix.org" + ] + } + }, + { + "name": "room id", + "input": "/test room reference !aiwVrNhPwbGBNjqlNu:matrix.org", + "output": { + "rooms": [ + { + "type": "room_id", + "id": "!aiwVrNhPwbGBNjqlNu:matrix.org" + } + ] + } + }, + { + "name": "two room ids", + "input": "/test room reference !mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ !aiwVrNhPwbGBNjqlNu:matrix.org", + "output": { + "rooms": [ + { + "type": "room_id", + "id": "!mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ" + }, + { + "type": "room_id", + "id": "!aiwVrNhPwbGBNjqlNu:matrix.org" + } + ] + } + }, + { + "name": "room id matrix: URI", + "input": "/test room reference matrix:roomid/mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ?via=maunium.net&via=matrix.org", + "output": { + "rooms": [ + { + "type": "room_id", + "id": "!mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ", + "via": [ + "maunium.net", + "matrix.org" + ] + } + ] + } + }, + { + "name": "room id matrix: URI and matrix.to URL", + "input": "/test room reference https://matrix.to/#/!aiwVrNhPwbGBNjqlNu:matrix.org?via=example.com matrix:roomid/mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ?via=maunium.net&via=matrix.org", + "output": { + "rooms": [ + { + "type": "room_id", + "id": "!aiwVrNhPwbGBNjqlNu:matrix.org", + "via": [ + "example.com" + ] + }, + { + "type": "room_id", + "id": "!mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ", + "via": [ + "maunium.net", + "matrix.org" + ] + } + ] + } + } + ] +} diff --git a/event/cmdschema/testdata/commands/simple.json b/event/cmdschema/testdata/commands/simple.json new file mode 100644 index 00000000..94667323 --- /dev/null +++ b/event/cmdschema/testdata/commands/simple.json @@ -0,0 +1,46 @@ +{ + "$schema": "../commands.schema.json#", + "spec": { + "command": "test simple", + "source": "@testbot", + "parameters": [ + { + "key": "meow", + "schema": { + "schema_type": "primitive", + "type": "string" + } + } + ] + }, + "tests": [ + { + "name": "success", + "input": "/test simple mrrp", + "output": { + "meow": "mrrp" + } + }, + { + "name": "directed success", + "input": "/test simple@testbot mrrp", + "output": { + "meow": "mrrp" + } + }, + { + "name": "missing parameter", + "input": "/test simple", + "error": true, + "output": { + "meow": "" + } + }, + { + "name": "directed at another bot", + "input": "/test simple@anotherbot mrrp", + "error": false, + "output": null + } + ] +} diff --git a/event/cmdschema/testdata/commands/tail.json b/event/cmdschema/testdata/commands/tail.json new file mode 100644 index 00000000..9782f8ec --- /dev/null +++ b/event/cmdschema/testdata/commands/tail.json @@ -0,0 +1,60 @@ +{ + "$schema": "../commands.schema.json#", + "spec": { + "command": "tail", + "source": "@testbot", + "parameters": [ + { + "key": "meow", + "schema": { + "schema_type": "primitive", + "type": "string" + } + }, + { + "key": "reason", + "schema": { + "schema_type": "primitive", + "type": "string" + }, + "optional": true + }, + { + "key": "woof", + "schema": { + "schema_type": "primitive", + "type": "boolean" + }, + "optional": true + } + ], + "fi.mau.tail_parameter": "reason" + }, + "tests": [ + { + "name": "no tail or flag", + "input": "/tail mrrp", + "output": { + "meow": "mrrp", + "reason": "" + } + }, + { + "name": "tail, no flag", + "input": "/tail mrrp meow meow", + "output": { + "meow": "mrrp", + "reason": "meow meow" + } + }, + { + "name": "flag before tail", + "input": "/tail mrrp --woof meow meow", + "output": { + "meow": "mrrp", + "reason": "meow meow", + "woof": true + } + } + ] +} diff --git a/event/cmdschema/testdata/data.go b/event/cmdschema/testdata/data.go new file mode 100644 index 00000000..eceea3d2 --- /dev/null +++ b/event/cmdschema/testdata/data.go @@ -0,0 +1,14 @@ +// Copyright (c) 2026 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package testdata + +import ( + "embed" +) + +//go:embed * +var FS embed.FS diff --git a/event/cmdschema/testdata/parse_quote.json b/event/cmdschema/testdata/parse_quote.json new file mode 100644 index 00000000..8f52b7f5 --- /dev/null +++ b/event/cmdschema/testdata/parse_quote.json @@ -0,0 +1,30 @@ +[ + {"name": "empty string", "input": "", "output": ["", "", false]}, + {"name": "single word", "input": "meow", "output": ["meow", "", false]}, + {"name": "two words", "input": "meow woof", "output": ["meow", "woof", false]}, + {"name": "many words", "input": "meow meow mrrp", "output": ["meow", "meow mrrp", false]}, + {"name": "extra spaces", "input": "meow meow mrrp", "output": ["meow", "meow mrrp", false]}, + {"name": "trailing space", "input": "meow ", "output": ["meow", "", false]}, + {"name": "only spaces", "input": " ", "output": ["", "", false]}, + {"name": "leading spaces", "input": " meow woof", "output": ["", "meow woof", false]}, + {"name": "backslash at end unquoted", "input": "meow\\ woof", "output": ["meow\\", "woof", false]}, + {"name": "quoted word", "input": "\"meow\" meow mrrp", "output": ["meow", "meow mrrp", true]}, + {"name": "quoted words", "input": "\"meow meow\" mrrp", "output": ["meow meow", "mrrp", true]}, + {"name": "spaces in quotes", "input": "\" meow meow \" mrrp", "output": [" meow meow ", "mrrp", true]}, + {"name": "empty quoted string", "input": "\"\"", "output": ["", "", true]}, + {"name": "empty quoted with trailing", "input": "\"\" meow", "output": ["", "meow", true]}, + {"name": "quote no space before next", "input": "\"meow\"woof", "output": ["meow", "woof", true]}, + {"name": "just opening quote", "input": "\"", "output": ["", "", true]}, + {"name": "quote then space then text", "input": "\" meow", "output": [" meow", "", true]}, + {"name": "quotes after word", "input": "meow \" meow mrrp \"", "output": ["meow", "\" meow mrrp \"", false]}, + {"name": "escaped quote", "input": "\"meow\\\" meow\" mrrp", "output": ["meow\" meow", "mrrp", true]}, + {"name": "missing end quote", "input": "\"meow meow mrrp", "output": ["meow meow mrrp", "", true]}, + {"name": "missing end quote with escaped quote", "input": "\"meow\\\" meow mrrp", "output": ["meow\" meow mrrp", "", true]}, + {"name": "quote in the middle", "input": "me\"ow meow mrrp", "output": ["me\"ow", "meow mrrp", false]}, + {"name": "backslash in the middle", "input": "me\\ow meow mrrp", "output": ["me\\ow", "meow mrrp", false]}, + {"name": "other escaped character", "input": "\"m\\eow\" meow mrrp", "output": ["meow", "meow mrrp", true]}, + {"name": "escaped backslashes", "input": "\"m\\\\e\\\"ow\\\\\" meow mrrp", "output": ["m\\e\"ow\\", "meow mrrp", true]}, + {"name": "just quotes", "input": "\"\\\"\\\"\\\\\\\"\" meow", "output": ["\"\"\\\"", "meow", true]}, + {"name": "escape at eof", "input": "\"meow\\", "output": ["meow", "", true]}, + {"name": "escaped backslash at eof", "input": "\"meow\\\\", "output": ["meow\\", "", true]} +] diff --git a/event/cmdschema/testdata/parse_quote.schema.json b/event/cmdschema/testdata/parse_quote.schema.json new file mode 100644 index 00000000..9f249116 --- /dev/null +++ b/event/cmdschema/testdata/parse_quote.schema.json @@ -0,0 +1,46 @@ +{ + "$schema": "https://json-schema.org/draft/2020-12/schema#", + "$id": "parse_quote.schema.json", + "title": "parseQuote test cases", + "description": "Test cases for the parseQuoted function", + "type": "array", + "items": { + "type": "object", + "required": [ + "name", + "input", + "output" + ], + "properties": { + "name": { + "type": "string", + "description": "Name of the test case" + }, + "input": { + "type": "string", + "description": "Input string to be parsed" + }, + "output": { + "type": "array", + "description": "Expected output of parsing: [first word, remaining text, was quoted]", + "minItems": 3, + "maxItems": 3, + "prefixItems": [ + { + "type": "string", + "description": "First parsed word" + }, + { + "type": "string", + "description": "Remaining text after the first word" + }, + { + "type": "boolean", + "description": "Whether the first word was quoted" + } + ] + } + }, + "additionalProperties": false + } +} diff --git a/event/content.go b/event/content.go index ab57c658..814aeec4 100644 --- a/event/content.go +++ b/event/content.go @@ -18,6 +18,7 @@ import ( // This is used by Content.ParseRaw() for creating the correct type of struct. var TypeMap = map[Type]reflect.Type{ StateMember: reflect.TypeOf(MemberEventContent{}), + StateThirdPartyInvite: reflect.TypeOf(ThirdPartyInviteEventContent{}), StatePowerLevels: reflect.TypeOf(PowerLevelsEventContent{}), StateCanonicalAlias: reflect.TypeOf(CanonicalAliasEventContent{}), StateRoomName: reflect.TypeOf(RoomNameEventContent{}), @@ -38,7 +39,9 @@ var TypeMap = map[Type]reflect.Type{ StateHalfShotBridge: reflect.TypeOf(BridgeEventContent{}), StateSpaceParent: reflect.TypeOf(SpaceParentEventContent{}), StateSpaceChild: reflect.TypeOf(SpaceChildEventContent{}), - StateInsertionMarker: reflect.TypeOf(InsertionMarkerContent{}), + + StateRoomPolicy: reflect.TypeOf(RoomPolicyEventContent{}), + StateUnstableRoomPolicy: reflect.TypeOf(RoomPolicyEventContent{}), StateLegacyPolicyRoom: reflect.TypeOf(ModPolicyContent{}), StateLegacyPolicyServer: reflect.TypeOf(ModPolicyContent{}), @@ -48,6 +51,8 @@ var TypeMap = map[Type]reflect.Type{ StateUnstablePolicyUser: reflect.TypeOf(ModPolicyContent{}), StateElementFunctionalMembers: reflect.TypeOf(ElementFunctionalMembersContent{}), + StateBeeperRoomFeatures: reflect.TypeOf(RoomFeatures{}), + StateBeeperDisappearingTimer: reflect.TypeOf(BeeperDisappearingTimer{}), EventMessage: reflect.TypeOf(MessageEventContent{}), EventSticker: reflect.TypeOf(MessageEventContent{}), @@ -58,7 +63,11 @@ var TypeMap = map[Type]reflect.Type{ EventUnstablePollStart: reflect.TypeOf(PollStartEventContent{}), EventUnstablePollResponse: reflect.TypeOf(PollResponseEventContent{}), - BeeperMessageStatus: reflect.TypeOf(BeeperMessageStatusEventContent{}), + BeeperMessageStatus: reflect.TypeOf(BeeperMessageStatusEventContent{}), + BeeperTranscription: reflect.TypeOf(BeeperTranscriptionEventContent{}), + BeeperDeleteChat: reflect.TypeOf(BeeperChatDeleteEventContent{}), + BeeperAcceptMessageRequest: reflect.TypeOf(BeeperAcceptMessageRequestEventContent{}), + BeeperSendState: reflect.TypeOf(BeeperSendStateEventContent{}), AccountDataRoomTags: reflect.TypeOf(TagEventContent{}), AccountDataDirectChats: reflect.TypeOf(DirectChatsEventContent{}), @@ -67,9 +76,11 @@ var TypeMap = map[Type]reflect.Type{ AccountDataMarkedUnread: reflect.TypeOf(MarkedUnreadEventContent{}), AccountDataBeeperMute: reflect.TypeOf(BeeperMuteEventContent{}), - EphemeralEventTyping: reflect.TypeOf(TypingEventContent{}), - EphemeralEventReceipt: reflect.TypeOf(ReceiptEventContent{}), - EphemeralEventPresence: reflect.TypeOf(PresenceEventContent{}), + EphemeralEventTyping: reflect.TypeOf(TypingEventContent{}), + EphemeralEventReceipt: reflect.TypeOf(ReceiptEventContent{}), + EphemeralEventPresence: reflect.TypeOf(PresenceEventContent{}), + EphemeralEventEncrypted: reflect.TypeOf(EncryptedEventContent{}), + BeeperEphemeralEventAIStream: reflect.TypeOf(BeeperAIStreamEventContent{}), InRoomVerificationReady: reflect.TypeOf(VerificationReadyEventContent{}), InRoomVerificationStart: reflect.TypeOf(VerificationStartEventContent{}), @@ -121,7 +132,7 @@ var TypeMap = map[Type]reflect.Type{ // When being marshaled into JSON, the data in Parsed will be marshaled first and then recursively merged // with the data in Raw. Values in Raw are preferred, but nested objects will be recursed into before merging, // rather than overriding the whole object with the one in Raw). -// If one of them is nil, the only the other is used. If both (Parsed and Raw) are nil, VeryRaw is used instead. +// If one of them is nil, then only the other is used. If both (Parsed and Raw) are nil, VeryRaw is used instead. type Content struct { VeryRaw json.RawMessage Raw map[string]interface{} diff --git a/event/delayed.go b/event/delayed.go new file mode 100644 index 00000000..fefb62af --- /dev/null +++ b/event/delayed.go @@ -0,0 +1,70 @@ +package event + +import ( + "encoding/json" + + "go.mau.fi/util/jsontime" + + "maunium.net/go/mautrix/id" +) + +type ScheduledDelayedEvent struct { + DelayID id.DelayID `json:"delay_id"` + RoomID id.RoomID `json:"room_id"` + Type Type `json:"type"` + StateKey *string `json:"state_key,omitempty"` + Delay int64 `json:"delay"` + RunningSince jsontime.UnixMilli `json:"running_since"` + Content Content `json:"content"` +} + +func (e ScheduledDelayedEvent) AsEvent(eventID id.EventID, ts jsontime.UnixMilli) (*Event, error) { + evt := &Event{ + ID: eventID, + RoomID: e.RoomID, + Type: e.Type, + StateKey: e.StateKey, + Content: e.Content, + Timestamp: ts.UnixMilli(), + } + return evt, evt.Content.ParseRaw(evt.Type) +} + +type FinalisedDelayedEvent struct { + DelayedEvent *ScheduledDelayedEvent `json:"scheduled_event"` + Outcome DelayOutcome `json:"outcome"` + Reason DelayReason `json:"reason"` + Error json.RawMessage `json:"error,omitempty"` + EventID id.EventID `json:"event_id,omitempty"` + Timestamp jsontime.UnixMilli `json:"origin_server_ts"` +} + +type DelayStatus string + +var ( + DelayStatusScheduled DelayStatus = "scheduled" + DelayStatusFinalised DelayStatus = "finalised" +) + +type DelayAction string + +var ( + DelayActionSend DelayAction = "send" + DelayActionCancel DelayAction = "cancel" + DelayActionRestart DelayAction = "restart" +) + +type DelayOutcome string + +var ( + DelayOutcomeSend DelayOutcome = "send" + DelayOutcomeCancel DelayOutcome = "cancel" +) + +type DelayReason string + +var ( + DelayReasonAction DelayReason = "action" + DelayReasonError DelayReason = "error" + DelayReasonDelay DelayReason = "delay" +) diff --git a/event/encryption.go b/event/encryption.go index cf9c2814..c60cb91a 100644 --- a/event/encryption.go +++ b/event/encryption.go @@ -63,7 +63,7 @@ func (content *EncryptedEventContent) UnmarshalJSON(data []byte) error { return json.Unmarshal(content.Ciphertext, &content.OlmCiphertext) case id.AlgorithmMegolmV1: if len(content.Ciphertext) == 0 || content.Ciphertext[0] != '"' || content.Ciphertext[len(content.Ciphertext)-1] != '"' { - return id.InputNotJSONString + return fmt.Errorf("ciphertext %w", id.ErrInputNotJSONString) } content.MegolmCiphertext = content.Ciphertext[1 : len(content.Ciphertext)-1] } @@ -132,8 +132,9 @@ type RoomKeyRequestEventContent struct { type RequestedKeyInfo struct { Algorithm id.Algorithm `json:"algorithm"` RoomID id.RoomID `json:"room_id"` - SenderKey id.SenderKey `json:"sender_key"` SessionID id.SessionID `json:"session_id"` + // Deprecated: Matrix v1.3 + SenderKey id.SenderKey `json:"sender_key"` } type RoomKeyWithheldCode string diff --git a/event/events.go b/event/events.go index 1c173351..72c1e161 100644 --- a/event/events.go +++ b/event/events.go @@ -118,6 +118,9 @@ type MautrixInfo struct { DecryptionDuration time.Duration CheckpointSent bool + // When using MSC4222 and the state_after field, this field is set + // for timeline events to indicate they shouldn't update room state. + IgnoreState bool } func (evt *Event) GetStateKey() string { @@ -127,31 +130,29 @@ func (evt *Event) GetStateKey() string { return "" } -type StrippedState struct { - Content Content `json:"content"` - Type Type `json:"type"` - StateKey string `json:"state_key"` - Sender id.UserID `json:"sender"` -} - type Unsigned struct { - PrevContent *Content `json:"prev_content,omitempty"` - PrevSender id.UserID `json:"prev_sender,omitempty"` - ReplacesState id.EventID `json:"replaces_state,omitempty"` - Age int64 `json:"age,omitempty"` - TransactionID string `json:"transaction_id,omitempty"` - Relations *Relations `json:"m.relations,omitempty"` - RedactedBecause *Event `json:"redacted_because,omitempty"` - InviteRoomState []StrippedState `json:"invite_room_state,omitempty"` + PrevContent *Content `json:"prev_content,omitempty"` + PrevSender id.UserID `json:"prev_sender,omitempty"` + Membership Membership `json:"membership,omitempty"` + ReplacesState id.EventID `json:"replaces_state,omitempty"` + Age int64 `json:"age,omitempty"` + TransactionID string `json:"transaction_id,omitempty"` + Relations *Relations `json:"m.relations,omitempty"` + RedactedBecause *Event `json:"redacted_because,omitempty"` + InviteRoomState []*Event `json:"invite_room_state,omitempty"` BeeperHSOrder int64 `json:"com.beeper.hs.order,omitempty"` BeeperHSSuborder int16 `json:"com.beeper.hs.suborder,omitempty"` BeeperHSOrderString *BeeperEncodedOrder `json:"com.beeper.hs.order_string,omitempty"` BeeperFromBackup bool `json:"com.beeper.from_backup,omitempty"` + + ElementSoftFailed bool `json:"io.element.synapse.soft_failed,omitempty"` + ElementPolicyServerSpammy bool `json:"io.element.synapse.policy_server_spammy,omitempty"` } func (us *Unsigned) IsEmpty() bool { - return us.PrevContent == nil && us.PrevSender == "" && us.ReplacesState == "" && us.Age == 0 && + return us.PrevContent == nil && us.PrevSender == "" && us.ReplacesState == "" && us.Age == 0 && us.Membership == "" && us.TransactionID == "" && us.RedactedBecause == nil && us.InviteRoomState == nil && us.Relations == nil && - us.BeeperHSOrder == 0 && us.BeeperHSSuborder == 0 && us.BeeperHSOrderString.IsZero() + us.BeeperHSOrder == 0 && us.BeeperHSSuborder == 0 && us.BeeperHSOrderString.IsZero() && + !us.ElementSoftFailed } diff --git a/event/member.go b/event/member.go index d0ff2a7c..9956a36b 100644 --- a/event/member.go +++ b/event/member.go @@ -7,8 +7,6 @@ package event import ( - "encoding/json" - "maunium.net/go/mautrix/id" ) @@ -35,20 +33,37 @@ const ( // MemberEventContent represents the content of a m.room.member state event. // https://spec.matrix.org/v1.2/client-server-api/#mroommember type MemberEventContent struct { - Membership Membership `json:"membership"` - AvatarURL id.ContentURIString `json:"avatar_url,omitempty"` - Displayname string `json:"displayname,omitempty"` - IsDirect bool `json:"is_direct,omitempty"` - ThirdPartyInvite *ThirdPartyInvite `json:"third_party_invite,omitempty"` - Reason string `json:"reason,omitempty"` - MSC3414File *EncryptedFileInfo `json:"org.matrix.msc3414.file,omitempty"` + Membership Membership `json:"membership"` + AvatarURL id.ContentURIString `json:"avatar_url,omitempty"` + Displayname string `json:"displayname,omitempty"` + IsDirect bool `json:"is_direct,omitempty"` + ThirdPartyInvite *ThirdPartyInvite `json:"third_party_invite,omitempty"` + Reason string `json:"reason,omitempty"` + JoinAuthorisedViaUsersServer id.UserID `json:"join_authorised_via_users_server,omitempty"` + MSC3414File *EncryptedFileInfo `json:"org.matrix.msc3414.file,omitempty"` + + MSC4293RedactEvents bool `json:"org.matrix.msc4293.redact_events,omitempty"` +} + +type SignedThirdPartyInvite struct { + Token string `json:"token"` + Signatures map[string]map[id.KeyID]string `json:"signatures,omitempty"` + MXID string `json:"mxid"` } type ThirdPartyInvite struct { - DisplayName string `json:"display_name"` - Signed struct { - Token string `json:"token"` - Signatures json.RawMessage `json:"signatures"` - MXID string `json:"mxid"` - } + DisplayName string `json:"display_name"` + Signed SignedThirdPartyInvite `json:"signed"` +} + +type ThirdPartyInviteEventContent struct { + DisplayName string `json:"display_name"` + KeyValidityURL string `json:"key_validity_url"` + PublicKey id.Ed25519 `json:"public_key"` + PublicKeys []ThirdPartyInviteKey `json:"public_keys,omitempty"` +} + +type ThirdPartyInviteKey struct { + KeyValidityURL string `json:"key_validity_url,omitempty"` + PublicKey id.Ed25519 `json:"public_key"` } diff --git a/event/message.go b/event/message.go index 92bdcf07..3fb3dc82 100644 --- a/event/message.go +++ b/event/message.go @@ -32,7 +32,7 @@ func (mt MessageType) IsText() bool { func (mt MessageType) IsMedia() bool { switch mt { - case MsgImage, MsgVideo, MsgAudio, MsgFile, MessageType(EventSticker.Type): + case MsgImage, MsgVideo, MsgAudio, MsgFile, CapMsgSticker: return true default: return false @@ -135,11 +135,42 @@ type MessageEventContent struct { BeeperGalleryCaption string `json:"com.beeper.gallery.caption,omitempty"` BeeperGalleryCaptionHTML string `json:"com.beeper.gallery.caption_html,omitempty"` BeeperPerMessageProfile *BeeperPerMessageProfile `json:"com.beeper.per_message_profile,omitempty"` + BeeperActionMessage *BeeperActionMessage `json:"com.beeper.action_message,omitempty"` BeeperLinkPreviews []*BeeperLinkPreview `json:"com.beeper.linkpreviews,omitempty"` + BeeperDisappearingTimer *BeeperDisappearingTimer `json:"com.beeper.disappearing_timer,omitempty"` + MSC1767Audio *MSC1767Audio `json:"org.matrix.msc1767.audio,omitempty"` MSC3245Voice *MSC3245Voice `json:"org.matrix.msc3245.voice,omitempty"` + + MSC4391BotCommand *MSC4391BotCommandInput `json:"org.matrix.msc4391.command,omitempty"` +} + +func (content *MessageEventContent) GetCapMsgType() CapabilityMsgType { + switch content.MsgType { + case CapMsgSticker: + return CapMsgSticker + case "": + if content.URL != "" || content.File != nil { + return CapMsgSticker + } + case MsgImage: + return MsgImage + case MsgAudio: + if content.MSC3245Voice != nil { + return CapMsgVoice + } + return MsgAudio + case MsgVideo: + if content.Info != nil && content.Info.MauGIF { + return CapMsgGIF + } + return MsgVideo + case MsgFile: + return MsgFile + } + return "" } func (content *MessageEventContent) GetFileName() string { @@ -184,6 +215,7 @@ func (content *MessageEventContent) SetEdit(original id.EventID) { content.RelatesTo = (&RelatesTo{}).SetReplace(original) if content.MsgType == MsgText || content.MsgType == MsgNotice { content.Body = "* " + content.Body + content.Mentions = &Mentions{} if content.Format == FormatHTML && len(content.FormattedBody) > 0 { content.FormattedBody = "* " + content.FormattedBody } @@ -244,24 +276,46 @@ func (m *Mentions) Has(userID id.UserID) bool { return m != nil && slices.Contains(m.UserIDs, userID) } +func (m *Mentions) Merge(other *Mentions) *Mentions { + if m == nil { + return other + } else if other == nil { + return m + } + return &Mentions{ + UserIDs: slices.Concat(m.UserIDs, other.UserIDs), + Room: m.Room || other.Room, + } +} + +type MSC4391BotCommandInputCustom[T any] struct { + Command string `json:"command"` + Arguments T `json:"arguments,omitempty"` +} + +type MSC4391BotCommandInput = MSC4391BotCommandInputCustom[json.RawMessage] + type EncryptedFileInfo struct { attachment.EncryptedFile URL id.ContentURIString `json:"url"` } type FileInfo struct { - MimeType string `json:"mimetype,omitempty"` - ThumbnailInfo *FileInfo `json:"thumbnail_info,omitempty"` - ThumbnailURL id.ContentURIString `json:"thumbnail_url,omitempty"` - ThumbnailFile *EncryptedFileInfo `json:"thumbnail_file,omitempty"` + MimeType string + ThumbnailInfo *FileInfo + ThumbnailURL id.ContentURIString + ThumbnailFile *EncryptedFileInfo - Blurhash string `json:"blurhash,omitempty"` - AnoaBlurhash string `json:"xyz.amorgan.blurhash,omitempty"` + Blurhash string + AnoaBlurhash string - Width int `json:"-"` - Height int `json:"-"` - Duration int `json:"-"` - Size int `json:"-"` + MauGIF bool + IsAnimated bool + + Width int + Height int + Duration int + Size int } type serializableFileInfo struct { @@ -273,6 +327,9 @@ type serializableFileInfo struct { Blurhash string `json:"blurhash,omitempty"` AnoaBlurhash string `json:"xyz.amorgan.blurhash,omitempty"` + MauGIF bool `json:"fi.mau.gif,omitempty"` + IsAnimated bool `json:"is_animated,omitempty"` + Width json.Number `json:"w,omitempty"` Height json.Number `json:"h,omitempty"` Duration json.Number `json:"duration,omitempty"` @@ -289,6 +346,9 @@ func (sfi *serializableFileInfo) CopyFrom(fileInfo *FileInfo) *serializableFileI ThumbnailInfo: (&serializableFileInfo{}).CopyFrom(fileInfo.ThumbnailInfo), ThumbnailFile: fileInfo.ThumbnailFile, + MauGIF: fileInfo.MauGIF, + IsAnimated: fileInfo.IsAnimated, + Blurhash: fileInfo.Blurhash, AnoaBlurhash: fileInfo.AnoaBlurhash, } @@ -317,6 +377,8 @@ func (sfi *serializableFileInfo) CopyTo(fileInfo *FileInfo) { MimeType: sfi.MimeType, ThumbnailURL: sfi.ThumbnailURL, ThumbnailFile: sfi.ThumbnailFile, + MauGIF: sfi.MauGIF, + IsAnimated: sfi.IsAnimated, Blurhash: sfi.Blurhash, AnoaBlurhash: sfi.AnoaBlurhash, } diff --git a/event/message_test.go b/event/message_test.go index 562a6622..c721df35 100644 --- a/event/message_test.go +++ b/event/message_test.go @@ -33,7 +33,7 @@ const invalidMessageEvent = `{ func TestMessageEventContent__ParseInvalid(t *testing.T) { var evt *event.Event err := json.Unmarshal([]byte(invalidMessageEvent), &evt) - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal(t, id.UserID("@tulir:maunium.net"), evt.Sender) assert.Equal(t, event.EventMessage, evt.Type) @@ -42,7 +42,7 @@ func TestMessageEventContent__ParseInvalid(t *testing.T) { assert.Equal(t, id.RoomID("!bar"), evt.RoomID) err = evt.Content.ParseRaw(evt.Type) - assert.NotNil(t, err) + assert.Error(t, err) } const messageEvent = `{ @@ -68,7 +68,7 @@ const messageEvent = `{ func TestMessageEventContent__ParseEdit(t *testing.T) { var evt *event.Event err := json.Unmarshal([]byte(messageEvent), &evt) - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal(t, id.UserID("@tulir:maunium.net"), evt.Sender) assert.Equal(t, event.EventMessage, evt.Type) @@ -110,7 +110,7 @@ const imageMessageEvent = `{ func TestMessageEventContent__ParseMedia(t *testing.T) { var evt *event.Event err := json.Unmarshal([]byte(imageMessageEvent), &evt) - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal(t, id.UserID("@tulir:maunium.net"), evt.Sender) assert.Equal(t, event.EventMessage, evt.Type) @@ -125,7 +125,7 @@ func TestMessageEventContent__ParseMedia(t *testing.T) { content := evt.Content.Parsed.(*event.MessageEventContent) assert.Equal(t, event.MsgImage, content.MsgType) parsedURL, err := content.URL.Parse() - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal(t, id.ContentURI{Homeserver: "example.com", FileID: "image"}, parsedURL) assert.Nil(t, content.NewContent) assert.Equal(t, "image/png", content.GetInfo().MimeType) @@ -145,7 +145,7 @@ const expectedMarshalResult = `{"msgtype":"m.text","body":"test"}` func TestMessageEventContent__Marshal(t *testing.T) { data, err := json.Marshal(parsedMessage) - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal(t, expectedMarshalResult, string(data)) } @@ -163,6 +163,6 @@ const expectedCustomMarshalResult = `{"body":"test","msgtype":"m.text","net.maun func TestMessageEventContent__Marshal_Custom(t *testing.T) { data, err := json.Marshal(customParsedMessage) - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal(t, expectedCustomMarshalResult, string(data)) } diff --git a/event/poll.go b/event/poll.go index 37333015..9082f65e 100644 --- a/event/poll.go +++ b/event/poll.go @@ -29,16 +29,13 @@ func (content *PollResponseEventContent) SetRelatesTo(rel *RelatesTo) { } type MSC1767Message struct { - Text string `json:"org.matrix.msc1767.text,omitempty"` - HTML string `json:"org.matrix.msc1767.html,omitempty"` - Message []struct { - MimeType string `json:"mimetype"` - Body string `json:"body"` - } `json:"org.matrix.msc1767.message,omitempty"` + Text string `json:"org.matrix.msc1767.text,omitempty"` + HTML string `json:"org.matrix.msc1767.html,omitempty"` + Message []ExtensibleText `json:"org.matrix.msc1767.message,omitempty"` } type PollStartEventContent struct { - RelatesTo *RelatesTo `json:"m.relates_to"` + RelatesTo *RelatesTo `json:"m.relates_to,omitempty"` Mentions *Mentions `json:"m.mentions,omitempty"` PollStart struct { Kind string `json:"kind"` diff --git a/event/powerlevels.go b/event/powerlevels.go index 2f4d4573..668eb6d3 100644 --- a/event/powerlevels.go +++ b/event/powerlevels.go @@ -7,6 +7,8 @@ package event import ( + "math" + "slices" "sync" "go.mau.fi/util/ptr" @@ -26,6 +28,9 @@ type PowerLevelsEventContent struct { Events map[string]int `json:"events,omitempty"` EventsDefault int `json:"events_default,omitempty"` + beeperEphemeralLock sync.RWMutex + BeeperEphemeral map[string]int `json:"com.beeper.ephemeral,omitempty"` + Notifications *NotificationPowerLevels `json:"notifications,omitempty"` StateDefaultPtr *int `json:"state_default,omitempty"` @@ -34,6 +39,12 @@ type PowerLevelsEventContent struct { KickPtr *int `json:"kick,omitempty"` BanPtr *int `json:"ban,omitempty"` RedactPtr *int `json:"redact,omitempty"` + + BeeperEphemeralDefaultPtr *int `json:"com.beeper.ephemeral_default,omitempty"` + + // This is not a part of power levels, it's added by mautrix-go internally in certain places + // in order to detect creator power accurately. + CreateEvent *Event `json:"-"` } func (pl *PowerLevelsEventContent) Clone() *PowerLevelsEventContent { @@ -45,6 +56,7 @@ func (pl *PowerLevelsEventContent) Clone() *PowerLevelsEventContent { UsersDefault: pl.UsersDefault, Events: maps.Clone(pl.Events), EventsDefault: pl.EventsDefault, + BeeperEphemeral: maps.Clone(pl.BeeperEphemeral), StateDefaultPtr: ptr.Clone(pl.StateDefaultPtr), Notifications: pl.Notifications.Clone(), @@ -53,6 +65,10 @@ func (pl *PowerLevelsEventContent) Clone() *PowerLevelsEventContent { KickPtr: ptr.Clone(pl.KickPtr), BanPtr: ptr.Clone(pl.BanPtr), RedactPtr: ptr.Clone(pl.RedactPtr), + + BeeperEphemeralDefaultPtr: ptr.Clone(pl.BeeperEphemeralDefaultPtr), + + CreateEvent: pl.CreateEvent, } } @@ -111,7 +127,17 @@ func (pl *PowerLevelsEventContent) StateDefault() int { return 50 } +func (pl *PowerLevelsEventContent) BeeperEphemeralDefault() int { + if pl.BeeperEphemeralDefaultPtr != nil { + return *pl.BeeperEphemeralDefaultPtr + } + return pl.EventsDefault +} + func (pl *PowerLevelsEventContent) GetUserLevel(userID id.UserID) int { + if pl.isCreator(userID) { + return math.MaxInt + } pl.usersLock.RLock() defer pl.usersLock.RUnlock() level, ok := pl.Users[userID] @@ -121,9 +147,19 @@ func (pl *PowerLevelsEventContent) GetUserLevel(userID id.UserID) int { return level } +const maxPL = 1<<53 - 1 + func (pl *PowerLevelsEventContent) SetUserLevel(userID id.UserID, level int) { pl.usersLock.Lock() defer pl.usersLock.Unlock() + if pl.isCreator(userID) { + return + } + if level == math.MaxInt && maxPL < math.MaxInt { + // Hack to avoid breaking on 32-bit systems (they're only slightly supported) + x := int64(maxPL) + level = int(x) + } if level == pl.UsersDefault { delete(pl.Users, userID) } else { @@ -138,9 +174,24 @@ func (pl *PowerLevelsEventContent) EnsureUserLevel(target id.UserID, level int) return pl.EnsureUserLevelAs("", target, level) } +func (pl *PowerLevelsEventContent) createContent() *CreateEventContent { + if pl.CreateEvent == nil { + return &CreateEventContent{} + } + return pl.CreateEvent.Content.AsCreate() +} + +func (pl *PowerLevelsEventContent) isCreator(userID id.UserID) bool { + cc := pl.createContent() + return cc.SupportsCreatorPower() && (userID == pl.CreateEvent.Sender || slices.Contains(cc.AdditionalCreators, userID)) +} + func (pl *PowerLevelsEventContent) EnsureUserLevelAs(actor, target id.UserID, level int) bool { + if pl.isCreator(target) { + return false + } existingLevel := pl.GetUserLevel(target) - if actor != "" { + if actor != "" && !pl.isCreator(actor) { actorLevel := pl.GetUserLevel(actor) if actorLevel <= existingLevel || actorLevel < level { return false @@ -166,6 +217,29 @@ func (pl *PowerLevelsEventContent) GetEventLevel(eventType Type) int { return level } +func (pl *PowerLevelsEventContent) GetBeeperEphemeralLevel(eventType Type) int { + pl.beeperEphemeralLock.RLock() + defer pl.beeperEphemeralLock.RUnlock() + level, ok := pl.BeeperEphemeral[eventType.String()] + if !ok { + return pl.BeeperEphemeralDefault() + } + return level +} + +func (pl *PowerLevelsEventContent) SetBeeperEphemeralLevel(eventType Type, level int) { + pl.beeperEphemeralLock.Lock() + defer pl.beeperEphemeralLock.Unlock() + if level == pl.BeeperEphemeralDefault() { + delete(pl.BeeperEphemeral, eventType.String()) + } else { + if pl.BeeperEphemeral == nil { + pl.BeeperEphemeral = make(map[string]int) + } + pl.BeeperEphemeral[eventType.String()] = level + } +} + func (pl *PowerLevelsEventContent) SetEventLevel(eventType Type, level int) { pl.eventsLock.Lock() defer pl.eventsLock.Unlock() @@ -185,7 +259,7 @@ func (pl *PowerLevelsEventContent) EnsureEventLevel(eventType Type, level int) b func (pl *PowerLevelsEventContent) EnsureEventLevelAs(actor id.UserID, eventType Type, level int) bool { existingLevel := pl.GetEventLevel(eventType) - if actor != "" { + if actor != "" && !pl.isCreator(actor) { actorLevel := pl.GetUserLevel(actor) if existingLevel > actorLevel || level > actorLevel { return false diff --git a/event/powerlevels_ephemeral_test.go b/event/powerlevels_ephemeral_test.go new file mode 100644 index 00000000..f5861583 --- /dev/null +++ b/event/powerlevels_ephemeral_test.go @@ -0,0 +1,67 @@ +// Copyright (c) 2026 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package event_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "maunium.net/go/mautrix/event" +) + +func TestPowerLevelsEventContent_BeeperEphemeralDefaultFallsBackToEventsDefault(t *testing.T) { + pl := &event.PowerLevelsEventContent{ + EventsDefault: 45, + } + + assert.Equal(t, 45, pl.BeeperEphemeralDefault()) + + override := 60 + pl.BeeperEphemeralDefaultPtr = &override + assert.Equal(t, 60, pl.BeeperEphemeralDefault()) +} + +func TestPowerLevelsEventContent_GetSetBeeperEphemeralLevel(t *testing.T) { + pl := &event.PowerLevelsEventContent{ + EventsDefault: 25, + } + evtType := event.Type{Type: "com.example.ephemeral", Class: event.EphemeralEventType} + + assert.Equal(t, 25, pl.GetBeeperEphemeralLevel(evtType)) + + pl.SetBeeperEphemeralLevel(evtType, 50) + assert.Equal(t, 50, pl.GetBeeperEphemeralLevel(evtType)) + require.NotNil(t, pl.BeeperEphemeral) + assert.Equal(t, 50, pl.BeeperEphemeral[evtType.String()]) + + pl.SetBeeperEphemeralLevel(evtType, 25) + _, exists := pl.BeeperEphemeral[evtType.String()] + assert.False(t, exists) +} + +func TestPowerLevelsEventContent_CloneCopiesBeeperEphemeralFields(t *testing.T) { + override := 70 + pl := &event.PowerLevelsEventContent{ + EventsDefault: 35, + BeeperEphemeral: map[string]int{"com.example.ephemeral": 90}, + BeeperEphemeralDefaultPtr: &override, + } + + cloned := pl.Clone() + require.NotNil(t, cloned) + require.NotNil(t, cloned.BeeperEphemeralDefaultPtr) + assert.Equal(t, 70, *cloned.BeeperEphemeralDefaultPtr) + assert.Equal(t, 90, cloned.BeeperEphemeral["com.example.ephemeral"]) + + cloned.BeeperEphemeral["com.example.ephemeral"] = 99 + *cloned.BeeperEphemeralDefaultPtr = 71 + + assert.Equal(t, 90, pl.BeeperEphemeral["com.example.ephemeral"]) + assert.Equal(t, 70, *pl.BeeperEphemeralDefaultPtr) +} diff --git a/event/relations.go b/event/relations.go index ea40cc06..2316cbc7 100644 --- a/event/relations.go +++ b/event/relations.go @@ -15,10 +15,11 @@ import ( type RelationType string const ( - RelReplace RelationType = "m.replace" - RelReference RelationType = "m.reference" - RelAnnotation RelationType = "m.annotation" - RelThread RelationType = "m.thread" + RelReplace RelationType = "m.replace" + RelReference RelationType = "m.reference" + RelAnnotation RelationType = "m.annotation" + RelThread RelationType = "m.thread" + RelBeeperTranscription RelationType = "com.beeper.transcription" ) type RelatesTo struct { @@ -33,7 +34,7 @@ type RelatesTo struct { type InReplyTo struct { EventID id.EventID `json:"event_id,omitempty"` - UnstableRoomID id.RoomID `json:"room_id,omitempty"` + UnstableRoomID id.RoomID `json:"com.beeper.cross_room_id,omitempty"` } func (rel *RelatesTo) Copy() *RelatesTo { @@ -100,6 +101,10 @@ func (rel *RelatesTo) SetReplace(mxid id.EventID) *RelatesTo { } func (rel *RelatesTo) SetReplyTo(mxid id.EventID) *RelatesTo { + if rel.Type != RelThread { + rel.Type = "" + rel.EventID = "" + } rel.InReplyTo = &InReplyTo{EventID: mxid} rel.IsFallingBack = false return rel diff --git a/event/reply.go b/event/reply.go index 1a88c619..5f55bb80 100644 --- a/event/reply.go +++ b/event/reply.go @@ -32,12 +32,13 @@ func TrimReplyFallbackText(text string) string { } func (content *MessageEventContent) RemoveReplyFallback() { - if len(content.RelatesTo.GetReplyTo()) > 0 && !content.replyFallbackRemoved { - if content.Format == FormatHTML { - content.FormattedBody = TrimReplyFallbackHTML(content.FormattedBody) + if len(content.RelatesTo.GetReplyTo()) > 0 && !content.replyFallbackRemoved && content.Format == FormatHTML { + origHTML := content.FormattedBody + content.FormattedBody = TrimReplyFallbackHTML(content.FormattedBody) + if content.FormattedBody != origHTML { + content.Body = TrimReplyFallbackText(content.Body) + content.replyFallbackRemoved = true } - content.Body = TrimReplyFallbackText(content.Body) - content.replyFallbackRemoved = true } } @@ -47,5 +48,27 @@ func (content *MessageEventContent) GetReplyTo() id.EventID { } func (content *MessageEventContent) SetReply(inReplyTo *Event) { - content.RelatesTo = (&RelatesTo{}).SetReplyTo(inReplyTo.ID) + if content.RelatesTo == nil { + content.RelatesTo = &RelatesTo{} + } + content.RelatesTo.SetReplyTo(inReplyTo.ID) + if content.Mentions == nil { + content.Mentions = &Mentions{} + } + content.Mentions.Add(inReplyTo.Sender) +} + +func (content *MessageEventContent) SetThread(inReplyTo *Event) { + root := inReplyTo.ID + relatable, ok := inReplyTo.Content.Parsed.(Relatable) + if ok { + targetRoot := relatable.OptionalGetRelatesTo().GetThreadParent() + if targetRoot != "" { + root = targetRoot + } + } + if content.RelatesTo == nil { + content.RelatesTo = &RelatesTo{} + } + content.RelatesTo.SetThread(root, inReplyTo.ID) } diff --git a/event/state.go b/event/state.go index 15972892..ace170a5 100644 --- a/event/state.go +++ b/event/state.go @@ -7,6 +7,12 @@ package event import ( + "encoding/base64" + "encoding/json" + "slices" + + "go.mau.fi/util/jsontime" + "maunium.net/go/mautrix/id" ) @@ -42,7 +48,52 @@ type ServerACLEventContent struct { // TopicEventContent represents the content of a m.room.topic state event. // https://spec.matrix.org/v1.2/client-server-api/#mroomtopic type TopicEventContent struct { - Topic string `json:"topic"` + Topic string `json:"topic"` + ExtensibleTopic *ExtensibleTopic `json:"m.topic,omitempty"` +} + +// ExtensibleTopic represents the contents of the m.topic field within the +// m.room.topic state event as described in [MSC3765]. +// +// [MSC3765]: https://github.com/matrix-org/matrix-spec-proposals/pull/3765 +type ExtensibleTopic = ExtensibleTextContainer + +type ExtensibleTextContainer struct { + Text []ExtensibleText `json:"m.text"` +} + +func (c *ExtensibleTextContainer) Equals(description *ExtensibleTextContainer) bool { + if c == nil || description == nil { + return c == description + } + return slices.Equal(c.Text, description.Text) +} + +func MakeExtensibleText(text string) *ExtensibleTextContainer { + return &ExtensibleTextContainer{ + Text: []ExtensibleText{{ + Body: text, + MimeType: "text/plain", + }}, + } +} + +func MakeExtensibleFormattedText(plaintext, html string) *ExtensibleTextContainer { + return &ExtensibleTextContainer{ + Text: []ExtensibleText{{ + Body: plaintext, + MimeType: "text/plain", + }, { + Body: html, + MimeType: "text/html", + }}, + } +} + +// ExtensibleText represents the contents of an m.text field. +type ExtensibleText struct { + MimeType string `json:"mimetype,omitempty"` + Body string `json:"body"` } // TombstoneEventContent represents the content of a m.room.tombstone state event. @@ -52,35 +103,64 @@ type TombstoneEventContent struct { ReplacementRoom id.RoomID `json:"replacement_room"` } +func (tec *TombstoneEventContent) GetReplacementRoom() id.RoomID { + if tec == nil { + return "" + } + return tec.ReplacementRoom +} + type Predecessor struct { RoomID id.RoomID `json:"room_id"` EventID id.EventID `json:"event_id"` } -type RoomVersion string +// Deprecated: use id.RoomVersion instead +type RoomVersion = id.RoomVersion +// Deprecated: use id.RoomVX constants instead const ( - RoomV1 RoomVersion = "1" - RoomV2 RoomVersion = "2" - RoomV3 RoomVersion = "3" - RoomV4 RoomVersion = "4" - RoomV5 RoomVersion = "5" - RoomV6 RoomVersion = "6" - RoomV7 RoomVersion = "7" - RoomV8 RoomVersion = "8" - RoomV9 RoomVersion = "9" - RoomV10 RoomVersion = "10" - RoomV11 RoomVersion = "11" + RoomV1 = id.RoomV1 + RoomV2 = id.RoomV2 + RoomV3 = id.RoomV3 + RoomV4 = id.RoomV4 + RoomV5 = id.RoomV5 + RoomV6 = id.RoomV6 + RoomV7 = id.RoomV7 + RoomV8 = id.RoomV8 + RoomV9 = id.RoomV9 + RoomV10 = id.RoomV10 + RoomV11 = id.RoomV11 + RoomV12 = id.RoomV12 ) // CreateEventContent represents the content of a m.room.create state event. // https://spec.matrix.org/v1.2/client-server-api/#mroomcreate type CreateEventContent struct { - Type RoomType `json:"type,omitempty"` - Creator id.UserID `json:"creator,omitempty"` - Federate bool `json:"m.federate,omitempty"` - RoomVersion RoomVersion `json:"room_version,omitempty"` - Predecessor *Predecessor `json:"predecessor,omitempty"` + Type RoomType `json:"type,omitempty"` + Federate *bool `json:"m.federate,omitempty"` + RoomVersion id.RoomVersion `json:"room_version,omitempty"` + Predecessor *Predecessor `json:"predecessor,omitempty"` + + // Room v12+ only + AdditionalCreators []id.UserID `json:"additional_creators,omitempty"` + + // Deprecated: use the event sender instead + Creator id.UserID `json:"creator,omitempty"` +} + +func (cec *CreateEventContent) GetPredecessor() (p Predecessor) { + if cec != nil && cec.Predecessor != nil { + p = *cec.Predecessor + } + return +} + +func (cec *CreateEventContent) SupportsCreatorPower() bool { + if cec == nil { + return false + } + return cec.RoomVersion.PrivilegedRoomCreators() } // JoinRule specifies how open a room is to new members. @@ -158,7 +238,8 @@ type BridgeInfoSection struct { AvatarURL id.ContentURIString `json:"avatar_url,omitempty"` ExternalURL string `json:"external_url,omitempty"` - Receiver string `json:"fi.mau.receiver,omitempty"` + Receiver string `json:"fi.mau.receiver,omitempty"` + MessageRequest bool `json:"com.beeper.message_request,omitempty"` } // BridgeEventContent represents the content of a m.bridge state event. @@ -172,6 +253,32 @@ type BridgeEventContent struct { BeeperRoomType string `json:"com.beeper.room_type,omitempty"` BeeperRoomTypeV2 string `json:"com.beeper.room_type.v2,omitempty"` + + TempSlackRemoteIDMigratedFlag bool `json:"com.beeper.slack_remote_id_migrated,omitempty"` + TempSlackRemoteIDMigratedFlag2 bool `json:"com.beeper.slack_remote_id_really_migrated,omitempty"` +} + +// DisappearingType represents the type of a disappearing message timer. +type DisappearingType string + +const ( + DisappearingTypeNone DisappearingType = "" + DisappearingTypeAfterRead DisappearingType = "after_read" + DisappearingTypeAfterSend DisappearingType = "after_send" +) + +type BeeperDisappearingTimer struct { + Type DisappearingType `json:"type"` + Timer jsontime.Milliseconds `json:"timer"` +} + +type marshalableBeeperDisappearingTimer BeeperDisappearingTimer + +func (bdt *BeeperDisappearingTimer) MarshalJSON() ([]byte, error) { + if bdt == nil || bdt.Type == DisappearingTypeNone { + return []byte("{}"), nil + } + return json.Marshal((*marshalableBeeperDisappearingTimer)(bdt)) } type SpaceChildEventContent struct { @@ -188,25 +295,63 @@ type SpaceParentEventContent struct { type PolicyRecommendation string const ( - PolicyRecommendationBan PolicyRecommendation = "m.ban" - PolicyRecommendationUnstableBan PolicyRecommendation = "org.matrix.mjolnir.ban" - PolicyRecommendationUnban PolicyRecommendation = "fi.mau.meowlnir.unban" + PolicyRecommendationBan PolicyRecommendation = "m.ban" + PolicyRecommendationUnstableTakedown PolicyRecommendation = "org.matrix.msc4204.takedown" + PolicyRecommendationUnstableBan PolicyRecommendation = "org.matrix.mjolnir.ban" + PolicyRecommendationUnban PolicyRecommendation = "fi.mau.meowlnir.unban" ) +type PolicyHashes struct { + SHA256 string `json:"sha256"` +} + +func (ph *PolicyHashes) DecodeSHA256() *[32]byte { + if ph == nil || ph.SHA256 == "" { + return nil + } + decoded, _ := base64.StdEncoding.DecodeString(ph.SHA256) + if len(decoded) == 32 { + return (*[32]byte)(decoded) + } + return nil +} + // ModPolicyContent represents the content of a m.room.rule.user, m.room.rule.room, and m.room.rule.server state event. // https://spec.matrix.org/v1.2/client-server-api/#moderation-policy-lists type ModPolicyContent struct { - Entity string `json:"entity"` + Entity string `json:"entity,omitempty"` Reason string `json:"reason"` Recommendation PolicyRecommendation `json:"recommendation"` + UnstableHashes *PolicyHashes `json:"org.matrix.msc4205.hashes,omitempty"` } -// Deprecated: MSC2716 has been abandoned -type InsertionMarkerContent struct { - InsertionID id.EventID `json:"org.matrix.msc2716.marker.insertion"` - Timestamp int64 `json:"com.beeper.timestamp,omitempty"` +func (mpc *ModPolicyContent) EntityOrHash() string { + if mpc.UnstableHashes != nil && mpc.UnstableHashes.SHA256 != "" { + return mpc.UnstableHashes.SHA256 + } + return mpc.Entity } type ElementFunctionalMembersContent struct { ServiceMembers []id.UserID `json:"service_members"` } + +func (efmc *ElementFunctionalMembersContent) Add(mxid id.UserID) bool { + if slices.Contains(efmc.ServiceMembers, mxid) { + return false + } + efmc.ServiceMembers = append(efmc.ServiceMembers, mxid) + return true +} + +type PolicyServerPublicKeys struct { + Ed25519 id.Ed25519 `json:"ed25519,omitempty"` +} + +type RoomPolicyEventContent struct { + Via string `json:"via,omitempty"` + PublicKeys *PolicyServerPublicKeys `json:"public_keys,omitempty"` + + // Deprecated, only for legacy use + PublicKey id.Ed25519 `json:"public_key,omitempty"` +} diff --git a/event/type.go b/event/type.go index f2b841ad..80b86728 100644 --- a/event/type.go +++ b/event/type.go @@ -108,13 +108,14 @@ func (et *Type) IsCustom() bool { func (et *Type) GuessClass() TypeClass { switch et.Type { - case StateAliases.Type, StateCanonicalAlias.Type, StateCreate.Type, StateJoinRules.Type, StateMember.Type, + case StateAliases.Type, StateCanonicalAlias.Type, StateCreate.Type, StateJoinRules.Type, StateMember.Type, StateThirdPartyInvite.Type, StatePowerLevels.Type, StateRoomName.Type, StateRoomAvatar.Type, StateServerACL.Type, StateTopic.Type, StatePinnedEvents.Type, StateTombstone.Type, StateEncryption.Type, StateBridge.Type, StateHalfShotBridge.Type, StateSpaceParent.Type, StateSpaceChild.Type, StatePolicyRoom.Type, StatePolicyServer.Type, StatePolicyUser.Type, - StateInsertionMarker.Type, StateElementFunctionalMembers.Type: + StateElementFunctionalMembers.Type, StateBeeperRoomFeatures.Type, StateBeeperDisappearingTimer.Type, + StateMSC4391BotCommand.Type, StateRoomPolicy.Type, StateUnstableRoomPolicy.Type: return StateEventType - case EphemeralEventReceipt.Type, EphemeralEventTyping.Type, EphemeralEventPresence.Type: + case EphemeralEventReceipt.Type, EphemeralEventTyping.Type, EphemeralEventPresence.Type, BeeperEphemeralEventAIStream.Type: return EphemeralEventType case AccountDataDirectChats.Type, AccountDataPushRules.Type, AccountDataRoomTags.Type, AccountDataFullyRead.Type, AccountDataIgnoredUserList.Type, AccountDataMarkedUnread.Type, @@ -126,7 +127,8 @@ func (et *Type) GuessClass() TypeClass { InRoomVerificationStart.Type, InRoomVerificationReady.Type, InRoomVerificationAccept.Type, InRoomVerificationKey.Type, InRoomVerificationMAC.Type, InRoomVerificationCancel.Type, CallInvite.Type, CallCandidates.Type, CallAnswer.Type, CallReject.Type, CallSelectAnswer.Type, - CallNegotiate.Type, CallHangup.Type, BeeperMessageStatus.Type, EventUnstablePollStart.Type, EventUnstablePollResponse.Type: + CallNegotiate.Type, CallHangup.Type, BeeperMessageStatus.Type, EventUnstablePollStart.Type, EventUnstablePollResponse.Type, + EventUnstablePollEnd.Type, BeeperTranscription.Type, BeeperDeleteChat.Type, BeeperAcceptMessageRequest.Type: return MessageEventType case ToDeviceRoomKey.Type, ToDeviceRoomKeyRequest.Type, ToDeviceForwardedRoomKey.Type, ToDeviceRoomKeyWithheld.Type, ToDeviceBeeperRoomKeyAck.Type: @@ -176,6 +178,7 @@ var ( StateHistoryVisibility = Type{"m.room.history_visibility", StateEventType} StateGuestAccess = Type{"m.room.guest_access", StateEventType} StateMember = Type{"m.room.member", StateEventType} + StateThirdPartyInvite = Type{"m.room.third_party_invite", StateEventType} StatePowerLevels = Type{"m.room.power_levels", StateEventType} StateRoomName = Type{"m.room.name", StateEventType} StateTopic = Type{"m.room.topic", StateEventType} @@ -192,6 +195,9 @@ var ( StateSpaceChild = Type{"m.space.child", StateEventType} StateSpaceParent = Type{"m.space.parent", StateEventType} + StateRoomPolicy = Type{"m.room.policy", StateEventType} + StateUnstableRoomPolicy = Type{"org.matrix.msc4284.policy", StateEventType} + StateLegacyPolicyRoom = Type{"m.room.rule.room", StateEventType} StateLegacyPolicyServer = Type{"m.room.rule.server", StateEventType} StateLegacyPolicyUser = Type{"m.room.rule.user", StateEventType} @@ -199,10 +205,10 @@ var ( StateUnstablePolicyServer = Type{"org.matrix.mjolnir.rule.server", StateEventType} StateUnstablePolicyUser = Type{"org.matrix.mjolnir.rule.user", StateEventType} - // Deprecated: MSC2716 has been abandoned - StateInsertionMarker = Type{"org.matrix.msc2716.marker", StateEventType} - StateElementFunctionalMembers = Type{"io.element.functional_members", StateEventType} + StateBeeperRoomFeatures = Type{"com.beeper.room_features", StateEventType} + StateBeeperDisappearingTimer = Type{"com.beeper.disappearing_timer", StateEventType} + StateMSC4391BotCommand = Type{"org.matrix.msc4391.command_description", StateEventType} ) // Message events @@ -231,17 +237,24 @@ var ( CallNegotiate = Type{"m.call.negotiate", MessageEventType} CallHangup = Type{"m.call.hangup", MessageEventType} - BeeperMessageStatus = Type{"com.beeper.message_send_status", MessageEventType} + BeeperMessageStatus = Type{"com.beeper.message_send_status", MessageEventType} + BeeperTranscription = Type{"com.beeper.transcription", MessageEventType} + BeeperDeleteChat = Type{"com.beeper.delete_chat", MessageEventType} + BeeperAcceptMessageRequest = Type{"com.beeper.accept_message_request", MessageEventType} + BeeperSendState = Type{"com.beeper.send_state", MessageEventType} EventUnstablePollStart = Type{Type: "org.matrix.msc3381.poll.start", Class: MessageEventType} EventUnstablePollResponse = Type{Type: "org.matrix.msc3381.poll.response", Class: MessageEventType} + EventUnstablePollEnd = Type{Type: "org.matrix.msc3381.poll.end", Class: MessageEventType} ) // Ephemeral events var ( - EphemeralEventReceipt = Type{"m.receipt", EphemeralEventType} - EphemeralEventTyping = Type{"m.typing", EphemeralEventType} - EphemeralEventPresence = Type{"m.presence", EphemeralEventType} + EphemeralEventReceipt = Type{"m.receipt", EphemeralEventType} + EphemeralEventTyping = Type{"m.typing", EphemeralEventType} + EphemeralEventPresence = Type{"m.presence", EphemeralEventType} + EphemeralEventEncrypted = Type{"m.room.encrypted", EphemeralEventType} + BeeperEphemeralEventAIStream = Type{"com.beeper.ai.stream_event", EphemeralEventType} ) // Account data events diff --git a/event/voip.go b/event/voip.go index 28f56c95..cd8364a1 100644 --- a/event/voip.go +++ b/event/voip.go @@ -76,7 +76,7 @@ func (cv *CallVersion) Int() (int, error) { type BaseCallEventContent struct { CallID string `json:"call_id"` PartyID string `json:"party_id"` - Version CallVersion `json:"version"` + Version CallVersion `json:"version,omitempty"` } type CallInviteEventContent struct { diff --git a/example/main.go b/example/main.go index d8006d46..2bf4bef3 100644 --- a/example/main.go +++ b/example/main.go @@ -143,7 +143,7 @@ func main() { if err != nil { log.Error().Err(err).Msg("Failed to send event") } else { - log.Info().Str("event_id", resp.EventID.String()).Msg("Event sent") + log.Info().Stringer("event_id", resp.EventID).Msg("Event sent") } } cancelSync() diff --git a/federation/cache.go b/federation/cache.go new file mode 100644 index 00000000..24154974 --- /dev/null +++ b/federation/cache.go @@ -0,0 +1,153 @@ +// Copyright (c) 2025 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package federation + +import ( + "errors" + "fmt" + "math" + "sync" + "time" +) + +// ResolutionCache is an interface for caching resolved server names. +type ResolutionCache interface { + StoreResolution(*ResolvedServerName) + // LoadResolution loads a resolved server name from the cache. + // Expired entries MUST NOT be returned. + LoadResolution(serverName string) (*ResolvedServerName, error) +} + +type KeyCache interface { + StoreKeys(*ServerKeyResponse) + StoreFetchError(serverName string, err error) + ShouldReQuery(serverName string) bool + LoadKeys(serverName string) (*ServerKeyResponse, error) +} + +type InMemoryCache struct { + MinKeyRefetchDelay time.Duration + + resolutions map[string]*ResolvedServerName + resolutionsLock sync.RWMutex + keys map[string]*ServerKeyResponse + lastReQueryAt map[string]time.Time + lastError map[string]*resolutionErrorCache + keysLock sync.RWMutex +} + +var ( + _ ResolutionCache = (*InMemoryCache)(nil) + _ KeyCache = (*InMemoryCache)(nil) +) + +func NewInMemoryCache() *InMemoryCache { + return &InMemoryCache{ + resolutions: make(map[string]*ResolvedServerName), + keys: make(map[string]*ServerKeyResponse), + lastReQueryAt: make(map[string]time.Time), + lastError: make(map[string]*resolutionErrorCache), + MinKeyRefetchDelay: 1 * time.Hour, + } +} + +func (c *InMemoryCache) StoreResolution(resolution *ResolvedServerName) { + c.resolutionsLock.Lock() + defer c.resolutionsLock.Unlock() + c.resolutions[resolution.ServerName] = resolution +} + +func (c *InMemoryCache) LoadResolution(serverName string) (*ResolvedServerName, error) { + c.resolutionsLock.RLock() + defer c.resolutionsLock.RUnlock() + resolution, ok := c.resolutions[serverName] + if !ok || time.Until(resolution.Expires) < 0 { + return nil, nil + } + return resolution, nil +} + +func (c *InMemoryCache) StoreKeys(keys *ServerKeyResponse) { + c.keysLock.Lock() + defer c.keysLock.Unlock() + c.keys[keys.ServerName] = keys + delete(c.lastError, keys.ServerName) +} + +type resolutionErrorCache struct { + Error error + Time time.Time + Count int +} + +const MaxBackoff = 7 * 24 * time.Hour + +func (rec *resolutionErrorCache) ShouldRetry() bool { + backoff := time.Duration(math.Exp(float64(rec.Count))) * time.Second + return time.Since(rec.Time) > backoff +} + +var ErrRecentKeyQueryFailed = errors.New("last retry was too recent") + +func (c *InMemoryCache) LoadKeys(serverName string) (*ServerKeyResponse, error) { + c.keysLock.RLock() + defer c.keysLock.RUnlock() + keys, ok := c.keys[serverName] + if !ok || time.Until(keys.ValidUntilTS.Time) < 0 { + err, ok := c.lastError[serverName] + if ok && !err.ShouldRetry() { + return nil, fmt.Errorf( + "%w (%s ago) and failed with %w", + ErrRecentKeyQueryFailed, + time.Since(err.Time).String(), + err.Error, + ) + } + return nil, nil + } + return keys, nil +} + +func (c *InMemoryCache) StoreFetchError(serverName string, err error) { + c.keysLock.Lock() + defer c.keysLock.Unlock() + errorCache, ok := c.lastError[serverName] + if ok { + errorCache.Time = time.Now() + errorCache.Error = err + errorCache.Count++ + } else { + c.lastError[serverName] = &resolutionErrorCache{Error: err, Time: time.Now(), Count: 1} + } +} + +func (c *InMemoryCache) ShouldReQuery(serverName string) bool { + c.keysLock.Lock() + defer c.keysLock.Unlock() + lastQuery, ok := c.lastReQueryAt[serverName] + if ok && time.Since(lastQuery) < c.MinKeyRefetchDelay { + return false + } + c.lastReQueryAt[serverName] = time.Now() + return true +} + +type noopCache struct{} + +func (*noopCache) StoreKeys(_ *ServerKeyResponse) {} +func (*noopCache) LoadKeys(_ string) (*ServerKeyResponse, error) { return nil, nil } +func (*noopCache) StoreFetchError(_ string, _ error) {} +func (*noopCache) ShouldReQuery(_ string) bool { return true } +func (*noopCache) StoreResolution(_ *ResolvedServerName) {} +func (*noopCache) LoadResolution(_ string) (*ResolvedServerName, error) { return nil, nil } + +var ( + _ ResolutionCache = (*noopCache)(nil) + _ KeyCache = (*noopCache)(nil) +) + +var NoopCache *noopCache diff --git a/federation/client.go b/federation/client.go index 098df095..183fb5d1 100644 --- a/federation/client.go +++ b/federation/client.go @@ -9,7 +9,6 @@ package federation import ( "bytes" "context" - "encoding/base64" "encoding/json" "fmt" "io" @@ -22,6 +21,7 @@ import ( "go.mau.fi/util/jsontime" "maunium.net/go/mautrix" + "maunium.net/go/mautrix/federation/signutil" "maunium.net/go/mautrix/id" ) @@ -30,17 +30,25 @@ type Client struct { ServerName string UserAgent string Key *SigningKey + + ResponseSizeLimit int64 } -func NewClient(serverName string, key *SigningKey) *Client { +func NewClient(serverName string, key *SigningKey, cache ResolutionCache) *Client { return &Client{ HTTP: &http.Client{ - Transport: NewServerResolvingTransport(), + Transport: NewServerResolvingTransport(cache), Timeout: 120 * time.Second, + CheckRedirect: func(req *http.Request, via []*http.Request) error { + // Federation requests do not allow redirects. + return http.ErrUseLastResponse + }, }, UserAgent: mautrix.DefaultUserAgent, ServerName: serverName, Key: key, + + ResponseSizeLimit: mautrix.DefaultResponseSizeLimit, } } @@ -54,7 +62,7 @@ func (c *Client) ServerKeys(ctx context.Context, serverName string) (resp *Serve return } -func (c *Client) QueryKeys(ctx context.Context, serverName string, req *ReqQueryKeys) (resp *ServerKeyResponse, err error) { +func (c *Client) QueryKeys(ctx context.Context, serverName string, req *ReqQueryKeys) (resp *QueryKeysResponse, err error) { err = c.MakeRequest(ctx, serverName, false, http.MethodPost, KeyURLPath{"v2", "query"}, req, &resp) return } @@ -81,7 +89,7 @@ type RespSendTransaction struct { } func (c *Client) SendTransaction(ctx context.Context, req *ReqSendTransaction) (resp *RespSendTransaction, err error) { - err = c.MakeRequest(ctx, req.Destination, true, http.MethodPost, URLPath{"v1", "send", req.TxnID}, req, &resp) + err = c.MakeRequest(ctx, req.Destination, true, http.MethodPut, URLPath{"v1", "send", req.TxnID}, req, &resp) return } @@ -220,6 +228,26 @@ func (c *Client) Query(ctx context.Context, serverName, queryType string, queryP return } +func queryToValues(query map[string]string) url.Values { + values := make(url.Values, len(query)) + for k, v := range query { + values[k] = []string{v} + } + return values +} + +func (c *Client) PublicRooms(ctx context.Context, serverName string, req *mautrix.ReqPublicRooms) (resp *mautrix.RespPublicRooms, err error) { + _, _, err = c.MakeFullRequest(ctx, RequestParams{ + ServerName: serverName, + Method: http.MethodGet, + Path: URLPath{"v1", "publicRooms"}, + Query: queryToValues(req.Query()), + Authenticate: true, + ResponseJSON: &resp, + }) + return +} + type RespOpenIDUserInfo struct { Sub id.UserID `json:"sub"` } @@ -235,6 +263,169 @@ func (c *Client) GetOpenIDUserInfo(ctx context.Context, serverName, accessToken return } +type ReqMakeJoin struct { + RoomID id.RoomID + UserID id.UserID + Via string + SupportedVersions []id.RoomVersion +} + +type RespMakeJoin struct { + RoomVersion id.RoomVersion `json:"room_version"` + Event PDU `json:"event"` +} + +type ReqSendJoin struct { + RoomID id.RoomID + EventID id.EventID + OmitMembers bool + Event PDU + Via string +} + +type ReqSendKnock struct { + RoomID id.RoomID + EventID id.EventID + Event PDU + Via string +} + +type RespSendJoin struct { + AuthChain []PDU `json:"auth_chain"` + Event PDU `json:"event"` + MembersOmitted bool `json:"members_omitted"` + ServersInRoom []string `json:"servers_in_room"` + State []PDU `json:"state"` +} + +type RespSendKnock struct { + KnockRoomState []PDU `json:"knock_room_state"` +} + +type ReqSendInvite struct { + RoomID id.RoomID `json:"-"` + UserID id.UserID `json:"-"` + Event PDU `json:"event"` + InviteRoomState []PDU `json:"invite_room_state"` + RoomVersion id.RoomVersion `json:"room_version"` +} + +type RespSendInvite struct { + Event PDU `json:"event"` +} + +type ReqMakeLeave struct { + RoomID id.RoomID + UserID id.UserID + Via string +} + +type ReqSendLeave struct { + RoomID id.RoomID + EventID id.EventID + Event PDU + Via string +} + +type ( + ReqMakeKnock = ReqMakeJoin + RespMakeKnock = RespMakeJoin + RespMakeLeave = RespMakeJoin +) + +func (c *Client) MakeJoin(ctx context.Context, req *ReqMakeJoin) (resp *RespMakeJoin, err error) { + versions := make([]string, len(req.SupportedVersions)) + for i, v := range req.SupportedVersions { + versions[i] = string(v) + } + _, _, err = c.MakeFullRequest(ctx, RequestParams{ + ServerName: req.Via, + Method: http.MethodGet, + Path: URLPath{"v1", "make_join", req.RoomID, req.UserID}, + Query: url.Values{"ver": versions}, + Authenticate: true, + ResponseJSON: &resp, + }) + return +} + +func (c *Client) MakeKnock(ctx context.Context, req *ReqMakeKnock) (resp *RespMakeKnock, err error) { + versions := make([]string, len(req.SupportedVersions)) + for i, v := range req.SupportedVersions { + versions[i] = string(v) + } + _, _, err = c.MakeFullRequest(ctx, RequestParams{ + ServerName: req.Via, + Method: http.MethodGet, + Path: URLPath{"v1", "make_knock", req.RoomID, req.UserID}, + Query: url.Values{"ver": versions}, + Authenticate: true, + ResponseJSON: &resp, + }) + return +} + +func (c *Client) SendJoin(ctx context.Context, req *ReqSendJoin) (resp *RespSendJoin, err error) { + _, _, err = c.MakeFullRequest(ctx, RequestParams{ + ServerName: req.Via, + Method: http.MethodPut, + Path: URLPath{"v2", "send_join", req.RoomID, req.EventID}, + Query: url.Values{ + "omit_members": {strconv.FormatBool(req.OmitMembers)}, + }, + Authenticate: true, + RequestJSON: req.Event, + ResponseJSON: &resp, + }) + return +} + +func (c *Client) SendKnock(ctx context.Context, req *ReqSendKnock) (resp *RespSendKnock, err error) { + _, _, err = c.MakeFullRequest(ctx, RequestParams{ + ServerName: req.Via, + Method: http.MethodPut, + Path: URLPath{"v1", "send_knock", req.RoomID, req.EventID}, + Authenticate: true, + RequestJSON: req.Event, + ResponseJSON: &resp, + }) + return +} + +func (c *Client) SendInvite(ctx context.Context, req *ReqSendInvite) (resp *RespSendInvite, err error) { + _, _, err = c.MakeFullRequest(ctx, RequestParams{ + ServerName: req.UserID.Homeserver(), + Method: http.MethodPut, + Path: URLPath{"v2", "invite", req.RoomID, req.UserID}, + Authenticate: true, + RequestJSON: req, + ResponseJSON: &resp, + }) + return +} + +func (c *Client) MakeLeave(ctx context.Context, req *ReqMakeLeave) (resp *RespMakeLeave, err error) { + _, _, err = c.MakeFullRequest(ctx, RequestParams{ + ServerName: req.Via, + Method: http.MethodGet, + Path: URLPath{"v1", "make_leave", req.RoomID, req.UserID}, + Authenticate: true, + ResponseJSON: &resp, + }) + return +} + +func (c *Client) SendLeave(ctx context.Context, req *ReqSendLeave) (err error) { + _, _, err = c.MakeFullRequest(ctx, RequestParams{ + ServerName: req.Via, + Method: http.MethodPut, + Path: URLPath{"v2", "send_leave", req.RoomID, req.EventID}, + Authenticate: true, + RequestJSON: req.Event, + }) + return +} + type URLPath []any func (fup URLPath) FullPath() []any { @@ -286,15 +477,27 @@ func (c *Client) MakeFullRequest(ctx context.Context, params RequestParams) ([]b WrappedError: err, } } - defer func() { - _ = resp.Body.Close() - }() + if !params.DontReadBody { + defer resp.Body.Close() + } var body []byte - if resp.StatusCode >= 400 { + if resp.StatusCode >= 300 { body, err = mautrix.ParseErrorResponse(req, resp) return body, resp, err } else if params.ResponseJSON != nil || !params.DontReadBody { - body, err = io.ReadAll(resp.Body) + if resp.ContentLength > c.ResponseSizeLimit { + return body, resp, mautrix.HTTPError{ + Request: req, + Response: resp, + + Message: "not reading response", + WrappedError: fmt.Errorf("%w (%.2f MiB)", mautrix.ErrResponseTooLong, float64(resp.ContentLength)/1024/1024), + } + } + body, err = io.ReadAll(io.LimitReader(resp.Body, c.ResponseSizeLimit+1)) + if err == nil && len(body) > int(c.ResponseSizeLimit) { + err = mautrix.ErrBodyReadReachedLimit + } if err != nil { return body, resp, mautrix.HTTPError{ Request: req, @@ -354,16 +557,12 @@ func (c *Client) compileRequest(ctx context.Context, params RequestParams) (*htt Message: "client not configured for authentication", } } - var contentAny any - if reqJSON != nil { - contentAny = reqJSON - } auth, err := (&signableRequest{ Method: req.Method, URI: reqURL.RequestURI(), Origin: c.ServerName, Destination: params.ServerName, - Content: contentAny, + Content: reqJSON, }).Sign(c.Key) if err != nil { return nil, mautrix.HTTPError{ @@ -377,11 +576,19 @@ func (c *Client) compileRequest(ctx context.Context, params RequestParams) (*htt } type signableRequest struct { - Method string `json:"method"` - URI string `json:"uri"` - Origin string `json:"origin"` - Destination string `json:"destination"` - Content any `json:"content,omitempty"` + Method string `json:"method"` + URI string `json:"uri"` + Origin string `json:"origin"` + Destination string `json:"destination"` + Content json.RawMessage `json:"content,omitempty"` +} + +func (r *signableRequest) Verify(key id.SigningKey, sig string) error { + message, err := json.Marshal(r) + if err != nil { + return fmt.Errorf("failed to marshal data: %w", err) + } + return signutil.VerifyJSONRaw(key, sig, message) } func (r *signableRequest) Sign(key *SigningKey) (string, error) { @@ -389,11 +596,10 @@ func (r *signableRequest) Sign(key *SigningKey) (string, error) { if err != nil { return "", err } - return fmt.Sprintf( - `X-Matrix origin="%s",destination="%s",key="%s",sig="%s"`, - r.Origin, - r.Destination, - key.ID, - base64.RawURLEncoding.EncodeToString(sig), - ), nil + return XMatrixAuth{ + Origin: r.Origin, + Destination: r.Destination, + KeyID: key.ID, + Signature: sig, + }.String(), nil } diff --git a/federation/client_test.go b/federation/client_test.go index ba3c3ed4..ece399ea 100644 --- a/federation/client_test.go +++ b/federation/client_test.go @@ -16,7 +16,7 @@ import ( ) func TestClient_Version(t *testing.T) { - cli := federation.NewClient("", nil) + cli := federation.NewClient("", nil, nil) resp, err := cli.Version(context.TODO(), "maunium.net") require.NoError(t, err) require.Equal(t, "Synapse", resp.Server.Name) diff --git a/federation/context.go b/federation/context.go new file mode 100644 index 00000000..eedb2dc1 --- /dev/null +++ b/federation/context.go @@ -0,0 +1,42 @@ +// Copyright (c) 2025 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package federation + +import ( + "context" + "net/http" +) + +type contextKey int + +const ( + contextKeyIPPort contextKey = iota + contextKeyDestinationServer + contextKeyOriginServer +) + +func DestinationServerNameFromRequest(r *http.Request) string { + return DestinationServerName(r.Context()) +} + +func DestinationServerName(ctx context.Context) string { + if dest, ok := ctx.Value(contextKeyDestinationServer).(string); ok { + return dest + } + return "" +} + +func OriginServerNameFromRequest(r *http.Request) string { + return OriginServerName(r.Context()) +} + +func OriginServerName(ctx context.Context) string { + if origin, ok := ctx.Value(contextKeyOriginServer).(string); ok { + return origin + } + return "" +} diff --git a/federation/eventauth/eventauth.go b/federation/eventauth/eventauth.go new file mode 100644 index 00000000..c72933c2 --- /dev/null +++ b/federation/eventauth/eventauth.go @@ -0,0 +1,851 @@ +// Copyright (c) 2025 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +//go:build goexperiment.jsonv2 + +package eventauth + +import ( + "encoding/json" + "encoding/json/jsontext" + "errors" + "fmt" + "slices" + "strconv" + "strings" + + "github.com/tidwall/gjson" + "go.mau.fi/util/exgjson" + "go.mau.fi/util/exstrings" + "go.mau.fi/util/ptr" + + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/federation/pdu" + "maunium.net/go/mautrix/federation/signutil" + "maunium.net/go/mautrix/id" +) + +type AuthFailError struct { + Index string + Message string + Wrapped error +} + +func (afe AuthFailError) Error() string { + if afe.Message != "" { + return fmt.Sprintf("fail %s: %s", afe.Index, afe.Message) + } else if afe.Wrapped != nil { + return fmt.Sprintf("fail %s: %s", afe.Index, afe.Wrapped.Error()) + } + return fmt.Sprintf("fail %s", afe.Index) +} + +func (afe AuthFailError) Unwrap() error { + return afe.Wrapped +} + +var mFederatePath = exgjson.Path("m.federate") + +var ( + ErrCreateHasPrevEvents = AuthFailError{Index: "1.1", Message: "m.room.create event has prev_events"} + ErrCreateHasRoomID = AuthFailError{Index: "1.2", Message: "m.room.create event has room_id set"} + ErrRoomIDDoesntMatchSender = AuthFailError{Index: "1.2", Message: "room ID server doesn't match sender server"} + ErrUnknownRoomVersion = AuthFailError{Index: "1.3", Wrapped: id.ErrUnknownRoomVersion} + ErrInvalidAdditionalCreators = AuthFailError{Index: "1.4", Message: "m.room.create event has invalid additional_creators"} + ErrMissingCreator = AuthFailError{Index: "1.4", Message: "m.room.create event is missing creator field"} + + ErrInvalidRoomIDLength = AuthFailError{Index: "2", Message: "room ID length is invalid"} + ErrFailedToGetCreateEvent = AuthFailError{Index: "2", Message: "failed to get m.room.create event"} + ErrCreateEventNotFound = AuthFailError{Index: "2", Message: "m.room.create event not found using room ID as event ID"} + ErrRejectedCreateEvent = AuthFailError{Index: "2", Message: "m.room.create event was rejected"} + + ErrFailedToGetAuthEvents = AuthFailError{Index: "3", Message: "failed to get auth events"} + ErrFailedToParsePowerLevels = AuthFailError{Index: "?", Message: "failed to parse power levels"} + ErrDuplicateAuthEvent = AuthFailError{Index: "3.1", Message: "duplicate type/state key pair in auth events"} + ErrNonStateAuthEvent = AuthFailError{Index: "3.2", Message: "non-state event in auth events"} + ErrMissingAuthEvent = AuthFailError{Index: "3.2", Message: "missing auth event"} + ErrUnexpectedAuthEvent = AuthFailError{Index: "3.2", Message: "unexpected type/state key pair in auth events"} + ErrNoCreateEvent = AuthFailError{Index: "3.2", Message: "no m.room.create event found in auth events"} + ErrRejectedAuthEvent = AuthFailError{Index: "3.3", Message: "auth event was rejected"} + ErrMismatchingRoomIDInAuthEvent = AuthFailError{Index: "3.4", Message: "auth event room ID does not match event room ID"} + + ErrFederationDisabled = AuthFailError{Index: "4", Message: "federation is disabled for this room"} + + ErrMemberNotState = AuthFailError{Index: "5.1", Message: "m.room.member event is not a state event"} + ErrNotSignedByAuthoriser = AuthFailError{Index: "5.2", Message: "m.room.member event is not signed by server of join_authorised_via_users_server"} + ErrCantJoinOtherUser = AuthFailError{Index: "5.3.2", Message: "can't send join event with different state key"} + ErrCantJoinBanned = AuthFailError{Index: "5.3.3", Message: "user is banned from the room"} + ErrAuthoriserCantInvite = AuthFailError{Index: "5.3.5.2", Message: "authoriser doesn't have sufficient power level to invite"} + ErrAuthoriserNotInRoom = AuthFailError{Index: "5.3.5.2", Message: "authoriser isn't a member of the room"} + ErrCantJoinWithoutInvite = AuthFailError{Index: "5.3.7", Message: "can't join invite-only room without invite"} + ErrInvalidJoinRule = AuthFailError{Index: "5.3.7", Message: "invalid join rule in room"} + ErrThirdPartyInviteBanned = AuthFailError{Index: "5.4.1.1", Message: "third party invite target user is banned"} + ErrThirdPartyInviteMissingFields = AuthFailError{Index: "5.4.1.3", Message: "third party invite is missing mxid or token fields"} + ErrThirdPartyInviteMXIDMismatch = AuthFailError{Index: "5.4.1.4", Message: "mxid in signed third party invite doesn't match event state key"} + ErrThirdPartyInviteNotFound = AuthFailError{Index: "5.4.1.5", Message: "matching m.room.third_party_invite event not found in auth events"} + ErrThirdPartyInviteSenderMismatch = AuthFailError{Index: "5.4.1.6", Message: "sender of third party invite doesn't match sender of member event"} + ErrThirdPartyInviteNotSigned = AuthFailError{Index: "5.4.1.8", Message: "no valid signatures found for third party invite"} + ErrInviterNotInRoom = AuthFailError{Index: "5.4.2", Message: "inviter's membership is not join"} + ErrInviteTargetAlreadyInRoom = AuthFailError{Index: "5.4.3", Message: "invite target user is already in the room"} + ErrInviteTargetBanned = AuthFailError{Index: "5.4.3", Message: "invite target user is banned"} + ErrInsufficientPermissionForInvite = AuthFailError{Index: "5.4.5", Message: "inviter does not have sufficient permission to send invites"} + ErrCantLeaveWithoutBeingInRoom = AuthFailError{Index: "5.5.1", Message: "can't leave room without being in it"} + ErrCantKickWithoutBeingInRoom = AuthFailError{Index: "5.5.2", Message: "can't kick another user without being in the room"} + ErrInsufficientPermissionForUnban = AuthFailError{Index: "5.5.3", Message: "sender does not have sufficient permission to unban users"} + ErrInsufficientPermissionForKick = AuthFailError{Index: "5.5.5", Message: "sender does not have sufficient permission to kick the user"} + ErrCantBanWithoutBeingInRoom = AuthFailError{Index: "5.6.1", Message: "can't ban another user without being in the room"} + ErrInsufficientPermissionForBan = AuthFailError{Index: "5.6.3", Message: "sender does not have sufficient permission to ban the user"} + ErrNotKnockableRoom = AuthFailError{Index: "5.7.1", Message: "join rule doesn't allow knocking"} + ErrCantKnockOtherUser = AuthFailError{Index: "5.7.1", Message: "can't send knock event with different state key"} + ErrCantKnockWhileInRoom = AuthFailError{Index: "5.7.2", Message: "can't knock while joined, invited or banned"} + ErrUnknownMembership = AuthFailError{Index: "5.8", Message: "unknown membership in m.room.member event"} + + ErrNotInRoom = AuthFailError{Index: "6", Message: "sender is not a member of the room"} + + ErrInsufficientPowerForThirdPartyInvite = AuthFailError{Index: "7.1", Message: "sender does not have sufficient power level to send third party invite"} + + ErrInsufficientPowerLevel = AuthFailError{Index: "8", Message: "sender does not have sufficient power level to send event"} + + ErrMismatchingPrivateStateKey = AuthFailError{Index: "9", Message: "state keys starting with @ must match sender user ID"} + + ErrTopLevelPLNotInteger = AuthFailError{Index: "10.1", Message: "invalid type for top-level power level field"} + ErrPLNotInteger = AuthFailError{Index: "10.2", Message: "invalid type for power level"} + ErrInvalidUserIDInPL = AuthFailError{Index: "10.3", Message: "invalid user ID in power levels"} + ErrUserPLNotInteger = AuthFailError{Index: "10.3", Message: "invalid type for user power level"} + ErrCreatorInPowerLevels = AuthFailError{Index: "10.4", Message: "room creators must not be specified in power levels"} + ErrInvalidPowerChange = AuthFailError{Index: "10.x", Message: "illegal power level change"} + ErrInvalidUserPowerChange = AuthFailError{Index: "10.9", Message: "illegal power level change"} +) + +func isRejected(evt *pdu.PDU) bool { + return evt.InternalMeta.Rejected +} + +type GetEventsFunc = func(ids []id.EventID) ([]*pdu.PDU, error) + +func Authorize(roomVersion id.RoomVersion, evt *pdu.PDU, getEvents GetEventsFunc, getKey pdu.GetKeyFunc) error { + if evt.Type == event.StateCreate.Type { + // 1. If type is m.room.create: + return authorizeCreate(roomVersion, evt) + } + var createEvt *pdu.PDU + if roomVersion.RoomIDIsCreateEventID() { + // 2. If the event’s room_id is not an event ID for an accepted (not rejected) m.room.create event, + // with the sigil ! instead of $, reject. + if len(evt.RoomID) != 44 { + return fmt.Errorf("%w (%d)", ErrInvalidRoomIDLength, len(evt.RoomID)) + } else if createEvts, err := getEvents([]id.EventID{id.EventID("$" + evt.RoomID[1:])}); err != nil { + return fmt.Errorf("%w: %w", ErrFailedToGetCreateEvent, err) + } else if len(createEvts) != 1 { + return fmt.Errorf("%w (%s)", ErrCreateEventNotFound, evt.RoomID) + } else if isRejected(createEvts[0]) { + return ErrRejectedCreateEvent + } else { + createEvt = createEvts[0] + } + } + authEvents, err := getEvents(evt.AuthEvents) + if err != nil { + return fmt.Errorf("%w: %w", ErrFailedToGetAuthEvents, err) + } + expectedAuthEvents := evt.AuthEventSelection(roomVersion) + deduplicator := make(map[pdu.StateKey]id.EventID, len(expectedAuthEvents)) + // 3. Considering the event’s auth_events: + for i, ae := range authEvents { + authEvtID := evt.AuthEvents[i] + if ae == nil { + return fmt.Errorf("%w (%s)", ErrMissingAuthEvent, authEvtID) + } else if ae.StateKey == nil { + // This approximately falls under rule 3.2. + return fmt.Errorf("%w (%s)", ErrNonStateAuthEvent, authEvtID) + } + key := pdu.StateKey{Type: ae.Type, StateKey: *ae.StateKey} + if prevEvtID, alreadyFound := deduplicator[key]; alreadyFound { + // 3.1. If there are duplicate entries for a given type and state_key pair, reject. + return fmt.Errorf("%w for %s/%s: found %s and %s", ErrDuplicateAuthEvent, ae.Type, *ae.StateKey, prevEvtID, authEvtID) + } else if !expectedAuthEvents.Has(key) { + // 3.2. If there are entries whose type and state_key don’t match those specified by + // the auth events selection algorithm described in the server specification, reject. + return fmt.Errorf("%w: found %s with key %s/%s", ErrUnexpectedAuthEvent, authEvtID, ae.Type, *ae.StateKey) + } else if isRejected(ae) { + // 3.3. If there are entries which were themselves rejected under the checks performed on receipt of a PDU, reject. + return fmt.Errorf("%w (%s)", ErrRejectedAuthEvent, authEvtID) + } else if ae.RoomID != evt.RoomID { + // 3.4. If any event in auth_events has a room_id which does not match that of the event being authorised, reject. + return fmt.Errorf("%w (%s)", ErrMismatchingRoomIDInAuthEvent, authEvtID) + } else { + deduplicator[key] = authEvtID + } + if ae.Type == event.StateCreate.Type { + if createEvt == nil { + createEvt = ae + } else { + // Duplicates are prevented by deduplicator, AuthEventSelection also won't allow a create event at all for v12+ + panic(fmt.Errorf("impossible case: multiple create events found in auth events")) + } + } + } + if createEvt == nil { + // This comes either from auth_events or room_id depending on the room version. + // The checks above make sure it's from the right source. + return ErrNoCreateEvent + } + if federateVal := gjson.GetBytes(createEvt.Content, mFederatePath); federateVal.Type == gjson.False && createEvt.Sender.Homeserver() != evt.Sender.Homeserver() { + // 4. If the content of the m.room.create event in the room state has the property m.federate set to false, + // and the sender domain of the event does not match the sender domain of the create event, reject. + return ErrFederationDisabled + } + if evt.Type == event.StateMember.Type { + // 5. If type is m.room.member: + return authorizeMember(roomVersion, evt, createEvt, authEvents, getKey) + } + senderMembership := event.Membership(findEventAndReadString(authEvents, event.StateMember.Type, evt.Sender.String(), "membership", "leave")) + if senderMembership != event.MembershipJoin { + // 6. If the sender’s current membership state is not join, reject. + return ErrNotInRoom + } + powerLevels, err := getPowerLevels(roomVersion, authEvents, createEvt) + if err != nil { + return err + } + senderPL := powerLevels.GetUserLevel(evt.Sender) + if evt.Type == event.StateThirdPartyInvite.Type { + // 7.1. Allow if and only if sender’s current power level is greater than or equal to the invite level. + if senderPL >= powerLevels.Invite() { + return nil + } + return ErrInsufficientPowerForThirdPartyInvite + } + typeClass := event.MessageEventType + if evt.StateKey != nil { + typeClass = event.StateEventType + } + evtLevel := powerLevels.GetEventLevel(event.Type{Type: evt.Type, Class: typeClass}) + if evtLevel > senderPL { + // 8. If the event type’s required power level is greater than the sender’s power level, reject. + return fmt.Errorf("%w (%d > %d)", ErrInsufficientPowerLevel, evtLevel, senderPL) + } + + if evt.StateKey != nil && strings.HasPrefix(*evt.StateKey, "@") && *evt.StateKey != evt.Sender.String() { + // 9. If the event has a state_key that starts with an @ and does not match the sender, reject. + return ErrMismatchingPrivateStateKey + } + + if evt.Type == event.StatePowerLevels.Type { + // 10. If type is m.room.power_levels: + return authorizePowerLevels(roomVersion, evt, createEvt, authEvents) + } + + // 11. Otherwise, allow. + return nil +} + +var ErrUserIDNotAString = errors.New("not a string") +var ErrUserIDNotValid = errors.New("not a valid user ID") + +func isValidUserID(roomVersion id.RoomVersion, userID gjson.Result) error { + if userID.Type != gjson.String { + return ErrUserIDNotAString + } + // In a future room version, user IDs will have stricter validation + _, _, err := id.UserID(userID.Str).Parse() + if err != nil { + return ErrUserIDNotValid + } + return nil +} + +func authorizeCreate(roomVersion id.RoomVersion, evt *pdu.PDU) error { + if len(evt.PrevEvents) > 0 { + // 1.1. If it has any prev_events, reject. + return ErrCreateHasPrevEvents + } + if roomVersion.RoomIDIsCreateEventID() { + if evt.RoomID != "" { + // 1.2. If the event has a room_id, reject. + return ErrCreateHasRoomID + } + } else { + _, _, server := id.ParseCommonIdentifier(evt.RoomID) + if server == "" || server != evt.Sender.Homeserver() { + // 1.2. (v11 and below) If the domain of the room_id does not match the domain of the sender, reject. + return ErrRoomIDDoesntMatchSender + } + } + if !roomVersion.IsKnown() { + // 1.3. If content.room_version is present and is not a recognised version, reject. + return fmt.Errorf("%w %s", ErrUnknownRoomVersion, roomVersion) + } + if roomVersion.PrivilegedRoomCreators() { + additionalCreators := gjson.GetBytes(evt.Content, "additional_creators") + if additionalCreators.Exists() { + if !additionalCreators.IsArray() { + return fmt.Errorf("%w: not an array", ErrInvalidAdditionalCreators) + } + for i, item := range additionalCreators.Array() { + // 1.4. If additional_creators is present in content and is not an array of strings + // where each string passes the same user ID validation applied to sender, reject. + if err := isValidUserID(roomVersion, item); err != nil { + return fmt.Errorf("%w: item #%d %w", ErrInvalidAdditionalCreators, i+1, err) + } + } + } + } + if roomVersion.CreatorInContent() { + // 1.4. (v10 and below) If content has no creator property, reject. + if !gjson.GetBytes(evt.Content, "creator").Exists() { + return ErrMissingCreator + } + } + // 1.5. Otherwise, allow. + return nil +} + +func authorizeMember(roomVersion id.RoomVersion, evt, createEvt *pdu.PDU, authEvents []*pdu.PDU, getKey pdu.GetKeyFunc) error { + membership := event.Membership(gjson.GetBytes(evt.Content, "membership").Str) + if evt.StateKey == nil { + // 5.1. If there is no state_key property, or no membership property in content, reject. + return ErrMemberNotState + } + authorizedVia := id.UserID(gjson.GetBytes(evt.Content, "authorised_via_users_server").Str) + if authorizedVia != "" { + homeserver := authorizedVia.Homeserver() + err := evt.VerifySignature(roomVersion, homeserver, getKey) + if err != nil { + // 5.2. If content has a join_authorised_via_users_server key: + // 5.2.1. If the event is not validly signed by the homeserver of the user ID denoted by the key, reject. + return fmt.Errorf("%w: %w", ErrNotSignedByAuthoriser, err) + } + } + targetPrevMembership := event.Membership(findEventAndReadString(authEvents, event.StateMember.Type, *evt.StateKey, "membership", "leave")) + senderMembership := event.Membership(findEventAndReadString(authEvents, event.StateMember.Type, evt.Sender.String(), "membership", "leave")) + switch membership { + case event.MembershipJoin: + createEvtID, err := createEvt.GetEventID(roomVersion) + if err != nil { + return fmt.Errorf("failed to get create event ID: %w", err) + } + creator := createEvt.Sender.String() + if roomVersion.CreatorInContent() { + creator = gjson.GetBytes(evt.Content, "creator").Str + } + if len(evt.PrevEvents) == 1 && + len(evt.AuthEvents) <= 1 && + evt.PrevEvents[0] == createEvtID && + *evt.StateKey == creator { + // 5.3.1. If the only previous event is an m.room.create and the state_key is the sender of the m.room.create, allow. + return nil + } + // Spec wart: this would make more sense before the check above. + // Now you can set anyone as the sender of the first join. + if evt.Sender.String() != *evt.StateKey { + // 5.3.2. If the sender does not match state_key, reject. + return ErrCantJoinOtherUser + } + + if senderMembership == event.MembershipBan { + // 5.3.3. If the sender is banned, reject. + return ErrCantJoinBanned + } + + joinRule := event.JoinRule(findEventAndReadString(authEvents, event.StateJoinRules.Type, "", "join_rule", "invite")) + switch joinRule { + case event.JoinRuleKnock: + if !roomVersion.Knocks() { + return ErrInvalidJoinRule + } + fallthrough + case event.JoinRuleInvite: + // 5.3.4. If the join_rule is invite or knock then allow if membership state is invite or join. + if targetPrevMembership == event.MembershipJoin || targetPrevMembership == event.MembershipInvite { + return nil + } + return ErrCantJoinWithoutInvite + case event.JoinRuleKnockRestricted: + if !roomVersion.KnockRestricted() { + return ErrInvalidJoinRule + } + fallthrough + case event.JoinRuleRestricted: + if joinRule == event.JoinRuleRestricted && !roomVersion.RestrictedJoins() { + return ErrInvalidJoinRule + } + if targetPrevMembership == event.MembershipJoin || targetPrevMembership == event.MembershipInvite { + // 5.3.5.1. If membership state is join or invite, allow. + return nil + } + powerLevels, err := getPowerLevels(roomVersion, authEvents, createEvt) + if err != nil { + return err + } + if powerLevels.GetUserLevel(authorizedVia) < powerLevels.Invite() { + // 5.3.5.2. If the join_authorised_via_users_server key in content is not a user with sufficient permission to invite other users, reject. + return ErrAuthoriserCantInvite + } + authorizerMembership := event.Membership(findEventAndReadString(authEvents, event.StateMember.Type, authorizedVia.String(), "membership", string(event.MembershipLeave))) + if authorizerMembership != event.MembershipJoin { + return ErrAuthoriserNotInRoom + } + // 5.3.5.3. Otherwise, allow. + return nil + case event.JoinRulePublic: + // 5.3.6. If the join_rule is public, allow. + return nil + default: + // 5.3.7. Otherwise, reject. + return ErrInvalidJoinRule + } + case event.MembershipInvite: + tpiVal := gjson.GetBytes(evt.Content, "third_party_invite") + if tpiVal.Exists() { + if targetPrevMembership == event.MembershipBan { + return ErrThirdPartyInviteBanned + } + signed := tpiVal.Get("signed") + mxid := signed.Get("mxid").Str + token := signed.Get("token").Str + if mxid == "" || token == "" { + // 5.4.1.2. If content.third_party_invite does not have a signed property, reject. + // 5.4.1.3. If signed does not have mxid and token properties, reject. + return ErrThirdPartyInviteMissingFields + } + if mxid != *evt.StateKey { + // 5.4.1.4. If mxid does not match state_key, reject. + return ErrThirdPartyInviteMXIDMismatch + } + tpiEvt := findEvent(authEvents, event.StateThirdPartyInvite.Type, token) + if tpiEvt == nil { + // 5.4.1.5. If there is no m.room.third_party_invite event in the current room state with state_key matching token, reject. + return ErrThirdPartyInviteNotFound + } + if tpiEvt.Sender != evt.Sender { + // 5.4.1.6. If sender does not match sender of the m.room.third_party_invite, reject. + return ErrThirdPartyInviteSenderMismatch + } + var keys []id.Ed25519 + const ed25519Base64Len = 43 + oldPubKey := gjson.GetBytes(evt.Content, "public_key.token") + if oldPubKey.Type == gjson.String && len(oldPubKey.Str) == ed25519Base64Len { + keys = append(keys, id.Ed25519(oldPubKey.Str)) + } + gjson.GetBytes(evt.Content, "public_keys").ForEach(func(key, value gjson.Result) bool { + if key.Type != gjson.Number { + return false + } + if value.Type == gjson.String && len(value.Str) == ed25519Base64Len { + keys = append(keys, id.Ed25519(value.Str)) + } + return true + }) + rawSigned := jsontext.Value(exstrings.UnsafeBytes(signed.Str)) + var validated bool + for _, key := range keys { + if signutil.VerifyJSONAny(key, rawSigned) == nil { + validated = true + } + } + if validated { + // 4.4.1.7. If any signature in signed matches any public key in the m.room.third_party_invite event, allow. + return nil + } + // 4.4.1.8. Otherwise, reject. + return ErrThirdPartyInviteNotSigned + } + if senderMembership != event.MembershipJoin { + // 5.4.2. If the sender’s current membership state is not join, reject. + return ErrInviterNotInRoom + } + // 5.4.3. If target user’s current membership state is join or ban, reject. + if targetPrevMembership == event.MembershipJoin { + return ErrInviteTargetAlreadyInRoom + } else if targetPrevMembership == event.MembershipBan { + return ErrInviteTargetBanned + } + powerLevels, err := getPowerLevels(roomVersion, authEvents, createEvt) + if err != nil { + return err + } + if powerLevels.GetUserLevel(evt.Sender) >= powerLevels.Invite() { + // 5.4.4. If the sender’s power level is greater than or equal to the invite level, allow. + return nil + } + // 5.4.5. Otherwise, reject. + return ErrInsufficientPermissionForInvite + case event.MembershipLeave: + if evt.Sender.String() == *evt.StateKey { + // 5.5.1. If the sender matches state_key, allow if and only if that user’s current membership state is invite, join, or knock. + if senderMembership == event.MembershipInvite || + senderMembership == event.MembershipJoin || + (senderMembership == event.MembershipKnock && roomVersion.Knocks()) { + return nil + } + return ErrCantLeaveWithoutBeingInRoom + } + if senderMembership != event.MembershipJoin { + // 5.5.2. If the sender’s current membership state is not join, reject. + return ErrCantKickWithoutBeingInRoom + } + powerLevels, err := getPowerLevels(roomVersion, authEvents, createEvt) + if err != nil { + return err + } + senderLevel := powerLevels.GetUserLevel(evt.Sender) + if targetPrevMembership == event.MembershipBan && senderLevel < powerLevels.Ban() { + // 5.5.3. If the target user’s current membership state is ban, and the sender’s power level is less than the ban level, reject. + return ErrInsufficientPermissionForUnban + } + if senderLevel >= powerLevels.Kick() && powerLevels.GetUserLevel(id.UserID(*evt.StateKey)) < senderLevel { + // 5.5.4. If the sender’s power level is greater than or equal to the kick level, and the target user’s power level is less than the sender’s power level, allow. + return nil + } + // TODO separate errors for < kick and < target user level? + // 5.5.5. Otherwise, reject. + return ErrInsufficientPermissionForKick + case event.MembershipBan: + if senderMembership != event.MembershipJoin { + // 5.6.1. If the sender’s current membership state is not join, reject. + return ErrCantBanWithoutBeingInRoom + } + powerLevels, err := getPowerLevels(roomVersion, authEvents, createEvt) + if err != nil { + return err + } + senderLevel := powerLevels.GetUserLevel(evt.Sender) + if senderLevel >= powerLevels.Ban() && powerLevels.GetUserLevel(id.UserID(*evt.StateKey)) < senderLevel { + // 5.6.2. If the sender’s power level is greater than or equal to the ban level, and the target user’s power level is less than the sender’s power level, allow. + return nil + } + // 5.6.3. Otherwise, reject. + return ErrInsufficientPermissionForBan + case event.MembershipKnock: + joinRule := event.JoinRule(findEventAndReadString(authEvents, event.StateJoinRules.Type, "", "join_rule", "invite")) + validKnockRule := roomVersion.Knocks() && joinRule == event.JoinRuleKnock + validKnockRestrictedRule := roomVersion.KnockRestricted() && joinRule == event.JoinRuleKnockRestricted + if !validKnockRule && !validKnockRestrictedRule { + // 5.7.1. If the join_rule is anything other than knock or knock_restricted, reject. + return ErrNotKnockableRoom + } + if evt.Sender.String() != *evt.StateKey { + // 5.7.2. If the sender does not match state_key, reject. + return ErrCantKnockOtherUser + } + if senderMembership != event.MembershipBan && senderMembership != event.MembershipInvite && senderMembership != event.MembershipJoin { + // 5.7.3. If the sender’s current membership is not ban, invite, or join, allow. + return nil + } + // 5.7.4. Otherwise, reject. + return ErrCantKnockWhileInRoom + default: + // 5.8. Otherwise, the membership is unknown. Reject. + return ErrUnknownMembership + } +} + +func authorizePowerLevels(roomVersion id.RoomVersion, evt, createEvt *pdu.PDU, authEvents []*pdu.PDU) error { + if roomVersion.ValidatePowerLevelInts() { + for _, key := range []string{"users_default", "events_default", "state_default", "ban", "redact", "kick", "invite"} { + res := gjson.GetBytes(evt.Content, key) + if !res.Exists() { + continue + } + if parseIntWithVersion(roomVersion, res) == nil { + // 10.1. If any of the properties users_default, events_default, state_default, ban, redact, kick, or invite in content are present and not an integer, reject. + return fmt.Errorf("%w %s", ErrTopLevelPLNotInteger, key) + } + } + for _, key := range []string{"events", "notifications"} { + obj := gjson.GetBytes(evt.Content, key) + if !obj.Exists() { + continue + } + // 10.2. If either of the properties events or notifications in content are present and not an object [...], reject. + if !obj.IsObject() { + return fmt.Errorf("%w %s", ErrTopLevelPLNotInteger, key) + } + var err error + // 10.2. [...] are not an object with values that are integers, reject. + obj.ForEach(func(innerKey, value gjson.Result) bool { + if parseIntWithVersion(roomVersion, value) == nil { + err = fmt.Errorf("%w %s.%s", ErrPLNotInteger, key, innerKey.Str) + return false + } + return true + }) + if err != nil { + return err + } + } + } + var creators []id.UserID + if roomVersion.PrivilegedRoomCreators() { + creators = append(creators, createEvt.Sender) + gjson.GetBytes(createEvt.Content, "additional_creators").ForEach(func(key, value gjson.Result) bool { + creators = append(creators, id.UserID(value.Str)) + return true + }) + } + users := gjson.GetBytes(evt.Content, "users") + if users.Exists() { + if !users.IsObject() { + // 10.3. If the users property in content is not an object [...], reject. + return fmt.Errorf("%w users", ErrTopLevelPLNotInteger) + } + var err error + users.ForEach(func(key, value gjson.Result) bool { + if validatorErr := isValidUserID(roomVersion, key); validatorErr != nil { + // 10.3. [...] is not an object with keys that are valid user IDs [...], reject. + err = fmt.Errorf("%w: %q %w", ErrInvalidUserIDInPL, key.Str, validatorErr) + return false + } + if parseIntWithVersion(roomVersion, value) == nil { + // 10.3. [...] is not an object [...] with values that are integers, reject. + err = fmt.Errorf("%w %q", ErrUserPLNotInteger, key.Str) + return false + } + // creators is only filled if the room version has privileged room creators + if slices.Contains(creators, id.UserID(key.Str)) { + // 10.4. If the users property in content contains the sender of the m.room.create event or any of + // the additional_creators array (if present) from the content of the m.room.create event, reject. + err = fmt.Errorf("%w: %q", ErrCreatorInPowerLevels, key.Str) + return false + } + return true + }) + if err != nil { + return err + } + } + oldPL := findEvent(authEvents, event.StatePowerLevels.Type, "") + if oldPL == nil { + // 10.5. If there is no previous m.room.power_levels event in the room, allow. + return nil + } + if slices.Contains(creators, evt.Sender) { + // Skip remaining checks for creators + return nil + } + senderPLPtr := parsePythonInt(gjson.GetBytes(oldPL.Content, exgjson.Path("users", evt.Sender.String()))) + if senderPLPtr == nil { + senderPLPtr = parsePythonInt(gjson.GetBytes(oldPL.Content, "users_default")) + if senderPLPtr == nil { + senderPLPtr = ptr.Ptr(0) + } + } + for _, key := range []string{"users_default", "events_default", "state_default", "ban", "redact", "kick", "invite"} { + oldVal := gjson.GetBytes(oldPL.Content, key) + newVal := gjson.GetBytes(evt.Content, key) + if err := allowPowerChange(roomVersion, *senderPLPtr, key, oldVal, newVal); err != nil { + return err + } + } + if err := allowPowerChangeMap( + roomVersion, *senderPLPtr, "events", "", + gjson.GetBytes(oldPL.Content, "events"), + gjson.GetBytes(evt.Content, "events"), + ); err != nil { + return err + } + if err := allowPowerChangeMap( + roomVersion, *senderPLPtr, "notifications", "", + gjson.GetBytes(oldPL.Content, "notifications"), + gjson.GetBytes(evt.Content, "notifications"), + ); err != nil { + return err + } + if err := allowPowerChangeMap( + roomVersion, *senderPLPtr, "users", evt.Sender.String(), + gjson.GetBytes(oldPL.Content, "users"), + gjson.GetBytes(evt.Content, "users"), + ); err != nil { + return err + } + return nil +} + +func allowPowerChangeMap(roomVersion id.RoomVersion, maxVal int, path, ownID string, old, new gjson.Result) (err error) { + old.ForEach(func(key, value gjson.Result) bool { + newVal := new.Get(exgjson.Path(key.Str)) + err = allowPowerChange(roomVersion, maxVal, path+"."+key.Str, value, newVal) + if err == nil && ownID != "" && key.Str != ownID { + parsedOldVal := parseIntWithVersion(roomVersion, value) + parsedNewVal := parseIntWithVersion(roomVersion, newVal) + if *parsedOldVal >= maxVal && *parsedOldVal != *parsedNewVal { + err = fmt.Errorf("%w: can't change users.%s from %s to %s with sender level %d", ErrInvalidUserPowerChange, key.Str, stringifyForError(value), stringifyForError(newVal), maxVal) + } + } + return err == nil + }) + if err != nil { + return + } + new.ForEach(func(key, value gjson.Result) bool { + err = allowPowerChange(roomVersion, maxVal, path+"."+key.Str, old.Get(exgjson.Path(key.Str)), value) + return err == nil + }) + return +} + +func allowPowerChange(roomVersion id.RoomVersion, maxVal int, path string, old, new gjson.Result) error { + oldVal := parseIntWithVersion(roomVersion, old) + newVal := parseIntWithVersion(roomVersion, new) + if oldVal == nil { + if newVal == nil || *newVal <= maxVal { + return nil + } + } else if newVal == nil { + if *oldVal <= maxVal { + return nil + } + } else if *oldVal == *newVal || (*oldVal <= maxVal && *newVal <= maxVal) { + return nil + } + return fmt.Errorf("%w can't change %s from %s to %s with sender level %d", ErrInvalidPowerChange, path, stringifyForError(old), stringifyForError(new), maxVal) +} + +func stringifyForError(val gjson.Result) string { + if !val.Exists() { + return "null" + } + return val.Raw +} + +func findEvent(events []*pdu.PDU, evtType, stateKey string) *pdu.PDU { + for _, evt := range events { + if evt.Type == evtType && *evt.StateKey == stateKey { + return evt + } + } + return nil +} + +func findEventAndReadData[T any](events []*pdu.PDU, evtType, stateKey string, reader func(evt *pdu.PDU) T) T { + return reader(findEvent(events, evtType, stateKey)) +} + +func findEventAndReadString(events []*pdu.PDU, evtType, stateKey, fieldPath, defVal string) string { + return findEventAndReadData(events, evtType, stateKey, func(evt *pdu.PDU) string { + if evt == nil { + return defVal + } + res := gjson.GetBytes(evt.Content, fieldPath) + if res.Type != gjson.String { + return defVal + } + return res.Str + }) +} + +func getPowerLevels(roomVersion id.RoomVersion, authEvents []*pdu.PDU, createEvt *pdu.PDU) (*event.PowerLevelsEventContent, error) { + var err error + powerLevels := findEventAndReadData(authEvents, event.StatePowerLevels.Type, "", func(evt *pdu.PDU) *event.PowerLevelsEventContent { + if evt == nil { + return nil + } + content := evt.Content + out := &event.PowerLevelsEventContent{} + if !roomVersion.ValidatePowerLevelInts() { + safeParsePowerLevels(content, out) + } else { + err = json.Unmarshal(content, out) + } + return out + }) + if err != nil { + // This should never happen thanks to safeParsePowerLevels for v1-9 and strict validation in v10+ + return nil, fmt.Errorf("%w: %w", ErrFailedToParsePowerLevels, err) + } + if roomVersion.PrivilegedRoomCreators() { + if powerLevels == nil { + powerLevels = &event.PowerLevelsEventContent{} + } + powerLevels.CreateEvent, err = createEvt.ToClientEvent(roomVersion) + if err != nil { + return nil, fmt.Errorf("%w: %w", ErrFailedToParsePowerLevels, err) + } + err = powerLevels.CreateEvent.Content.ParseRaw(powerLevels.CreateEvent.Type) + if err != nil { + return nil, fmt.Errorf("%w: %w", ErrFailedToParsePowerLevels, err) + } + } else if powerLevels == nil { + powerLevels = &event.PowerLevelsEventContent{ + Users: map[id.UserID]int{ + createEvt.Sender: 100, + }, + } + } + return powerLevels, nil +} + +func parseIntWithVersion(roomVersion id.RoomVersion, val gjson.Result) *int { + if roomVersion.ValidatePowerLevelInts() { + if val.Type != gjson.Number { + return nil + } + return ptr.Ptr(int(val.Int())) + } + return parsePythonInt(val) +} + +func parsePythonInt(val gjson.Result) *int { + switch val.Type { + case gjson.True: + return ptr.Ptr(1) + case gjson.False: + return ptr.Ptr(0) + case gjson.Number: + return ptr.Ptr(int(val.Int())) + case gjson.String: + // strconv.Atoi accepts signs as well as leading zeroes, so we just need to trim spaces beforehand + num, err := strconv.Atoi(strings.TrimSpace(val.Str)) + if err != nil { + return nil + } + return &num + default: + // Python int() doesn't accept nulls, arrays or dicts + return nil + } +} + +func safeParsePowerLevels(content jsontext.Value, into *event.PowerLevelsEventContent) { + *into = event.PowerLevelsEventContent{ + Users: make(map[id.UserID]int), + UsersDefault: ptr.Val(parsePythonInt(gjson.GetBytes(content, "users_default"))), + Events: make(map[string]int), + EventsDefault: ptr.Val(parsePythonInt(gjson.GetBytes(content, "events_default"))), + Notifications: nil, // irrelevant for event auth + StateDefaultPtr: parsePythonInt(gjson.GetBytes(content, "state_default")), + InvitePtr: parsePythonInt(gjson.GetBytes(content, "invite")), + KickPtr: parsePythonInt(gjson.GetBytes(content, "kick")), + BanPtr: parsePythonInt(gjson.GetBytes(content, "ban")), + RedactPtr: parsePythonInt(gjson.GetBytes(content, "redact")), + } + gjson.GetBytes(content, "events").ForEach(func(key, value gjson.Result) bool { + if key.Type != gjson.String { + return false + } + val := parsePythonInt(value) + if val != nil { + into.Events[key.Str] = *val + } + return true + }) + gjson.GetBytes(content, "users").ForEach(func(key, value gjson.Result) bool { + if key.Type != gjson.String { + return false + } + val := parsePythonInt(value) + if val == nil { + return false + } + userID := id.UserID(key.Str) + if _, _, err := userID.Parse(); err != nil { + return false + } + into.Users[userID] = *val + return true + }) +} diff --git a/federation/eventauth/eventauth_internal_test.go b/federation/eventauth/eventauth_internal_test.go new file mode 100644 index 00000000..d316f3c8 --- /dev/null +++ b/federation/eventauth/eventauth_internal_test.go @@ -0,0 +1,66 @@ +// Copyright (c) 2026 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +//go:build goexperiment.jsonv2 + +package eventauth + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" +) + +type pythonIntTest struct { + Name string + Input string + Expected int64 +} + +var pythonIntTests = []pythonIntTest{ + {"True", `true`, 1}, + {"False", `false`, 0}, + {"SmallFloat", `3.1415`, 3}, + {"SmallFloatRoundDown", `10.999999999999999`, 10}, + {"SmallFloatRoundUp", `10.9999999999999999`, 11}, + {"BigFloatRoundDown", `1000000.9999999999`, 1000000}, + {"BigFloatRoundUp", `1000000.99999999999`, 1000001}, + {"BigFloatPrecisionError", `9007199254740993.0`, 9007199254740992}, + {"BigFloatPrecisionError2", `9007199254740993.123`, 9007199254740994}, + {"Int64", `9223372036854775807`, 9223372036854775807}, + {"Int64String", `"9223372036854775807"`, 9223372036854775807}, + {"String", `"123"`, 123}, + {"InvalidFloatInString", `"123.456"`, 0}, + {"StringWithPlusSign", `"+123"`, 123}, + {"StringWithMinusSign", `"-123"`, -123}, + {"StringWithSpaces", `" 123 "`, 123}, + {"StringWithSpacesAndSign", `" -123 "`, -123}, + //{"StringWithUnderscores", `"123_456"`, 123456}, + //{"StringWithUnderscores", `"123_456"`, 123456}, + {"InvalidStringWithTrailingUnderscore", `"123_456_"`, 0}, + {"InvalidStringWithMultipleUnderscores", `"123__456"`, 0}, + {"InvalidStringWithLeadingUnderscore", `"_123_456"`, 0}, + {"InvalidStringWithUnderscoreAfterSign", `"+_123_456"`, 0}, + {"InvalidStringWithUnderscoreAfterSpace", `" _123_456"`, 0}, + //{"StringWithUnderscoresAndSpaces", `" +1_2_3_4_5_6 "`, 123456}, +} + +func TestParsePythonInt(t *testing.T) { + for _, test := range pythonIntTests { + t.Run(test.Name, func(t *testing.T) { + output := parsePythonInt(gjson.Parse(test.Input)) + if strings.HasPrefix(test.Name, "Invalid") { + assert.Nil(t, output) + } else { + require.NotNil(t, output) + assert.Equal(t, int(test.Expected), *output) + } + }) + } +} diff --git a/federation/eventauth/eventauth_test.go b/federation/eventauth/eventauth_test.go new file mode 100644 index 00000000..e3c5cd76 --- /dev/null +++ b/federation/eventauth/eventauth_test.go @@ -0,0 +1,85 @@ +// Copyright (c) 2025 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +//go:build goexperiment.jsonv2 + +package eventauth_test + +import ( + "embed" + "encoding/json/jsontext" + "encoding/json/v2" + "errors" + "io" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" + "go.mau.fi/util/exerrors" + "go.mau.fi/util/ptr" + + "maunium.net/go/mautrix/federation/eventauth" + "maunium.net/go/mautrix/federation/pdu" + "maunium.net/go/mautrix/id" +) + +//go:embed *.jsonl +var data embed.FS + +type eventMap map[id.EventID]*pdu.PDU + +func (em eventMap) Get(ids []id.EventID) ([]*pdu.PDU, error) { + output := make([]*pdu.PDU, len(ids)) + for i, evtID := range ids { + output[i] = em[evtID] + } + return output, nil +} + +func GetKey(serverName string, keyID id.KeyID, validUntilTS time.Time) (id.SigningKey, time.Time, error) { + return "", time.Time{}, nil +} + +func TestAuthorize(t *testing.T) { + files := exerrors.Must(data.ReadDir(".")) + for _, file := range files { + t.Run(file.Name(), func(t *testing.T) { + decoder := jsontext.NewDecoder(exerrors.Must(data.Open(file.Name()))) + events := make(eventMap) + var roomVersion *id.RoomVersion + for i := 1; ; i++ { + var evt *pdu.PDU + err := json.UnmarshalDecode(decoder, &evt) + if errors.Is(err, io.EOF) { + break + } + require.NoError(t, err) + if roomVersion == nil { + require.Equal(t, evt.Type, "m.room.create") + roomVersion = ptr.Ptr(id.RoomVersion(gjson.GetBytes(evt.Content, "room_version").Str)) + } + expectedEventID := gjson.GetBytes(evt.Unsigned, "event_id").Str + evtID, err := evt.GetEventID(*roomVersion) + require.NoError(t, err) + require.Equalf(t, id.EventID(expectedEventID), evtID, "Event ID mismatch for event #%d", i) + + // TODO allow redacted events + assert.True(t, evt.VerifyContentHash(), i) + + events[evtID] = evt + err = eventauth.Authorize(*roomVersion, evt, events.Get, GetKey) + if err != nil { + evt.InternalMeta.Rejected = true + } + // TODO allow testing intentionally rejected events + assert.NoErrorf(t, err, "Failed to authorize event #%d / %s of type %s", i, evtID, evt.Type) + } + }) + } + +} diff --git a/federation/eventauth/testroom-v12-success.jsonl b/federation/eventauth/testroom-v12-success.jsonl new file mode 100644 index 00000000..2b751de3 --- /dev/null +++ b/federation/eventauth/testroom-v12-success.jsonl @@ -0,0 +1,21 @@ +{"auth_events":[],"content":{"room_version":"12"},"depth":1,"hashes":{"sha256":"qJYytb+EqWPiiZ0ogDODcLeA8XYw/2hVTaLHihcVBZQ"},"origin_server_ts":1756071567186,"prev_events":[],"sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"/9pp+2tkLo6XcZ3opqLeIpa3D96fh3QLpR2PQrZ6Z6j7wyRAvBrcgCpAeMtuyDCzW8Wh1QFEPG4FSsGvVaEFBg"}},"state_key":"","type":"m.room.create","unsigned":{"age_ts":1756071567186,"event_id":"$lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54"}} +{"auth_events":[],"content":{"avatar_url":"mxc://maunium.net/jdlSfvudiMSmcRrleeiYjjFO","displayname":"tulir","membership":"join"},"depth":2,"hashes":{"sha256":"MXmgq0e4J9CdIP0IVKVvueFhOb+ndlsXpeyI+6l/2FI"},"origin_server_ts":1756071567259,"prev_events":["$lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"xMgRzyRg9VM9XCKpfFJA+MrYoI68b8PIddKpMTcxz/fDzmGSHEy6Ta2b59VxiX3NoJe2CigkDZ3+jVsQoZYIBA"}},"state_key":"@tulir:maunium.net","type":"m.room.member","unsigned":{"age_ts":1756071567259,"event_id":"$mmqm2KS4UExkNL65c6CIhKofn_L9fzF2OhghVqajksU"}} +{"auth_events":["$mmqm2KS4UExkNL65c6CIhKofn_L9fzF2OhghVqajksU"],"content":{"ban":50,"events":{"m.room.avatar":50,"m.room.canonical_alias":50,"m.room.encryption":100,"m.room.history_visibility":100,"m.room.name":50,"m.room.power_levels":100,"m.room.server_acl":100,"m.room.tombstone":150},"events_default":0,"historical":100,"invite":0,"kick":50,"redact":50,"state_default":50,"users":{"@tulir:envs.net":9001},"users_default":0},"depth":3,"hashes":{"sha256":"/JzQNBNqJ/i8vwj6xESDaD5EDdOqB4l/LmKlvAVl5jY"},"origin_server_ts":1756071567319,"prev_events":["$mmqm2KS4UExkNL65c6CIhKofn_L9fzF2OhghVqajksU"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"W3N3X/enja+lumXw3uz66/wT9oczoxrmHbAD5/RF069cX4wkCtqtDd61VWPkSGmKxdV1jurgbCqSX6+Q9/t3AA"}},"state_key":"","type":"m.room.power_levels","unsigned":{"age_ts":1756071567319,"event_id":"$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU"}} +{"auth_events":["$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU","$mmqm2KS4UExkNL65c6CIhKofn_L9fzF2OhghVqajksU"],"content":{"join_rule":"invite"},"depth":4,"hashes":{"sha256":"GBu5AySj75ZXlOLd65mB03KueFKOHNgvtg2o/LUnLyI"},"origin_server_ts":1756071567320,"prev_events":["$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"XqWEnFREo2PhRnaebGjNzdHdtD691BtCQKkLnpKd8P3lVDewDt8OkCbDSk/Uzh9rDtzwWEsbsIoKSYuOm+G6CA"}},"state_key":"","type":"m.room.join_rules","unsigned":{"age_ts":1756071567320,"event_id":"$deNVGs6Ef7OKVrvewhtPv7DCCqSip112cEJYp-jkP6M"}} +{"auth_events":["$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU","$mmqm2KS4UExkNL65c6CIhKofn_L9fzF2OhghVqajksU"],"content":{"history_visibility":"shared"},"depth":5,"hashes":{"sha256":"niDi5vG2akQm0f5pm0aoCYXqmWjXRfmP1ulr/ZEPm/k"},"origin_server_ts":1756071567320,"prev_events":["$deNVGs6Ef7OKVrvewhtPv7DCCqSip112cEJYp-jkP6M"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"PTIrNke/fc9+ObKAl/K0PGZfmpe8dwREyoA5rXffOXWdRHSaBifn9UIiJUqd68Bzvrv4RcADTR/ci7lUquFBBw"}},"state_key":"","type":"m.room.history_visibility","unsigned":{"age_ts":1756071567320,"event_id":"$Wmy3G9yxl9ArVg5ZsdeIDPxBsNAdgseuvHoqHTZ2vug"}} +{"auth_events":["$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU","$mmqm2KS4UExkNL65c6CIhKofn_L9fzF2OhghVqajksU"],"content":{"guest_access":"can_join"},"depth":6,"hashes":{"sha256":"sZ9QqsId4oarFF724esTohXuRxDNnaXPl+QmTDG60dw"},"origin_server_ts":1756071567321,"prev_events":["$Wmy3G9yxl9ArVg5ZsdeIDPxBsNAdgseuvHoqHTZ2vug"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"Eh2P9/hl38wfZx2AQbeS5VCD4wldXPfeP2sQsJsLtfmdwFV74jrlGVBaKIkaYcXY4eA08iDp8HW5jqttZqKKDg"}},"state_key":"","type":"m.room.guest_access","unsigned":{"age_ts":1756071567321,"event_id":"$hYVRH7F4P5mB5IqvBDDU5aXY7pYGG0ApstrryiVPKmQ"}} +{"auth_events":["$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU","$mmqm2KS4UExkNL65c6CIhKofn_L9fzF2OhghVqajksU"],"content":{"name":"event auth test v12"},"depth":7,"hashes":{"sha256":"tjwPo38yR+23Was6SbxLvPMhNx44DaXLhF3rKgngepU"},"origin_server_ts":1756071567321,"prev_events":["$hYVRH7F4P5mB5IqvBDDU5aXY7pYGG0ApstrryiVPKmQ"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"q1rk0c5m8TJYE9tePsMaLeaigatNNbvaLRom0X8KiZY0EH+itujfA+/UnksvmPmMmThfAXWlFLx5u8tcuSVyCQ"}},"state_key":"","type":"m.room.name","unsigned":{"age_ts":1756071567321,"event_id":"$fFDwIavLTEIfcnggWuryB6JwfS-L2KT6vP1ap3P6ctE"}} +{"auth_events":["$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU","$mmqm2KS4UExkNL65c6CIhKofn_L9fzF2OhghVqajksU","$deNVGs6Ef7OKVrvewhtPv7DCCqSip112cEJYp-jkP6M"],"content":{"avatar_url":"mxc://envs.net/000cf1510b7c61018f9c72ca4cc63668370782c81725865933316030464","displayname":"tulir[e]","membership":"invite"},"depth":8,"hashes":{"sha256":"r5EBUZN/4LbVcMYwuffDcVV9G4OMHzAQuNbnjigL+OE"},"origin_server_ts":1756071567548,"prev_events":["$fFDwIavLTEIfcnggWuryB6JwfS-L2KT6vP1ap3P6ctE"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:maunium.net","signatures":{"envs.net":{"ed25519:wuJyKT":"svB+uW4Tsj8/I+SYbLl+LPPjBlqxGNXE4wGyAxlP7vfyJtFf7Kn/19jx65wT9ebeCq5sTGlEDV4Fabwma9LhDA"},"maunium.net":{"ed25519:a_xxeS":"LBYMcdJVSNsLd6SmOgx5oOU/0xOeCl03o4g83VwJfHWlRuTT5l9+qlpNED28wY07uxoU9MgLgXXICJ0EezMBCg"}},"state_key":"@tulir:envs.net","type":"m.room.member","unsigned":{"age_ts":1756071567548,"event_id":"$qYZqSKiKMCNjzH6Trhr6nBSvbfuwr8Sh2bC4USSAxok","invite_room_state":[{"content":{"join_rule":"invite"},"sender":"@tulir:maunium.net","state_key":"","type":"m.room.join_rules"},{"content":{"name":"event auth test v12"},"sender":"@tulir:maunium.net","state_key":"","type":"m.room.name"},{"auth_events":[],"content":{"room_version":"12"},"depth":1,"hashes":{"sha256":"qJYytb+EqWPiiZ0ogDODcLeA8XYw/2hVTaLHihcVBZQ"},"origin_server_ts":1756071567186,"prev_events":[],"sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"/9pp+2tkLo6XcZ3opqLeIpa3D96fh3QLpR2PQrZ6Z6j7wyRAvBrcgCpAeMtuyDCzW8Wh1QFEPG4FSsGvVaEFBg"}},"state_key":"","type":"m.room.create","unsigned":{"age_ts":1756071567186}},{"content":{"avatar_url":"mxc://maunium.net/jdlSfvudiMSmcRrleeiYjjFO","displayname":"tulir","membership":"join"},"sender":"@tulir:maunium.net","state_key":"@tulir:maunium.net","type":"m.room.member"}]}} +{"auth_events":["$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU","$mmqm2KS4UExkNL65c6CIhKofn_L9fzF2OhghVqajksU"],"content":{"body":"meow","com.beeper.linkpreviews":[],"m.mentions":{},"msgtype":"m.text"},"depth":9,"hashes":{"sha256":"23rgMf7EGJcYt3Aj0qAFnmBWCxuU9Uk+ReidqtIJDKQ"},"origin_server_ts":1756071575986,"prev_events":["$qYZqSKiKMCNjzH6Trhr6nBSvbfuwr8Sh2bC4USSAxok"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"p+Fm/uWO8VXJdCYvN/dVb8HF8W3t1sssNCBiOWbzAeuS3QqYjoMKHyixLuN1mOdnCyATv7SsHHmA4+cELRGdAA"}},"type":"m.room.message","unsigned":{"age_ts":1756071576002,"event_id":"$eZDCydRWSRnR5od0c7ahz2qSZQDHbl5g5PITT0OMC3E"}} +{"auth_events":["$qYZqSKiKMCNjzH6Trhr6nBSvbfuwr8Sh2bC4USSAxok","$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU","$deNVGs6Ef7OKVrvewhtPv7DCCqSip112cEJYp-jkP6M"],"content":{"avatar_url":"mxc://envs.net/000cf1510b7c61018f9c72ca4cc63668370782c81725865933316030464","displayname":"tulir[e]","membership":"join"},"depth":10,"hashes":{"sha256":"2kJPx2UsysNzTH8QGYHUKTO/05yetxKRlI0nKFeGbts"},"origin_server_ts":1756071578631,"prev_events":["$eZDCydRWSRnR5od0c7ahz2qSZQDHbl5g5PITT0OMC3E"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:envs.net","signatures":{"envs.net":{"ed25519:wuJyKT":"Wuzxkh8nEEX6mdJzph6Bt5ku+odFkEg2RIpFAAirOqxgcrwRaz42PsJni3YbfzH1qneF+iWQ/neA+up6jLXFBw"}},"state_key":"@tulir:envs.net","type":"m.room.member","unsigned":{"age":6,"event_id":"$Bz2lxsbUYkeBDE7eMAsOm_TK_iuSuHNvQdrHnc-T1PE","replaces_state":"$qYZqSKiKMCNjzH6Trhr6nBSvbfuwr8Sh2bC4USSAxok"}} +{"auth_events":["$Bz2lxsbUYkeBDE7eMAsOm_TK_iuSuHNvQdrHnc-T1PE","$deNVGs6Ef7OKVrvewhtPv7DCCqSip112cEJYp-jkP6M","$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU"],"content":{"avatar_url":"mxc://matrix.org/BDYVQFSLvZHMaKHDGiRkvhVg","displayname":"tulir[m]","membership":"invite"},"depth":11,"hashes":{"sha256":"dRE11R2hBfFalQ5tIJdyaElUIiSE5aCKMddjek4wR3c"},"origin_server_ts":1756071591449,"prev_events":["$Bz2lxsbUYkeBDE7eMAsOm_TK_iuSuHNvQdrHnc-T1PE"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:envs.net","signatures":{"envs.net":{"ed25519:wuJyKT":"/Mi4kX40fbR+V3DCJJGI/9L3Uuf8y5Un8LHlCQv1T0O5gnFZGQ3qN6rRNaZ1Kdh3QJBU6H4NTfnd+SVj3wt3CQ"},"matrix.org":{"ed25519:a_RXGa":"ZeLm/oxP3/Cds/uCL2FaZpgjUp0vTDBlGG6YVFNl76yIVlyIKKQKR6BSVw2u5KC5Mu9M1f+0lDmLGQujR5NkBg"}},"state_key":"@tulir:matrix.org","type":"m.room.member","unsigned":{"event_id":"$g4eBtA9EFNGLkHOofvQ4U87GNt4W8NmfmNRyR0wOUO4","invite_room_state":[{"content":{"join_rule":"invite"},"sender":"@tulir:maunium.net","state_key":"","type":"m.room.join_rules"},{"content":{"name":"event auth test v12"},"sender":"@tulir:maunium.net","state_key":"","type":"m.room.name"},{"auth_events":[],"content":{"room_version":"12"},"depth":1,"hashes":{"sha256":"qJYytb+EqWPiiZ0ogDODcLeA8XYw/2hVTaLHihcVBZQ"},"origin_server_ts":1756071567186,"prev_events":[],"sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"/9pp+2tkLo6XcZ3opqLeIpa3D96fh3QLpR2PQrZ6Z6j7wyRAvBrcgCpAeMtuyDCzW8Wh1QFEPG4FSsGvVaEFBg"}},"state_key":"","type":"m.room.create","unsigned":{"age":11553}},{"content":{"avatar_url":"mxc://envs.net/000cf1510b7c61018f9c72ca4cc63668370782c81725865933316030464","displayname":"tulir[e]","membership":"join"},"sender":"@tulir:envs.net","state_key":"@tulir:envs.net","type":"m.room.member"}]}} +{"auth_events":["$g4eBtA9EFNGLkHOofvQ4U87GNt4W8NmfmNRyR0wOUO4","$deNVGs6Ef7OKVrvewhtPv7DCCqSip112cEJYp-jkP6M","$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU"],"content":{"avatar_url":"mxc://matrix.org/BDYVQFSLvZHMaKHDGiRkvhVg","displayname":"tulir[m]","membership":"join"},"depth":12,"hashes":{"sha256":"hR/fRIyFkxKnA1XNxIB+NKC0VR0vHs82EDgydhmmZXU"},"origin_server_ts":1756071609205,"prev_events":["$g4eBtA9EFNGLkHOofvQ4U87GNt4W8NmfmNRyR0wOUO4"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:matrix.org","signatures":{"matrix.org":{"ed25519:a_RXGa":"keWbZHm+LPW22XWxb14Att4Ae4GVc6XAKAnxFRr3hxhrgEhsnMcxUx7fjqlA1dk3As6kjLKdekcyCef+AQCXCA"}},"state_key":"@tulir:matrix.org","type":"m.room.member","unsigned":{"age":19,"event_id":"$_gYjNODWJdo5-S1IN0bmAk3rzIeXzr5W5cmXZSmUsNw","replaces_state":"$g4eBtA9EFNGLkHOofvQ4U87GNt4W8NmfmNRyR0wOUO4"}} +{"auth_events":["$Bz2lxsbUYkeBDE7eMAsOm_TK_iuSuHNvQdrHnc-T1PE","$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU"],"content":{"ban":50,"events":{"m.room.avatar":50,"m.room.canonical_alias":50,"m.room.encryption":100,"m.room.history_visibility":100,"m.room.name":50,"m.room.power_levels":100,"m.room.server_acl":100,"m.room.tombstone":150},"events_default":0,"historical":100,"invite":0,"kick":50,"redact":50,"state_default":50,"users":{"@tulir:envs.net":9001,"@tulir:matrix.org":9000},"users_default":0},"depth":13,"hashes":{"sha256":"30Wuw3xIbA8+eXQBa4nFDKcyHtMbKPBYhLW1zft9/fE"},"origin_server_ts":1756071643928,"prev_events":["$_gYjNODWJdo5-S1IN0bmAk3rzIeXzr5W5cmXZSmUsNw"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:envs.net","signatures":{"envs.net":{"ed25519:wuJyKT":"x6Y4uViq4nK8LVPqtMLdCuvNET2bnjxYTgiKuEe1JYfwB4jPBnPuqvrt1O9oaanMpcRWbnuiZjckq4bUlRZ7Cw"}},"state_key":"","type":"m.room.power_levels","unsigned":{"event_id":"$Qg1xRB8nL8lGykGvt9_agu_WCWq8Y3rl_p_LKa6D2Hg","replaces_state":"$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU"}} +{"auth_events":["$Qg1xRB8nL8lGykGvt9_agu_WCWq8Y3rl_p_LKa6D2Hg","$_gYjNODWJdo5-S1IN0bmAk3rzIeXzr5W5cmXZSmUsNw"],"content":{"name":"event auth test v12!"},"depth":14,"hashes":{"sha256":"WT0gz7KYXvbdNruRavqIi9Hhul3rxCdZ+YY9yMGN+Fw"},"origin_server_ts":1756071656988,"prev_events":["$Qg1xRB8nL8lGykGvt9_agu_WCWq8Y3rl_p_LKa6D2Hg"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:matrix.org","signatures":{"matrix.org":{"ed25519:a_RXGa":"bSplmqtXVhO2Z3hJ8JMQ/u7G2Wmg6yt7SwhYXObRQJfthekddJN152ME4YJIwy7YD8WFq7EkyB/NMyQoliYyCg"}},"state_key":"","type":"m.room.name","unsigned":{"event_id":"$p4xvOczrhzQMtRW3-Tf86LYUb5aqpGFIgjwHBuxWIcI","replaces_state":"$fFDwIavLTEIfcnggWuryB6JwfS-L2KT6vP1ap3P6ctE"}} +{"auth_events":["$Bz2lxsbUYkeBDE7eMAsOm_TK_iuSuHNvQdrHnc-T1PE","$Qg1xRB8nL8lGykGvt9_agu_WCWq8Y3rl_p_LKa6D2Hg"],"content":{"ban":50,"events":{"m.room.avatar":50,"m.room.canonical_alias":50,"m.room.encryption":100,"m.room.history_visibility":100,"m.room.name":50,"m.room.power_levels":100,"m.room.server_acl":100,"m.room.tombstone":9001},"events_default":0,"historical":12345,"invite":0,"kick":50,"redact":50,"state_default":50,"users":{"@tulir:envs.net":9001,"@tulir:matrix.org":9000},"users_default":0},"depth":15,"hashes":{"sha256":"FnGzbcXc8YOiB1TY33QunGA17Axoyuu3sdVOj5Z408o"},"origin_server_ts":1756071804931,"prev_events":["$p4xvOczrhzQMtRW3-Tf86LYUb5aqpGFIgjwHBuxWIcI"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:envs.net","signatures":{"envs.net":{"ed25519:wuJyKT":"uyTUsPR+CzCtlevzB5+sNXvmfbPSp6u7RZC4E4TLVsj45+pjmMRswAvuHP9PT2+Tkl6Hu8ZPigsXgbKZtR35Aw"}},"state_key":"","type":"m.room.power_levels","unsigned":{"event_id":"$uZ4OOtkM8RcbEkhjNp-YlEH0zBqgsRx1eI8b2YP7ovw","replaces_state":"$Qg1xRB8nL8lGykGvt9_agu_WCWq8Y3rl_p_LKa6D2Hg"}} +{"auth_events":["$Bz2lxsbUYkeBDE7eMAsOm_TK_iuSuHNvQdrHnc-T1PE","$uZ4OOtkM8RcbEkhjNp-YlEH0zBqgsRx1eI8b2YP7ovw"],"content":{"ban":50,"events":{"m.room.avatar":50,"m.room.canonical_alias":50,"m.room.encryption":100,"m.room.history_visibility":100,"m.room.name":50,"m.room.power_levels":100,"m.room.server_acl":100,"m.room.tombstone":100},"events_default":0,"historical":12345,"invite":0,"kick":50,"redact":50,"state_default":50,"users":{"@tulir:envs.net":9001,"@tulir:matrix.org":9000},"users_default":0},"depth":16,"hashes":{"sha256":"KcivsiLesdnUnKX23Akk3OJEJFGRSY0g4H+p7XIThnw"},"origin_server_ts":1756071812688,"prev_events":["$uZ4OOtkM8RcbEkhjNp-YlEH0zBqgsRx1eI8b2YP7ovw"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:envs.net","signatures":{"envs.net":{"ed25519:wuJyKT":"cAK8dO2AVZklY9te5aVKbF1jR/eB5rzeNOXfYPjBLf+aSAS4Z6R2aMKW6hJB9PqRS4S+UZc24DTrjUjnvMzeBA"}},"state_key":"","type":"m.room.power_levels","unsigned":{"event_id":"$iwqRXQc2cx8K4AclTjU1Se-BMJpUl4DxrLm3nfUgeQU","replaces_state":"$uZ4OOtkM8RcbEkhjNp-YlEH0zBqgsRx1eI8b2YP7ovw"}} +{"auth_events":["$Bz2lxsbUYkeBDE7eMAsOm_TK_iuSuHNvQdrHnc-T1PE","$iwqRXQc2cx8K4AclTjU1Se-BMJpUl4DxrLm3nfUgeQU"],"content":{"body":"meow #2","com.beeper.linkpreviews":[],"m.mentions":{},"msgtype":"m.text"},"depth":17,"hashes":{"sha256":"SgH9fOXGdbdqpRfYmoz1t29+gX8Ze4ThSoj6klZs3Og"},"origin_server_ts":1756247476706,"prev_events":["$iwqRXQc2cx8K4AclTjU1Se-BMJpUl4DxrLm3nfUgeQU"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:envs.net","signatures":{"envs.net":{"ed25519:wuJyKT":"SMYK7zP3SaQOKhzZUKUBVCKwffYqi3PFAlPM34kRJtmfGU3KZXNBT0zi+veXDMmxkMunqhF2RTHBD6joa0kBAQ"}},"type":"m.room.message","unsigned":{"event_id":"$KFHLO0-ENYOGQXogp84C-ISSu1xtKUzIMaZ6LiBcR_w"}} +{"auth_events":["$_gYjNODWJdo5-S1IN0bmAk3rzIeXzr5W5cmXZSmUsNw","$iwqRXQc2cx8K4AclTjU1Se-BMJpUl4DxrLm3nfUgeQU"],"content":{"ban":50,"events":{"m.room.avatar":50,"m.room.canonical_alias":50,"m.room.encryption":100,"m.room.history_visibility":100,"m.room.name":50,"m.room.power_levels":100,"m.room.server_acl":100,"m.room.tombstone":100},"events_default":0,"historical":12345,"invite":0,"kick":50,"redact":50,"state_default":50,"users":{"@tulir:beeper.com":8999,"@tulir:envs.net":9001,"@tulir:matrix.org":9000},"users_default":0},"depth":18,"hashes":{"sha256":"l8Mw3VKn/Bvntg7bZ8uh5J8M2IBZM93Xg7hsdaSci8s"},"origin_server_ts":1758918656341,"prev_events":["$KFHLO0-ENYOGQXogp84C-ISSu1xtKUzIMaZ6LiBcR_w"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:matrix.org","signatures":{"matrix.org":{"ed25519:a_RXGa":"cg5LP0WuTnVB5jFhNERLLU5b+EhmyACiOq6cp3gKJnZsTAb1yajcgJybLWKrc8QQqxPa7hPnskRBgt4OBTFNAA"}},"state_key":"","type":"m.room.power_levels","unsigned":{"event_id":"$x-CCUewbWOHQXqfcUsywOmbHvNnOSwNM1RyOu-c8SB0","replaces_state":"$iwqRXQc2cx8K4AclTjU1Se-BMJpUl4DxrLm3nfUgeQU"}} +{"auth_events":["$deNVGs6Ef7OKVrvewhtPv7DCCqSip112cEJYp-jkP6M","$_gYjNODWJdo5-S1IN0bmAk3rzIeXzr5W5cmXZSmUsNw","$x-CCUewbWOHQXqfcUsywOmbHvNnOSwNM1RyOu-c8SB0"],"content":{"avatar_url":"mxc://beeper.com/eBdwbHbllONoAySQkXLjbfFM","displayname":"tulir[b]","membership":"invite"},"depth":19,"hashes":{"sha256":"KpmaRUQnJju8TIDMPzakitUIKOWJxTvULpFB3a1CGgc"},"origin_server_ts":1758918665952,"prev_events":["$x-CCUewbWOHQXqfcUsywOmbHvNnOSwNM1RyOu-c8SB0"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:matrix.org","signatures":{"beeper.com":{"ed25519:a_zgvp":"mzI9rPkQ1xHl2/G5Yrn0qmIRt5OyjPNqRwilPfH4jmr1tP+vv3vC0m4mph/MCOq8S1c/DQaCWSpdOX1uWfchBQ"},"matrix.org":{"ed25519:a_RXGa":"kEdfr8DjxC/bdvGYxnniFI/pxDWeyG73OjG/Gu1uoHLhjdtAT/vEQ6lotJJs214/KX5eAaQWobE9qtMvtPwMDw"}},"state_key":"@tulir:beeper.com","type":"m.room.member","unsigned":{"event_id":"$PZJZoUwNySl0jY16DkHBHR0HyAppLdxc0rkSuYp5Mro","invite_room_state":[{"auth_events":[],"content":{"room_version":"12"},"depth":1,"hashes":{"sha256":"qJYytb+EqWPiiZ0ogDODcLeA8XYw/2hVTaLHihcVBZQ"},"origin_server_ts":1756071567186,"prev_events":[],"sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"/9pp+2tkLo6XcZ3opqLeIpa3D96fh3QLpR2PQrZ6Z6j7wyRAvBrcgCpAeMtuyDCzW8Wh1QFEPG4FSsGvVaEFBg"}},"state_key":"","type":"m.room.create","unsigned":{"age":11553}},{"content":{"avatar_url":"mxc://matrix.org/BDYVQFSLvZHMaKHDGiRkvhVg","displayname":"tulir[m]","membership":"join"},"sender":"@tulir:matrix.org","state_key":"@tulir:matrix.org","type":"m.room.member"},{"content":{"name":"event auth test v12!"},"sender":"@tulir:matrix.org","state_key":"","type":"m.room.name"},{"content":{"join_rule":"invite"},"sender":"@tulir:maunium.net","state_key":"","type":"m.room.join_rules"}]}} +{"auth_events":["$deNVGs6Ef7OKVrvewhtPv7DCCqSip112cEJYp-jkP6M","$PZJZoUwNySl0jY16DkHBHR0HyAppLdxc0rkSuYp5Mro","$x-CCUewbWOHQXqfcUsywOmbHvNnOSwNM1RyOu-c8SB0"],"content":{"avatar_url":"mxc://beeper.com/eBdwbHbllONoAySQkXLjbfFM","displayname":"tulir[b]","membership":"join"},"depth":20,"hashes":{"sha256":"bmaHSm4mYPNBNlUfFsauSTxLrUH4CUSAKYvr1v76qkk"},"origin_server_ts":1758918670276,"prev_events":["$PZJZoUwNySl0jY16DkHBHR0HyAppLdxc0rkSuYp5Mro"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:beeper.com","signatures":{"beeper.com":{"ed25519:a_zgvp":"D3cz3m15m89a3G4c5yWOBCjhtSeI5IxBfQKt5XOr9a44QHyc3nwjjvIJaRrKNcS5tLUJwZ2IpVzjlrpbPHpxDA"}},"state_key":"@tulir:beeper.com","type":"m.room.member","unsigned":{"age":6,"event_id":"$_hayW1Y0HRWp3VEGZZbsMf0Ncg9x6n0ikveD0lbCwMw","replaces_state":"$PZJZoUwNySl0jY16DkHBHR0HyAppLdxc0rkSuYp5Mro"}} +{"auth_events":["$x-CCUewbWOHQXqfcUsywOmbHvNnOSwNM1RyOu-c8SB0","$Bz2lxsbUYkeBDE7eMAsOm_TK_iuSuHNvQdrHnc-T1PE"],"content":{"ban":50,"events":{"m.room.avatar":50,"m.room.canonical_alias":50,"m.room.encryption":100,"m.room.history_visibility":100,"m.room.name":50,"m.room.power_levels":100,"m.room.server_acl":100,"m.room.tombstone":100},"events_default":0,"historical":12345,"invite":0,"kick":50,"redact":50,"state_default":50,"users":{"@tulir:beeper.com":9000,"@tulir:envs.net":9001,"@tulir:matrix.org":8999},"users_default":0},"depth":21,"hashes":{"sha256":"xCj9vszChHiXba9DaPzhtF79Tphek3pRViMp36DOurU"},"origin_server_ts":1758918689485,"prev_events":["$_hayW1Y0HRWp3VEGZZbsMf0Ncg9x6n0ikveD0lbCwMw"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:envs.net","signatures":{"envs.net":{"ed25519:wuJyKT":"odkrWD30+ObeYtagULtECB/QmGae7qNy66nmJMWYXiQMYUJw/GMzSmgAiLAWfVYlfD3aEvMb/CBdrhL07tfSBw"}},"state_key":"","type":"m.room.power_levels","unsigned":{"event_id":"$di6cI89-GxX8-Wbx-0T69l4wg6TUWITRkjWXzG7EBqo","replaces_state":"$x-CCUewbWOHQXqfcUsywOmbHvNnOSwNM1RyOu-c8SB0"}} diff --git a/federation/httpclient.go b/federation/httpclient.go index d6d97280..2f8dbb4f 100644 --- a/federation/httpclient.go +++ b/federation/httpclient.go @@ -12,7 +12,6 @@ import ( "net" "net/http" "sync" - "time" ) // ServerResolvingTransport is an http.RoundTripper that resolves Matrix server names before sending requests. @@ -22,17 +21,20 @@ type ServerResolvingTransport struct { Transport *http.Transport Dialer *net.Dialer - cache map[string]*ResolvedServerName - resolveLocks map[string]*sync.Mutex - cacheLock sync.Mutex + cache ResolutionCache + + resolveLocks map[string]*sync.Mutex + resolveLocksLock sync.Mutex } -func NewServerResolvingTransport() *ServerResolvingTransport { +func NewServerResolvingTransport(cache ResolutionCache) *ServerResolvingTransport { + if cache == nil { + cache = NewInMemoryCache() + } srt := &ServerResolvingTransport{ - cache: make(map[string]*ResolvedServerName), resolveLocks: make(map[string]*sync.Mutex), - - Dialer: &net.Dialer{}, + cache: cache, + Dialer: &net.Dialer{}, } srt.Transport = &http.Transport{ DialContext: srt.DialContext, @@ -50,12 +52,6 @@ func (srt *ServerResolvingTransport) DialContext(ctx context.Context, network, a return srt.Dialer.DialContext(ctx, network, addrs[0]) } -type contextKey int - -const ( - contextKeyIPPort contextKey = iota -) - func (srt *ServerResolvingTransport) RoundTrip(request *http.Request) (*http.Response, error) { if request.URL.Scheme != "matrix-federation" { return nil, fmt.Errorf("unsupported scheme: %s", request.URL.Scheme) @@ -72,37 +68,25 @@ func (srt *ServerResolvingTransport) RoundTrip(request *http.Request) (*http.Res } func (srt *ServerResolvingTransport) resolve(ctx context.Context, serverName string) (*ResolvedServerName, error) { - res, lock := srt.getResolveCache(serverName) - if res != nil { - return res, nil + srt.resolveLocksLock.Lock() + lock, ok := srt.resolveLocks[serverName] + if !ok { + lock = &sync.Mutex{} + srt.resolveLocks[serverName] = lock } + srt.resolveLocksLock.Unlock() + lock.Lock() defer lock.Unlock() - res, _ = srt.getResolveCache(serverName) - if res != nil { + res, err := srt.cache.LoadResolution(serverName) + if err != nil { + return nil, fmt.Errorf("failed to read cache: %w", err) + } else if res != nil { + return res, nil + } else if res, err = ResolveServerName(ctx, serverName, srt.ResolveOpts); err != nil { + return nil, err + } else { + srt.cache.StoreResolution(res) return res, nil } - var err error - res, err = ResolveServerName(ctx, serverName, srt.ResolveOpts) - if err != nil { - return nil, err - } - srt.cacheLock.Lock() - srt.cache[serverName] = res - srt.cacheLock.Unlock() - return res, nil -} - -func (srt *ServerResolvingTransport) getResolveCache(serverName string) (*ResolvedServerName, *sync.Mutex) { - srt.cacheLock.Lock() - defer srt.cacheLock.Unlock() - if val, ok := srt.cache[serverName]; ok && time.Until(val.Expires) > 0 { - return val, nil - } - rl, ok := srt.resolveLocks[serverName] - if !ok { - rl = &sync.Mutex{} - srt.resolveLocks[serverName] = rl - } - return nil, rl } diff --git a/federation/keyserver.go b/federation/keyserver.go index 3e74bfdf..d32ba5cf 100644 --- a/federation/keyserver.go +++ b/federation/keyserver.go @@ -1,4 +1,4 @@ -// Copyright (c) 2024 Tulir Asokan +// Copyright (c) 2025 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this @@ -8,13 +8,17 @@ package federation import ( "encoding/json" - "fmt" "net/http" "strconv" "time" - "github.com/gorilla/mux" + "github.com/rs/zerolog" + "github.com/rs/zerolog/hlog" + "go.mau.fi/util/exerrors" + "go.mau.fi/util/exhttp" "go.mau.fi/util/jsontime" + "go.mau.fi/util/ptr" + "go.mau.fi/util/requestlog" "maunium.net/go/mautrix" "maunium.net/go/mautrix/id" @@ -47,34 +51,29 @@ type KeyServer struct { KeyProvider ServerKeyProvider Version ServerVersion WellKnownTarget string + OtherKeys KeyCache } // Register registers the key server endpoints to the given router. -func (ks *KeyServer) Register(r *mux.Router) { - r.HandleFunc("/.well-known/matrix/server", ks.GetWellKnown).Methods(http.MethodGet) - r.HandleFunc("/_matrix/federation/v1/version", ks.GetServerVersion).Methods(http.MethodGet) - keyRouter := r.PathPrefix("/_matrix/key").Subrouter() - keyRouter.HandleFunc("/v2/server", ks.GetServerKey).Methods(http.MethodGet) - keyRouter.HandleFunc("/v2/query/{serverName}", ks.GetQueryKeys).Methods(http.MethodGet) - keyRouter.HandleFunc("/v2/query", ks.PostQueryKeys).Methods(http.MethodPost) - keyRouter.NotFoundHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - jsonResponse(w, http.StatusNotFound, &mautrix.RespError{ - ErrCode: mautrix.MUnrecognized.ErrCode, - Err: "Unrecognized endpoint", - }) - }) - keyRouter.MethodNotAllowedHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - jsonResponse(w, http.StatusMethodNotAllowed, &mautrix.RespError{ - ErrCode: mautrix.MUnrecognized.ErrCode, - Err: "Invalid method for endpoint", - }) - }) -} - -func jsonResponse(w http.ResponseWriter, code int, data any) { - w.Header().Add("Content-Type", "application/json") - w.WriteHeader(code) - _ = json.NewEncoder(w).Encode(data) +func (ks *KeyServer) Register(r *http.ServeMux, log zerolog.Logger) { + r.HandleFunc("GET /.well-known/matrix/server", ks.GetWellKnown) + r.HandleFunc("GET /_matrix/federation/v1/version", ks.GetServerVersion) + keyRouter := http.NewServeMux() + keyRouter.HandleFunc("GET /v2/server", ks.GetServerKey) + keyRouter.HandleFunc("GET /v2/query/{serverName}", ks.GetQueryKeys) + keyRouter.HandleFunc("POST /v2/query", ks.PostQueryKeys) + errorBodies := exhttp.ErrorBodies{ + NotFound: exerrors.Must(ptr.Ptr(mautrix.MUnrecognized.WithMessage("Unrecognized endpoint")).MarshalJSON()), + MethodNotAllowed: exerrors.Must(ptr.Ptr(mautrix.MUnrecognized.WithMessage("Invalid method for endpoint")).MarshalJSON()), + } + r.Handle("/_matrix/key/", exhttp.ApplyMiddleware( + keyRouter, + exhttp.StripPrefix("/_matrix/key"), + hlog.NewHandler(log), + hlog.RequestIDHandler("request_id", "Request-Id"), + requestlog.AccessLogger(requestlog.Options{TrustXForwardedFor: true}), + exhttp.HandleErrors(errorBodies), + )) } // RespWellKnown is the response body for the `GET /.well-known/matrix/server` endpoint. @@ -87,12 +86,9 @@ type RespWellKnown struct { // https://spec.matrix.org/v1.9/server-server-api/#get_well-knownmatrixserver func (ks *KeyServer) GetWellKnown(w http.ResponseWriter, r *http.Request) { if ks.WellKnownTarget == "" { - jsonResponse(w, http.StatusNotFound, &mautrix.RespError{ - ErrCode: mautrix.MNotFound.ErrCode, - Err: "No well-known target set", - }) + mautrix.MNotFound.WithMessage("No well-known target set").Write(w) } else { - jsonResponse(w, http.StatusOK, &RespWellKnown{Server: ks.WellKnownTarget}) + exhttp.WriteJSONResponse(w, http.StatusOK, &RespWellKnown{Server: ks.WellKnownTarget}) } } @@ -105,7 +101,7 @@ type RespServerVersion struct { // // https://spec.matrix.org/v1.9/server-server-api/#get_matrixfederationv1version func (ks *KeyServer) GetServerVersion(w http.ResponseWriter, r *http.Request) { - jsonResponse(w, http.StatusOK, &RespServerVersion{Server: ks.Version}) + exhttp.WriteJSONResponse(w, http.StatusOK, &RespServerVersion{Server: ks.Version}) } // GetServerKey implements the `GET /_matrix/key/v2/server` endpoint. @@ -114,12 +110,9 @@ func (ks *KeyServer) GetServerVersion(w http.ResponseWriter, r *http.Request) { func (ks *KeyServer) GetServerKey(w http.ResponseWriter, r *http.Request) { domain, key := ks.KeyProvider.Get(r) if key == nil { - jsonResponse(w, http.StatusNotFound, &mautrix.RespError{ - ErrCode: mautrix.MNotFound.ErrCode, - Err: fmt.Sprintf("No signing key found for %q", r.Host), - }) + mautrix.MNotFound.WithMessage("No signing key found for %q", r.Host).Write(w) } else { - jsonResponse(w, http.StatusOK, key.GenerateKeyResponse(domain, nil)) + exhttp.WriteJSONResponse(w, http.StatusOK, key.GenerateKeyResponse(domain, nil)) } } @@ -144,10 +137,7 @@ func (ks *KeyServer) PostQueryKeys(w http.ResponseWriter, r *http.Request) { var req ReqQueryKeys err := json.NewDecoder(r.Body).Decode(&req) if err != nil { - jsonResponse(w, http.StatusBadRequest, &mautrix.RespError{ - ErrCode: mautrix.MBadJSON.ErrCode, - Err: fmt.Sprintf("failed to parse request: %v", err), - }) + mautrix.MBadJSON.WithMessage("failed to parse request: %v", err).Write(w) return } @@ -165,7 +155,7 @@ func (ks *KeyServer) PostQueryKeys(w http.ResponseWriter, r *http.Request) { } } } - jsonResponse(w, http.StatusOK, resp) + exhttp.WriteJSONResponse(w, http.StatusOK, resp) } // GetQueryKeysResponse is the response body for the `GET /_matrix/key/v2/query/{serverName}` endpoint @@ -177,27 +167,39 @@ type GetQueryKeysResponse struct { // // https://spec.matrix.org/v1.9/server-server-api/#get_matrixkeyv2queryservername func (ks *KeyServer) GetQueryKeys(w http.ResponseWriter, r *http.Request) { - serverName := mux.Vars(r)["serverName"] + serverName := r.PathValue("serverName") minimumValidUntilTSString := r.URL.Query().Get("minimum_valid_until_ts") minimumValidUntilTS, err := strconv.ParseInt(minimumValidUntilTSString, 10, 64) if err != nil && minimumValidUntilTSString != "" { - jsonResponse(w, http.StatusBadRequest, &mautrix.RespError{ - ErrCode: mautrix.MInvalidParam.ErrCode, - Err: fmt.Sprintf("failed to parse ?minimum_valid_until_ts: %v", err), - }) + mautrix.MInvalidParam.WithMessage("failed to parse ?minimum_valid_until_ts: %v", err).Write(w) return } else if time.UnixMilli(minimumValidUntilTS).After(time.Now().Add(24 * time.Hour)) { - jsonResponse(w, http.StatusBadRequest, &mautrix.RespError{ - ErrCode: mautrix.MInvalidParam.ErrCode, - Err: "minimum_valid_until_ts may not be more than 24 hours in the future", - }) + mautrix.MInvalidParam.WithMessage("minimum_valid_until_ts may not be more than 24 hours in the future").Write(w) return } resp := &GetQueryKeysResponse{ ServerKeys: []*ServerKeyResponse{}, } - if domain, key := ks.KeyProvider.Get(r); key != nil && domain == serverName { - resp.ServerKeys = append(resp.ServerKeys, key.GenerateKeyResponse(serverName, nil)) + domain, key := ks.KeyProvider.Get(r) + if domain == serverName { + if key != nil { + resp.ServerKeys = append(resp.ServerKeys, key.GenerateKeyResponse(serverName, nil)) + } + } else if ks.OtherKeys != nil { + otherKey, err := ks.OtherKeys.LoadKeys(serverName) + if err != nil { + mautrix.MUnknown.WithMessage("Failed to load keys from cache").Write(w) + return + } + if key != nil && domain != "" { + signature, err := key.SignJSON(otherKey) + if err == nil { + otherKey.Signatures[domain] = map[id.KeyID]string{ + key.ID: signature, + } + } + } + resp.ServerKeys = append(resp.ServerKeys, otherKey) } - jsonResponse(w, http.StatusOK, resp) + exhttp.WriteJSONResponse(w, http.StatusOK, resp) } diff --git a/federation/pdu/auth.go b/federation/pdu/auth.go new file mode 100644 index 00000000..16706fe5 --- /dev/null +++ b/federation/pdu/auth.go @@ -0,0 +1,71 @@ +// Copyright (c) 2025 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +//go:build goexperiment.jsonv2 + +package pdu + +import ( + "slices" + + "github.com/tidwall/gjson" + "go.mau.fi/util/exgjson" + + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +type StateKey struct { + Type string + StateKey string +} + +var thirdPartyInviteTokenPath = exgjson.Path("third_party_invite", "signed", "token") + +type AuthEventSelection []StateKey + +func (aes *AuthEventSelection) Add(evtType, stateKey string) { + key := StateKey{Type: evtType, StateKey: stateKey} + if !aes.Has(key) { + *aes = append(*aes, key) + } +} + +func (aes *AuthEventSelection) Has(key StateKey) bool { + return slices.Contains(*aes, key) +} + +func (pdu *PDU) AuthEventSelection(roomVersion id.RoomVersion) (keys AuthEventSelection) { + if pdu.Type == event.StateCreate.Type && pdu.StateKey != nil { + return AuthEventSelection{} + } + keys = make(AuthEventSelection, 0, 3) + if !roomVersion.RoomIDIsCreateEventID() { + keys.Add(event.StateCreate.Type, "") + } + keys.Add(event.StatePowerLevels.Type, "") + keys.Add(event.StateMember.Type, pdu.Sender.String()) + if pdu.Type == event.StateMember.Type && pdu.StateKey != nil { + keys.Add(event.StateMember.Type, *pdu.StateKey) + membership := event.Membership(gjson.GetBytes(pdu.Content, "membership").Str) + if membership == event.MembershipJoin || membership == event.MembershipInvite || membership == event.MembershipKnock { + keys.Add(event.StateJoinRules.Type, "") + } + if membership == event.MembershipInvite { + thirdPartyInviteToken := gjson.GetBytes(pdu.Content, thirdPartyInviteTokenPath).Str + if thirdPartyInviteToken != "" { + keys.Add(event.StateThirdPartyInvite.Type, thirdPartyInviteToken) + } + } + if membership == event.MembershipJoin && roomVersion.RestrictedJoins() { + authorizedVia := gjson.GetBytes(pdu.Content, "authorised_via_users_server").Str + if authorizedVia != "" { + keys.Add(event.StateMember.Type, authorizedVia) + } + } + } + return +} diff --git a/federation/pdu/hash.go b/federation/pdu/hash.go new file mode 100644 index 00000000..38ef83e9 --- /dev/null +++ b/federation/pdu/hash.go @@ -0,0 +1,118 @@ +// Copyright (c) 2025 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +//go:build goexperiment.jsonv2 + +package pdu + +import ( + "crypto/hmac" + "crypto/sha256" + "encoding/base64" + "fmt" + + "github.com/tidwall/gjson" + + "maunium.net/go/mautrix/id" +) + +func (pdu *PDU) CalculateContentHash() ([32]byte, error) { + if pdu == nil { + return [32]byte{}, ErrPDUIsNil + } + pduClone := pdu.Clone() + pduClone.Signatures = nil + pduClone.Unsigned = nil + pduClone.Hashes = nil + rawJSON, err := marshalCanonical(pduClone) + if err != nil { + return [32]byte{}, fmt.Errorf("failed to marshal PDU to calculate content hash: %w", err) + } + return sha256.Sum256(rawJSON), nil +} + +func (pdu *PDU) FillContentHash() error { + if pdu == nil { + return ErrPDUIsNil + } else if pdu.Hashes != nil { + return nil + } else if hash, err := pdu.CalculateContentHash(); err != nil { + return err + } else { + pdu.Hashes = &Hashes{SHA256: hash[:]} + return nil + } +} + +func (pdu *PDU) VerifyContentHash() bool { + if pdu == nil || pdu.Hashes == nil { + return false + } + calculatedHash, err := pdu.CalculateContentHash() + if err != nil { + return false + } + return hmac.Equal(calculatedHash[:], pdu.Hashes.SHA256) +} + +func (pdu *PDU) GetRoomID() (id.RoomID, error) { + if pdu == nil { + return "", ErrPDUIsNil + } else if pdu.Type != "m.room.create" { + return "", fmt.Errorf("room ID can only be calculated for m.room.create events") + } else if roomVersion := id.RoomVersion(gjson.GetBytes(pdu.Content, "room_version").Str); !roomVersion.RoomIDIsCreateEventID() { + return "", fmt.Errorf("room version %s does not use m.room.create event ID as room ID", roomVersion) + } else if evtID, err := pdu.calculateEventID(roomVersion, '!'); err != nil { + return "", fmt.Errorf("failed to calculate event ID: %w", err) + } else { + return id.RoomID(evtID), nil + } +} + +var UseInternalMetaForGetEventID = false + +func (pdu *PDU) GetEventID(roomVersion id.RoomVersion) (id.EventID, error) { + if UseInternalMetaForGetEventID && pdu.InternalMeta.EventID != "" { + return pdu.InternalMeta.EventID, nil + } + return pdu.calculateEventID(roomVersion, '$') +} + +func (pdu *PDU) GetReferenceHash(roomVersion id.RoomVersion) ([32]byte, error) { + if pdu == nil { + return [32]byte{}, ErrPDUIsNil + } + if pdu.Hashes == nil || pdu.Hashes.SHA256 == nil { + if err := pdu.FillContentHash(); err != nil { + return [32]byte{}, err + } + } + rawJSON, err := marshalCanonical(pdu.Clone().RedactForSignature(roomVersion)) + if err != nil { + return [32]byte{}, fmt.Errorf("failed to marshal redacted PDU to calculate event ID: %w", err) + } + return sha256.Sum256(rawJSON), nil +} + +func (pdu *PDU) calculateEventID(roomVersion id.RoomVersion, prefix byte) (id.EventID, error) { + referenceHash, err := pdu.GetReferenceHash(roomVersion) + if err != nil { + return "", err + } + eventID := make([]byte, 44) + eventID[0] = prefix + switch roomVersion.EventIDFormat() { + case id.EventIDFormatCustom: + return "", fmt.Errorf("*pdu.PDU can only be used for room v3+") + case id.EventIDFormatBase64: + base64.RawStdEncoding.Encode(eventID[1:], referenceHash[:]) + case id.EventIDFormatURLSafeBase64: + base64.RawURLEncoding.Encode(eventID[1:], referenceHash[:]) + default: + return "", fmt.Errorf("unknown event ID format %v", roomVersion.EventIDFormat()) + } + return id.EventID(eventID), nil +} diff --git a/federation/pdu/hash_test.go b/federation/pdu/hash_test.go new file mode 100644 index 00000000..17417e12 --- /dev/null +++ b/federation/pdu/hash_test.go @@ -0,0 +1,55 @@ +// Copyright (c) 2025 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +//go:build goexperiment.jsonv2 + +package pdu_test + +import ( + "encoding/base64" + "testing" + + "github.com/stretchr/testify/assert" + "go.mau.fi/util/exerrors" +) + +func TestPDU_CalculateContentHash(t *testing.T) { + for _, test := range testPDUs { + if test.redacted { + continue + } + t.Run(test.name, func(t *testing.T) { + parsed := parsePDU(test.pdu) + contentHash := exerrors.Must(parsed.CalculateContentHash()) + assert.Equal( + t, + base64.RawStdEncoding.EncodeToString(parsed.Hashes.SHA256), + base64.RawStdEncoding.EncodeToString(contentHash[:]), + ) + }) + } +} + +func TestPDU_VerifyContentHash(t *testing.T) { + for _, test := range testPDUs { + if test.redacted { + continue + } + t.Run(test.name, func(t *testing.T) { + parsed := parsePDU(test.pdu) + assert.True(t, parsed.VerifyContentHash()) + }) + } +} + +func TestPDU_GetEventID(t *testing.T) { + for _, test := range testPDUs { + t.Run(test.name, func(t *testing.T) { + gotEventID := exerrors.Must(parsePDU(test.pdu).GetEventID(test.roomVersion)) + assert.Equal(t, test.eventID, gotEventID) + }) + } +} diff --git a/federation/pdu/pdu.go b/federation/pdu/pdu.go new file mode 100644 index 00000000..17db6995 --- /dev/null +++ b/federation/pdu/pdu.go @@ -0,0 +1,156 @@ +// Copyright (c) 2025 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +//go:build goexperiment.jsonv2 + +package pdu + +import ( + "bytes" + "crypto/ed25519" + "encoding/json/jsontext" + "encoding/json/v2" + "errors" + "fmt" + "strings" + "time" + + "github.com/tidwall/gjson" + "go.mau.fi/util/jsonbytes" + "go.mau.fi/util/ptr" + + "maunium.net/go/mautrix/crypto/canonicaljson" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +// GetKeyFunc is a callback for retrieving the key corresponding to a given key ID when verifying the signature of a PDU. +// +// The input time is the timestamp of the event. The function should attempt to fetch a key that is +// valid at or after this time, but if that is not possible, the latest available key should be +// returned without an error. The verify function will do its own validity checking based on the +// returned valid until timestamp. +type GetKeyFunc = func(serverName string, keyID id.KeyID, minValidUntil time.Time) (key id.SigningKey, validUntil time.Time, err error) + +type AnyPDU interface { + GetRoomID() (id.RoomID, error) + GetEventID(roomVersion id.RoomVersion) (id.EventID, error) + GetReferenceHash(roomVersion id.RoomVersion) ([32]byte, error) + CalculateContentHash() ([32]byte, error) + FillContentHash() error + VerifyContentHash() bool + Sign(roomVersion id.RoomVersion, serverName string, keyID id.KeyID, privateKey ed25519.PrivateKey) error + VerifySignature(roomVersion id.RoomVersion, serverName string, getKey GetKeyFunc) error + ToClientEvent(roomVersion id.RoomVersion) (*event.Event, error) + AuthEventSelection(roomVersion id.RoomVersion) (keys AuthEventSelection) +} + +var ( + _ AnyPDU = (*PDU)(nil) + _ AnyPDU = (*RoomV1PDU)(nil) +) + +type InternalMeta struct { + EventID id.EventID `json:"event_id,omitempty"` + Rejected bool `json:"rejected,omitempty"` + Extra map[string]any `json:",unknown"` +} + +type PDU struct { + AuthEvents []id.EventID `json:"auth_events"` + Content jsontext.Value `json:"content"` + Depth int64 `json:"depth"` + Hashes *Hashes `json:"hashes,omitzero"` + OriginServerTS int64 `json:"origin_server_ts"` + PrevEvents []id.EventID `json:"prev_events"` + Redacts *id.EventID `json:"redacts,omitzero"` + RoomID id.RoomID `json:"room_id,omitzero"` // not present for room v12+ create events + Sender id.UserID `json:"sender"` + Signatures map[string]map[id.KeyID]string `json:"signatures,omitzero"` + StateKey *string `json:"state_key,omitzero"` + Type string `json:"type"` + Unsigned jsontext.Value `json:"unsigned,omitzero"` + InternalMeta InternalMeta `json:"-"` + + Unknown jsontext.Value `json:",unknown"` + + // Deprecated legacy fields + DeprecatedPrevState jsontext.Value `json:"prev_state,omitzero"` + DeprecatedOrigin jsontext.Value `json:"origin,omitzero"` + DeprecatedMembership jsontext.Value `json:"membership,omitzero"` +} + +var ErrPDUIsNil = errors.New("PDU is nil") + +type Hashes struct { + SHA256 jsonbytes.UnpaddedBytes `json:"sha256"` + + Unknown jsontext.Value `json:",unknown"` +} + +func (pdu *PDU) ToClientEvent(roomVersion id.RoomVersion) (*event.Event, error) { + if pdu.Type == "m.room.create" && roomVersion == "" { + roomVersion = id.RoomVersion(gjson.GetBytes(pdu.Content, "room_version").Str) + } + evtType := event.Type{Type: pdu.Type, Class: event.MessageEventType} + if pdu.StateKey != nil { + evtType.Class = event.StateEventType + } + eventID, err := pdu.GetEventID(roomVersion) + if err != nil { + return nil, err + } + roomID := pdu.RoomID + if pdu.Type == "m.room.create" && roomVersion.RoomIDIsCreateEventID() { + roomID = id.RoomID(strings.Replace(string(eventID), "$", "!", 1)) + } + evt := &event.Event{ + StateKey: pdu.StateKey, + Sender: pdu.Sender, + Type: evtType, + Timestamp: pdu.OriginServerTS, + ID: eventID, + RoomID: roomID, + Redacts: ptr.Val(pdu.Redacts), + } + err = json.Unmarshal(pdu.Content, &evt.Content) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal content: %w", err) + } + return evt, nil +} + +func (pdu *PDU) AddSignature(serverName string, keyID id.KeyID, signature string) { + if signature == "" { + return + } + if pdu.Signatures == nil { + pdu.Signatures = make(map[string]map[id.KeyID]string) + } + if _, ok := pdu.Signatures[serverName]; !ok { + pdu.Signatures[serverName] = make(map[id.KeyID]string) + } + pdu.Signatures[serverName][keyID] = signature +} + +func marshalCanonical(data any) (jsontext.Value, error) { + marshaledBytes, err := json.Marshal(data) + if err != nil { + return nil, err + } + marshaled := jsontext.Value(marshaledBytes) + err = marshaled.Canonicalize() + if err != nil { + return nil, err + } + check := canonicaljson.CanonicalJSONAssumeValid(marshaled) + if !bytes.Equal(marshaled, check) { + fmt.Println(string(marshaled)) + fmt.Println(string(check)) + return nil, fmt.Errorf("canonical JSON mismatch for %s", string(marshaled)) + } + return marshaled, nil +} diff --git a/federation/pdu/pdu_test.go b/federation/pdu/pdu_test.go new file mode 100644 index 00000000..59d7c3a6 --- /dev/null +++ b/federation/pdu/pdu_test.go @@ -0,0 +1,193 @@ +// Copyright (c) 2025 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +//go:build goexperiment.jsonv2 + +package pdu_test + +import ( + "encoding/json/v2" + "time" + + "go.mau.fi/util/exerrors" + + "maunium.net/go/mautrix/federation/pdu" + "maunium.net/go/mautrix/id" +) + +type serverKey struct { + key id.SigningKey + validUntilTS time.Time +} + +type serverDetails struct { + serverName string + keys map[id.KeyID]serverKey +} + +func (sd serverDetails) getKey(serverName string, keyID id.KeyID, _ time.Time) (id.SigningKey, time.Time, error) { + if serverName != sd.serverName { + return "", time.Time{}, nil + } + key, ok := sd.keys[keyID] + if ok { + return key.key, key.validUntilTS, nil + } + return "", time.Time{}, nil +} + +var mauniumNet = serverDetails{ + serverName: "maunium.net", + keys: map[id.KeyID]serverKey{ + "ed25519:a_xxeS": { + key: "lVt/CC3tv74OH6xTph2JrUmeRj/j+1q0HVa0Xf4QlCg", + validUntilTS: time.Now(), + }, + }, +} +var envsNet = serverDetails{ + serverName: "envs.net", + keys: map[id.KeyID]serverKey{ + "ed25519:a_zIqy": { + key: "vCUcZpt9hUn0aabfh/9GP/6sZvXcydww8DUstPHdJm0", + validUntilTS: time.UnixMilli(1722360538068), + }, + "ed25519:wuJyKT": { + key: "xbE1QssgomL4wCSlyMYF5/7KxVyM4HPwAbNa+nFFnx0", + validUntilTS: time.Now(), + }, + }, +} +var matrixOrg = serverDetails{ + serverName: "matrix.org", + keys: map[id.KeyID]serverKey{ + "ed25519:auto": { + key: "Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw", + validUntilTS: time.UnixMilli(1576767829750), + }, + "ed25519:a_RXGa": { + key: "l8Hft5qXKn1vfHrg3p4+W8gELQVo8N13JkluMfmn2sQ", + validUntilTS: time.Now(), + }, + }, +} +var continuwuityOrg = serverDetails{ + serverName: "continuwuity.org", + keys: map[id.KeyID]serverKey{ + "ed25519:PwHlNsFu": { + key: "8eNx2s0zWW+heKAmOH5zKv/nCPkEpraDJfGHxDu6hFI", + validUntilTS: time.Now(), + }, + }, +} +var novaAstraltechOrg = serverDetails{ + serverName: "nova.astraltech.org", + keys: map[id.KeyID]serverKey{ + "ed25519:a_afpo": { + key: "O1Y9GWuKo9xkuzuQef6gROxtTgxxAbS3WPNghPYXF3o", + validUntilTS: time.Now(), + }, + }, +} + +type testPDU struct { + name string + pdu string + eventID id.EventID + roomVersion id.RoomVersion + redacted bool + serverDetails +} + +var roomV4MessageTestPDU = testPDU{ + name: "m.room.message in v4 room", + pdu: `{"auth_events":["$OB87jNemaIVDHAfu0-pa_cP7OPFXUXCbFpjYVi8gll4","$RaWbTF9wQfGQgUpe1S13wzICtGTB2PNKRHUNHu9IO1c","$ZmEWOXw6cC4Rd1wTdY5OzeLJVzjhrkxFPwwKE4gguGk"],"content":{"body":"the last one is saying it shouldn't have effects","com.beeper.linkpreviews":[],"m.mentions":{},"msgtype":"m.text"},"depth":13103,"hashes":{"sha256":"c2wb8qMlvzIPCP1Wd+eYZ4BRgnGYxS97dR1UlJjVMeg"},"origin_server_ts":1752875275263,"prev_events":["$-7_BMI3BXwj3ayoxiJvraJxYWTKwjiQ6sh7CW_Brvj0"],"room_id":"!JiiOHXrIUCtcOJsZCa:matrix.org","sender":"@tulir:maunium.net","type":"m.room.message","signatures":{"maunium.net":{"ed25519:a_xxeS":"99TAqHpBkUEtgCraXsVXogmf/hnijPbgbG9eACtA+mbix3Y6gURI4QGQgcX/NhcE3pJQZ/YDjmbuvCnKvEccAA"}},"unsigned":{"age_ts":1752875275281}}`, + eventID: "$Jo_lmFR-e6lzrimzCA7DevIn2OwhuQYmd9xkcJBoqAA", + roomVersion: id.RoomV4, + serverDetails: mauniumNet, +} + +var roomV12MessageTestPDU = testPDU{ + name: "m.room.message in v12 room", + pdu: `{"auth_events":["$gCzdJUVV93Qory0x7p_PLG5UUiDjPJNe1H12qbHTuFA","$hyeL_nU_L3tsZ2dtZZpAHk0Skv-PqFQIipuII_By584"],"content":{"body":"meow","com.beeper.linkpreviews":[],"m.mentions":{},"msgtype":"m.text"},"depth":122,"hashes":{"sha256":"IQ0zlc+PXeEs6R3JvRkW3xTPV3zlGKSSd3x07KXGjzs"},"origin_server_ts":1755384351627,"prev_events":["$gCzdJUVV93Qory0x7p_PLG5UUiDjPJNe1H12qbHTuFA"],"room_id":"!mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ","sender":"@tulir_test:maunium.net","type":"m.room.message","signatures":{"maunium.net":{"ed25519:a_xxeS":"0GDMddL2k7gF4V1VU8sL3wTfhAIzAu5iVH5jeavZ2VEg3J9/tHLWXAOn2tzkLaMRWl0/XpINT2YlH/rd2U21Ag"}},"unsigned":{"age_ts":1755384351627}}`, + eventID: "$xmP-wZfpannuHG-Akogi6c4YvqxChMtdyYbUMGOrMWc", + roomVersion: id.RoomV12, + serverDetails: mauniumNet, +} + +var testPDUs = []testPDU{roomV4MessageTestPDU, { + name: "m.room.message in v5 room", + pdu: `{"auth_events":["$hp0ImHqYgHTRbLeWKPeTeFmxdb5SdMJN9cfmTrTk7d0","$KAj7X7tnJbR9qYYMWJSw-1g414_KlPptbbkZm7_kUtg","$V-2ShOwZYhA_nxMijaf3lqFgIJgzE2UMeFPtOLnoBYM"],"content":{"body":"meow","com.beeper.linkpreviews":[],"m.mentions":{},"msgtype":"m.text"},"depth":2248,"hashes":{"sha256":"kV+JuLbWXJ2r6PjHT3wt8bFc/TfI1nTaSN3Lamg/xHs"},"origin_server_ts":1755422945654,"prev_events":["$49lFLem2Nk4dxHk9RDXxTdaq9InIJpmkHpzVnjKcYwg"],"room_id":"!vzBgJsjNzgHSdWsmki:mozilla.org","sender":"@tulir:maunium.net","type":"m.room.message","signatures":{"maunium.net":{"ed25519:a_xxeS":"JIl60uVgfCLBZLPoSiE7wVkJ9U5cNEPVPuv1sCCYUOq5yOW56WD1adgpBUdX2UFpYkCHvkRnyQGxU0+6HBp5BA"}},"unsigned":{"age_ts":1755422945673}}`, + eventID: "$Qn4tHfuAe6PlnKXPZnygAU9wd6RXqMKtt_ZzstHTSgA", + roomVersion: id.RoomV5, + serverDetails: mauniumNet, +}, { + name: "m.room.message in v10 room", + pdu: `{"auth_events":["$--ilpwnsHaEdHrwiMrZNu5xHP6TthWG0FIXMHnlHCcs","$tn1FZUI_YUpfTr_a3Y_r8kC3inliIZZratzg0UsNdCQ","$Z-qMWmiMvm-aIEffcfSO6lN7TyjyTOsIcHIymfzoo20"],"content":{"body":"meow","com.beeper.linkpreviews":[],"m.mentions":{},"msgtype":"m.text"},"depth":100885,"hashes":{"sha256":"jc9272JPpPIVreJC3UEAm3BNVnLX8sm3U/TZs23wsHo"},"origin_server_ts":1755422792518,"prev_events":["$HDtbzpSys36Hk-F2NsiXfp9slsGXBH0b58qyddj_q5E"],"room_id":"!UzZHbJYcgggctGnlzr:envs.net","sender":"@tulir:maunium.net","type":"m.room.message","signatures":{"maunium.net":{"ed25519:a_xxeS":"sAMLo9jPtNB0Jq67IQm06siEBx82qZa2edu56IDQ4tDylEV4Mq7iFO23gCghqXA7B/MqBsjXotGBxv6AvlJ2Dw"}},"unsigned":{"age_ts":1755422792540}}`, + eventID: "$4ZFr_ypfp4DyZQP4zyxM_cvuOMFkl07doJmwi106YFY", + roomVersion: id.RoomV10, + serverDetails: mauniumNet, +}, { + name: "m.room.message in v11 room", + pdu: `{"auth_events":["$L8Ak6A939llTRIsZrytMlLDXQhI4uLEjx-wb1zSg-Bw","$QJmr7mmGeXGD4Tof0ZYSPW2oRGklseyHTKtZXnF-YNM","$7bkKK_Z-cGQ6Ae4HXWGBwXyZi3YjC6rIcQzGfVyl3Eo"],"content":{"body":"meow","com.beeper.linkpreviews":[],"m.mentions":{},"msgtype":"m.text"},"depth":3212,"hashes":{"sha256":"K549YdTnv62Jn84Y7sS5ZN3+AdmhleZHbenbhUpR2R8"},"origin_server_ts":1754242687127,"prev_events":["$DAhJg4jVsqk5FRatE2hbT1dSA8D2ASy5DbjEHIMSHwY"],"room_id":"!offtopic-2:continuwuity.org","sender":"@tulir:maunium.net","type":"m.room.message","signatures":{"maunium.net":{"ed25519:a_xxeS":"SkzZdZ+rH22kzCBBIAErTdB0Vg6vkFmzvwjlOarGul72EnufgtE/tJcd3a8szAdK7f1ZovRyQxDgVm/Ib2u0Aw"}},"unsigned":{"age_ts":1754242687146}}`, + eventID: `$qkWfTL7_l3oRZO2CItW8-Q0yAmi_l_1ua629ZDqponE`, + roomVersion: id.RoomV11, + serverDetails: mauniumNet, +}, roomV12MessageTestPDU, { + name: "m.room.create in v4 room", + pdu: `{"auth_events": [], "prev_events": [], "type": "m.room.create", "room_id": "!jxlRxnrZCsjpjDubDX:matrix.org", "sender": "@neilj:matrix.org", "content": {"room_version": "4", "predecessor": {"room_id": "!DYgXKezaHgMbiPMzjX:matrix.org", "event_id": "$156171636353XwPJT:matrix.org"}, "creator": "@neilj:matrix.org"}, "depth": 1, "prev_state": [], "state_key": "", "origin": "matrix.org", "origin_server_ts": 1561716363993, "hashes": {"sha256": "9tj8GpXjTAJvdNAbnuKLemZZk+Tjv2LAbGodSX6nJAo"}, "signatures": {"matrix.org": {"ed25519:auto": "2+sNt8uJUhzU4GPxnFVYtU2ZRgFdtVLT1vEZGUdJYN40zBpwYEGJy+kyb5matA+8/yLeYD9gu1O98lhleH0aCA"}}, "unsigned": {"age": 104769}}`, + eventID: "$ay_9_nPilrTpb3UxIwHHBBfFjTJb6hBAE_JzQwSjqeY", + roomVersion: id.RoomV4, + serverDetails: matrixOrg, +}, { + name: "m.room.create in v10 room", + pdu: `{"auth_events":[],"content":{"creator":"@creme:envs.net","predecessor":{"event_id":"$BxYNisKcyBDhPLiVC06t18qhv7wsT72MzMCqn5vRhfY","room_id":"!tEyFYiMHhwJlDXTxwf:envs.net"},"room_version":"10"},"depth":1,"hashes":{"sha256":"us3TrsIjBWpwbm+k3F9fUVnz9GIuhnb+LcaY47fWwUI"},"origin":"envs.net","origin_server_ts":1664394769527,"prev_events":[],"room_id":"!UzZHbJYcgggctGnlzr:envs.net","sender":"@creme:envs.net","state_key":"","type":"m.room.create","signatures":{"envs.net":{"ed25519:a_zIqy":"0g3FDaD1e5BekJYW2sR7dgxuKoZshrf8P067c9+jmH6frsWr2Ua86Ax08CFa/n46L8uvV2SGofP8iiVYgXCRBg"}},"unsigned":{"age":2060}}`, + eventID: "$tn1FZUI_YUpfTr_a3Y_r8kC3inliIZZratzg0UsNdCQ", + roomVersion: id.RoomV10, + serverDetails: envsNet, +}, { + name: "m.room.create in v12 room", + pdu: `{"auth_events":[],"content":{"fi.mau.randomness":"AAXZ6aIc","predecessor":{"room_id":"!#test/room\nversion 11, with @\ud83d\udc08\ufe0f:maunium.net"},"room_version":"12"},"depth":1,"hashes":{"sha256":"d3L1M3KUdyIKWcShyW6grUoJ8GOjCdSIEvQrDVHSpE8"},"origin_server_ts":1754940000000,"prev_events":[],"sender":"@tulir:maunium.net","state_key":"","type":"m.room.create","signatures":{"maunium.net":{"ed25519:a_xxeS":"ebjIRpzToc82cjb/RGY+VUzZic0yeRZrjctgx0SUTJxkprXn3/i1KdiYULfl/aD0cUJ5eL8gLakOSk2glm+sBw"}},"unsigned":{"age_ts":1754939139045}}`, + eventID: "$mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ", + roomVersion: id.RoomV12, + serverDetails: mauniumNet, +}, { + name: "m.room.member in v4 room", + pdu: `{"auth_events":["$ay_9_nPilrTpb3UxIwHHBBfFjTJb6hBAE_JzQwSjqeY","$jg2AgCfnwnjR-osoyM0lVYS21QrtfmZxhGO90PRkmO4","$wMGMP4Ucij2_d4h_fVDgIT2xooLZAgMcBruT9oo3Jio","$yyDgV8w0_e8qslmn0nh9OeSq_fO0zjpjTjSEdKFxDso"],"prev_events":["$zSjNuTXhUe3Rq6NpKD3sNyl8a_asMnBhGC5IbacHlJ4"],"type":"m.room.member","room_id":"!jxlRxnrZCsjpjDubDX:matrix.org","sender":"@tulir:maunium.net","content":{"membership":"join","displayname":"tulir","avatar_url":"mxc://maunium.net/jdlSfvudiMSmcRrleeiYjjFO","clicked \"send membership event with no changes\"":true},"depth":14370,"prev_state":[],"state_key":"@tulir:maunium.net","origin":"maunium.net","origin_server_ts":1600871136259,"hashes":{"sha256":"Ga6bG9Mk0887ruzM9TAAfa1O3DbNssb+qSFtE9oeRL4"},"signatures":{"maunium.net":{"ed25519:a_xxeS":"fzOyDG3G3pEzixtWPttkRA1DfnHETiKbiG8SEBQe2qycQbZWPky7xX8WujSrUJH/+bxTABpQwEH49d+RakxtBw"}},"unsigned":{"age_ts":1600871136259,"replaces_state":"$jg2AgCfnwnjR-osoyM0lVYS21QrtfmZxhGO90PRkmO4"}}`, + eventID: "$VtuCNOfAWGow-cxy0ajeK3fvONcC8QzF2yWa43g0Gwo", + roomVersion: id.RoomV4, + serverDetails: mauniumNet, +}, { + name: "m.room.member in v10 room", + pdu: `{"auth_events":["$HQC4hWaioLKVbMH94qKbfb3UnL4ocql2vi-VdUYI48I","$R9FUDgNAp9ms7b6ASunZOIkpqmsIRq_ROrNEznu62fs","$kEPF8Aj87EzRmFPriu2zdyEY0rY15XSqywTYVLUUlCA","$tn1FZUI_YUpfTr_a3Y_r8kC3inliIZZratzg0UsNdCQ"],"content":{"avatar_url":"mxc://maunium.net/jdlSfvudiMSmcRrleeiYjjFO","displayname":"tulir","membership":"join"},"depth":182,"hashes":{"sha256":"0HscBc921QV2dxK2qY7qrnyoAgfxBM7kKvqAXlEk+GE"},"origin":"maunium.net","origin_server_ts":1665402609039,"prev_events":["$R9FUDgNAp9ms7b6ASunZOIkpqmsIRq_ROrNEznu62fs"],"room_id":"!UzZHbJYcgggctGnlzr:envs.net","sender":"@tulir:maunium.net","state_key":"@tulir:maunium.net","type":"m.room.member","signatures":{"maunium.net":{"ed25519:a_xxeS":"lkOW0FSJ8MJ0wZpdwLH1Uf6FSl2q9/u6KthRIlM0CwHDJG4sIZ9DrMA8BdU8L/PWoDS/CoDUlLanDh99SplgBw"}},"unsigned":{"age_ts":1665402609039,"replaces_state":"$R9FUDgNAp9ms7b6ASunZOIkpqmsIRq_ROrNEznu62fs"}}`, + eventID: "$--ilpwnsHaEdHrwiMrZNu5xHP6TthWG0FIXMHnlHCcs", + roomVersion: id.RoomV10, + serverDetails: mauniumNet, +}, { + name: "m.room.member of creator in v12 room", + pdu: `{"auth_events":[],"content":{"avatar_url":"mxc://maunium.net/jdlSfvudiMSmcRrleeiYjjFO","displayname":"tulir","membership":"join"},"depth":2,"hashes":{"sha256":"IebdOBYaaWYIx2zq/lkVCnjWIXTLk1g+vgFpJMgd2/E"},"origin_server_ts":1754939139117,"prev_events":["$mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ"],"room_id":"!mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ","sender":"@tulir:maunium.net","state_key":"@tulir:maunium.net","type":"m.room.member","signatures":{"maunium.net":{"ed25519:a_xxeS":"rFCgF2hmavdm6+P6/f7rmuOdoSOmELFaH3JdWjgBLZXS2z51Ma7fa2v2+BkAH1FvBo9FLhvEoFVM4WbNQLXtAA"}},"unsigned":{"age_ts":1754939139117}}`, + eventID: "$accqGxfvhBvMP4Sf6P7t3WgnaJK6UbonO2ZmwqSE5Sg", + roomVersion: id.RoomV12, + serverDetails: mauniumNet, +}, { + name: "custom message event in v4 room", + pdu: `{"auth_events":["$VtuCNOfAWGow-cxy0ajeK3fvONcC8QzF2yWa43g0Gwo","$ay_9_nPilrTpb3UxIwHHBBfFjTJb6hBAE_JzQwSjqeY","$Gau_XwziYsr-rt3SouhbKN14twgmbKjcZZc_hz-nOgU"],"content":{"\ud83d\udc08\ufe0f":true,"\ud83d\udc15\ufe0f":false},"depth":69645,"hashes":{"sha256":"VHtWyCt+15ZesNnStU3FOkxrjzHJYZfd3JUgO9JWe0s"},"origin_server_ts":1755423939146,"prev_events":["$exmp4cj0OKOFSxuqBYiOYwQi5j_0XRc78d6EavAkhy0"],"room_id":"!jxlRxnrZCsjpjDubDX:matrix.org","sender":"@tulir:maunium.net","type":"\ud83d\udc08\ufe0f","signatures":{"maunium.net":{"ed25519:a_xxeS":"wfmP1XN4JBkKVkqrQnwysyEUslXt8hQRFwN9NC9vJaIeDMd0OJ6uqCas75808DuG71p23fzqbzhRnHckst6FCQ"}},"unsigned":{"age_ts":1755423939164}}`, + eventID: "$kAagtZAIEeZaLVCUSl74tAxQbdKbE22GU7FM-iAJBc0", + roomVersion: id.RoomV4, + serverDetails: mauniumNet, +}, { + name: "redacted m.room.member event in v11 room with 2 signatures", + pdu: `{"auth_events":["$9f12-_stoY07BOTmyguE1QlqvghLBh9Rk6PWRLoZn_M","$IP8hyjBkIDREVadyv0fPCGAW9IXGNllaZyxqQwiY_tA","$7dN5J8EveliaPkX6_QSejl4GQtem4oieavgALMeWZyE"],"content":{"membership":"join"},"depth":96978,"hashes":{"sha256":"APYA/aj3u+P0EwNaEofuSIlfqY3cK3lBz6RkwHX+Zak"},"origin_server_ts":1755664164485,"prev_events":["$XBN9W5Ll8VEH3eYqJaemxCBTDdy0hZB0sWpmyoUp93c"],"room_id":"!main-1:continuwuity.org","sender":"@6a19abdd4766:nova.astraltech.org","state_key":"@6a19abdd4766:nova.astraltech.org","type":"m.room.member","signatures":{"continuwuity.org":{"ed25519:PwHlNsFu":"+b/Fp2vWnC+Z2lI3GnCu7ZHdo3iWNDZ2AJqMoU9owMtLBPMxs4dVIsJXvaFq0ryawsgwDwKZ7f4xaFUNARJSDg"},"nova.astraltech.org":{"ed25519:a_afpo":"pXIngyxKukCPR7WOIIy8FTZxQ5L2dLiou5Oc8XS4WyY4YzJuckQzOaToigLLZxamfbN/jXbO+XUizpRpYccDAA"}},"unsigned":{}}`, + eventID: "$r6d9m125YWG28-Tln47bWtm6Jlv4mcSUWJTHijBlXLQ", + roomVersion: id.RoomV11, + serverDetails: novaAstraltechOrg, + redacted: true, +}} + +func parsePDU(pdu string) (out *pdu.PDU) { + exerrors.PanicIfNotNil(json.Unmarshal([]byte(pdu), &out)) + return +} diff --git a/federation/pdu/redact.go b/federation/pdu/redact.go new file mode 100644 index 00000000..d7ee0c15 --- /dev/null +++ b/federation/pdu/redact.go @@ -0,0 +1,111 @@ +// Copyright (c) 2025 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +//go:build goexperiment.jsonv2 + +package pdu + +import ( + "encoding/json/jsontext" + + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" + "go.mau.fi/util/exgjson" + "go.mau.fi/util/ptr" + + "maunium.net/go/mautrix/id" +) + +func filteredObject(object jsontext.Value, allowedPaths ...string) jsontext.Value { + filtered := jsontext.Value("{}") + var err error + for _, path := range allowedPaths { + res := gjson.GetBytes(object, path) + if res.Exists() { + var raw jsontext.Value + if res.Index > 0 { + raw = object[res.Index : res.Index+len(res.Raw)] + } else { + raw = jsontext.Value(res.Raw) + } + filtered, err = sjson.SetRawBytes(filtered, path, raw) + if err != nil { + panic(err) + } + } + } + return filtered +} + +func (pdu *PDU) Clone() *PDU { + return ptr.Clone(pdu) +} + +func (pdu *PDU) RedactForSignature(roomVersion id.RoomVersion) *PDU { + pdu.Signatures = nil + return pdu.Redact(roomVersion) +} + +var emptyObject = jsontext.Value("{}") + +func RedactContent(eventType string, content jsontext.Value, roomVersion id.RoomVersion) jsontext.Value { + switch eventType { + case "m.room.member": + allowedPaths := []string{"membership"} + if roomVersion.RestrictedJoinsFix() { + allowedPaths = append(allowedPaths, "join_authorised_via_users_server") + } + if roomVersion.UpdatedRedactionRules() { + allowedPaths = append(allowedPaths, exgjson.Path("third_party_invite", "signed")) + } + return filteredObject(content, allowedPaths...) + case "m.room.create": + if !roomVersion.UpdatedRedactionRules() { + return filteredObject(content, "creator") + } + return content + case "m.room.join_rules": + if roomVersion.RestrictedJoins() { + return filteredObject(content, "join_rule", "allow") + } + return filteredObject(content, "join_rule") + case "m.room.power_levels": + allowedKeys := []string{"ban", "events", "events_default", "kick", "redact", "state_default", "users", "users_default"} + if roomVersion.UpdatedRedactionRules() { + allowedKeys = append(allowedKeys, "invite") + } + return filteredObject(content, allowedKeys...) + case "m.room.history_visibility": + return filteredObject(content, "history_visibility") + case "m.room.redaction": + if roomVersion.RedactsInContent() { + return filteredObject(content, "redacts") + } + return emptyObject + case "m.room.aliases": + if roomVersion.SpecialCasedAliasesAuth() { + return filteredObject(content, "aliases") + } + return emptyObject + default: + return emptyObject + } +} + +func (pdu *PDU) Redact(roomVersion id.RoomVersion) *PDU { + pdu.Unknown = nil + pdu.Unsigned = nil + if roomVersion.UpdatedRedactionRules() { + pdu.DeprecatedPrevState = nil + pdu.DeprecatedOrigin = nil + pdu.DeprecatedMembership = nil + } + if pdu.Type != "m.room.redaction" || roomVersion.RedactsInContent() { + pdu.Redacts = nil + } + pdu.Content = RedactContent(pdu.Type, pdu.Content, roomVersion) + return pdu +} diff --git a/federation/pdu/signature.go b/federation/pdu/signature.go new file mode 100644 index 00000000..04e7c5ef --- /dev/null +++ b/federation/pdu/signature.go @@ -0,0 +1,60 @@ +// Copyright (c) 2025 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +//go:build goexperiment.jsonv2 + +package pdu + +import ( + "crypto/ed25519" + "encoding/base64" + "fmt" + "time" + + "maunium.net/go/mautrix/federation/signutil" + "maunium.net/go/mautrix/id" +) + +func (pdu *PDU) Sign(roomVersion id.RoomVersion, serverName string, keyID id.KeyID, privateKey ed25519.PrivateKey) error { + err := pdu.FillContentHash() + if err != nil { + return err + } + rawJSON, err := marshalCanonical(pdu.Clone().RedactForSignature(roomVersion)) + if err != nil { + return fmt.Errorf("failed to marshal redacted PDU to sign: %w", err) + } + signature := ed25519.Sign(privateKey, rawJSON) + pdu.AddSignature(serverName, keyID, base64.RawStdEncoding.EncodeToString(signature)) + return nil +} + +func (pdu *PDU) VerifySignature(roomVersion id.RoomVersion, serverName string, getKey GetKeyFunc) error { + rawJSON, err := marshalCanonical(pdu.Clone().RedactForSignature(roomVersion)) + if err != nil { + return fmt.Errorf("failed to marshal redacted PDU to verify signature: %w", err) + } + verified := false + for keyID, sig := range pdu.Signatures[serverName] { + originServerTS := time.UnixMilli(pdu.OriginServerTS) + key, validUntil, err := getKey(serverName, keyID, originServerTS) + if err != nil { + return fmt.Errorf("failed to get key %s for %s: %w", keyID, serverName, err) + } else if key == "" { + return fmt.Errorf("key %s not found for %s", keyID, serverName) + } else if validUntil.Before(originServerTS) && roomVersion.EnforceSigningKeyValidity() { + return fmt.Errorf("key %s for %s is only valid until %s, but event is from %s", keyID, serverName, validUntil, originServerTS) + } else if err = signutil.VerifyJSONRaw(key, sig, rawJSON); err != nil { + return fmt.Errorf("failed to verify signature from key %s: %w", keyID, err) + } else { + verified = true + } + } + if !verified { + return fmt.Errorf("no verifiable signatures found for server %s", serverName) + } + return nil +} diff --git a/federation/pdu/signature_test.go b/federation/pdu/signature_test.go new file mode 100644 index 00000000..01df5076 --- /dev/null +++ b/federation/pdu/signature_test.go @@ -0,0 +1,102 @@ +// Copyright (c) 2025 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +//go:build goexperiment.jsonv2 + +package pdu_test + +import ( + "crypto/ed25519" + "encoding/base64" + "encoding/json/jsontext" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.mau.fi/util/exerrors" + + "maunium.net/go/mautrix/federation/pdu" + "maunium.net/go/mautrix/id" +) + +func TestPDU_VerifySignature(t *testing.T) { + for _, test := range testPDUs { + t.Run(test.name, func(t *testing.T) { + parsed := parsePDU(test.pdu) + err := parsed.VerifySignature(test.roomVersion, test.serverName, test.getKey) + assert.NoError(t, err) + }) + } +} + +func TestPDU_VerifySignature_Fail_NoKey(t *testing.T) { + test := roomV12MessageTestPDU + parsed := parsePDU(test.pdu) + err := parsed.VerifySignature(test.roomVersion, test.serverName, func(serverName string, keyID id.KeyID, minValidUntil time.Time) (key id.SigningKey, validUntil time.Time, err error) { + return + }) + assert.Error(t, err) +} + +func TestPDU_VerifySignature_V4ExpiredKey(t *testing.T) { + test := roomV4MessageTestPDU + parsed := parsePDU(test.pdu) + err := parsed.VerifySignature(test.roomVersion, test.serverName, func(serverName string, keyID id.KeyID, minValidUntil time.Time) (key id.SigningKey, validUntil time.Time, err error) { + key = test.keys[keyID].key + validUntil = time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC) + return + }) + assert.NoError(t, err) +} + +func TestPDU_VerifySignature_V12ExpiredKey(t *testing.T) { + test := roomV12MessageTestPDU + parsed := parsePDU(test.pdu) + err := parsed.VerifySignature(test.roomVersion, test.serverName, func(serverName string, keyID id.KeyID, minValidUntil time.Time) (key id.SigningKey, validUntil time.Time, err error) { + key = test.keys[keyID].key + validUntil = time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC) + return + }) + assert.Error(t, err) +} + +func TestPDU_VerifySignature_V12InvalidSignature(t *testing.T) { + test := roomV12MessageTestPDU + parsed := parsePDU(test.pdu) + for _, sigs := range parsed.Signatures { + for key := range sigs { + sigs[key] = sigs[key][:len(sigs[key])-3] + "ABC" + } + } + err := parsed.VerifySignature(test.roomVersion, test.serverName, test.getKey) + assert.Error(t, err) +} + +func TestPDU_Sign(t *testing.T) { + pubKey, privKey := exerrors.Must2(ed25519.GenerateKey(nil)) + evt := &pdu.PDU{ + AuthEvents: []id.EventID{"$gCzdJUVV93Qory0x7p_PLG5UUiDjPJNe1H12qbHTuFA", "$hyeL_nU_L3tsZ2dtZZpAHk0Skv-PqFQIipuII_By584"}, + Content: jsontext.Value(`{"msgtype":"m.text","body":"Hello, world!"}`), + Depth: 123, + OriginServerTS: 1755384351627, + PrevEvents: []id.EventID{"$gCzdJUVV93Qory0x7p_PLG5UUiDjPJNe1H12qbHTuFA"}, + RoomID: "!mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ", + Sender: "@tulir:example.com", + Type: "m.room.message", + } + err := evt.Sign(id.RoomV12, "example.com", "ed25519:rand", privKey) + require.NoError(t, err) + err = evt.VerifySignature(id.RoomV11, "example.com", func(serverName string, keyID id.KeyID, minValidUntil time.Time) (key id.SigningKey, validUntil time.Time, err error) { + if serverName == "example.com" && keyID == "ed25519:rand" { + key = id.SigningKey(base64.RawStdEncoding.EncodeToString(pubKey)) + validUntil = time.Now() + } + return + }) + require.NoError(t, err) + +} diff --git a/federation/pdu/v1.go b/federation/pdu/v1.go new file mode 100644 index 00000000..9557f8ab --- /dev/null +++ b/federation/pdu/v1.go @@ -0,0 +1,277 @@ +// Copyright (c) 2025 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +//go:build goexperiment.jsonv2 + +package pdu + +import ( + "crypto/ed25519" + "crypto/hmac" + "crypto/sha256" + "encoding/base64" + "encoding/json/jsontext" + "encoding/json/v2" + "fmt" + "time" + + "github.com/tidwall/gjson" + "go.mau.fi/util/ptr" + + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/federation/signutil" + "maunium.net/go/mautrix/id" +) + +type V1EventReference struct { + ID id.EventID + Hashes Hashes +} + +var ( + _ json.UnmarshalerFrom = (*V1EventReference)(nil) + _ json.MarshalerTo = (*V1EventReference)(nil) +) + +func (er *V1EventReference) MarshalJSONTo(enc *jsontext.Encoder) error { + return json.MarshalEncode(enc, []any{er.ID, er.Hashes}) +} + +func (er *V1EventReference) UnmarshalJSONFrom(dec *jsontext.Decoder) error { + var ref V1EventReference + var data []jsontext.Value + if err := json.UnmarshalDecode(dec, &data); err != nil { + return err + } else if len(data) != 2 { + return fmt.Errorf("V1EventReference.UnmarshalJSONFrom: expected array with 2 elements, got %d", len(data)) + } else if err = json.Unmarshal(data[0], &ref.ID); err != nil { + return fmt.Errorf("V1EventReference.UnmarshalJSONFrom: failed to unmarshal event ID: %w", err) + } else if err = json.Unmarshal(data[1], &ref.Hashes); err != nil { + return fmt.Errorf("V1EventReference.UnmarshalJSONFrom: failed to unmarshal hashes: %w", err) + } + *er = ref + return nil +} + +type RoomV1PDU struct { + AuthEvents []V1EventReference `json:"auth_events"` + Content jsontext.Value `json:"content"` + Depth int64 `json:"depth"` + EventID id.EventID `json:"event_id"` + Hashes *Hashes `json:"hashes,omitzero"` + OriginServerTS int64 `json:"origin_server_ts"` + PrevEvents []V1EventReference `json:"prev_events"` + Redacts *id.EventID `json:"redacts,omitzero"` + RoomID id.RoomID `json:"room_id"` + Sender id.UserID `json:"sender"` + Signatures map[string]map[id.KeyID]string `json:"signatures,omitzero"` + StateKey *string `json:"state_key,omitzero"` + Type string `json:"type"` + Unsigned jsontext.Value `json:"unsigned,omitzero"` + + Unknown jsontext.Value `json:",unknown"` + + // Deprecated legacy fields + DeprecatedPrevState jsontext.Value `json:"prev_state,omitzero"` + DeprecatedOrigin jsontext.Value `json:"origin,omitzero"` + DeprecatedMembership jsontext.Value `json:"membership,omitzero"` +} + +func (pdu *RoomV1PDU) GetRoomID() (id.RoomID, error) { + return pdu.RoomID, nil +} + +func (pdu *RoomV1PDU) GetEventID(roomVersion id.RoomVersion) (id.EventID, error) { + if !pdu.SupportsRoomVersion(roomVersion) { + return "", fmt.Errorf("RoomV1PDU.GetEventID: unsupported room version %s", roomVersion) + } + return pdu.EventID, nil +} + +func (pdu *RoomV1PDU) RedactForSignature(roomVersion id.RoomVersion) *RoomV1PDU { + pdu.Signatures = nil + return pdu.Redact(roomVersion) +} + +func (pdu *RoomV1PDU) Redact(roomVersion id.RoomVersion) *RoomV1PDU { + pdu.Unknown = nil + pdu.Unsigned = nil + if pdu.Type != "m.room.redaction" { + pdu.Redacts = nil + } + pdu.Content = RedactContent(pdu.Type, pdu.Content, roomVersion) + return pdu +} + +func (pdu *RoomV1PDU) GetReferenceHash(roomVersion id.RoomVersion) ([32]byte, error) { + if !pdu.SupportsRoomVersion(roomVersion) { + return [32]byte{}, fmt.Errorf("RoomV1PDU.GetReferenceHash: unsupported room version %s", roomVersion) + } + if pdu == nil { + return [32]byte{}, ErrPDUIsNil + } + if pdu.Hashes == nil || pdu.Hashes.SHA256 == nil { + if err := pdu.FillContentHash(); err != nil { + return [32]byte{}, err + } + } + rawJSON, err := marshalCanonical(pdu.Clone().RedactForSignature(roomVersion)) + if err != nil { + return [32]byte{}, fmt.Errorf("failed to marshal redacted PDU to calculate event ID: %w", err) + } + return sha256.Sum256(rawJSON), nil +} + +func (pdu *RoomV1PDU) CalculateContentHash() ([32]byte, error) { + if pdu == nil { + return [32]byte{}, ErrPDUIsNil + } + pduClone := pdu.Clone() + pduClone.Signatures = nil + pduClone.Unsigned = nil + pduClone.Hashes = nil + rawJSON, err := marshalCanonical(pduClone) + if err != nil { + return [32]byte{}, fmt.Errorf("failed to marshal PDU to calculate content hash: %w", err) + } + return sha256.Sum256(rawJSON), nil +} + +func (pdu *RoomV1PDU) FillContentHash() error { + if pdu == nil { + return ErrPDUIsNil + } else if pdu.Hashes != nil { + return nil + } else if hash, err := pdu.CalculateContentHash(); err != nil { + return err + } else { + pdu.Hashes = &Hashes{SHA256: hash[:]} + return nil + } +} + +func (pdu *RoomV1PDU) VerifyContentHash() bool { + if pdu == nil || pdu.Hashes == nil { + return false + } + calculatedHash, err := pdu.CalculateContentHash() + if err != nil { + return false + } + return hmac.Equal(calculatedHash[:], pdu.Hashes.SHA256) +} + +func (pdu *RoomV1PDU) Clone() *RoomV1PDU { + return ptr.Clone(pdu) +} + +func (pdu *RoomV1PDU) Sign(roomVersion id.RoomVersion, serverName string, keyID id.KeyID, privateKey ed25519.PrivateKey) error { + if !pdu.SupportsRoomVersion(roomVersion) { + return fmt.Errorf("RoomV1PDU.Sign: unsupported room version %s", roomVersion) + } + err := pdu.FillContentHash() + if err != nil { + return err + } + rawJSON, err := marshalCanonical(pdu.Clone().RedactForSignature(roomVersion)) + if err != nil { + return fmt.Errorf("failed to marshal redacted PDU to sign: %w", err) + } + signature := ed25519.Sign(privateKey, rawJSON) + if pdu.Signatures == nil { + pdu.Signatures = make(map[string]map[id.KeyID]string) + } + if _, ok := pdu.Signatures[serverName]; !ok { + pdu.Signatures[serverName] = make(map[id.KeyID]string) + } + pdu.Signatures[serverName][keyID] = base64.RawStdEncoding.EncodeToString(signature) + return nil +} + +func (pdu *RoomV1PDU) VerifySignature(roomVersion id.RoomVersion, serverName string, getKey GetKeyFunc) error { + if !pdu.SupportsRoomVersion(roomVersion) { + return fmt.Errorf("RoomV1PDU.VerifySignature: unsupported room version %s", roomVersion) + } + rawJSON, err := marshalCanonical(pdu.Clone().RedactForSignature(roomVersion)) + if err != nil { + return fmt.Errorf("failed to marshal redacted PDU to verify signature: %w", err) + } + verified := false + for keyID, sig := range pdu.Signatures[serverName] { + originServerTS := time.UnixMilli(pdu.OriginServerTS) + key, _, err := getKey(serverName, keyID, originServerTS) + if err != nil { + return fmt.Errorf("failed to get key %s for %s: %w", keyID, serverName, err) + } else if key == "" { + return fmt.Errorf("key %s not found for %s", keyID, serverName) + } else if err = signutil.VerifyJSONRaw(key, sig, rawJSON); err != nil { + return fmt.Errorf("failed to verify signature from key %s: %w", keyID, err) + } else { + verified = true + } + } + if !verified { + return fmt.Errorf("no verifiable signatures found for server %s", serverName) + } + return nil +} + +func (pdu *RoomV1PDU) SupportsRoomVersion(roomVersion id.RoomVersion) bool { + switch roomVersion { + case id.RoomV0, id.RoomV1, id.RoomV2: + return true + default: + return false + } +} + +func (pdu *RoomV1PDU) ToClientEvent(roomVersion id.RoomVersion) (*event.Event, error) { + if !pdu.SupportsRoomVersion(roomVersion) { + return nil, fmt.Errorf("RoomV1PDU.ToClientEvent: unsupported room version %s", roomVersion) + } + evtType := event.Type{Type: pdu.Type, Class: event.MessageEventType} + if pdu.StateKey != nil { + evtType.Class = event.StateEventType + } + evt := &event.Event{ + StateKey: pdu.StateKey, + Sender: pdu.Sender, + Type: evtType, + Timestamp: pdu.OriginServerTS, + ID: pdu.EventID, + RoomID: pdu.RoomID, + Redacts: ptr.Val(pdu.Redacts), + } + err := json.Unmarshal(pdu.Content, &evt.Content) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal content: %w", err) + } + return evt, nil +} + +func (pdu *RoomV1PDU) AuthEventSelection(_ id.RoomVersion) (keys AuthEventSelection) { + if pdu.Type == event.StateCreate.Type && pdu.StateKey != nil { + return AuthEventSelection{} + } + keys = make(AuthEventSelection, 0, 3) + keys.Add(event.StateCreate.Type, "") + keys.Add(event.StatePowerLevels.Type, "") + keys.Add(event.StateMember.Type, pdu.Sender.String()) + if pdu.Type == event.StateMember.Type && pdu.StateKey != nil { + keys.Add(event.StateMember.Type, *pdu.StateKey) + membership := event.Membership(gjson.GetBytes(pdu.Content, "membership").Str) + if membership == event.MembershipJoin || membership == event.MembershipInvite || membership == event.MembershipKnock { + keys.Add(event.StateJoinRules.Type, "") + } + if membership == event.MembershipInvite { + thirdPartyInviteToken := gjson.GetBytes(pdu.Content, thirdPartyInviteTokenPath).Str + if thirdPartyInviteToken != "" { + keys.Add(event.StateThirdPartyInvite.Type, thirdPartyInviteToken) + } + } + } + return +} diff --git a/federation/pdu/v1_test.go b/federation/pdu/v1_test.go new file mode 100644 index 00000000..ecf2dbd2 --- /dev/null +++ b/federation/pdu/v1_test.go @@ -0,0 +1,86 @@ +// Copyright (c) 2025 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +//go:build goexperiment.jsonv2 + +package pdu_test + +import ( + "encoding/base64" + "encoding/json/v2" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "go.mau.fi/util/exerrors" + + "maunium.net/go/mautrix/federation/pdu" + "maunium.net/go/mautrix/id" +) + +var testV1PDUs = []testPDU{{ + name: "m.room.message in v1 room", + pdu: `{"auth_events":[["$159234730483190eXavq:matrix.org",{"sha256":"VprZrhMqOQyKbfF3UE26JXE8D27ih4R/FGGc8GZ0Whs"}],["$143454825711DhCxH:matrix.org",{"sha256":"3sJh/5GOB094OKuhbjL634Gt69YIcge9GD55ciJa9ok"}],["$156837651426789wiPdh:maunium.net",{"sha256":"FGyR3sxJ/VxYabDkO/5qtwrPR3hLwGknJ0KX0w3GUHE"}]],"content":{"body":"photo-1526336024174-e58f5cdd8e13.jpg","info":{"h":1620,"mimetype":"image/jpeg","size":208053,"w":1080},"msgtype":"m.image","url":"mxc://maunium.net/aEqEghIjFPAerIhCxJCYpQeC"},"depth":16669,"event_id":"$16738169022163bokdi:maunium.net","hashes":{"sha256":"XYB47Gf2vAci3BTguIJaC75ZYGMuVY65jcvoUVgpcLA"},"origin":"maunium.net","origin_server_ts":1673816902100,"prev_events":[["$1673816901121325UMCjA:matrix.org",{"sha256":"t7e0IYHLI3ydIPoIU8a8E/pIWXH9cNLlQBEtGyGtHwc"}]],"room_id":"!jhpZBTbckszblMYjMK:matrix.org","sender":"@cat:maunium.net","type":"m.room.message","signatures":{"maunium.net":{"ed25519:a_xxeS":"uRZbEm+P+Y1ZVgwBn5I6SlaUZdzlH1bB4nv81yt5EIQ0b1fZ8YgM4UWMijrrXp3+NmqRFl0cakSM3MneJOtFCw"}},"unsigned":{"age_ts":1673816902100}}`, + eventID: "$16738169022163bokdi:maunium.net", + roomVersion: id.RoomV1, + serverDetails: mauniumNet, +}, { + name: "m.room.create in v1 room", + pdu: `{"origin": "matrix.org", "signatures": {"matrix.org": {"ed25519:auto": "XTejpXn5REoHrZWgCpJglGX7MfOWS2zUjYwJRLrwW2PQPbFdqtL+JnprBXwIP2C1NmgWSKG+am1QdApu0KoHCQ"}}, "origin_server_ts": 1434548257426, "sender": "@appservice-irc:matrix.org", "event_id": "$143454825711DhCxH:matrix.org", "prev_events": [], "unsigned": {"age": 12872287834}, "state_key": "", "content": {"creator": "@appservice-irc:matrix.org"}, "depth": 1, "prev_state": [], "room_id": "!jhpZBTbckszblMYjMK:matrix.org", "auth_events": [], "hashes": {"sha256": "+SSdmeeoKI/6yK6sY4XAFljWFiugSlCiXQf0QMCZjTs"}, "type": "m.room.create"}`, + eventID: "$143454825711DhCxH:matrix.org", + roomVersion: id.RoomV1, + serverDetails: matrixOrg, +}, { + name: "m.room.member in v1 room", + pdu: `{"auth_events": [["$1536447669931522zlyWe:matrix.org", {"sha256": "UkzPGd7cPAGvC0FVx3Yy2/Q0GZhA2kcgj8MGp5pjYV8"}], ["$143454825711DhCxH:matrix.org", {"sha256": "3sJh/5GOB094OKuhbjL634Gt69YIcge9GD55ciJa9ok"}], ["$143454825714nUEqZ:matrix.org", {"sha256": "NjuZXu8EDMfIfejPcNlC/IdnKQAGpPIcQjHaf0BZaHk"}]], "prev_events": [["$15660585503271JRRMm:maunium.net", {"sha256": "/Sm7uSLkYMHapp6I3NuEVJlk2JucW2HqjsQy9vzhciA"}]], "type": "m.room.member", "room_id": "!jhpZBTbckszblMYjMK:matrix.org", "sender": "@tulir:maunium.net", "content": {"membership": "join", "avatar_url": "mxc://maunium.net/jdlSfvudiMSmcRrleeiYjjFO", "displayname": "tulir"}, "depth": 10485, "prev_state": [], "state_key": "@tulir:maunium.net", "event_id": "$15660585693272iEryv:maunium.net", "origin": "maunium.net", "origin_server_ts": 1566058569201, "hashes": {"sha256": "1D6fdDzKsMGCxSqlXPA7I9wGQNTutVuJke1enGHoWK8"}, "signatures": {"maunium.net": {"ed25519:a_xxeS": "Lj/zDK6ozr4vgsxyL8jY56wTGWoA4jnlvkTs5paCX1w3nNKHnQnSMi+wuaqI6yv5vYh9usGWco2LLMuMzYXcBg"}}, "unsigned": {"age_ts": 1566058569201, "replaces_state": "$15660585383268liyBc:maunium.net"}}`, + eventID: "$15660585693272iEryv:maunium.net", + roomVersion: id.RoomV1, + serverDetails: mauniumNet, +}} + +func parseV1PDU(pdu string) (out *pdu.RoomV1PDU) { + exerrors.PanicIfNotNil(json.Unmarshal([]byte(pdu), &out)) + return +} + +func TestRoomV1PDU_CalculateContentHash(t *testing.T) { + for _, test := range testV1PDUs { + t.Run(test.name, func(t *testing.T) { + parsed := parseV1PDU(test.pdu) + contentHash := exerrors.Must(parsed.CalculateContentHash()) + assert.Equal( + t, + base64.RawStdEncoding.EncodeToString(parsed.Hashes.SHA256), + base64.RawStdEncoding.EncodeToString(contentHash[:]), + ) + }) + } +} + +func TestRoomV1PDU_VerifyContentHash(t *testing.T) { + for _, test := range testV1PDUs { + t.Run(test.name, func(t *testing.T) { + parsed := parseV1PDU(test.pdu) + assert.True(t, parsed.VerifyContentHash()) + }) + } +} + +func TestRoomV1PDU_VerifySignature(t *testing.T) { + for _, test := range testV1PDUs { + t.Run(test.name, func(t *testing.T) { + parsed := parseV1PDU(test.pdu) + err := parsed.VerifySignature(test.roomVersion, test.serverName, func(serverName string, keyID id.KeyID, _ time.Time) (id.SigningKey, time.Time, error) { + key, ok := test.keys[keyID] + if ok { + return key.key, key.validUntilTS, nil + } + return "", time.Time{}, nil + }) + assert.NoError(t, err) + }) + } +} diff --git a/federation/resolution.go b/federation/resolution.go index 24085282..a3188266 100644 --- a/federation/resolution.go +++ b/federation/resolution.go @@ -20,6 +20,8 @@ import ( "time" "github.com/rs/zerolog" + + "maunium.net/go/mautrix" ) type ResolvedServerName struct { @@ -78,7 +80,10 @@ func ResolveServerName(ctx context.Context, serverName string, opts ...*ResolveS } else if wellKnown != nil { output.Expires = expiry output.HostHeader = wellKnown.Server - hostname, port, ok = ParseServerName(wellKnown.Server) + wkHost, wkPort, ok := ParseServerName(wellKnown.Server) + if ok { + hostname, port = wkHost, wkPort + } // Step 3.1 and 3.2: IP literals and hostnames with port inside .well-known if net.ParseIP(hostname) != nil || port != 0 { if port == 0 { @@ -120,6 +125,38 @@ func RequestSRV(ctx context.Context, cli *net.Resolver, hostname string) ([]*net return target, err } +func parseCacheControl(resp *http.Response) time.Duration { + cc := resp.Header.Get("Cache-Control") + if cc == "" { + return 0 + } + parts := strings.Split(cc, ",") + for _, part := range parts { + kv := strings.SplitN(strings.TrimSpace(part), "=", 1) + switch kv[0] { + case "no-cache", "no-store": + return 0 + case "max-age": + if len(kv) < 2 { + continue + } + maxAge, err := strconv.Atoi(kv[1]) + if err != nil || maxAge < 0 { + continue + } + age, _ := strconv.Atoi(resp.Header.Get("Age")) + return time.Duration(maxAge-age) * time.Second + } + } + return 0 +} + +const ( + MinCacheDuration = 1 * time.Hour + MaxCacheDuration = 72 * time.Hour + DefaultCacheDuration = 24 * time.Hour +) + // RequestWellKnown sends a request to the well-known endpoint of a server and returns the response, // plus the time when the cache should expire. func RequestWellKnown(ctx context.Context, cli *http.Client, hostname string) (*RespWellKnown, time.Time, error) { @@ -139,14 +176,23 @@ func RequestWellKnown(ctx context.Context, cli *http.Client, hostname string) (* defer resp.Body.Close() if resp.StatusCode != http.StatusOK { return nil, time.Time{}, fmt.Errorf("unexpected status code %d", resp.StatusCode) + } else if resp.ContentLength > mautrix.WellKnownMaxSize { + return nil, time.Time{}, fmt.Errorf("response too large: %d bytes", resp.ContentLength) } var respData RespWellKnown - err = json.NewDecoder(io.LimitReader(resp.Body, 50*1024)).Decode(&respData) + err = json.NewDecoder(io.LimitReader(resp.Body, mautrix.WellKnownMaxSize)).Decode(&respData) if err != nil { return nil, time.Time{}, fmt.Errorf("failed to decode response: %w", err) } else if respData.Server == "" { return nil, time.Time{}, errors.New("server name not found in response") } - // TODO parse cache-control header + cacheDuration := parseCacheControl(resp) + if cacheDuration <= 0 { + cacheDuration = DefaultCacheDuration + } else if cacheDuration < MinCacheDuration { + cacheDuration = MinCacheDuration + } else if cacheDuration > MaxCacheDuration { + cacheDuration = MaxCacheDuration + } return &respData, time.Now().Add(24 * time.Hour), nil } diff --git a/federation/serverauth.go b/federation/serverauth.go new file mode 100644 index 00000000..cd300341 --- /dev/null +++ b/federation/serverauth.go @@ -0,0 +1,264 @@ +// Copyright (c) 2025 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package federation + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "maps" + "net/http" + "slices" + "strings" + "sync" + + "github.com/rs/zerolog" + "go.mau.fi/util/ptr" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/id" +) + +type ServerAuth struct { + Keys KeyCache + Client *Client + GetDestination func(XMatrixAuth) string + MaxBodySize int64 + + keyFetchLocks map[string]*sync.Mutex + keyFetchLocksLock sync.Mutex +} + +func NewServerAuth(client *Client, keyCache KeyCache, getDestination func(auth XMatrixAuth) string) *ServerAuth { + return &ServerAuth{ + Keys: keyCache, + Client: client, + GetDestination: getDestination, + MaxBodySize: 50 * 1024 * 1024, + keyFetchLocks: make(map[string]*sync.Mutex), + } +} + +var MUnauthorized = mautrix.RespError{ErrCode: "M_UNAUTHORIZED", StatusCode: http.StatusUnauthorized} + +var ( + errMissingAuthHeader = MUnauthorized.WithMessage("Missing Authorization header") + errInvalidAuthHeader = MUnauthorized.WithMessage("Authorization header does not start with X-Matrix") + errMalformedAuthHeader = MUnauthorized.WithMessage("X-Matrix value is missing required components") + errInvalidDestination = MUnauthorized.WithMessage("Invalid destination in X-Matrix header") + errFailedToQueryKeys = MUnauthorized.WithMessage("Failed to query server keys") + errInvalidSelfSignatures = MUnauthorized.WithMessage("Server keys don't have valid self-signatures") + errRequestBodyTooLarge = mautrix.MTooLarge.WithMessage("Request body too large") + errInvalidJSONBody = mautrix.MBadJSON.WithMessage("Request body is not valid JSON") + errBodyReadFailed = mautrix.MUnknown.WithMessage("Failed to read request body") + errInvalidRequestSignature = MUnauthorized.WithMessage("Failed to verify request signature") +) + +type XMatrixAuth struct { + Origin string + Destination string + KeyID id.KeyID + Signature string +} + +func (xma XMatrixAuth) String() string { + return fmt.Sprintf( + `X-Matrix origin="%s",destination="%s",key="%s",sig="%s"`, + xma.Origin, + xma.Destination, + xma.KeyID, + xma.Signature, + ) +} + +func ParseXMatrixAuth(auth string) (xma XMatrixAuth) { + auth = strings.TrimPrefix(auth, "X-Matrix ") + // TODO upgrade to strings.SplitSeq after Go 1.24 is the minimum + for _, part := range strings.Split(auth, ",") { + part = strings.TrimSpace(part) + eqIdx := strings.Index(part, "=") + if eqIdx == -1 || strings.Count(part, "=") > 1 { + continue + } + val := strings.Trim(part[eqIdx+1:], "\"") + switch strings.ToLower(part[:eqIdx]) { + case "origin": + xma.Origin = val + case "destination": + xma.Destination = val + case "key": + xma.KeyID = id.KeyID(val) + case "sig": + xma.Signature = val + } + } + return +} + +func (sa *ServerAuth) GetKeysWithCache(ctx context.Context, serverName string, keyID id.KeyID) (*ServerKeyResponse, error) { + res, err := sa.Keys.LoadKeys(serverName) + if err != nil { + return nil, fmt.Errorf("failed to read cache: %w", err) + } else if res.HasKey(keyID) { + return res, nil + } + + sa.keyFetchLocksLock.Lock() + lock, ok := sa.keyFetchLocks[serverName] + if !ok { + lock = &sync.Mutex{} + sa.keyFetchLocks[serverName] = lock + } + sa.keyFetchLocksLock.Unlock() + + lock.Lock() + defer lock.Unlock() + res, err = sa.Keys.LoadKeys(serverName) + if err != nil { + return nil, fmt.Errorf("failed to read cache: %w", err) + } else if res != nil { + if res.HasKey(keyID) { + return res, nil + } else if !sa.Keys.ShouldReQuery(serverName) { + zerolog.Ctx(ctx).Trace(). + Str("server_name", serverName). + Stringer("key_id", keyID). + Msg("Not sending key request for missing key ID, last query was too recent") + return res, nil + } + } + res, err = sa.Client.ServerKeys(ctx, serverName) + if err != nil { + sa.Keys.StoreFetchError(serverName, err) + return nil, err + } + sa.Keys.StoreKeys(res) + return res, nil +} + +type fixedLimitedReader struct { + R io.Reader + N int64 + Err error +} + +func (l *fixedLimitedReader) Read(p []byte) (n int, err error) { + if l.N <= 0 { + return 0, l.Err + } + if int64(len(p)) > l.N { + p = p[0:l.N] + } + n, err = l.R.Read(p) + l.N -= int64(n) + return +} + +func (sa *ServerAuth) Authenticate(r *http.Request) (*http.Request, *mautrix.RespError) { + defer func() { + _ = r.Body.Close() + }() + log := zerolog.Ctx(r.Context()) + if r.ContentLength > sa.MaxBodySize { + return nil, &errRequestBodyTooLarge + } + auth := r.Header.Get("Authorization") + if auth == "" { + return nil, &errMissingAuthHeader + } else if !strings.HasPrefix(auth, "X-Matrix ") { + return nil, &errInvalidAuthHeader + } + parsed := ParseXMatrixAuth(auth) + if parsed.Origin == "" || parsed.KeyID == "" || parsed.Signature == "" { + log.Trace().Str("auth_header", auth).Msg("Malformed X-Matrix header") + return nil, &errMalformedAuthHeader + } + destination := sa.GetDestination(parsed) + if destination == "" || (parsed.Destination != "" && parsed.Destination != destination) { + log.Trace(). + Str("got_destination", parsed.Destination). + Str("expected_destination", destination). + Msg("Invalid destination in X-Matrix header") + return nil, &errInvalidDestination + } + resp, err := sa.GetKeysWithCache(r.Context(), parsed.Origin, parsed.KeyID) + if err != nil { + if !errors.Is(err, ErrRecentKeyQueryFailed) { + log.Err(err). + Str("server_name", parsed.Origin). + Msg("Failed to query keys to authenticate request") + } else { + log.Trace().Err(err). + Str("server_name", parsed.Origin). + Msg("Failed to query keys to authenticate request (cached error)") + } + return nil, &errFailedToQueryKeys + } else if err := resp.VerifySelfSignature(); err != nil { + log.Trace().Err(err). + Str("server_name", parsed.Origin). + Msg("Failed to validate self-signatures of server keys") + return nil, &errInvalidSelfSignatures + } + key, ok := resp.VerifyKeys[parsed.KeyID] + if !ok { + keys := slices.Collect(maps.Keys(resp.VerifyKeys)) + log.Trace(). + Stringer("expected_key_id", parsed.KeyID). + Any("found_key_ids", keys). + Msg("Didn't find expected key ID to verify request") + return nil, ptr.Ptr(MUnauthorized.WithMessage("Key ID %q not found (got %v)", parsed.KeyID, keys)) + } + var reqBody []byte + if r.ContentLength != 0 && r.Method != http.MethodGet && r.Method != http.MethodHead { + reqBody, err = io.ReadAll(&fixedLimitedReader{R: r.Body, N: sa.MaxBodySize, Err: errRequestBodyTooLarge}) + if errors.Is(err, errRequestBodyTooLarge) { + return nil, &errRequestBodyTooLarge + } else if err != nil { + log.Err(err). + Str("server_name", parsed.Origin). + Msg("Failed to read request body to authenticate") + return nil, &errBodyReadFailed + } else if !json.Valid(reqBody) { + return nil, &errInvalidJSONBody + } + } + err = (&signableRequest{ + Method: r.Method, + URI: r.URL.RequestURI(), + Origin: parsed.Origin, + Destination: destination, + Content: reqBody, + }).Verify(key.Key, parsed.Signature) + if err != nil { + log.Trace().Err(err).Msg("Request has invalid signature") + return nil, &errInvalidRequestSignature + } + ctx := context.WithValue(r.Context(), contextKeyDestinationServer, destination) + ctx = context.WithValue(ctx, contextKeyOriginServer, parsed.Origin) + ctx = log.With(). + Str("origin_server_name", parsed.Origin). + Str("destination_server_name", destination). + Logger().WithContext(ctx) + modifiedReq := r.WithContext(ctx) + if reqBody != nil { + modifiedReq.Body = io.NopCloser(bytes.NewReader(reqBody)) + } + return modifiedReq, nil +} + +func (sa *ServerAuth) AuthenticateMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if modifiedReq, err := sa.Authenticate(r); err != nil { + err.Write(w) + } else { + next.ServeHTTP(w, modifiedReq) + } + }) +} diff --git a/federation/serverauth_test.go b/federation/serverauth_test.go new file mode 100644 index 00000000..f99fc6cf --- /dev/null +++ b/federation/serverauth_test.go @@ -0,0 +1,29 @@ +// Copyright (c) 2025 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package federation_test + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "maunium.net/go/mautrix/federation" +) + +func TestServerKeyResponse_VerifySelfSignature(t *testing.T) { + cli := federation.NewClient("", nil, nil) + ctx := context.Background() + for _, name := range []string{"matrix.org", "maunium.net", "cd.mau.dev", "uwu.mau.dev"} { + t.Run(name, func(t *testing.T) { + resp, err := cli.ServerKeys(ctx, name) + require.NoError(t, err) + assert.NoError(t, resp.VerifySelfSignature()) + }) + } +} diff --git a/federation/signingkey.go b/federation/signingkey.go index 67751b48..a4ad9679 100644 --- a/federation/signingkey.go +++ b/federation/signingkey.go @@ -14,9 +14,11 @@ import ( "strings" "time" + "github.com/tidwall/sjson" "go.mau.fi/util/jsontime" "maunium.net/go/mautrix/crypto/canonicaljson" + "maunium.net/go/mautrix/federation/signutil" "maunium.net/go/mautrix/id" ) @@ -31,8 +33,8 @@ type SigningKey struct { // // The output of this function can be parsed back into a [SigningKey] using the [ParseSynapseKey] function. func (sk *SigningKey) SynapseString() string { - alg, id := sk.ID.Parse() - return fmt.Sprintf("%s %s %s", alg, id, base64.RawStdEncoding.EncodeToString(sk.Priv.Seed())) + alg, keyID := sk.ID.Parse() + return fmt.Sprintf("%s %s %s", alg, keyID, base64.RawStdEncoding.EncodeToString(sk.Priv.Seed())) } // ParseSynapseKey parses a Synapse-compatible private key string into a SigningKey. @@ -77,6 +79,37 @@ type ServerKeyResponse struct { OldVerifyKeys map[id.KeyID]OldVerifyKey `json:"old_verify_keys,omitempty"` Signatures map[string]map[id.KeyID]string `json:"signatures,omitempty"` ValidUntilTS jsontime.UnixMilli `json:"valid_until_ts"` + + Raw json.RawMessage `json:"-"` +} + +type QueryKeysResponse struct { + ServerKeys []*ServerKeyResponse `json:"server_keys"` +} + +func (skr *ServerKeyResponse) HasKey(keyID id.KeyID) bool { + if skr == nil { + return false + } else if _, ok := skr.VerifyKeys[keyID]; ok { + return true + } + return false +} + +func (skr *ServerKeyResponse) VerifySelfSignature() error { + for keyID, key := range skr.VerifyKeys { + if err := signutil.VerifyJSON(skr.ServerName, keyID, key.Key, skr.Raw); err != nil { + return fmt.Errorf("failed to verify self signature for key %s: %w", keyID, err) + } + } + return nil +} + +type marshalableSKR ServerKeyResponse + +func (skr *ServerKeyResponse) UnmarshalJSON(data []byte) error { + skr.Raw = data + return json.Unmarshal(data, (*marshalableSKR)(skr)) } type ServerVerifyKey struct { @@ -92,12 +125,16 @@ type OldVerifyKey struct { ExpiredTS jsontime.UnixMilli `json:"expired_ts"` } -func (sk *SigningKey) SignJSON(data any) ([]byte, error) { +func (sk *SigningKey) SignJSON(data any) (string, error) { marshaled, err := json.Marshal(data) if err != nil { - return nil, err + return "", err } - return sk.SignRawJSON(marshaled), nil + marshaled, err = sjson.DeleteBytes(marshaled, "signatures") + if err != nil { + return "", err + } + return base64.RawStdEncoding.EncodeToString(sk.SignRawJSON(marshaled)), nil } func (sk *SigningKey) SignRawJSON(data json.RawMessage) []byte { @@ -120,7 +157,7 @@ func (sk *SigningKey) GenerateKeyResponse(serverName string, oldVerifyKeys map[i } skr.Signatures = map[string]map[id.KeyID]string{ serverName: { - sk.ID: base64.RawURLEncoding.EncodeToString(signature), + sk.ID: signature, }, } return skr diff --git a/federation/signutil/verify.go b/federation/signutil/verify.go new file mode 100644 index 00000000..ea0e7886 --- /dev/null +++ b/federation/signutil/verify.go @@ -0,0 +1,106 @@ +// Copyright (c) 2025 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package signutil + +import ( + "crypto/ed25519" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" + "go.mau.fi/util/exgjson" + + "maunium.net/go/mautrix/crypto/canonicaljson" + "maunium.net/go/mautrix/id" +) + +var ErrSignatureNotFound = errors.New("signature not found") +var ErrInvalidSignature = errors.New("invalid signature") + +func VerifyJSON(serverName string, keyID id.KeyID, key id.SigningKey, data any) error { + var err error + message, ok := data.(json.RawMessage) + if !ok { + message, err = json.Marshal(data) + if err != nil { + return fmt.Errorf("failed to marshal data: %w", err) + } + } + sigVal := gjson.GetBytes(message, exgjson.Path("signatures", serverName, string(keyID))) + if sigVal.Type != gjson.String { + return ErrSignatureNotFound + } + message, err = sjson.DeleteBytes(message, "signatures") + if err != nil { + return fmt.Errorf("failed to delete signatures: %w", err) + } + message, err = sjson.DeleteBytes(message, "unsigned") + if err != nil { + return fmt.Errorf("failed to delete unsigned: %w", err) + } + return VerifyJSONRaw(key, sigVal.Str, message) +} + +func VerifyJSONAny(key id.SigningKey, data any) error { + var err error + message, ok := data.(json.RawMessage) + if !ok { + message, err = json.Marshal(data) + if err != nil { + return fmt.Errorf("failed to marshal data: %w", err) + } + } + sigs := gjson.GetBytes(message, "signatures") + if !sigs.IsObject() { + return ErrSignatureNotFound + } + message, err = sjson.DeleteBytes(message, "signatures") + if err != nil { + return fmt.Errorf("failed to delete signatures: %w", err) + } + message, err = sjson.DeleteBytes(message, "unsigned") + if err != nil { + return fmt.Errorf("failed to delete unsigned: %w", err) + } + var validated bool + sigs.ForEach(func(_, value gjson.Result) bool { + if !value.IsObject() { + return true + } + value.ForEach(func(_, value gjson.Result) bool { + if value.Type != gjson.String { + return true + } + validated = VerifyJSONRaw(key, value.Str, message) == nil + return !validated + }) + return !validated + }) + if !validated { + return ErrInvalidSignature + } + return nil +} + +func VerifyJSONRaw(key id.SigningKey, sig string, message json.RawMessage) error { + sigBytes, err := base64.RawStdEncoding.DecodeString(sig) + if err != nil { + return fmt.Errorf("failed to decode signature: %w", err) + } + keyBytes, err := base64.RawStdEncoding.DecodeString(string(key)) + if err != nil { + return fmt.Errorf("failed to decode key: %w", err) + } + message = canonicaljson.CanonicalJSONAssumeValid(message) + if !ed25519.Verify(keyBytes, message, sigBytes) { + return ErrInvalidSignature + } + return nil +} diff --git a/filter.go b/filter.go index 2603bfb9..54973dab 100644 --- a/filter.go +++ b/filter.go @@ -19,45 +19,45 @@ const ( // Filter is used by clients to specify how the server should filter responses to e.g. sync requests // Specified by: https://spec.matrix.org/v1.2/client-server-api/#filtering type Filter struct { - AccountData FilterPart `json:"account_data,omitempty"` + AccountData *FilterPart `json:"account_data,omitempty"` EventFields []string `json:"event_fields,omitempty"` EventFormat EventFormat `json:"event_format,omitempty"` - Presence FilterPart `json:"presence,omitempty"` - Room RoomFilter `json:"room,omitempty"` + Presence *FilterPart `json:"presence,omitempty"` + Room *RoomFilter `json:"room,omitempty"` BeeperToDevice *FilterPart `json:"com.beeper.to_device,omitempty"` } // RoomFilter is used to define filtering rules for room events type RoomFilter struct { - AccountData FilterPart `json:"account_data,omitempty"` - Ephemeral FilterPart `json:"ephemeral,omitempty"` + AccountData *FilterPart `json:"account_data,omitempty"` + Ephemeral *FilterPart `json:"ephemeral,omitempty"` IncludeLeave bool `json:"include_leave,omitempty"` NotRooms []id.RoomID `json:"not_rooms,omitempty"` Rooms []id.RoomID `json:"rooms,omitempty"` - State FilterPart `json:"state,omitempty"` - Timeline FilterPart `json:"timeline,omitempty"` + State *FilterPart `json:"state,omitempty"` + Timeline *FilterPart `json:"timeline,omitempty"` } // FilterPart is used to define filtering rules for specific categories of events type FilterPart struct { - NotRooms []id.RoomID `json:"not_rooms,omitempty"` - Rooms []id.RoomID `json:"rooms,omitempty"` - Limit int `json:"limit,omitempty"` - NotSenders []id.UserID `json:"not_senders,omitempty"` - NotTypes []event.Type `json:"not_types,omitempty"` - Senders []id.UserID `json:"senders,omitempty"` - Types []event.Type `json:"types,omitempty"` - ContainsURL *bool `json:"contains_url,omitempty"` - - LazyLoadMembers bool `json:"lazy_load_members,omitempty"` - IncludeRedundantMembers bool `json:"include_redundant_members,omitempty"` + NotRooms []id.RoomID `json:"not_rooms,omitempty"` + Rooms []id.RoomID `json:"rooms,omitempty"` + Limit int `json:"limit,omitempty"` + NotSenders []id.UserID `json:"not_senders,omitempty"` + NotTypes []event.Type `json:"not_types,omitempty"` + Senders []id.UserID `json:"senders,omitempty"` + Types []event.Type `json:"types,omitempty"` + ContainsURL *bool `json:"contains_url,omitempty"` + LazyLoadMembers bool `json:"lazy_load_members,omitempty"` + IncludeRedundantMembers bool `json:"include_redundant_members,omitempty"` + UnreadThreadNotifications bool `json:"unread_thread_notifications,omitempty"` } // Validate checks if the filter contains valid property values func (filter *Filter) Validate() error { if filter.EventFormat != EventFormatClient && filter.EventFormat != EventFormatFederation { - return errors.New("Bad event_format value. Must be one of [\"client\", \"federation\"]") + return errors.New("bad event_format value") } return nil } @@ -69,7 +69,7 @@ func DefaultFilter() Filter { EventFields: nil, EventFormat: "client", Presence: DefaultFilterPart(), - Room: RoomFilter{ + Room: &RoomFilter{ AccountData: DefaultFilterPart(), Ephemeral: DefaultFilterPart(), IncludeLeave: false, @@ -82,8 +82,8 @@ func DefaultFilter() Filter { } // DefaultFilterPart returns the default filter part used by the Matrix server if no filter is provided in the request -func DefaultFilterPart() FilterPart { - return FilterPart{ +func DefaultFilterPart() *FilterPart { + return &FilterPart{ NotRooms: nil, Rooms: nil, Limit: 20, diff --git a/format/htmlparser.go b/format/htmlparser.go index 7c3b3c88..e0507d93 100644 --- a/format/htmlparser.go +++ b/format/htmlparser.go @@ -13,6 +13,7 @@ import ( "strconv" "strings" + "go.mau.fi/util/exstrings" "golang.org/x/net/html" "maunium.net/go/mautrix/event" @@ -92,6 +93,30 @@ func DefaultPillConverter(displayname, mxid, eventID string, ctx Context) string } } +func onlyBacktickCount(line string) (count int) { + for i := 0; i < len(line); i++ { + if line[i] != '`' { + return -1 + } + count++ + } + return +} + +func DefaultMonospaceBlockConverter(code, language string, ctx Context) string { + if len(code) == 0 || code[len(code)-1] != '\n' { + code += "\n" + } + fence := "```" + for line := range strings.SplitSeq(code, "\n") { + count := onlyBacktickCount(strings.TrimSpace(line)) + if count >= len(fence) { + fence = strings.Repeat("`", count+1) + } + } + return fmt.Sprintf("%s%s\n%s%s", fence, language, code, fence) +} + // HTMLParser is a somewhat customizable Matrix HTML parser. type HTMLParser struct { PillConverter PillConverter @@ -187,25 +212,6 @@ func (parser *HTMLParser) listToString(node *html.Node, ctx Context) string { return strings.Join(children, "\n") } -func LongestSequence(in string, of rune) int { - currentSeq := 0 - maxSeq := 0 - for _, chr := range in { - if chr == of { - currentSeq++ - } else { - if currentSeq > maxSeq { - maxSeq = currentSeq - } - currentSeq = 0 - } - } - if currentSeq > maxSeq { - maxSeq = currentSeq - } - return maxSeq -} - func (parser *HTMLParser) basicFormatToString(node *html.Node, ctx Context) string { str := parser.nodeToTagAwareString(node.FirstChild, ctx) switch node.Data { @@ -232,8 +238,7 @@ func (parser *HTMLParser) basicFormatToString(node *html.Node, ctx Context) stri if parser.MonospaceConverter != nil { return parser.MonospaceConverter(str, ctx) } - surround := strings.Repeat("`", LongestSequence(str, '`')+1) - return fmt.Sprintf("%s%s%s", surround, str, surround) + return SafeMarkdownCode(str) } return str } @@ -306,7 +311,10 @@ func (parser *HTMLParser) linkToString(node *html.Node, ctx Context) string { } if parser.LinkConverter != nil { return parser.LinkConverter(str, href, ctx) - } else if str == href { + } else if str == href || + str == strings.TrimPrefix(href, "mailto:") || + str == strings.TrimPrefix(href, "http://") || + str == strings.TrimPrefix(href, "https://") { return str } return fmt.Sprintf("%s (%s)", str, href) @@ -348,6 +356,8 @@ func (parser *HTMLParser) tagToString(node *html.Node, ctx Context) string { return parser.imgToString(node, ctx) case "hr": return parser.HorizontalLine + case "input": + return parser.inputToString(node, ctx) case "pre": var preStr, language string if node.FirstChild != nil && node.FirstChild.Type == html.ElementNode && node.FirstChild.Data == "code" { @@ -362,20 +372,28 @@ func (parser *HTMLParser) tagToString(node *html.Node, ctx Context) string { if parser.MonospaceBlockConverter != nil { return parser.MonospaceBlockConverter(preStr, language, ctx) } - if len(preStr) == 0 || preStr[len(preStr)-1] != '\n' { - preStr += "\n" - } - return fmt.Sprintf("```%s\n%s```", language, preStr) + return DefaultMonospaceBlockConverter(preStr, language, ctx) default: return parser.nodeToTagAwareString(node.FirstChild, ctx) } } +func (parser *HTMLParser) inputToString(node *html.Node, ctx Context) string { + if len(ctx.TagStack) > 1 && ctx.TagStack[len(ctx.TagStack)-2] == "li" { + _, checked := parser.maybeGetAttribute(node, "checked") + if checked { + return "[x]" + } + return "[ ]" + } + return parser.nodeToTagAwareString(node.FirstChild, ctx) +} + func (parser *HTMLParser) singleNodeToString(node *html.Node, ctx Context) TaggedString { switch node.Type { case html.TextNode: if !ctx.PreserveWhitespace { - node.Data = strings.Replace(node.Data, "\n", "", -1) + node.Data = exstrings.CollapseSpaces(strings.ReplaceAll(node.Data, "\n", "")) } if parser.TextConverter != nil { node.Data = parser.TextConverter(node.Data, ctx) @@ -455,7 +473,7 @@ var MarkdownHTMLParser = &HTMLParser{ PillConverter: DefaultPillConverter, LinkConverter: func(text, href string, ctx Context) string { if text == href { - return text + return fmt.Sprintf("<%s>", href) } return fmt.Sprintf("[%s](%s)", text, href) }, diff --git a/format/markdown.go b/format/markdown.go index d099ba00..77ced0dc 100644 --- a/format/markdown.go +++ b/format/markdown.go @@ -8,14 +8,17 @@ package format import ( "fmt" + "regexp" "strings" "github.com/yuin/goldmark" "github.com/yuin/goldmark/extension" "github.com/yuin/goldmark/renderer/html" + "go.mau.fi/util/exstrings" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/format/mdext" + "maunium.net/go/mautrix/id" ) const paragraphStart = "

      " @@ -39,6 +42,55 @@ func UnwrapSingleParagraph(html string) string { return html } +var mdEscapeRegex = regexp.MustCompile("([\\\\`*_[\\]()])") + +func EscapeMarkdown(text string) string { + text = mdEscapeRegex.ReplaceAllString(text, "\\$1") + text = strings.ReplaceAll(text, ">", ">") + text = strings.ReplaceAll(text, "<", "<") + return text +} + +type uriAble interface { + String() string + URI() *id.MatrixURI +} + +func MarkdownMention(id uriAble) string { + return MarkdownMentionWithName(id.String(), id) +} + +func MarkdownMentionWithName(name string, id uriAble) string { + return MarkdownLink(name, id.URI().MatrixToURL()) +} + +func MarkdownMentionRoomID(name string, id id.RoomID, via ...string) string { + if name == "" { + name = id.String() + } + return MarkdownLink(name, id.URI(via...).MatrixToURL()) +} + +func MarkdownLink(name string, url string) string { + return fmt.Sprintf("[%s](%s)", EscapeMarkdown(name), EscapeMarkdown(url)) +} + +func SafeMarkdownCode[T ~string](textInput T) string { + if textInput == "" { + return "` `" + } + text := strings.ReplaceAll(string(textInput), "\n", " ") + backtickCount := exstrings.LongestSequenceOf(text, '`') + if backtickCount == 0 { + return fmt.Sprintf("`%s`", text) + } + quotes := strings.Repeat("`", backtickCount+1) + if text[0] == '`' || text[len(text)-1] == '`' { + return fmt.Sprintf("%s %s %s", quotes, text, quotes) + } + return fmt.Sprintf("%s%s%s", quotes, text, quotes) +} + func RenderMarkdownCustom(text string, renderer goldmark.Markdown) event.MessageEventContent { var buf strings.Builder err := renderer.Convert([]byte(text), &buf) diff --git a/format/markdown_test.go b/format/markdown_test.go index d4e7d716..46ea4886 100644 --- a/format/markdown_test.go +++ b/format/markdown_test.go @@ -196,3 +196,18 @@ func TestRenderMarkdown_CustomEmoji(t *testing.T) { assert.Equal(t, html, rendered, "with input %q", markdown) } } + +var codeTests = map[string]string{ + "meow": "`meow`", + "me`ow": "``me`ow``", + "`me`ow": "`` `me`ow ``", + "me`ow`": "`` me`ow` ``", + "`meow`": "`` `meow` ``", + "`````````": "`````````` ````````` ``````````", +} + +func TestSafeMarkdownCode(t *testing.T) { + for input, expected := range codeTests { + assert.Equal(t, expected, format.SafeMarkdownCode(input), "with input %q", input) + } +} diff --git a/go.mod b/go.mod index c686489a..49a1d4e4 100644 --- a/go.mod +++ b/go.mod @@ -1,43 +1,42 @@ module maunium.net/go/mautrix -go 1.22.0 +go 1.25.0 -toolchain go1.23.4 +toolchain go1.26.0 require ( - filippo.io/edwards25519 v1.1.0 + filippo.io/edwards25519 v1.2.0 github.com/chzyer/readline v1.5.1 - github.com/gorilla/mux v1.8.0 - github.com/gorilla/websocket v1.5.0 - github.com/lib/pq v1.10.9 - github.com/mattn/go-sqlite3 v1.14.24 + github.com/coder/websocket v1.8.14 + github.com/lib/pq v1.11.2 + github.com/mattn/go-sqlite3 v1.14.34 github.com/rs/xid v1.6.0 - github.com/rs/zerolog v1.33.0 + github.com/rs/zerolog v1.34.0 github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e - github.com/stretchr/testify v1.10.0 + github.com/stretchr/testify v1.11.1 github.com/tidwall/gjson v1.18.0 github.com/tidwall/sjson v1.2.5 - github.com/yuin/goldmark v1.7.8 - go.mau.fi/util v0.8.3 - go.mau.fi/zeroconfig v0.1.3 - golang.org/x/crypto v0.31.0 - golang.org/x/exp v0.0.0-20241215155358-4a5509556b9e - golang.org/x/net v0.32.0 - golang.org/x/sync v0.10.0 + github.com/yuin/goldmark v1.7.16 + go.mau.fi/util v0.9.6 + go.mau.fi/zeroconfig v0.2.0 + golang.org/x/crypto v0.48.0 + golang.org/x/exp v0.0.0-20260218203240-3dfff04db8fa + golang.org/x/net v0.50.0 + golang.org/x/sync v0.19.0 gopkg.in/yaml.v3 v3.0.1 maunium.net/go/mauflag v1.0.0 ) require ( - github.com/coreos/go-systemd/v22 v22.5.0 // indirect + github.com/coreos/go-systemd/v22 v22.6.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect - github.com/mattn/go-colorable v0.1.13 // indirect - github.com/mattn/go-isatty v0.0.19 // indirect - github.com/petermattis/goid v0.0.0-20241211131331-93ee7e083c43 // indirect + github.com/mattn/go-colorable v0.1.14 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect + github.com/petermattis/goid v0.0.0-20260113132338-7c7de50cc741 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/tidwall/match v1.1.1 // indirect - github.com/tidwall/pretty v1.2.0 // indirect - golang.org/x/sys v0.28.0 // indirect - golang.org/x/text v0.21.0 // indirect + github.com/tidwall/pretty v1.2.1 // indirect + golang.org/x/sys v0.41.0 // indirect + golang.org/x/text v0.34.0 // indirect gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect ) diff --git a/go.sum b/go.sum index 47ac74ef..871a5156 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,5 @@ -filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= -filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= +filippo.io/edwards25519 v1.2.0 h1:crnVqOiS4jqYleHd9vaKZ+HKtHfllngJIiOpNpoJsjo= +filippo.io/edwards25519 v1.2.0/go.mod h1:xzAOLCNug/yB62zG1bQ8uziwrIqIuxhctzJT18Q77mc= github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU= github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU= github.com/chzyer/logex v1.2.1 h1:XHDu3E6q+gdHgsdTPH6ImJMIp436vR6MPtH8gP05QzM= @@ -8,69 +8,70 @@ github.com/chzyer/readline v1.5.1 h1:upd/6fQk4src78LMRzh5vItIt361/o4uq553V8B5sGI github.com/chzyer/readline v1.5.1/go.mod h1:Eh+b79XXUwfKfcPLepksvw2tcLE/Ct21YObkaSkeBlk= github.com/chzyer/test v1.0.0 h1:p3BQDXSxOhOG0P9z6/hGnII4LGiEPOYBhs8asl/fC04= github.com/chzyer/test v1.0.0/go.mod h1:2JlltgoNkt4TW/z9V/IzDdFaMTM2JPIi26O1pF38GC8= -github.com/coreos/go-systemd/v22 v22.5.0 h1:RrqgGjYQKalulkV8NGVIfkXQf6YYmOyiJKk8iXXhfZs= +github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g= +github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= +github.com/coreos/go-systemd/v22 v22.6.0 h1:aGVa/v8B7hpb0TKl0MWoAavPDmHvobFe5R5zn0bCJWo= +github.com/coreos/go-systemd/v22 v22.6.0/go.mod h1:iG+pp635Fo7ZmV/j14KUcmEyWF+0X7Lua8rrTWzYgWU= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= -github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI= -github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So= -github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= -github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= -github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= -github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= -github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= +github.com/lib/pq v1.11.2 h1:x6gxUeu39V0BHZiugWe8LXZYZ+Utk7hSJGThs8sdzfs= +github.com/lib/pq v1.11.2/go.mod h1:/p+8NSbOcwzAEI7wiMXFlgydTwcgTr3OSKMsD2BitpA= github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= +github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= +github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= -github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= -github.com/mattn/go-sqlite3 v1.14.24 h1:tpSp2G2KyMnnQu99ngJ47EIkWVmliIizyZBfPrBWDRM= -github.com/mattn/go-sqlite3 v1.14.24/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= -github.com/petermattis/goid v0.0.0-20241211131331-93ee7e083c43 h1:ah1dvbqPMN5+ocrg/ZSgZ6k8bOk+kcZQ7fnyx6UvOm4= -github.com/petermattis/goid v0.0.0-20241211131331-93ee7e083c43/go.mod h1:pxMtw7cyUw6B2bRH0ZBANSPg+AoSud1I1iyJHI69jH4= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-sqlite3 v1.14.34 h1:3NtcvcUnFBPsuRcno8pUtupspG/GM+9nZ88zgJcp6Zk= +github.com/mattn/go-sqlite3 v1.14.34/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= +github.com/petermattis/goid v0.0.0-20260113132338-7c7de50cc741 h1:KPpdlQLZcHfTMQRi6bFQ7ogNO0ltFT4PmtwTLW4W+14= +github.com/petermattis/goid v0.0.0-20260113132338-7c7de50cc741/go.mod h1:pxMtw7cyUw6B2bRH0ZBANSPg+AoSud1I1iyJHI69jH4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= github.com/rs/xid v1.6.0 h1:fV591PaemRlL6JfRxGDEPl69wICngIQ3shQtzfy2gxU= github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0= -github.com/rs/zerolog v1.33.0 h1:1cU2KZkvPxNyfgEmhHAz/1A9Bz+llsdYzklWFzgp0r8= -github.com/rs/zerolog v1.33.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss= +github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY= +github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ= github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0= github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M= -github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= -github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY= github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= -github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs= github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= +github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= -github.com/yuin/goldmark v1.7.8 h1:iERMLn0/QJeHFhxSt3p6PeN9mGnvIKSpG9YYorDMnic= -github.com/yuin/goldmark v1.7.8/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E= -go.mau.fi/util v0.8.3 h1:sulhXtfquMrQjsOP67x9CzWVBYUwhYeoo8hNQIpCWZ4= -go.mau.fi/util v0.8.3/go.mod h1:c00Db8xog70JeIsEvhdHooylTkTkakgnAOsZ04hplQY= -go.mau.fi/zeroconfig v0.1.3 h1:As9wYDKmktjmNZW5i1vn8zvJlmGKHeVxHVIBMXsm4kM= -go.mau.fi/zeroconfig v0.1.3/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70= -golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U= -golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= -golang.org/x/exp v0.0.0-20241215155358-4a5509556b9e h1:4qufH0hlUYs6AO6XmZC3GqfDPGSXHVXUFR6OND+iJX4= -golang.org/x/exp v0.0.0-20241215155358-4a5509556b9e/go.mod h1:qj5a5QZpwLU2NLQudwIN5koi3beDhSAlJwa67PuM98c= -golang.org/x/net v0.32.0 h1:ZqPmj8Kzc+Y6e0+skZsuACbx+wzMgo5MQsJh9Qd6aYI= -golang.org/x/net v0.32.0/go.mod h1:CwU0IoeOlnQQWJ6ioyFrfRuomB8GKF6KbYXZVyeXNfs= -golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ= -golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +github.com/yuin/goldmark v1.7.16 h1:n+CJdUxaFMiDUNnWC3dMWCIQJSkxH4uz3ZwQBkAlVNE= +github.com/yuin/goldmark v1.7.16/go.mod h1:ip/1k0VRfGynBgxOz0yCqHrbZXhcjxyuS66Brc7iBKg= +go.mau.fi/util v0.9.6 h1:2nsvxm49KhI3wrFltr0+wSUBlnQ4CMtykuELjpIU+ts= +go.mau.fi/util v0.9.6/go.mod h1:sIJpRH7Iy5Ad1SBuxQoatxtIeErgzxCtjd/2hCMkYMI= +go.mau.fi/zeroconfig v0.2.0 h1:e/OGEERqVRRKlgaro7E6bh8xXiKFSXB3eNNIud7FUjU= +go.mau.fi/zeroconfig v0.2.0/go.mod h1:J0Vn0prHNOm493oZoQ84kq83ZaNCYZnq+noI1b1eN8w= +golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts= +golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos= +golang.org/x/exp v0.0.0-20260218203240-3dfff04db8fa h1:Zt3DZoOFFYkKhDT3v7Lm9FDMEV06GpzjG2jrqW+QTE0= +golang.org/x/exp v0.0.0-20260218203240-3dfff04db8fa/go.mod h1:K79w1Vqn7PoiZn+TkNpx3BUWUQksGO3JcVX6qIjytmA= +golang.org/x/net v0.50.0 h1:ucWh9eiCGyDR3vtzso0WMQinm2Dnt8cFMuQa9K33J60= +golang.org/x/net v0.50.0/go.mod h1:UgoSli3F/pBgdJBHCTc+tp3gmrU4XswgGRgtnwWTfyM= +golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= +golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA= -golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= -golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= +golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k= +golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk= +golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc= diff --git a/id/contenturi.go b/id/contenturi.go index e6a313f5..67127b6c 100644 --- a/id/contenturi.go +++ b/id/contenturi.go @@ -17,8 +17,14 @@ import ( ) var ( - InvalidContentURI = errors.New("invalid Matrix content URI") - InputNotJSONString = errors.New("input doesn't look like a JSON string") + ErrInvalidContentURI = errors.New("invalid Matrix content URI") + ErrInputNotJSONString = errors.New("input doesn't look like a JSON string") +) + +// Deprecated: use variables prefixed with Err +var ( + InvalidContentURI = ErrInvalidContentURI + InputNotJSONString = ErrInputNotJSONString ) // ContentURIString is a string that's expected to be a Matrix content URI. @@ -55,9 +61,9 @@ func ParseContentURI(uri string) (parsed ContentURI, err error) { if len(uri) == 0 { return } else if !strings.HasPrefix(uri, "mxc://") { - err = InvalidContentURI + err = ErrInvalidContentURI } else if index := strings.IndexRune(uri[6:], '/'); index == -1 || index == len(uri)-7 { - err = InvalidContentURI + err = ErrInvalidContentURI } else { parsed.Homeserver = uri[6 : 6+index] parsed.FileID = uri[6+index+1:] @@ -71,9 +77,9 @@ func ParseContentURIBytes(uri []byte) (parsed ContentURI, err error) { if len(uri) == 0 { return } else if !bytes.HasPrefix(uri, mxcBytes) { - err = InvalidContentURI + err = ErrInvalidContentURI } else if index := bytes.IndexRune(uri[6:], '/'); index == -1 || index == len(uri)-7 { - err = InvalidContentURI + err = ErrInvalidContentURI } else { parsed.Homeserver = string(uri[6 : 6+index]) parsed.FileID = string(uri[6+index+1:]) @@ -86,7 +92,7 @@ func (uri *ContentURI) UnmarshalJSON(raw []byte) (err error) { *uri = ContentURI{} return nil } else if len(raw) < 2 || raw[0] != '"' || raw[len(raw)-1] != '"' { - return InputNotJSONString + return fmt.Errorf("ContentURI: %w", ErrInputNotJSONString) } parsed, err := ParseContentURIBytes(raw[1 : len(raw)-1]) if err != nil { diff --git a/id/crypto.go b/id/crypto.go index 355a84a8..ee857f78 100644 --- a/id/crypto.go +++ b/id/crypto.go @@ -53,6 +53,34 @@ const ( KeyBackupAlgorithmMegolmBackupV1 KeyBackupAlgorithm = "m.megolm_backup.v1.curve25519-aes-sha2" ) +type KeySource string + +func (source KeySource) String() string { + return string(source) +} + +func (source KeySource) Int() int { + switch source { + case KeySourceDirect: + return 100 + case KeySourceBackup: + return 90 + case KeySourceImport: + return 80 + case KeySourceForward: + return 50 + default: + return 0 + } +} + +const ( + KeySourceDirect KeySource = "direct" + KeySourceBackup KeySource = "backup" + KeySourceImport KeySource = "import" + KeySourceForward KeySource = "forward" +) + // BackupVersion is an arbitrary string that identifies a server side key backup. type KeyBackupVersion string diff --git a/id/matrixuri.go b/id/matrixuri.go index 2637d876..d5c78bc7 100644 --- a/id/matrixuri.go +++ b/id/matrixuri.go @@ -54,7 +54,7 @@ var SigilToPathSegment = map[rune]string{ func (uri *MatrixURI) getQuery() url.Values { q := make(url.Values) - if uri.Via != nil && len(uri.Via) > 0 { + if len(uri.Via) > 0 { q["via"] = uri.Via } if len(uri.Action) > 0 { @@ -210,7 +210,11 @@ func ProcessMatrixURI(uri *url.URL) (*MatrixURI, error) { if len(parts[1]) == 0 { return nil, ErrEmptySecondSegment } - parsed.MXID1 = parts[1] + var err error + parsed.MXID1, err = url.PathUnescape(parts[1]) + if err != nil { + return nil, fmt.Errorf("failed to url decode second segment %q: %w", parts[1], err) + } // Step 6: if the first part is a room and the URI has 4 segments, construct a second level identifier if parsed.Sigil1 == '!' && len(parts) == 4 { @@ -226,7 +230,10 @@ func ProcessMatrixURI(uri *url.URL) (*MatrixURI, error) { if len(parts[3]) == 0 { return nil, ErrEmptyFourthSegment } - parsed.MXID2 = parts[3] + parsed.MXID2, err = url.PathUnescape(parts[3]) + if err != nil { + return nil, fmt.Errorf("failed to url decode fourth segment %q: %w", parts[3], err) + } } // Step 7: parse the query and extract via and action items diff --git a/id/matrixuri_test.go b/id/matrixuri_test.go index 8b1096cb..90a0754d 100644 --- a/id/matrixuri_test.go +++ b/id/matrixuri_test.go @@ -77,8 +77,12 @@ func TestParseMatrixURI_RoomID(t *testing.T) { parsedVia, err := id.ParseMatrixURI("matrix:roomid/7NdBVvkd4aLSbgKt9RXl:example.org?via=maunium.net&via=matrix.org") require.NoError(t, err) require.NotNil(t, parsedVia) + parsedEncoded, err := id.ParseMatrixURI("matrix:roomid/7NdBVvkd4aLSbgKt9RXl%3Aexample.org") + require.NoError(t, err) + require.NotNil(t, parsedEncoded) assert.Equal(t, roomIDLink, *parsed) + assert.Equal(t, roomIDLink, *parsedEncoded) assert.Equal(t, roomIDViaLink, *parsedVia) } diff --git a/id/opaque.go b/id/opaque.go index 1d9f0dcf..c1ad4988 100644 --- a/id/opaque.go +++ b/id/opaque.go @@ -32,6 +32,9 @@ type EventID string // https://github.com/matrix-org/matrix-doc/pull/2716 type BatchID string +// A DelayID is a string identifying a delayed event. +type DelayID string + func (roomID RoomID) String() string { return string(roomID) } diff --git a/id/roomversion.go b/id/roomversion.go new file mode 100644 index 00000000..578c10bd --- /dev/null +++ b/id/roomversion.go @@ -0,0 +1,265 @@ +// Copyright (c) 2025 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package id + +import ( + "errors" + "fmt" + "slices" +) + +type RoomVersion string + +const ( + RoomV0 RoomVersion = "" // No room version, used for rooms created before room versions were introduced, equivalent to v1 + RoomV1 RoomVersion = "1" + RoomV2 RoomVersion = "2" + RoomV3 RoomVersion = "3" + RoomV4 RoomVersion = "4" + RoomV5 RoomVersion = "5" + RoomV6 RoomVersion = "6" + RoomV7 RoomVersion = "7" + RoomV8 RoomVersion = "8" + RoomV9 RoomVersion = "9" + RoomV10 RoomVersion = "10" + RoomV11 RoomVersion = "11" + RoomV12 RoomVersion = "12" +) + +func (rv RoomVersion) Equals(versions ...RoomVersion) bool { + return slices.Contains(versions, rv) +} + +func (rv RoomVersion) NotEquals(versions ...RoomVersion) bool { + return !rv.Equals(versions...) +} + +var ErrUnknownRoomVersion = errors.New("unknown room version") + +func (rv RoomVersion) unknownVersionError() error { + return fmt.Errorf("%w %s", ErrUnknownRoomVersion, rv) +} + +func (rv RoomVersion) IsKnown() bool { + switch rv { + case RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5, RoomV6, RoomV7, RoomV8, RoomV9, RoomV10, RoomV11, RoomV12: + return true + default: + return false + } +} + +type StateResVersion int + +const ( + // StateResV1 is the original state resolution algorithm. + StateResV1 StateResVersion = 0 + // StateResV2 is state resolution v2 introduced by https://github.com/matrix-org/matrix-spec-proposals/pull/1759 + StateResV2 StateResVersion = 1 + // StateResV2_1 is state resolution v2.1 introduced by https://github.com/matrix-org/matrix-spec-proposals/pull/4297 + StateResV2_1 StateResVersion = 2 +) + +// StateResVersion returns the version of the state resolution algorithm used by this room version. +func (rv RoomVersion) StateResVersion() StateResVersion { + switch rv { + case RoomV0, RoomV1: + return StateResV1 + case RoomV2, RoomV3, RoomV4, RoomV5, RoomV6, RoomV7, RoomV8, RoomV9, RoomV10, RoomV11: + return StateResV2 + case RoomV12: + return StateResV2_1 + default: + panic(rv.unknownVersionError()) + } +} + +type EventIDFormat int + +const ( + // EventIDFormatCustom is the original format used by room v1 and v2. + // Event IDs in this format are an arbitrary string followed by a colon and the server name. + EventIDFormatCustom EventIDFormat = 0 + // EventIDFormatBase64 is the format used by room v3 introduced by https://github.com/matrix-org/matrix-spec-proposals/pull/1659. + // Event IDs in this format are the standard unpadded base64-encoded SHA256 reference hash of the event. + EventIDFormatBase64 EventIDFormat = 1 + // EventIDFormatURLSafeBase64 is the format used by room v4 and later introduced by https://github.com/matrix-org/matrix-spec-proposals/pull/2002. + // Event IDs in this format are the url-safe unpadded base64-encoded SHA256 reference hash of the event. + EventIDFormatURLSafeBase64 EventIDFormat = 2 +) + +// EventIDFormat returns the format of event IDs used by this room version. +func (rv RoomVersion) EventIDFormat() EventIDFormat { + switch rv { + case RoomV0, RoomV1, RoomV2: + return EventIDFormatCustom + case RoomV3: + return EventIDFormatBase64 + default: + return EventIDFormatURLSafeBase64 + } +} + +///////////////////// +// Room v5 changes // +///////////////////// +// https://github.com/matrix-org/matrix-spec-proposals/pull/2077 + +// EnforceSigningKeyValidity returns true if the `valid_until_ts` field of federation signing keys +// must be enforced on received events. +// +// See https://github.com/matrix-org/matrix-spec-proposals/pull/2076 +func (rv RoomVersion) EnforceSigningKeyValidity() bool { + return rv.NotEquals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4) +} + +///////////////////// +// Room v6 changes // +///////////////////// +// https://github.com/matrix-org/matrix-spec-proposals/pull/2240 + +// SpecialCasedAliasesAuth returns true if the `m.room.aliases` event authorization is special cased +// to only always allow servers to modify the state event with their own server name as state key. +// This also implies that the `aliases` field is protected from redactions. +// +// See https://github.com/matrix-org/matrix-spec-proposals/pull/2432 +func (rv RoomVersion) SpecialCasedAliasesAuth() bool { + return rv.Equals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5) +} + +// ForbidFloatsAndBigInts returns true if floats and integers greater than 2^53-1 or lower than -2^53+1 are forbidden everywhere. +// +// See https://github.com/matrix-org/matrix-spec-proposals/pull/2540 +func (rv RoomVersion) ForbidFloatsAndBigInts() bool { + return rv.NotEquals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5) +} + +// NotificationsPowerLevels returns true if the `notifications` field in `m.room.power_levels` is validated in event auth. +// However, the field is not protected from redactions. +// +// See https://github.com/matrix-org/matrix-spec-proposals/pull/2209 +func (rv RoomVersion) NotificationsPowerLevels() bool { + return rv.NotEquals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5) +} + +///////////////////// +// Room v7 changes // +///////////////////// +// https://github.com/matrix-org/matrix-spec-proposals/pull/2998 + +// Knocks returns true if the `knock` join rule is supported. +// +// See https://github.com/matrix-org/matrix-spec-proposals/pull/2403 +func (rv RoomVersion) Knocks() bool { + return rv.NotEquals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5, RoomV6) +} + +///////////////////// +// Room v8 changes // +///////////////////// +// https://github.com/matrix-org/matrix-spec-proposals/pull/3289 + +// RestrictedJoins returns true if the `restricted` join rule is supported. +// This also implies that the `allow` field in the `m.room.join_rules` event is supported and protected from redactions. +// +// See https://github.com/matrix-org/matrix-spec-proposals/pull/3083 +func (rv RoomVersion) RestrictedJoins() bool { + return rv.NotEquals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5, RoomV6, RoomV7) +} + +///////////////////// +// Room v9 changes // +///////////////////// +// https://github.com/matrix-org/matrix-spec-proposals/pull/3375 + +// RestrictedJoinsFix returns true if the `join_authorised_via_users_server` field in `m.room.member` events is protected from redactions. +// +// See https://github.com/matrix-org/matrix-spec-proposals/pull/3375 +func (rv RoomVersion) RestrictedJoinsFix() bool { + return rv.RestrictedJoins() && rv != RoomV8 +} + +////////////////////// +// Room v10 changes // +////////////////////// +// https://github.com/matrix-org/matrix-spec-proposals/pull/3604 + +// ValidatePowerLevelInts returns true if the known values in `m.room.power_levels` must be integers (and not strings). +// +// See https://github.com/matrix-org/matrix-spec-proposals/pull/3667 +func (rv RoomVersion) ValidatePowerLevelInts() bool { + return rv.NotEquals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5, RoomV6, RoomV7, RoomV8, RoomV9) +} + +// KnockRestricted returns true if the `knock_restricted` join rule is supported. +// +// See https://github.com/matrix-org/matrix-spec-proposals/pull/3787 +func (rv RoomVersion) KnockRestricted() bool { + return rv.NotEquals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5, RoomV6, RoomV7, RoomV8, RoomV9) +} + +////////////////////// +// Room v11 changes // +////////////////////// +// https://github.com/matrix-org/matrix-spec-proposals/pull/3820 + +// CreatorInContent returns true if the `m.room.create` event has a `creator` field in content. +// +// See https://github.com/matrix-org/matrix-spec-proposals/pull/2175 +func (rv RoomVersion) CreatorInContent() bool { + return rv.Equals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5, RoomV6, RoomV7, RoomV8, RoomV9, RoomV10) +} + +// RedactsInContent returns true if the `m.room.redaction` event has the `redacts` field in content instead of at the top level. +// The redaction protection is also moved from the top level to the content field. +// +// See https://github.com/matrix-org/matrix-spec-proposals/pull/2174 +// (and https://github.com/matrix-org/matrix-spec-proposals/pull/2176 for the redaction protection). +func (rv RoomVersion) RedactsInContent() bool { + return rv.NotEquals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5, RoomV6, RoomV7, RoomV8, RoomV9, RoomV10) +} + +// UpdatedRedactionRules returns true if various updates to the redaction algorithm are applied. +// +// Specifically: +// +// * the `membership`, `origin`, and `prev_state` fields at the top level of all events are no longer protected. +// * the entire content of `m.room.create` is protected. +// * the `redacts` field in `m.room.redaction` content is protected instead of the top-level field. +// * the `m.room.power_levels` event protects the `invite` field in content. +// * the `signed` field inside the `third_party_invite` field in content of `m.room.member` events is protected. +// +// See https://github.com/matrix-org/matrix-spec-proposals/pull/2176, +// https://github.com/matrix-org/matrix-spec-proposals/pull/3821, and +// https://github.com/matrix-org/matrix-spec-proposals/pull/3989 +func (rv RoomVersion) UpdatedRedactionRules() bool { + return rv.NotEquals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5, RoomV6, RoomV7, RoomV8, RoomV9, RoomV10) +} + +////////////////////// +// Room v12 changes // +////////////////////// +// https://github.com/matrix-org/matrix-spec-proposals/pull/4304 + +// Return value of StateResVersion was changed to StateResV2_1 + +// PrivilegedRoomCreators returns true if the creator(s) of a room always have infinite power level. +// This also implies that the `m.room.create` event has an `additional_creators` field, +// and that the creators can't be present in the `m.room.power_levels` event. +// +// See https://github.com/matrix-org/matrix-spec-proposals/pull/4289 +func (rv RoomVersion) PrivilegedRoomCreators() bool { + return rv.NotEquals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5, RoomV6, RoomV7, RoomV8, RoomV9, RoomV10, RoomV11) +} + +// RoomIDIsCreateEventID returns true if the ID of rooms is the same as the ID of the `m.room.create` event. +// This also implies that `m.room.create` events do not have a `room_id` field. +// +// See https://github.com/matrix-org/matrix-spec-proposals/pull/4291 +func (rv RoomVersion) RoomIDIsCreateEventID() bool { + return rv.NotEquals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5, RoomV6, RoomV7, RoomV8, RoomV9, RoomV10, RoomV11) +} diff --git a/id/servername.go b/id/servername.go new file mode 100644 index 00000000..923705b6 --- /dev/null +++ b/id/servername.go @@ -0,0 +1,58 @@ +// Copyright (c) 2025 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package id + +import ( + "regexp" + "strconv" +) + +type ParsedServerNameType int + +const ( + ServerNameDNS ParsedServerNameType = iota + ServerNameIPv4 + ServerNameIPv6 +) + +type ParsedServerName struct { + Type ParsedServerNameType + Host string + Port int +} + +var ServerNameRegex = regexp.MustCompile(`^(?:\[([0-9A-Fa-f:.]{2,45})]|(\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})|([0-9A-Za-z.-]{1,255}))(?::(\d{1,5}))?$`) + +func ValidateServerName(serverName string) bool { + return len(serverName) <= 255 && len(serverName) > 0 && ServerNameRegex.MatchString(serverName) +} + +func ParseServerName(serverName string) *ParsedServerName { + if len(serverName) > 255 || len(serverName) < 1 { + return nil + } + match := ServerNameRegex.FindStringSubmatch(serverName) + if len(match) != 5 { + return nil + } + port, _ := strconv.Atoi(match[4]) + parsed := &ParsedServerName{ + Port: port, + } + switch { + case match[1] != "": + parsed.Type = ServerNameIPv6 + parsed.Host = match[1] + case match[2] != "": + parsed.Type = ServerNameIPv4 + parsed.Host = match[2] + case match[3] != "": + parsed.Type = ServerNameDNS + parsed.Host = match[3] + } + return parsed +} diff --git a/id/trust.go b/id/trust.go index 04f6e36b..6255093e 100644 --- a/id/trust.go +++ b/id/trust.go @@ -16,6 +16,7 @@ type TrustState int const ( TrustStateBlacklisted TrustState = -100 + TrustStateDeviceKeyMismatch TrustState = -5 TrustStateUnset TrustState = 0 TrustStateUnknownDevice TrustState = 10 TrustStateForwarded TrustState = 20 @@ -23,7 +24,7 @@ const ( TrustStateCrossSignedTOFU TrustState = 100 TrustStateCrossSignedVerified TrustState = 200 TrustStateVerified TrustState = 300 - TrustStateInvalid TrustState = (1 << 31) - 1 + TrustStateInvalid TrustState = -2147483647 ) func (ts *TrustState) UnmarshalText(data []byte) error { @@ -44,6 +45,8 @@ func ParseTrustState(val string) TrustState { switch strings.ToLower(val) { case "blacklisted": return TrustStateBlacklisted + case "device-key-mismatch": + return TrustStateDeviceKeyMismatch case "unverified": return TrustStateUnset case "cross-signed-untrusted": @@ -67,6 +70,8 @@ func (ts TrustState) String() string { switch ts { case TrustStateBlacklisted: return "blacklisted" + case TrustStateDeviceKeyMismatch: + return "device-key-mismatch" case TrustStateUnset: return "unverified" case TrustStateCrossSignedUntrusted: diff --git a/id/userid.go b/id/userid.go index 1e1f3b29..726a0d58 100644 --- a/id/userid.go +++ b/id/userid.go @@ -30,10 +30,11 @@ func NewEncodedUserID(localpart, homeserver string) UserID { } var ( - ErrInvalidUserID = errors.New("is not a valid user ID") - ErrNoncompliantLocalpart = errors.New("contains characters that are not allowed") - ErrUserIDTooLong = errors.New("the given user ID is longer than 255 characters") - ErrEmptyLocalpart = errors.New("empty localparts are not allowed") + ErrInvalidUserID = errors.New("is not a valid user ID") + ErrNoncompliantLocalpart = errors.New("contains characters that are not allowed") + ErrUserIDTooLong = errors.New("the given user ID is longer than 255 characters") + ErrEmptyLocalpart = errors.New("empty localparts are not allowed") + ErrNoncompliantServerPart = errors.New("is not a valid server name") ) // ParseCommonIdentifier parses a common identifier according to https://spec.matrix.org/v1.9/appendices/#common-identifier-format @@ -43,10 +44,10 @@ func ParseCommonIdentifier[Stringish ~string](identifier Stringish) (sigil byte, } sigil = identifier[0] strIdentifier := string(identifier) - if strings.ContainsRune(strIdentifier, ':') { - parts := strings.SplitN(strIdentifier, ":", 2) - localpart = parts[0][1:] - homeserver = parts[1] + colonIdx := strings.IndexByte(strIdentifier, ':') + if colonIdx > 0 { + localpart = strIdentifier[1:colonIdx] + homeserver = strIdentifier[colonIdx+1:] } else { localpart = strIdentifier[1:] } @@ -103,21 +104,32 @@ func ValidateUserLocalpart(localpart string) error { return nil } -// ParseAndValidate parses the user ID into the localpart and server name like Parse, -// and also validates that the localpart is allowed according to the user identifiers spec. -func (userID UserID) ParseAndValidate() (localpart, homeserver string, err error) { - localpart, homeserver, err = userID.Parse() +// ParseAndValidateStrict is a stricter version of ParseAndValidateRelaxed that checks the localpart to only allow non-historical localparts. +// This should be used with care: there are real users still using historical localparts. +func (userID UserID) ParseAndValidateStrict() (localpart, homeserver string, err error) { + localpart, homeserver, err = userID.ParseAndValidateRelaxed() if err == nil { err = ValidateUserLocalpart(localpart) } - if err == nil && len(userID) > UserIDMaxLength { + return +} + +// ParseAndValidateRelaxed parses the user ID into the localpart and server name like Parse, +// and also validates that the user ID is not too long and that the server name is valid. +func (userID UserID) ParseAndValidateRelaxed() (localpart, homeserver string, err error) { + if len(userID) > UserIDMaxLength { err = ErrUserIDTooLong + return + } + localpart, homeserver, err = userID.Parse() + if err == nil && !ValidateServerName(homeserver) { + err = fmt.Errorf("%q %q", homeserver, ErrNoncompliantServerPart) } return } func (userID UserID) ParseAndDecode() (localpart, homeserver string, err error) { - localpart, homeserver, err = userID.ParseAndValidate() + localpart, homeserver, err = userID.ParseAndValidateStrict() if err == nil { localpart, err = DecodeUserLocalpart(localpart) } @@ -207,15 +219,15 @@ func DecodeUserLocalpart(str string) (string, error) { for i := 0; i < len(strBytes); i++ { b := strBytes[i] if !isValidByte(b) { - return "", fmt.Errorf("Byte pos %d: Invalid byte", i) + return "", fmt.Errorf("invalid encoded byte at position %d: %c", i, b) } if b == '_' { // next byte is a-z and should be upper-case or is another _ and should be a literal _ if i+1 >= len(strBytes) { - return "", fmt.Errorf("Byte pos %d: expected _[a-z_] encoding but ran out of string", i) + return "", fmt.Errorf("unexpected end of string after underscore at %d", i) } if !isValidEscapedChar(strBytes[i+1]) { // invalid escaping - return "", fmt.Errorf("Byte pos %d: expected _[a-z_] encoding", i) + return "", fmt.Errorf("unexpected byte %c after underscore at %d", strBytes[i+1], i) } if strBytes[i+1] == '_' { outputBuffer.WriteByte('_') @@ -225,7 +237,7 @@ func DecodeUserLocalpart(str string) (string, error) { i++ // skip next byte since we just handled it } else if b == '=' { // next 2 bytes are hex and should be buffered ready to be read as utf8 if i+2 >= len(strBytes) { - return "", fmt.Errorf("Byte pos: %d: expected quote-printable encoding but ran out of string", i) + return "", fmt.Errorf("unexpected end of string after equals sign at %d", i) } dst := make([]byte, 1) _, err := hex.Decode(dst, strBytes[i+1:i+3]) diff --git a/id/userid_test.go b/id/userid_test.go index 359bc687..57a88066 100644 --- a/id/userid_test.go +++ b/id/userid_test.go @@ -38,30 +38,30 @@ func TestUserID_Parse_Invalid(t *testing.T) { assert.True(t, errors.Is(err, id.ErrInvalidUserID)) } -func TestUserID_ParseAndValidate_Invalid(t *testing.T) { +func TestUserID_ParseAndValidateStrict_Invalid(t *testing.T) { const inputUserID = "@s p a c e:maunium.net" - _, _, err := id.UserID(inputUserID).ParseAndValidate() + _, _, err := id.UserID(inputUserID).ParseAndValidateStrict() assert.Error(t, err) assert.True(t, errors.Is(err, id.ErrNoncompliantLocalpart)) } -func TestUserID_ParseAndValidate_Empty(t *testing.T) { +func TestUserID_ParseAndValidateStrict_Empty(t *testing.T) { const inputUserID = "@:ponies.im" - _, _, err := id.UserID(inputUserID).ParseAndValidate() + _, _, err := id.UserID(inputUserID).ParseAndValidateStrict() assert.Error(t, err) assert.True(t, errors.Is(err, id.ErrEmptyLocalpart)) } -func TestUserID_ParseAndValidate_Long(t *testing.T) { +func TestUserID_ParseAndValidateStrict_Long(t *testing.T) { const inputUserID = "@aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa:example.com" - _, _, err := id.UserID(inputUserID).ParseAndValidate() + _, _, err := id.UserID(inputUserID).ParseAndValidateStrict() assert.Error(t, err) assert.True(t, errors.Is(err, id.ErrUserIDTooLong)) } -func TestUserID_ParseAndValidate_NotLong(t *testing.T) { +func TestUserID_ParseAndValidateStrict_NotLong(t *testing.T) { const inputUserID = "@aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa:example.com" - _, _, err := id.UserID(inputUserID).ParseAndValidate() + _, _, err := id.UserID(inputUserID).ParseAndValidateStrict() assert.NoError(t, err) } @@ -70,7 +70,7 @@ func TestUserIDEncoding(t *testing.T) { const encodedLocalpart = "_this=20local+part=20contains=20_il_le_ga_l=20ch=c3=a4racters=20=f0=9f=9a=a8" const inputServerName = "example.com" userID := id.NewEncodedUserID(inputLocalpart, inputServerName) - parsedLocalpart, parsedServerName, err := userID.ParseAndValidate() + parsedLocalpart, parsedServerName, err := userID.ParseAndValidateStrict() assert.NoError(t, err) assert.Equal(t, encodedLocalpart, parsedLocalpart) assert.Equal(t, inputServerName, parsedServerName) diff --git a/mediaproxy/mediaproxy.go b/mediaproxy/mediaproxy.go index ff8b2157..4d2bc7cf 100644 --- a/mediaproxy/mediaproxy.go +++ b/mediaproxy/mediaproxy.go @@ -1,4 +1,4 @@ -// Copyright (c) 2024 Tulir Asokan +// Copyright (c) 2025 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this @@ -8,7 +8,6 @@ package mediaproxy import ( "context" - "encoding/json" "errors" "fmt" "io" @@ -22,11 +21,16 @@ import ( "strings" "time" - "github.com/gorilla/mux" "github.com/rs/zerolog" + "github.com/rs/zerolog/hlog" + "go.mau.fi/util/exerrors" + "go.mau.fi/util/exhttp" + "go.mau.fi/util/ptr" + "go.mau.fi/util/requestlog" "maunium.net/go/mautrix" "maunium.net/go/mautrix/federation" + "maunium.net/go/mautrix/id" ) type GetMediaResponse interface { @@ -91,17 +95,20 @@ func (d *GetMediaResponseCallback) GetContentType() string { return d.ContentType } +type FileMeta struct { + ContentType string + ReplacementFile string +} + type GetMediaResponseFile struct { - Callback func(w *os.File) error - ContentType string + Callback func(w *os.File) (*FileMeta, error) } type GetMediaFunc = func(ctx context.Context, mediaID string, params map[string]string) (response GetMediaResponse, err error) type MediaProxy struct { - KeyServer *federation.KeyServer - - ForceProxyLegacyFederation bool + KeyServer *federation.KeyServer + ServerAuth *federation.ServerAuth GetMedia GetMediaFunc PrepareProxyRequest func(*http.Request) @@ -109,8 +116,8 @@ type MediaProxy struct { serverName string serverKey *federation.SigningKey - FederationRouter *mux.Router - ClientMediaRouter *mux.Router + FederationRouter *http.ServeMux + ClientMediaRouter *http.ServeMux } func New(serverName string, serverKey string, getMedia GetMediaFunc) (*MediaProxy, error) { @@ -118,7 +125,7 @@ func New(serverName string, serverKey string, getMedia GetMediaFunc) (*MediaProx if err != nil { return nil, err } - return &MediaProxy{ + mp := &MediaProxy{ serverName: serverName, serverKey: parsed, GetMedia: getMedia, @@ -133,12 +140,27 @@ func New(serverName string, serverKey string, getMedia GetMediaFunc) (*MediaProx Version: strings.TrimPrefix(mautrix.VersionWithCommit, "v"), }, }, - }, nil + } + mp.FederationRouter = http.NewServeMux() + mp.FederationRouter.HandleFunc("GET /v1/media/download/{mediaID}", mp.DownloadMediaFederation) + mp.FederationRouter.HandleFunc("GET /v1/media/thumbnail/{mediaID}", mp.DownloadMediaFederation) + mp.FederationRouter.HandleFunc("GET /v1/version", mp.KeyServer.GetServerVersion) + mp.ClientMediaRouter = http.NewServeMux() + mp.ClientMediaRouter.HandleFunc("GET /download/{serverName}/{mediaID}", mp.DownloadMedia) + mp.ClientMediaRouter.HandleFunc("GET /download/{serverName}/{mediaID}/{fileName}", mp.DownloadMedia) + mp.ClientMediaRouter.HandleFunc("GET /thumbnail/{serverName}/{mediaID}", mp.DownloadMedia) + mp.ClientMediaRouter.HandleFunc("PUT /upload/{serverName}/{mediaID}", mp.UploadNotSupported) + mp.ClientMediaRouter.HandleFunc("POST /upload", mp.UploadNotSupported) + mp.ClientMediaRouter.HandleFunc("POST /create", mp.UploadNotSupported) + mp.ClientMediaRouter.HandleFunc("GET /config", mp.UploadNotSupported) + mp.ClientMediaRouter.HandleFunc("GET /preview_url", mp.PreviewURLNotSupported) + return mp, nil } type BasicConfig struct { ServerName string `yaml:"server_name" json:"server_name"` ServerKey string `yaml:"server_key" json:"server_key"` + FederationAuth bool `yaml:"federation_auth" json:"federation_auth"` WellKnownResponse string `yaml:"well_known_response" json:"well_known_response"` } @@ -150,6 +172,9 @@ func NewFromConfig(cfg BasicConfig, getMedia GetMediaFunc) (*MediaProxy, error) if cfg.WellKnownResponse != "" { mp.KeyServer.WellKnownTarget = cfg.WellKnownResponse } + if cfg.FederationAuth { + mp.EnableServerAuth(nil, nil) + } return mp, nil } @@ -159,8 +184,8 @@ type ServerConfig struct { } func (mp *MediaProxy) Listen(cfg ServerConfig) error { - router := mux.NewRouter() - mp.RegisterRoutes(router) + router := http.NewServeMux() + mp.RegisterRoutes(router, zerolog.Nop()) return http.ListenAndServe(fmt.Sprintf("%s:%d", cfg.Hostname, cfg.Port), router) } @@ -172,49 +197,42 @@ func (mp *MediaProxy) GetServerKey() *federation.SigningKey { return mp.serverKey } -func (mp *MediaProxy) RegisterRoutes(router *mux.Router) { - if mp.FederationRouter == nil { - mp.FederationRouter = router.PathPrefix("/_matrix/federation").Subrouter() +func (mp *MediaProxy) EnableServerAuth(client *federation.Client, keyCache federation.KeyCache) { + if keyCache == nil { + keyCache = federation.NewInMemoryCache() } - if mp.ClientMediaRouter == nil { - mp.ClientMediaRouter = router.PathPrefix("/_matrix/client/v1/media").Subrouter() + if client == nil { + resCache, _ := keyCache.(federation.ResolutionCache) + client = federation.NewClient(mp.serverName, mp.serverKey, resCache) } - - mp.FederationRouter.HandleFunc("/v1/media/download/{mediaID}", mp.DownloadMediaFederation).Methods(http.MethodGet) - mp.FederationRouter.HandleFunc("/v1/version", mp.KeyServer.GetServerVersion).Methods(http.MethodGet) - mp.ClientMediaRouter.HandleFunc("/download/{serverName}/{mediaID}", mp.DownloadMedia).Methods(http.MethodGet) - mp.ClientMediaRouter.HandleFunc("/download/{serverName}/{mediaID}/{fileName}", mp.DownloadMedia).Methods(http.MethodGet) - mp.ClientMediaRouter.HandleFunc("/thumbnail/{serverName}/{mediaID}", mp.DownloadMedia).Methods(http.MethodGet) - mp.ClientMediaRouter.HandleFunc("/upload/{serverName}/{mediaID}", mp.UploadNotSupported).Methods(http.MethodPut) - mp.ClientMediaRouter.HandleFunc("/upload", mp.UploadNotSupported).Methods(http.MethodPost) - mp.ClientMediaRouter.HandleFunc("/create", mp.UploadNotSupported).Methods(http.MethodPost) - mp.ClientMediaRouter.HandleFunc("/config", mp.UploadNotSupported).Methods(http.MethodGet) - mp.ClientMediaRouter.HandleFunc("/preview_url", mp.PreviewURLNotSupported).Methods(http.MethodGet) - mp.FederationRouter.NotFoundHandler = http.HandlerFunc(mp.UnknownEndpoint) - mp.FederationRouter.MethodNotAllowedHandler = http.HandlerFunc(mp.UnsupportedMethod) - mp.ClientMediaRouter.NotFoundHandler = http.HandlerFunc(mp.UnknownEndpoint) - mp.ClientMediaRouter.MethodNotAllowedHandler = http.HandlerFunc(mp.UnsupportedMethod) - corsMiddleware := func(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Access-Control-Allow-Origin", "*") - w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") - w.Header().Set("Access-Control-Allow-Headers", "X-Requested-With, Content-Type, Authorization") - w.Header().Set("Content-Security-Policy", "sandbox; default-src 'none'; script-src 'none'; plugin-types application/pdf; style-src 'unsafe-inline'; object-src 'self';") - next.ServeHTTP(w, r) - }) - } - mp.ClientMediaRouter.Use(corsMiddleware) - mp.KeyServer.Register(router) + mp.ServerAuth = federation.NewServerAuth(client, keyCache, func(auth federation.XMatrixAuth) string { + return mp.GetServerName() + }) } -// Deprecated: use mautrix.RespError instead -type ResponseError struct { - Status int - Data any -} - -func (err *ResponseError) Error() string { - return fmt.Sprintf("HTTP %d: %v", err.Status, err.Data) +func (mp *MediaProxy) RegisterRoutes(router *http.ServeMux, log zerolog.Logger) { + errorBodies := exhttp.ErrorBodies{ + NotFound: exerrors.Must(ptr.Ptr(mautrix.MUnrecognized.WithMessage("Unrecognized endpoint")).MarshalJSON()), + MethodNotAllowed: exerrors.Must(ptr.Ptr(mautrix.MUnrecognized.WithMessage("Invalid method for endpoint")).MarshalJSON()), + } + router.Handle("/_matrix/federation/", exhttp.ApplyMiddleware( + mp.FederationRouter, + exhttp.StripPrefix("/_matrix/federation"), + hlog.NewHandler(log), + hlog.RequestIDHandler("request_id", "Request-Id"), + requestlog.AccessLogger(requestlog.Options{TrustXForwardedFor: true}), + exhttp.HandleErrors(errorBodies), + )) + router.Handle("/_matrix/client/v1/media/", exhttp.ApplyMiddleware( + mp.ClientMediaRouter, + exhttp.StripPrefix("/_matrix/client/v1/media"), + hlog.NewHandler(log), + hlog.RequestIDHandler("request_id", "Request-Id"), + exhttp.CORSMiddleware, + requestlog.AccessLogger(requestlog.Options{TrustXForwardedFor: true}), + exhttp.HandleErrors(errorBodies), + )) + mp.KeyServer.Register(router, log) } var ErrInvalidMediaIDSyntax = errors.New("invalid media ID syntax") @@ -228,20 +246,18 @@ func queryToMap(vals url.Values) map[string]string { } func (mp *MediaProxy) getMedia(w http.ResponseWriter, r *http.Request) GetMediaResponse { - mediaID := mux.Vars(r)["mediaID"] + mediaID := r.PathValue("mediaID") + if !id.IsValidMediaID(mediaID) { + mautrix.MNotFound.WithMessage("Media ID %q is not valid", mediaID).Write(w) + return nil + } resp, err := mp.GetMedia(r.Context(), mediaID, queryToMap(r.URL.Query())) if err != nil { - //lint:ignore SA1019 deprecated types need to be supported until they're removed - var respError *ResponseError var mautrixRespError mautrix.RespError if errors.Is(err, ErrInvalidMediaIDSyntax) { mautrix.MNotFound.WithMessage("This is a media proxy at %q, other media downloads are not available here", mp.serverName).Write(w) } else if errors.As(err, &mautrixRespError) { mautrixRespError.Write(w) - } else if errors.As(err, &respError) { - w.Header().Add("Content-Type", "application/json") - w.WriteHeader(respError.Status) - _ = json.NewEncoder(w).Encode(respError.Data) } else { zerolog.Ctx(r.Context()).Err(err).Str("media_id", mediaID).Msg("Failed to get media URL") mautrix.MNotFound.WithMessage("Media not found").Write(w) @@ -271,9 +287,16 @@ func startMultipart(ctx context.Context, w http.ResponseWriter) *multipart.Write } func (mp *MediaProxy) DownloadMediaFederation(w http.ResponseWriter, r *http.Request) { + if mp.ServerAuth != nil { + var err *mautrix.RespError + r, err = mp.ServerAuth.Authenticate(r) + if err != nil { + err.Write(w) + return + } + } ctx := r.Context() log := zerolog.Ctx(ctx) - // TODO check destination header in X-Matrix auth resp := mp.getMedia(w, r) if resp == nil { @@ -369,8 +392,7 @@ func (mp *MediaProxy) addHeaders(w http.ResponseWriter, mimeType, fileName strin func (mp *MediaProxy) DownloadMedia(w http.ResponseWriter, r *http.Request) { ctx := r.Context() log := zerolog.Ctx(ctx) - vars := mux.Vars(r) - if vars["serverName"] != mp.serverName { + if r.PathValue("serverName") != mp.serverName { mautrix.MNotFound.WithMessage("This is a media proxy at %q, other media downloads are not available here", mp.serverName).Write(w) return } @@ -393,7 +415,7 @@ func (mp *MediaProxy) DownloadMedia(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusTemporaryRedirect) } else if fileResp, ok := resp.(*GetMediaResponseFile); ok { responseStarted, err := doTempFileDownload(fileResp, func(wt io.WriterTo, size int64, mimeType string) error { - mp.addHeaders(w, mimeType, vars["fileName"]) + mp.addHeaders(w, mimeType, r.PathValue("fileName")) w.Header().Set("Content-Length", strconv.FormatInt(size, 10)) w.WriteHeader(http.StatusOK) _, err := wt.WriteTo(w) @@ -410,13 +432,16 @@ func (mp *MediaProxy) DownloadMedia(w http.ResponseWriter, r *http.Request) { } } } - } else if dataResp, ok := resp.(GetMediaResponseWriter); ok { - mp.addHeaders(w, dataResp.GetContentType(), vars["fileName"]) - if dataResp.GetContentLength() != 0 { - w.Header().Set("Content-Length", strconv.FormatInt(dataResp.GetContentLength(), 10)) + } else if writerResp, ok := resp.(GetMediaResponseWriter); ok { + if dataResp, ok := writerResp.(*GetMediaResponseData); ok { + defer dataResp.Reader.Close() + } + mp.addHeaders(w, writerResp.GetContentType(), r.PathValue("fileName")) + if writerResp.GetContentLength() != 0 { + w.Header().Set("Content-Length", strconv.FormatInt(writerResp.GetContentLength(), 10)) } w.WriteHeader(http.StatusOK) - _, err := dataResp.WriteTo(w) + _, err := writerResp.WriteTo(w) if err != nil { log.Err(err).Msg("Failed to write media data") } @@ -433,23 +458,35 @@ func doTempFileDownload( if err != nil { return false, fmt.Errorf("failed to create temp file: %w", err) } + origTempFile := tempFile defer func() { - _ = tempFile.Close() - _ = os.Remove(tempFile.Name()) + _ = origTempFile.Close() + _ = os.Remove(origTempFile.Name()) }() - err = data.Callback(tempFile) + meta, err := data.Callback(tempFile) if err != nil { return false, err } - _, err = tempFile.Seek(0, io.SeekStart) - if err != nil { - return false, fmt.Errorf("failed to seek to start of temp file: %w", err) + if meta.ReplacementFile != "" { + tempFile, err = os.Open(meta.ReplacementFile) + if err != nil { + return false, fmt.Errorf("failed to open replacement file: %w", err) + } + defer func() { + _ = tempFile.Close() + _ = os.Remove(origTempFile.Name()) + }() + } else { + _, err = tempFile.Seek(0, io.SeekStart) + if err != nil { + return false, fmt.Errorf("failed to seek to start of temp file: %w", err) + } } fileInfo, err := tempFile.Stat() if err != nil { return false, fmt.Errorf("failed to stat temp file: %w", err) } - mimeType := data.ContentType + mimeType := meta.ContentType if mimeType == "" { buf := make([]byte, 512) n, err := tempFile.Read(buf) @@ -477,11 +514,6 @@ var ( ErrPreviewURLNotSupported = mautrix.MUnrecognized. WithMessage("This is a media proxy and does not support URL previews."). WithStatus(http.StatusNotImplemented) - ErrUnknownEndpoint = mautrix.MUnrecognized. - WithMessage("Unrecognized endpoint") - ErrUnsupportedMethod = mautrix.MUnrecognized. - WithMessage("Invalid method for endpoint"). - WithStatus(http.StatusMethodNotAllowed) ) func (mp *MediaProxy) UploadNotSupported(w http.ResponseWriter, r *http.Request) { @@ -491,11 +523,3 @@ func (mp *MediaProxy) UploadNotSupported(w http.ResponseWriter, r *http.Request) func (mp *MediaProxy) PreviewURLNotSupported(w http.ResponseWriter, r *http.Request) { ErrPreviewURLNotSupported.Write(w) } - -func (mp *MediaProxy) UnknownEndpoint(w http.ResponseWriter, r *http.Request) { - ErrUnknownEndpoint.Write(w) -} - -func (mp *MediaProxy) UnsupportedMethod(w http.ResponseWriter, r *http.Request) { - ErrUnsupportedMethod.Write(w) -} diff --git a/mockserver/mockserver.go b/mockserver/mockserver.go new file mode 100644 index 00000000..507c24a5 --- /dev/null +++ b/mockserver/mockserver.go @@ -0,0 +1,307 @@ +// Copyright (c) 2025 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package mockserver + +import ( + "context" + "encoding/json" + "fmt" + "io" + "maps" + "net/http" + "net/http/httptest" + "strings" + "testing" + + globallog "github.com/rs/zerolog/log" // zerolog-allow-global-log + "github.com/stretchr/testify/require" + "go.mau.fi/util/dbutil" + "go.mau.fi/util/exerrors" + "go.mau.fi/util/exhttp" + "go.mau.fi/util/random" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/crypto" + "maunium.net/go/mautrix/crypto/cryptohelper" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +func mustDecode(r *http.Request, data any) { + exerrors.PanicIfNotNil(json.NewDecoder(r.Body).Decode(data)) +} + +type userAndDeviceID struct { + UserID id.UserID + DeviceID id.DeviceID +} + +type MockServer struct { + Router *http.ServeMux + Server *httptest.Server + + AccessTokenToUserID map[string]userAndDeviceID + DeviceInbox map[id.UserID]map[id.DeviceID][]event.Event + AccountData map[id.UserID]map[event.Type]json.RawMessage + DeviceKeys map[id.UserID]map[id.DeviceID]mautrix.DeviceKeys + OneTimeKeys map[id.UserID]map[id.DeviceID]map[id.KeyID]mautrix.OneTimeKey + MasterKeys map[id.UserID]mautrix.CrossSigningKeys + SelfSigningKeys map[id.UserID]mautrix.CrossSigningKeys + UserSigningKeys map[id.UserID]mautrix.CrossSigningKeys + + PopOTKs bool + MemoryStore bool +} + +func Create(t testing.TB) *MockServer { + t.Helper() + + server := MockServer{ + AccessTokenToUserID: map[string]userAndDeviceID{}, + DeviceInbox: map[id.UserID]map[id.DeviceID][]event.Event{}, + AccountData: map[id.UserID]map[event.Type]json.RawMessage{}, + DeviceKeys: map[id.UserID]map[id.DeviceID]mautrix.DeviceKeys{}, + OneTimeKeys: map[id.UserID]map[id.DeviceID]map[id.KeyID]mautrix.OneTimeKey{}, + MasterKeys: map[id.UserID]mautrix.CrossSigningKeys{}, + SelfSigningKeys: map[id.UserID]mautrix.CrossSigningKeys{}, + UserSigningKeys: map[id.UserID]mautrix.CrossSigningKeys{}, + PopOTKs: true, + MemoryStore: true, + } + + router := http.NewServeMux() + router.HandleFunc("POST /_matrix/client/v3/login", server.postLogin) + router.HandleFunc("POST /_matrix/client/v3/keys/query", server.postKeysQuery) + router.HandleFunc("POST /_matrix/client/v3/keys/claim", server.postKeysClaim) + router.HandleFunc("PUT /_matrix/client/v3/sendToDevice/{type}/{txn}", server.putSendToDevice) + router.HandleFunc("PUT /_matrix/client/v3/user/{userID}/account_data/{type}", server.putAccountData) + router.HandleFunc("POST /_matrix/client/v3/keys/device_signing/upload", server.postDeviceSigningUpload) + router.HandleFunc("POST /_matrix/client/v3/keys/signatures/upload", server.emptyResp) + router.HandleFunc("POST /_matrix/client/v3/keys/upload", server.postKeysUpload) + server.Router = router + server.Server = httptest.NewServer(router) + t.Cleanup(server.Server.Close) + return &server +} + +func (ms *MockServer) getUserID(r *http.Request) userAndDeviceID { + authHeader := r.Header.Get("Authorization") + authHeader = strings.TrimPrefix(authHeader, "Bearer ") + userID, ok := ms.AccessTokenToUserID[authHeader] + if !ok { + panic("no user ID found for access token " + authHeader) + } + return userID +} + +func (ms *MockServer) emptyResp(w http.ResponseWriter, _ *http.Request) { + exhttp.WriteEmptyJSONResponse(w, http.StatusOK) +} + +func (ms *MockServer) postLogin(w http.ResponseWriter, r *http.Request) { + var loginReq mautrix.ReqLogin + mustDecode(r, &loginReq) + + deviceID := loginReq.DeviceID + if deviceID == "" { + deviceID = id.DeviceID(random.String(10)) + } + + accessToken := random.String(30) + userID := id.UserID(loginReq.Identifier.User) + ms.AccessTokenToUserID[accessToken] = userAndDeviceID{ + UserID: userID, + DeviceID: deviceID, + } + + exhttp.WriteJSONResponse(w, http.StatusOK, &mautrix.RespLogin{ + AccessToken: accessToken, + DeviceID: deviceID, + UserID: userID, + }) +} + +func (ms *MockServer) putSendToDevice(w http.ResponseWriter, r *http.Request) { + var req mautrix.ReqSendToDevice + mustDecode(r, &req) + evtType := event.Type{Type: r.PathValue("type"), Class: event.ToDeviceEventType} + + for user, devices := range req.Messages { + for device, content := range devices { + if _, ok := ms.DeviceInbox[user]; !ok { + ms.DeviceInbox[user] = map[id.DeviceID][]event.Event{} + } + content.ParseRaw(evtType) + ms.DeviceInbox[user][device] = append(ms.DeviceInbox[user][device], event.Event{ + Sender: ms.getUserID(r).UserID, + Type: evtType, + Content: *content, + }) + } + } + ms.emptyResp(w, r) +} + +func (ms *MockServer) putAccountData(w http.ResponseWriter, r *http.Request) { + userID := id.UserID(r.PathValue("userID")) + eventType := event.Type{Type: r.PathValue("type"), Class: event.AccountDataEventType} + + jsonData, _ := io.ReadAll(r.Body) + if _, ok := ms.AccountData[userID]; !ok { + ms.AccountData[userID] = map[event.Type]json.RawMessage{} + } + ms.AccountData[userID][eventType] = json.RawMessage(jsonData) + ms.emptyResp(w, r) +} + +func (ms *MockServer) postKeysQuery(w http.ResponseWriter, r *http.Request) { + var req mautrix.ReqQueryKeys + mustDecode(r, &req) + resp := mautrix.RespQueryKeys{ + MasterKeys: map[id.UserID]mautrix.CrossSigningKeys{}, + UserSigningKeys: map[id.UserID]mautrix.CrossSigningKeys{}, + SelfSigningKeys: map[id.UserID]mautrix.CrossSigningKeys{}, + DeviceKeys: map[id.UserID]map[id.DeviceID]mautrix.DeviceKeys{}, + } + for user := range req.DeviceKeys { + resp.MasterKeys[user] = ms.MasterKeys[user] + resp.UserSigningKeys[user] = ms.UserSigningKeys[user] + resp.SelfSigningKeys[user] = ms.SelfSigningKeys[user] + resp.DeviceKeys[user] = ms.DeviceKeys[user] + } + exhttp.WriteJSONResponse(w, http.StatusOK, &resp) +} + +func (ms *MockServer) postKeysClaim(w http.ResponseWriter, r *http.Request) { + var req mautrix.ReqClaimKeys + mustDecode(r, &req) + resp := mautrix.RespClaimKeys{ + OneTimeKeys: map[id.UserID]map[id.DeviceID]map[id.KeyID]mautrix.OneTimeKey{}, + } + for user, devices := range req.OneTimeKeys { + resp.OneTimeKeys[user] = map[id.DeviceID]map[id.KeyID]mautrix.OneTimeKey{} + for device := range devices { + keys := ms.OneTimeKeys[user][device] + for keyID, key := range keys { + if ms.PopOTKs { + delete(keys, keyID) + } + resp.OneTimeKeys[user][device] = map[id.KeyID]mautrix.OneTimeKey{ + keyID: key, + } + break + } + } + } + exhttp.WriteJSONResponse(w, http.StatusOK, &resp) +} + +func (ms *MockServer) postKeysUpload(w http.ResponseWriter, r *http.Request) { + var req mautrix.ReqUploadKeys + mustDecode(r, &req) + + uid := ms.getUserID(r) + userID := uid.UserID + if _, ok := ms.DeviceKeys[userID]; !ok { + ms.DeviceKeys[userID] = map[id.DeviceID]mautrix.DeviceKeys{} + } + if _, ok := ms.OneTimeKeys[userID]; !ok { + ms.OneTimeKeys[userID] = map[id.DeviceID]map[id.KeyID]mautrix.OneTimeKey{} + } + + if req.DeviceKeys != nil { + ms.DeviceKeys[userID][uid.DeviceID] = *req.DeviceKeys + } + otks, ok := ms.OneTimeKeys[userID][uid.DeviceID] + if !ok { + otks = map[id.KeyID]mautrix.OneTimeKey{} + ms.OneTimeKeys[userID][uid.DeviceID] = otks + } + if req.OneTimeKeys != nil { + maps.Copy(otks, req.OneTimeKeys) + } + + exhttp.WriteJSONResponse(w, http.StatusOK, &mautrix.RespUploadKeys{ + OneTimeKeyCounts: mautrix.OTKCount{SignedCurve25519: len(otks)}, + }) +} + +func (ms *MockServer) postDeviceSigningUpload(w http.ResponseWriter, r *http.Request) { + var req mautrix.UploadCrossSigningKeysReq[any] + mustDecode(r, &req) + + userID := ms.getUserID(r).UserID + ms.MasterKeys[userID] = req.Master + ms.SelfSigningKeys[userID] = req.SelfSigning + ms.UserSigningKeys[userID] = req.UserSigning + + ms.emptyResp(w, r) +} + +func (ms *MockServer) Login(t testing.TB, ctx context.Context, userID id.UserID, deviceID id.DeviceID) (*mautrix.Client, crypto.Store) { + t.Helper() + if ctx == nil { + ctx = context.TODO() + } + client, err := mautrix.NewClient(ms.Server.URL, "", "") + require.NoError(t, err) + client.Client = ms.Server.Client() + + _, err = client.Login(ctx, &mautrix.ReqLogin{ + Type: mautrix.AuthTypePassword, + Identifier: mautrix.UserIdentifier{ + Type: mautrix.IdentifierTypeUser, + User: userID.String(), + }, + DeviceID: deviceID, + Password: "password", + StoreCredentials: true, + }) + require.NoError(t, err) + + var store any + if ms.MemoryStore { + store = crypto.NewMemoryStore(nil) + client.StateStore = mautrix.NewMemoryStateStore() + } else { + store, err = dbutil.NewFromConfig("", dbutil.Config{ + PoolConfig: dbutil.PoolConfig{ + Type: "sqlite3-fk-wal", + URI: fmt.Sprintf("file:%s?mode=memory&cache=shared&_txlock=immediate", random.String(10)), + MaxOpenConns: 5, + MaxIdleConns: 1, + }, + }, nil) + require.NoError(t, err) + } + cryptoHelper, err := cryptohelper.NewCryptoHelper(client, []byte("test"), store) + require.NoError(t, err) + client.Crypto = cryptoHelper + + err = cryptoHelper.Init(ctx) + require.NoError(t, err) + + machineLog := globallog.Logger.With(). + Stringer("my_user_id", userID). + Stringer("my_device_id", deviceID). + Logger() + cryptoHelper.Machine().Log = &machineLog + + err = cryptoHelper.Machine().ShareKeys(ctx, 50) + require.NoError(t, err) + + return client, cryptoHelper.Machine().CryptoStore +} + +func (ms *MockServer) DispatchToDevice(t testing.TB, ctx context.Context, client *mautrix.Client) { + t.Helper() + + for _, evt := range ms.DeviceInbox[client.UserID][client.DeviceID] { + client.Syncer.(*mautrix.DefaultSyncer).Dispatch(ctx, &evt) + ms.DeviceInbox[client.UserID][client.DeviceID] = ms.DeviceInbox[client.UserID][client.DeviceID][1:] + } +} diff --git a/pushrules/action.go b/pushrules/action.go index 9838e88b..b5a884b2 100644 --- a/pushrules/action.go +++ b/pushrules/action.go @@ -105,7 +105,7 @@ func (action *PushAction) UnmarshalJSON(raw []byte) error { if ok { action.Action = ActionSetTweak action.Tweak = PushActionTweak(tweak) - action.Value, _ = val["value"] + action.Value = val["value"] } } return nil diff --git a/pushrules/action_test.go b/pushrules/action_test.go index a8f68415..3c0aa168 100644 --- a/pushrules/action_test.go +++ b/pushrules/action_test.go @@ -139,9 +139,9 @@ func TestPushAction_UnmarshalJSON_InvalidTypeDoesNothing(t *testing.T) { } err := pa.UnmarshalJSON([]byte(`{"foo": "bar"}`)) - assert.Nil(t, err) + assert.NoError(t, err) err = pa.UnmarshalJSON([]byte(`9001`)) - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal(t, pushrules.PushActionType("unchanged"), pa.Action) assert.Equal(t, pushrules.PushActionTweak("unchanged"), pa.Tweak) @@ -156,7 +156,7 @@ func TestPushAction_UnmarshalJSON_StringChangesActionType(t *testing.T) { } err := pa.UnmarshalJSON([]byte(`"foo"`)) - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal(t, pushrules.PushActionType("foo"), pa.Action) assert.Equal(t, pushrules.PushActionTweak("unchanged"), pa.Tweak) @@ -171,7 +171,7 @@ func TestPushAction_UnmarshalJSON_SetTweakChangesTweak(t *testing.T) { } err := pa.UnmarshalJSON([]byte(`{"set_tweak": "foo", "value": 123.0}`)) - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal(t, pushrules.ActionSetTweak, pa.Action) assert.Equal(t, pushrules.PushActionTweak("foo"), pa.Tweak) @@ -185,7 +185,7 @@ func TestPushAction_MarshalJSON_TweakOutputWorks(t *testing.T) { Value: "bar", } data, err := pa.MarshalJSON() - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal(t, []byte(`{"set_tweak":"foo","value":"bar"}`), data) } @@ -196,6 +196,6 @@ func TestPushAction_MarshalJSON_OtherOutputWorks(t *testing.T) { Value: "bar", } data, err := pa.MarshalJSON() - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal(t, []byte(`"something else"`), data) } diff --git a/pushrules/condition_test.go b/pushrules/condition_test.go index 0d3eaf7a..37af3e34 100644 --- a/pushrules/condition_test.go +++ b/pushrules/condition_test.go @@ -102,14 +102,6 @@ func newEventPropertyIsPushCondition(key string, value any) *pushrules.PushCondi } } -func newEventPropertyContainsPushCondition(key string, value any) *pushrules.PushCondition { - return &pushrules.PushCondition{ - Kind: pushrules.KindEventPropertyContains, - Key: key, - Value: value, - } -} - func TestPushCondition_Match_InvalidKind(t *testing.T) { condition := &pushrules.PushCondition{ Kind: pushrules.PushCondKind("invalid"), diff --git a/pushrules/pushrules_test.go b/pushrules/pushrules_test.go index a531ca28..a5a0f5e7 100644 --- a/pushrules/pushrules_test.go +++ b/pushrules/pushrules_test.go @@ -25,7 +25,7 @@ func TestEventToPushRules(t *testing.T) { }, } pushRuleset, err := pushrules.EventToPushRules(evt) - assert.Nil(t, err) + assert.NoError(t, err) assert.NotNil(t, pushRuleset) assert.IsType(t, pushRuleset.Override, pushrules.PushRuleArray{}) diff --git a/pushrules/rule.go b/pushrules/rule.go index ee6d33c4..cf659695 100644 --- a/pushrules/rule.go +++ b/pushrules/rule.go @@ -8,7 +8,10 @@ package pushrules import ( "encoding/gob" + "regexp" + "strings" + "go.mau.fi/util/exerrors" "go.mau.fi/util/glob" "maunium.net/go/mautrix/event" @@ -165,13 +168,20 @@ func (rule *PushRule) matchConditions(room Room, evt *event.Event) bool { } func (rule *PushRule) matchPattern(room Room, evt *event.Event) bool { - pattern := glob.CompileWithImplicitContains(rule.Pattern) - if pattern == nil { - return false - } msg, ok := evt.Content.Raw["body"].(string) if !ok { return false } - return pattern.Match(msg) + var buf strings.Builder + // As per https://spec.matrix.org/unstable/client-server-api/#push-rules, content rules are case-insensitive + // and must match whole words, so wrap the converted glob in (?i) and \b. + buf.WriteString(`(?i)\b`) + // strings.Builder will never return errors + exerrors.PanicIfNotNil(glob.ToRegexPattern(rule.Pattern, &buf)) + buf.WriteString(`\b`) + pattern, err := regexp.Compile(buf.String()) + if err != nil { + return false + } + return pattern.MatchString(msg) } diff --git a/pushrules/rule_test.go b/pushrules/rule_test.go index 803c721e..7ff839a7 100644 --- a/pushrules/rule_test.go +++ b/pushrules/rule_test.go @@ -186,6 +186,34 @@ func TestPushRule_Match_Content(t *testing.T) { assert.True(t, rule.Match(blankTestRoom, evt)) } +func TestPushRule_Match_WordBoundary(t *testing.T) { + rule := &pushrules.PushRule{ + Type: pushrules.ContentRule, + Enabled: true, + Pattern: "test", + } + + evt := newFakeEvent(event.EventMessage, &event.MessageEventContent{ + MsgType: event.MsgEmote, + Body: "is testing pushrules", + }) + assert.False(t, rule.Match(blankTestRoom, evt)) +} + +func TestPushRule_Match_CaseInsensitive(t *testing.T) { + rule := &pushrules.PushRule{ + Type: pushrules.ContentRule, + Enabled: true, + Pattern: "test", + } + + evt := newFakeEvent(event.EventMessage, &event.MessageEventContent{ + MsgType: event.MsgEmote, + Body: "is TeSt-InG pushrules", + }) + assert.True(t, rule.Match(blankTestRoom, evt)) +} + func TestPushRule_Match_Content_Fail(t *testing.T) { rule := &pushrules.PushRule{ Type: pushrules.ContentRule, diff --git a/requests.go b/requests.go index 9e7eb0bd..cc8b7266 100644 --- a/requests.go +++ b/requests.go @@ -2,7 +2,9 @@ package mautrix import ( "encoding/json" + "fmt" "strconv" + "time" "maunium.net/go/mautrix/crypto/signatures" "maunium.net/go/mautrix/event" @@ -38,20 +40,40 @@ const ( type Direction rune +func (d Direction) MarshalJSON() ([]byte, error) { + return json.Marshal(string(d)) +} + +func (d *Direction) UnmarshalJSON(data []byte) error { + var str string + if err := json.Unmarshal(data, &str); err != nil { + return err + } + switch str { + case "f": + *d = DirectionForward + case "b": + *d = DirectionBackward + default: + return fmt.Errorf("invalid direction %q, must be 'f' or 'b'", str) + } + return nil +} + const ( DirectionForward Direction = 'f' DirectionBackward Direction = 'b' ) // ReqRegister is the JSON request for https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3register -type ReqRegister struct { +type ReqRegister[UIAType any] struct { Username string `json:"username,omitempty"` Password string `json:"password,omitempty"` DeviceID id.DeviceID `json:"device_id,omitempty"` InitialDeviceDisplayName string `json:"initial_device_display_name,omitempty"` InhibitLogin bool `json:"inhibit_login,omitempty"` RefreshToken bool `json:"refresh_token,omitempty"` - Auth interface{} `json:"auth,omitempty"` + Auth UIAType `json:"auth,omitempty"` // Type for registration, only used for appservice user registrations // https://spec.matrix.org/v1.2/application-service-api/#server-admin-style-permissions @@ -119,11 +141,12 @@ type ReqCreateRoom struct { InitialState []*event.Event `json:"initial_state,omitempty"` Preset string `json:"preset,omitempty"` IsDirect bool `json:"is_direct,omitempty"` - RoomVersion string `json:"room_version,omitempty"` + RoomVersion id.RoomVersion `json:"room_version,omitempty"` PowerLevelOverride *event.PowerLevelsEventContent `json:"power_level_content_override,omitempty"` MeowRoomID id.RoomID `json:"fi.mau.room_id,omitempty"` + MeowCreateTS int64 `json:"fi.mau.origin_server_ts,omitempty"` BeeperInitialMembers []id.UserID `json:"com.beeper.initial_members,omitempty"` BeeperAutoJoinInvites bool `json:"com.beeper.auto_join_invites,omitempty"` BeeperLocalRoomID id.RoomID `json:"com.beeper.local_room_id,omitempty"` @@ -138,12 +161,33 @@ type ReqRedact struct { Extra map[string]interface{} } +type ReqRedactUser struct { + Reason string `json:"reason"` + Limit int `json:"-"` +} + type ReqMembers struct { At string `json:"at"` Membership event.Membership `json:"membership,omitempty"` NotMembership event.Membership `json:"not_membership,omitempty"` } +type ReqJoinRoom struct { + Via []string `json:"-"` + Reason string `json:"reason,omitempty"` + ThirdPartySigned any `json:"third_party_signed,omitempty"` +} + +type ReqKnockRoom struct { + Via []string `json:"-"` + Reason string `json:"reason,omitempty"` +} + +type ReqSearchUserDirectory struct { + SearchTerm string `json:"search_term"` + Limit int `json:"limit,omitempty"` +} + type ReqMutualRooms struct { From string `json:"-"` } @@ -176,6 +220,8 @@ type ReqKickUser struct { type ReqBanUser struct { Reason string `json:"reason,omitempty"` UserID id.UserID `json:"user_id"` + + MSC4293RedactEvents bool `json:"org.matrix.msc4293.redact_events,omitempty"` } // ReqUnbanUser is the JSON request for https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3roomsroomidunban @@ -191,7 +237,8 @@ type ReqTyping struct { } type ReqPresence struct { - Presence event.Presence `json:"presence"` + Presence event.Presence `json:"presence"` + StatusMsg string `json:"status_msg,omitempty"` } type ReqAliasCreate struct { @@ -273,11 +320,11 @@ func (csk *CrossSigningKeys) FirstKey() id.Ed25519 { return "" } -type UploadCrossSigningKeysReq struct { +type UploadCrossSigningKeysReq[UIAType any] struct { Master CrossSigningKeys `json:"master_key"` SelfSigning CrossSigningKeys `json:"self_signing_key"` UserSigning CrossSigningKeys `json:"user_signing_key"` - Auth interface{} `json:"auth,omitempty"` + Auth UIAType `json:"auth,omitempty"` } type KeyMap map[id.DeviceKeyID]string @@ -319,20 +366,40 @@ type ReqSendToDevice struct { Messages map[id.UserID]map[id.DeviceID]*event.Content `json:"messages"` } +type ReqSendEvent struct { + Timestamp int64 + TransactionID string + UnstableDelay time.Duration + UnstableStickyDuration time.Duration + DontEncrypt bool + MeowEventID id.EventID +} + +type ReqDelayedEvents struct { + DelayID id.DelayID `json:"-"` + Status event.DelayStatus `json:"-"` + NextBatch string `json:"-"` +} + +type ReqUpdateDelayedEvent struct { + DelayID id.DelayID `json:"-"` + Action event.DelayAction `json:"action"` +} + // ReqDeviceInfo is the JSON request for https://spec.matrix.org/v1.2/client-server-api/#put_matrixclientv3devicesdeviceid type ReqDeviceInfo struct { DisplayName string `json:"display_name,omitempty"` } // ReqDeleteDevice is the JSON request for https://spec.matrix.org/v1.2/client-server-api/#delete_matrixclientv3devicesdeviceid -type ReqDeleteDevice struct { - Auth interface{} `json:"auth,omitempty"` +type ReqDeleteDevice[UIAType any] struct { + Auth UIAType `json:"auth,omitempty"` } // ReqDeleteDevices is the JSON request for https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3delete_devices -type ReqDeleteDevices struct { +type ReqDeleteDevices[UIAType any] struct { Devices []id.DeviceID `json:"devices"` - Auth interface{} `json:"auth,omitempty"` + Auth UIAType `json:"auth,omitempty"` } type ReqPutPushRule struct { @@ -344,18 +411,6 @@ type ReqPutPushRule struct { Pattern string `json:"pattern"` } -// Deprecated: MSC2716 was abandoned -type ReqBatchSend struct { - PrevEventID id.EventID `json:"-"` - BatchID id.BatchID `json:"-"` - - BeeperNewMessages bool `json:"-"` - BeeperMarkReadBy id.UserID `json:"-"` - - StateEventsAtStart []*event.Event `json:"state_events_at_start"` - Events []*event.Event `json:"events"` -} - type ReqBeeperBatchSend struct { // ForwardIfNoMessages should be set to true if the batch should be forward // backfilled if there are no messages currently in the room. @@ -391,6 +446,33 @@ type ReqSendReceipt struct { ThreadID string `json:"thread_id,omitempty"` } +type ReqPublicRooms struct { + IncludeAllNetworks bool + Limit int + Since string + ThirdPartyInstanceID string +} + +func (req *ReqPublicRooms) Query() map[string]string { + query := map[string]string{} + if req == nil { + return query + } + if req.IncludeAllNetworks { + query["include_all_networks"] = "true" + } + if req.Limit > 0 { + query["limit"] = strconv.Itoa(req.Limit) + } + if req.Since != "" { + query["since"] = req.Since + } + if req.ThirdPartyInstanceID != "" { + query["third_party_instance_id"] = req.ThirdPartyInstanceID + } + return query +} + // ReqHierarchy contains the parameters for https://spec.matrix.org/v1.4/client-server-api/#get_matrixclientv1roomsroomidhierarchy // // As it's a GET method, there is no JSON body, so this is only query parameters. @@ -483,3 +565,54 @@ type ReqReport struct { Reason string `json:"reason,omitempty"` Score int `json:"score,omitempty"` } + +type ReqGetRelations struct { + RelationType event.RelationType + EventType event.Type + + Dir Direction + From string + To string + Limit int + Recurse bool +} + +func (rgr *ReqGetRelations) PathSuffix() ClientURLPath { + if rgr.RelationType != "" { + if rgr.EventType.Type != "" { + return ClientURLPath{rgr.RelationType, rgr.EventType.Type} + } + return ClientURLPath{rgr.RelationType} + } + return ClientURLPath{} +} + +func (rgr *ReqGetRelations) Query() map[string]string { + query := map[string]string{} + if rgr.Dir != 0 { + query["dir"] = string(rgr.Dir) + } + if rgr.From != "" { + query["from"] = rgr.From + } + if rgr.To != "" { + query["to"] = rgr.To + } + if rgr.Limit > 0 { + query["limit"] = strconv.Itoa(rgr.Limit) + } + if rgr.Recurse { + query["recurse"] = "true" + } + return query +} + +// ReqSuspend is the request body for https://github.com/matrix-org/matrix-spec-proposals/pull/4323 +type ReqSuspend struct { + Suspended bool `json:"suspended"` +} + +// ReqLocked is the request body for https://github.com/matrix-org/matrix-spec-proposals/pull/4323 +type ReqLocked struct { + Locked bool `json:"locked"` +} diff --git a/responses.go b/responses.go index 6ead355e..4fbe1fbc 100644 --- a/responses.go +++ b/responses.go @@ -4,13 +4,16 @@ import ( "bytes" "encoding/json" "fmt" + "maps" "reflect" + "slices" "strconv" "strings" "github.com/tidwall/gjson" "github.com/tidwall/sjson" "go.mau.fi/util/jsontime" + "go.mau.fi/util/ptr" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" @@ -32,6 +35,11 @@ type RespJoinRoom struct { RoomID id.RoomID `json:"room_id"` } +// RespKnockRoom is the JSON response for https://spec.matrix.org/v1.13/client-server-api/#post_matrixclientv3knockroomidoralias +type RespKnockRoom struct { + RoomID id.RoomID `json:"room_id"` +} + // RespLeaveRoom is the JSON response for https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3roomsroomidleave type RespLeaveRoom struct{} @@ -97,6 +105,29 @@ type RespContext struct { // RespSendEvent is the JSON response for https://spec.matrix.org/v1.2/client-server-api/#put_matrixclientv3roomsroomidsendeventtypetxnid type RespSendEvent struct { EventID id.EventID `json:"event_id"` + + UnstableDelayID id.DelayID `json:"delay_id,omitempty"` +} + +type RespUpdateDelayedEvent struct{} + +type RespDelayedEvents struct { + Scheduled []*event.ScheduledDelayedEvent `json:"scheduled,omitempty"` + Finalised []*event.FinalisedDelayedEvent `json:"finalised,omitempty"` + NextBatch string `json:"next_batch,omitempty"` + + // Deprecated: Synapse implementation still returns this + DelayedEvents []*event.ScheduledDelayedEvent `json:"delayed_events,omitempty"` + // Deprecated: Synapse implementation still returns this + FinalisedEvents []*event.FinalisedDelayedEvent `json:"finalised_events,omitempty"` +} + +type RespRedactUserEvents struct { + IsMoreEvents bool `json:"is_more_events"` + RedactedEvents struct { + Total int `json:"total"` + SoftFailed int `json:"soft_failed"` + } `json:"redacted_events"` } // RespMediaConfig is the JSON response for https://spec.matrix.org/v1.4/client-server-api/#get_matrixmediav3config @@ -155,13 +186,89 @@ type RespUserDisplayName struct { } type RespUserProfile struct { - DisplayName string `json:"displayname"` - AvatarURL id.ContentURI `json:"avatar_url"` + DisplayName string `json:"displayname,omitempty"` + AvatarURL id.ContentURI `json:"avatar_url,omitempty"` + Extra map[string]any `json:"-"` +} + +type marshalableUserProfile RespUserProfile + +func (r *RespUserProfile) UnmarshalJSON(data []byte) error { + err := json.Unmarshal(data, &r.Extra) + if err != nil { + return err + } + r.DisplayName, _ = r.Extra["displayname"].(string) + avatarURL, _ := r.Extra["avatar_url"].(string) + if avatarURL != "" { + r.AvatarURL, _ = id.ParseContentURI(avatarURL) + } + delete(r.Extra, "displayname") + delete(r.Extra, "avatar_url") + return nil +} + +func (r *RespUserProfile) MarshalJSON() ([]byte, error) { + if len(r.Extra) == 0 { + return json.Marshal((*marshalableUserProfile)(r)) + } + marshalMap := maps.Clone(r.Extra) + if r.DisplayName != "" { + marshalMap["displayname"] = r.DisplayName + } else { + delete(marshalMap, "displayname") + } + if !r.AvatarURL.IsEmpty() { + marshalMap["avatar_url"] = r.AvatarURL.String() + } else { + delete(marshalMap, "avatar_url") + } + return json.Marshal(marshalMap) +} + +type RespSearchUserDirectory struct { + Limited bool `json:"limited"` + Results []*UserDirectoryEntry `json:"results"` +} + +type UserDirectoryEntry struct { + RespUserProfile + UserID id.UserID `json:"user_id"` +} + +func (r *UserDirectoryEntry) UnmarshalJSON(data []byte) error { + err := r.RespUserProfile.UnmarshalJSON(data) + if err != nil { + return err + } + userIDStr, _ := r.Extra["user_id"].(string) + r.UserID = id.UserID(userIDStr) + delete(r.Extra, "user_id") + return nil +} + +func (r *UserDirectoryEntry) MarshalJSON() ([]byte, error) { + if r.Extra == nil { + r.Extra = make(map[string]any) + } + r.Extra["user_id"] = r.UserID.String() + return r.RespUserProfile.MarshalJSON() } type RespMutualRooms struct { Joined []id.RoomID `json:"joined"` NextBatch string `json:"next_batch,omitempty"` + Count int `json:"count,omitempty"` +} + +type RespRoomSummary struct { + PublicRoomInfo + + Membership event.Membership `json:"membership,omitempty"` + + UnstableRoomVersion id.RoomVersion `json:"im.nheko.summary.room_version,omitempty"` + UnstableRoomVersionOld id.RoomVersion `json:"im.nheko.summary.version,omitempty"` + UnstableEncryption id.Algorithm `json:"im.nheko.summary.encryption,omitempty"` } // RespRegisterAvailable is the JSON response for https://spec.matrix.org/v1.4/client-server-api/#get_matrixclientv3registeravailable @@ -212,6 +319,9 @@ type RespLogin struct { DeviceID id.DeviceID `json:"device_id"` UserID id.UserID `json:"user_id"` WellKnown *ClientWellKnown `json:"well_known,omitempty"` + + RefreshToken string `json:"refresh_token,omitempty"` + ExpiresInMS int64 `json:"expires_in_ms,omitempty"` } // RespLogout is the JSON response for https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3logout @@ -232,6 +342,24 @@ type LazyLoadSummary struct { InvitedMemberCount *int `json:"m.invited_member_count,omitempty"` } +func (lls *LazyLoadSummary) MemberCount() int { + if lls == nil { + return 0 + } + return ptr.Val(lls.JoinedMemberCount) + ptr.Val(lls.InvitedMemberCount) +} + +func (lls *LazyLoadSummary) Equal(other *LazyLoadSummary) bool { + if lls == other { + return true + } else if lls == nil || other == nil { + return false + } + return ptr.Val(lls.JoinedMemberCount) == ptr.Val(other.JoinedMemberCount) && + ptr.Val(lls.InvitedMemberCount) == ptr.Val(other.InvitedMemberCount) && + slices.Equal(lls.Heroes, other.Heroes) +} + type SyncEventsList struct { Events []*event.Event `json:"events,omitempty"` } @@ -327,6 +455,7 @@ type BeeperInboxPreviewEvent struct { type SyncJoinedRoom struct { Summary LazyLoadSummary `json:"summary"` State SyncEventsList `json:"state"` + StateAfter *SyncEventsList `json:"state_after,omitempty"` Timeline SyncTimeline `json:"timeline"` Ephemeral SyncEventsList `json:"ephemeral"` AccountData SyncEventsList `json:"account_data"` @@ -352,16 +481,7 @@ func (sjr SyncJoinedRoom) MarshalJSON() ([]byte, error) { } type SyncInvitedRoom struct { - Summary LazyLoadSummary `json:"summary"` - State SyncEventsList `json:"invite_state"` -} - -type marshalableSyncInvitedRoom SyncInvitedRoom - -var syncInvitedRoomPathsToDelete = []string{"summary"} - -func (sir SyncInvitedRoom) MarshalJSON() ([]byte, error) { - return marshalAndDeleteEmpty((marshalableSyncInvitedRoom)(sir), syncInvitedRoomPathsToDelete) + State SyncEventsList `json:"invite_state"` } type SyncKnockedRoom struct { @@ -426,29 +546,19 @@ type RespDeviceInfo struct { LastSeenTS int64 `json:"last_seen_ts"` } -// Deprecated: MSC2716 was abandoned -type RespBatchSend struct { - StateEventIDs []id.EventID `json:"state_event_ids"` - EventIDs []id.EventID `json:"event_ids"` - - InsertionEventID id.EventID `json:"insertion_event_id"` - BatchEventID id.EventID `json:"batch_event_id"` - BaseInsertionEventID id.EventID `json:"base_insertion_event_id"` - - NextBatchID id.BatchID `json:"next_batch_id"` -} - type RespBeeperBatchSend struct { EventIDs []id.EventID `json:"event_ids"` } // RespCapabilities is the JSON response for https://spec.matrix.org/v1.3/client-server-api/#get_matrixclientv3capabilities type RespCapabilities struct { - RoomVersions *CapRoomVersions `json:"m.room_versions,omitempty"` - ChangePassword *CapBooleanTrue `json:"m.change_password,omitempty"` - SetDisplayname *CapBooleanTrue `json:"m.set_displayname,omitempty"` - SetAvatarURL *CapBooleanTrue `json:"m.set_avatar_url,omitempty"` - ThreePIDChanges *CapBooleanTrue `json:"m.3pid_changes,omitempty"` + RoomVersions *CapRoomVersions `json:"m.room_versions,omitempty"` + ChangePassword *CapBooleanTrue `json:"m.change_password,omitempty"` + SetDisplayname *CapBooleanTrue `json:"m.set_displayname,omitempty"` + SetAvatarURL *CapBooleanTrue `json:"m.set_avatar_url,omitempty"` + ThreePIDChanges *CapBooleanTrue `json:"m.3pid_changes,omitempty"` + GetLoginToken *CapBooleanTrue `json:"m.get_login_token,omitempty"` + UnstableAccountModeration *CapUnstableAccountModeration `json:"uk.timedout.msc4323,omitempty"` Custom map[string]interface{} `json:"-"` } @@ -557,29 +667,44 @@ func (vers *CapRoomVersions) IsAvailable(version string) bool { return available } +type CapUnstableAccountModeration struct { + Suspend bool `json:"suspend"` + Lock bool `json:"lock"` +} + +type RespPublicRooms struct { + Chunk []*PublicRoomInfo `json:"chunk"` + NextBatch string `json:"next_batch,omitempty"` + PrevBatch string `json:"prev_batch,omitempty"` + TotalRoomCountEstimate int `json:"total_room_count_estimate"` +} + +type PublicRoomInfo struct { + RoomID id.RoomID `json:"room_id"` + AvatarURL id.ContentURIString `json:"avatar_url,omitempty"` + CanonicalAlias id.RoomAlias `json:"canonical_alias,omitempty"` + GuestCanJoin bool `json:"guest_can_join"` + JoinRule event.JoinRule `json:"join_rule,omitempty"` + Name string `json:"name,omitempty"` + NumJoinedMembers int `json:"num_joined_members"` + RoomType event.RoomType `json:"room_type"` + Topic string `json:"topic,omitempty"` + WorldReadable bool `json:"world_readable"` + + RoomVersion id.RoomVersion `json:"room_version,omitempty"` + Encryption id.Algorithm `json:"encryption,omitempty"` + AllowedRoomIDs []id.RoomID `json:"allowed_room_ids,omitempty"` +} + // RespHierarchy is the JSON response for https://spec.matrix.org/v1.4/client-server-api/#get_matrixclientv1roomsroomidhierarchy type RespHierarchy struct { - NextBatch string `json:"next_batch,omitempty"` - Rooms []ChildRoomsChunk `json:"rooms"` + NextBatch string `json:"next_batch,omitempty"` + Rooms []*ChildRoomsChunk `json:"rooms"` } type ChildRoomsChunk struct { - AvatarURL id.ContentURI `json:"avatar_url,omitempty"` - CanonicalAlias id.RoomAlias `json:"canonical_alias,omitempty"` - ChildrenState []StrippedStateWithTime `json:"children_state"` - GuestCanJoin bool `json:"guest_can_join"` - JoinRule event.JoinRule `json:"join_rule,omitempty"` - Name string `json:"name,omitempty"` - NumJoinedMembers int `json:"num_joined_members"` - RoomID id.RoomID `json:"room_id"` - RoomType event.RoomType `json:"room_type"` - Topic string `json:"topic,omitempty"` - WorldReadble bool `json:"world_readable"` -} - -type StrippedStateWithTime struct { - event.StrippedState - Timestamp jsontime.UnixMilli `json:"origin_server_ts"` + PublicRoomInfo + ChildrenState []*event.Event `json:"children_state"` } type RespAppservicePing struct { @@ -628,3 +753,47 @@ type RespRoomKeysUpdate struct { Count int `json:"count"` ETag string `json:"etag"` } + +type RespOpenIDToken struct { + AccessToken string `json:"access_token"` + ExpiresIn int `json:"expires_in"` + MatrixServerName string `json:"matrix_server_name"` + TokenType string `json:"token_type"` // Always "Bearer" +} + +type RespGetRelations struct { + Chunk []*event.Event `json:"chunk"` + NextBatch string `json:"next_batch,omitempty"` + PrevBatch string `json:"prev_batch,omitempty"` + RecursionDepth int `json:"recursion_depth,omitempty"` +} + +// RespSuspended is the response body for https://github.com/matrix-org/matrix-spec-proposals/pull/4323 +type RespSuspended struct { + Suspended bool `json:"suspended"` +} + +// RespLocked is the response body for https://github.com/matrix-org/matrix-spec-proposals/pull/4323 +type RespLocked struct { + Locked bool `json:"locked"` +} + +type ConnectionInfo struct { + IP string `json:"ip,omitempty"` + LastSeen jsontime.UnixMilli `json:"last_seen,omitempty"` + UserAgent string `json:"user_agent,omitempty"` +} + +type SessionInfo struct { + Connections []ConnectionInfo `json:"connections,omitempty"` +} + +type DeviceInfo struct { + Sessions []SessionInfo `json:"sessions,omitempty"` +} + +// RespWhoIs is the response body for https://spec.matrix.org/v1.15/client-server-api/#get_matrixclientv3adminwhoisuserid +type RespWhoIs struct { + UserID id.UserID `json:"user_id,omitempty"` + Devices map[id.DeviceID]DeviceInfo `json:"devices,omitempty"` +} diff --git a/responses_test.go b/responses_test.go index b23d85ad..73d82635 100644 --- a/responses_test.go +++ b/responses_test.go @@ -8,7 +8,6 @@ package mautrix_test import ( "encoding/json" - "fmt" "testing" "github.com/stretchr/testify/assert" @@ -86,7 +85,6 @@ func TestRespCapabilities_UnmarshalJSON(t *testing.T) { var caps mautrix.RespCapabilities err := json.Unmarshal([]byte(sampleData), &caps) require.NoError(t, err) - fmt.Println(caps) require.NotNil(t, caps.RoomVersions) assert.Equal(t, "9", caps.RoomVersions.Default) diff --git a/room.go b/room.go index c3ddb7e6..4292bff5 100644 --- a/room.go +++ b/room.go @@ -5,8 +5,6 @@ import ( "maunium.net/go/mautrix/id" ) -type RoomStateMap = map[event.Type]map[string]*event.Event - // Room represents a single Matrix room. type Room struct { ID id.RoomID @@ -25,8 +23,8 @@ func (room Room) UpdateState(evt *event.Event) { // GetStateEvent returns the state event for the given type/state_key combo, or nil. func (room Room) GetStateEvent(eventType event.Type, stateKey string) *event.Event { - stateEventMap, _ := room.State[eventType] - evt, _ := stateEventMap[stateKey] + stateEventMap := room.State[eventType] + evt := stateEventMap[stateKey] return evt } diff --git a/sqlstatestore/statestore.go b/sqlstatestore/statestore.go index 33c10c4c..11957dfa 100644 --- a/sqlstatestore/statestore.go +++ b/sqlstatestore/statestore.go @@ -62,6 +62,9 @@ func (store *SQLStateStore) IsRegistered(ctx context.Context, userID id.UserID) } func (store *SQLStateStore) MarkRegistered(ctx context.Context, userID id.UserID) error { + if userID == "" { + return fmt.Errorf("user ID is empty") + } _, err := store.Exec(ctx, "INSERT INTO mx_registrations (user_id) VALUES ($1) ON CONFLICT (user_id) DO NOTHING", userID) return err } @@ -85,14 +88,11 @@ func (store *SQLStateStore) GetRoomMembers(ctx context.Context, roomID id.RoomID query = fmt.Sprintf("%s AND membership IN (%s)", query, strings.Join(placeholders, ",")) } rows, err := store.Query(ctx, query, args...) - if err != nil { - return nil, err - } members := make(map[id.UserID]*event.MemberEventContent) - return members, dbutil.NewRowIter(rows, func(row dbutil.Scannable) (ret Member, err error) { + return members, dbutil.NewRowIterWithError(rows, func(row dbutil.Scannable) (ret Member, err error) { err = row.Scan(&ret.UserID, &ret.Membership, &ret.Displayname, &ret.AvatarURL) return - }).Iter(func(m Member) (bool, error) { + }, err).Iter(func(m Member) (bool, error) { members[m.UserID] = &m.MemberEventContent return true, nil }) @@ -159,10 +159,7 @@ func (store *SQLStateStore) FindSharedRooms(ctx context.Context, userID id.UserI ` } rows, err := store.Query(ctx, query, userID) - if err != nil { - return nil, err - } - return dbutil.NewRowIter(rows, dbutil.ScanSingleColumn[id.RoomID]).AsList() + return dbutil.NewRowIterWithError(rows, dbutil.ScanSingleColumn[id.RoomID], err).AsList() } func (store *SQLStateStore) IsInRoom(ctx context.Context, roomID id.RoomID, userID id.UserID) bool { @@ -188,6 +185,11 @@ func (store *SQLStateStore) IsMembership(ctx context.Context, roomID id.RoomID, } func (store *SQLStateStore) SetMembership(ctx context.Context, roomID id.RoomID, userID id.UserID, membership event.Membership) error { + if roomID == "" { + return fmt.Errorf("room ID is empty") + } else if userID == "" { + return fmt.Errorf("user ID is empty") + } _, err := store.Exec(ctx, ` INSERT INTO mx_user_profile (room_id, user_id, membership, displayname, avatar_url) VALUES ($1, $2, $3, '', '') ON CONFLICT (room_id, user_id) DO UPDATE SET membership=excluded.membership @@ -220,6 +222,11 @@ func (u *userProfileRow) GetMassInsertValues() [5]any { var userProfileMassInserter = dbutil.NewMassInsertBuilder[*userProfileRow, [1]any](insertUserProfileQuery, "($1, $%d, $%d, $%d, $%d, $%d)") func (store *SQLStateStore) SetMember(ctx context.Context, roomID id.RoomID, userID id.UserID, member *event.MemberEventContent) error { + if roomID == "" { + return fmt.Errorf("room ID is empty") + } else if userID == "" { + return fmt.Errorf("user ID is empty") + } var nameSkeleton []byte if !store.DisableNameDisambiguation && len(member.Displayname) > 0 { nameSkeletonArr := confusable.SkeletonHash(member.Displayname) @@ -241,6 +248,9 @@ func (store *SQLStateStore) IsConfusableName(ctx context.Context, roomID id.Room const userProfileMassInsertBatchSize = 500 func (store *SQLStateStore) ReplaceCachedMembers(ctx context.Context, roomID id.RoomID, evts []*event.Event, onlyMemberships ...event.Membership) error { + if roomID == "" { + return fmt.Errorf("room ID is empty") + } return store.DoTxn(ctx, nil, func(ctx context.Context) error { err := store.ClearCachedMembers(ctx, roomID, onlyMemberships...) if err != nil { @@ -311,6 +321,9 @@ func (store *SQLStateStore) HasFetchedMembers(ctx context.Context, roomID id.Roo } func (store *SQLStateStore) MarkMembersFetched(ctx context.Context, roomID id.RoomID) error { + if roomID == "" { + return fmt.Errorf("room ID is empty") + } _, err := store.Exec(ctx, ` INSERT INTO mx_room_state (room_id, members_fetched) VALUES ($1, true) ON CONFLICT (room_id) DO UPDATE SET members_fetched=true @@ -340,6 +353,9 @@ func (store *SQLStateStore) GetAllMembers(ctx context.Context, roomID id.RoomID) } func (store *SQLStateStore) SetEncryptionEvent(ctx context.Context, roomID id.RoomID, content *event.EncryptionEventContent) error { + if roomID == "" { + return fmt.Errorf("room ID is empty") + } contentBytes, err := json.Marshal(content) if err != nil { return fmt.Errorf("failed to marshal content JSON: %w", err) @@ -354,7 +370,7 @@ func (store *SQLStateStore) SetEncryptionEvent(ctx context.Context, roomID id.Ro func (store *SQLStateStore) GetEncryptionEvent(ctx context.Context, roomID id.RoomID) (*event.EncryptionEventContent, error) { var data []byte err := store. - QueryRow(ctx, "SELECT encryption FROM mx_room_state WHERE room_id=$1", roomID). + QueryRow(ctx, "SELECT encryption FROM mx_room_state WHERE room_id=$1 AND encryption IS NOT NULL", roomID). Scan(&data) if errors.Is(err, sql.ErrNoRows) { return nil, nil @@ -377,6 +393,9 @@ func (store *SQLStateStore) IsEncrypted(ctx context.Context, roomID id.RoomID) ( } func (store *SQLStateStore) SetPowerLevels(ctx context.Context, roomID id.RoomID, levels *event.PowerLevelsEventContent) error { + if roomID == "" { + return fmt.Errorf("room ID is empty") + } _, err := store.Exec(ctx, ` INSERT INTO mx_room_state (room_id, power_levels) VALUES ($1, $2) ON CONFLICT (room_id) DO UPDATE SET power_levels=excluded.power_levels @@ -385,89 +404,92 @@ func (store *SQLStateStore) SetPowerLevels(ctx context.Context, roomID id.RoomID } func (store *SQLStateStore) GetPowerLevels(ctx context.Context, roomID id.RoomID) (levels *event.PowerLevelsEventContent, err error) { + levels = &event.PowerLevelsEventContent{} err = store. - QueryRow(ctx, "SELECT power_levels FROM mx_room_state WHERE room_id=$1", roomID). - Scan(&dbutil.JSON{Data: &levels}) + QueryRow(ctx, "SELECT power_levels, create_event FROM mx_room_state WHERE room_id=$1 AND power_levels IS NOT NULL", roomID). + Scan(&dbutil.JSON{Data: &levels}, &dbutil.JSON{Data: &levels.CreateEvent}) if errors.Is(err, sql.ErrNoRows) { - err = nil + return nil, nil + } else if err != nil { + return nil, err + } + if levels.CreateEvent != nil { + err = levels.CreateEvent.Content.ParseRaw(event.StateCreate) } return } func (store *SQLStateStore) GetPowerLevel(ctx context.Context, roomID id.RoomID, userID id.UserID) (int, error) { - if store.Dialect == dbutil.Postgres { - var powerLevel int - err := store. - QueryRow(ctx, ` - SELECT COALESCE((power_levels->'users'->$2)::int, (power_levels->'users_default')::int, 0) - FROM mx_room_state WHERE room_id=$1 - `, roomID, userID). - Scan(&powerLevel) - return powerLevel, err - } else { - levels, err := store.GetPowerLevels(ctx, roomID) - if err != nil { - return 0, err - } - return levels.GetUserLevel(userID), nil + levels, err := store.GetPowerLevels(ctx, roomID) + if err != nil { + return 0, err } + return levels.GetUserLevel(userID), nil } func (store *SQLStateStore) GetPowerLevelRequirement(ctx context.Context, roomID id.RoomID, eventType event.Type) (int, error) { - if store.Dialect == dbutil.Postgres { - defaultType := "events_default" - defaultValue := 0 - if eventType.IsState() { - defaultType = "state_default" - defaultValue = 50 - } - var powerLevel int - err := store. - QueryRow(ctx, ` - SELECT COALESCE((power_levels->'events'->$2)::int, (power_levels->'$3')::int, $4) - FROM mx_room_state WHERE room_id=$1 - `, roomID, eventType.Type, defaultType, defaultValue). - Scan(&powerLevel) - if errors.Is(err, sql.ErrNoRows) { - err = nil - powerLevel = defaultValue - } - return powerLevel, err - } else { - levels, err := store.GetPowerLevels(ctx, roomID) - if err != nil { - return 0, err - } - return levels.GetEventLevel(eventType), nil + levels, err := store.GetPowerLevels(ctx, roomID) + if err != nil { + return 0, err } + return levels.GetEventLevel(eventType), nil } func (store *SQLStateStore) HasPowerLevel(ctx context.Context, roomID id.RoomID, userID id.UserID, eventType event.Type) (bool, error) { - if store.Dialect == dbutil.Postgres { - defaultType := "events_default" - defaultValue := 0 - if eventType.IsState() { - defaultType = "state_default" - defaultValue = 50 - } - var hasPower bool - err := store. - QueryRow(ctx, `SELECT - COALESCE((power_levels->'users'->$2)::int, (power_levels->'users_default')::int, 0) - >= - COALESCE((power_levels->'events'->$3)::int, (power_levels->'$4')::int, $5) - FROM mx_room_state WHERE room_id=$1`, roomID, userID, eventType.Type, defaultType, defaultValue). - Scan(&hasPower) - if errors.Is(err, sql.ErrNoRows) { - err = nil - hasPower = defaultValue == 0 - } - return hasPower, err - } else { - levels, err := store.GetPowerLevels(ctx, roomID) - if err != nil { - return false, err - } - return levels.GetUserLevel(userID) >= levels.GetEventLevel(eventType), nil + levels, err := store.GetPowerLevels(ctx, roomID) + if err != nil { + return false, err } + return levels.GetUserLevel(userID) >= levels.GetEventLevel(eventType), nil +} + +func (store *SQLStateStore) SetCreate(ctx context.Context, evt *event.Event) error { + if evt.Type != event.StateCreate { + return fmt.Errorf("invalid event type for create event: %s", evt.Type) + } else if evt.RoomID == "" { + return fmt.Errorf("room ID is empty") + } + _, err := store.Exec(ctx, ` + INSERT INTO mx_room_state (room_id, create_event) VALUES ($1, $2) + ON CONFLICT (room_id) DO UPDATE SET create_event=excluded.create_event + `, evt.RoomID, dbutil.JSON{Data: evt}) + return err +} + +func (store *SQLStateStore) GetCreate(ctx context.Context, roomID id.RoomID) (evt *event.Event, err error) { + err = store. + QueryRow(ctx, "SELECT create_event FROM mx_room_state WHERE room_id=$1 AND create_event IS NOT NULL", roomID). + Scan(&dbutil.JSON{Data: &evt}) + if errors.Is(err, sql.ErrNoRows) { + return nil, nil + } else if err != nil { + return nil, err + } + if evt != nil { + err = evt.Content.ParseRaw(event.StateCreate) + } + return +} + +func (store *SQLStateStore) SetJoinRules(ctx context.Context, roomID id.RoomID, rules *event.JoinRulesEventContent) error { + if roomID == "" { + return fmt.Errorf("room ID is empty") + } + _, err := store.Exec(ctx, ` + INSERT INTO mx_room_state (room_id, join_rules) VALUES ($1, $2) + ON CONFLICT (room_id) DO UPDATE SET join_rules=excluded.join_rules + `, roomID, dbutil.JSON{Data: rules}) + return err +} + +func (store *SQLStateStore) GetJoinRules(ctx context.Context, roomID id.RoomID) (levels *event.JoinRulesEventContent, err error) { + levels = &event.JoinRulesEventContent{} + err = store. + QueryRow(ctx, "SELECT join_rules FROM mx_room_state WHERE room_id=$1 AND join_rules IS NOT NULL", roomID). + Scan(&dbutil.JSON{Data: &levels}) + if errors.Is(err, sql.ErrNoRows) { + levels = nil + err = nil + } + return } diff --git a/sqlstatestore/v00-latest-revision.sql b/sqlstatestore/v00-latest-revision.sql index a58cc56a..4679f1c6 100644 --- a/sqlstatestore/v00-latest-revision.sql +++ b/sqlstatestore/v00-latest-revision.sql @@ -1,4 +1,4 @@ --- v0 -> v7 (compatible with v3+): Latest revision +-- v0 -> v10 (compatible with v3+): Latest revision CREATE TABLE mx_registrations ( user_id TEXT PRIMARY KEY @@ -26,5 +26,7 @@ CREATE TABLE mx_room_state ( room_id TEXT PRIMARY KEY, power_levels jsonb, encryption jsonb, + create_event jsonb, + join_rules jsonb, members_fetched BOOLEAN NOT NULL DEFAULT false ); diff --git a/sqlstatestore/v08-create-event.sql b/sqlstatestore/v08-create-event.sql new file mode 100644 index 00000000..9f1b55c9 --- /dev/null +++ b/sqlstatestore/v08-create-event.sql @@ -0,0 +1,2 @@ +-- v8 (compatible with v3+): Add create event to room state table +ALTER TABLE mx_room_state ADD COLUMN create_event jsonb; diff --git a/sqlstatestore/v09-clear-empty-room-ids.sql b/sqlstatestore/v09-clear-empty-room-ids.sql new file mode 100644 index 00000000..ca951068 --- /dev/null +++ b/sqlstatestore/v09-clear-empty-room-ids.sql @@ -0,0 +1,3 @@ +-- v9 (compatible with v3+): Clear invalid rows +DELETE FROM mx_room_state WHERE room_id=''; +DELETE FROM mx_user_profile WHERE room_id='' OR user_id=''; diff --git a/sqlstatestore/v10-join-rules.sql b/sqlstatestore/v10-join-rules.sql new file mode 100644 index 00000000..3074c46a --- /dev/null +++ b/sqlstatestore/v10-join-rules.sql @@ -0,0 +1,2 @@ +-- v10 (compatible with v3+): Add join rules to room state table +ALTER TABLE mx_room_state ADD COLUMN join_rules jsonb; diff --git a/statestore.go b/statestore.go index e728b885..2bd498dd 100644 --- a/statestore.go +++ b/statestore.go @@ -34,6 +34,12 @@ type StateStore interface { SetPowerLevels(ctx context.Context, roomID id.RoomID, levels *event.PowerLevelsEventContent) error GetPowerLevels(ctx context.Context, roomID id.RoomID) (*event.PowerLevelsEventContent, error) + SetCreate(ctx context.Context, evt *event.Event) error + GetCreate(ctx context.Context, roomID id.RoomID) (*event.Event, error) + + GetJoinRules(ctx context.Context, roomID id.RoomID) (*event.JoinRulesEventContent, error) + SetJoinRules(ctx context.Context, roomID id.RoomID, content *event.JoinRulesEventContent) error + HasFetchedMembers(ctx context.Context, roomID id.RoomID) (bool, error) MarkMembersFetched(ctx context.Context, roomID id.RoomID) error GetAllMembers(ctx context.Context, roomID id.RoomID) (map[id.UserID]*event.MemberEventContent, error) @@ -68,9 +74,13 @@ func UpdateStateStore(ctx context.Context, store StateStore, evt *event.Event) { err = store.SetPowerLevels(ctx, evt.RoomID, content) case *event.EncryptionEventContent: err = store.SetEncryptionEvent(ctx, evt.RoomID, content) + case *event.CreateEventContent: + err = store.SetCreate(ctx, evt) + case *event.JoinRulesEventContent: + err = store.SetJoinRules(ctx, evt.RoomID, content) default: switch evt.Type { - case event.StateMember, event.StatePowerLevels, event.StateEncryption: + case event.StateMember, event.StatePowerLevels, event.StateEncryption, event.StateCreate: zerolog.Ctx(ctx).Warn(). Stringer("event_id", evt.ID). Str("event_type", evt.Type.Type). @@ -101,11 +111,14 @@ type MemoryStateStore struct { MembersFetched map[id.RoomID]bool `json:"members_fetched"` PowerLevels map[id.RoomID]*event.PowerLevelsEventContent `json:"power_levels"` Encryption map[id.RoomID]*event.EncryptionEventContent `json:"encryption"` + Create map[id.RoomID]*event.Event `json:"create"` + JoinRules map[id.RoomID]*event.JoinRulesEventContent `json:"join_rules"` registrationsLock sync.RWMutex membersLock sync.RWMutex powerLevelsLock sync.RWMutex encryptionLock sync.RWMutex + joinRulesLock sync.RWMutex } func NewMemoryStateStore() StateStore { @@ -115,6 +128,8 @@ func NewMemoryStateStore() StateStore { MembersFetched: make(map[id.RoomID]bool), PowerLevels: make(map[id.RoomID]*event.PowerLevelsEventContent), Encryption: make(map[id.RoomID]*event.EncryptionEventContent), + Create: make(map[id.RoomID]*event.Event), + JoinRules: make(map[id.RoomID]*event.JoinRulesEventContent), } } @@ -298,6 +313,9 @@ func (store *MemoryStateStore) SetPowerLevels(_ context.Context, roomID id.RoomI func (store *MemoryStateStore) GetPowerLevels(_ context.Context, roomID id.RoomID) (levels *event.PowerLevelsEventContent, err error) { store.powerLevelsLock.RLock() levels = store.PowerLevels[roomID] + if levels != nil && levels.CreateEvent == nil { + levels.CreateEvent = store.Create[roomID] + } store.powerLevelsLock.RUnlock() return } @@ -314,6 +332,23 @@ func (store *MemoryStateStore) HasPowerLevel(ctx context.Context, roomID id.Room return exerrors.Must(store.GetPowerLevel(ctx, roomID, userID)) >= exerrors.Must(store.GetPowerLevelRequirement(ctx, roomID, eventType)), nil } +func (store *MemoryStateStore) SetCreate(ctx context.Context, evt *event.Event) error { + store.powerLevelsLock.Lock() + store.Create[evt.RoomID] = evt + if pls, ok := store.PowerLevels[evt.RoomID]; ok && pls.CreateEvent == nil { + pls.CreateEvent = evt + } + store.powerLevelsLock.Unlock() + return nil +} + +func (store *MemoryStateStore) GetCreate(ctx context.Context, roomID id.RoomID) (*event.Event, error) { + store.powerLevelsLock.RLock() + evt := store.Create[roomID] + store.powerLevelsLock.RUnlock() + return evt, nil +} + func (store *MemoryStateStore) SetEncryptionEvent(_ context.Context, roomID id.RoomID, content *event.EncryptionEventContent) error { store.encryptionLock.Lock() store.Encryption[roomID] = content @@ -327,6 +362,19 @@ func (store *MemoryStateStore) GetEncryptionEvent(_ context.Context, roomID id.R return store.Encryption[roomID], nil } +func (store *MemoryStateStore) SetJoinRules(ctx context.Context, roomID id.RoomID, content *event.JoinRulesEventContent) error { + store.joinRulesLock.Lock() + store.JoinRules[roomID] = content + store.joinRulesLock.Unlock() + return nil +} + +func (store *MemoryStateStore) GetJoinRules(ctx context.Context, roomID id.RoomID) (*event.JoinRulesEventContent, error) { + store.joinRulesLock.RLock() + defer store.joinRulesLock.RUnlock() + return store.JoinRules[roomID], nil +} + func (store *MemoryStateStore) IsEncrypted(ctx context.Context, roomID id.RoomID) (bool, error) { cfg, err := store.GetEncryptionEvent(ctx, roomID) return cfg != nil && cfg.Algorithm == id.AlgorithmMegolmV1, err diff --git a/synapseadmin/client.go b/synapseadmin/client.go index 775b4b13..6925ca7d 100644 --- a/synapseadmin/client.go +++ b/synapseadmin/client.go @@ -14,9 +14,9 @@ import ( // // https://matrix-org.github.io/synapse/latest/usage/administration/admin_api/index.html type Client struct { - *mautrix.Client + Client *mautrix.Client } func (cli *Client) BuildAdminURL(path ...any) string { - return cli.BuildURL(mautrix.SynapseAdminURLPath(path)) + return cli.Client.BuildURL(mautrix.SynapseAdminURLPath(path)) } diff --git a/synapseadmin/register.go b/synapseadmin/register.go index 641f9b56..05e0729a 100644 --- a/synapseadmin/register.go +++ b/synapseadmin/register.go @@ -73,7 +73,7 @@ func (req *ReqSharedSecretRegister) Sign(secret string) string { // This does not need to be called manually as SharedSecretRegister will automatically call this if no nonce is provided. func (cli *Client) GetRegisterNonce(ctx context.Context) (string, error) { var resp respGetRegisterNonce - _, err := cli.MakeRequest(ctx, http.MethodGet, cli.BuildURL(mautrix.SynapseAdminURLPath{"v1", "register"}), nil, &resp) + _, err := cli.Client.MakeRequest(ctx, http.MethodGet, cli.BuildAdminURL("v1", "register"), nil, &resp) if err != nil { return "", err } @@ -93,7 +93,7 @@ func (cli *Client) SharedSecretRegister(ctx context.Context, sharedSecret string } req.SHA1Checksum = req.Sign(sharedSecret) var resp mautrix.RespRegister - _, err = cli.MakeRequest(ctx, http.MethodPost, cli.BuildURL(mautrix.SynapseAdminURLPath{"v1", "register"}), &req, &resp) + _, err = cli.Client.MakeRequest(ctx, http.MethodPost, cli.BuildAdminURL("v1", "register"), &req, &resp) if err != nil { return nil, err } diff --git a/synapseadmin/roomapi.go b/synapseadmin/roomapi.go index 6c072e23..0925b748 100644 --- a/synapseadmin/roomapi.go +++ b/synapseadmin/roomapi.go @@ -75,12 +75,17 @@ type RespListRooms struct { // https://matrix-org.github.io/synapse/latest/admin_api/rooms.html#list-room-api func (cli *Client) ListRooms(ctx context.Context, req ReqListRoom) (RespListRooms, error) { var resp RespListRooms - var reqURL string - reqURL = cli.BuildURLWithQuery(mautrix.SynapseAdminURLPath{"v1", "rooms"}, req.BuildQuery()) - _, err := cli.MakeRequest(ctx, http.MethodGet, reqURL, nil, &resp) + reqURL := cli.Client.BuildURLWithQuery(mautrix.SynapseAdminURLPath{"v1", "rooms"}, req.BuildQuery()) + _, err := cli.Client.MakeRequest(ctx, http.MethodGet, reqURL, nil, &resp) return resp, err } +func (cli *Client) RoomInfo(ctx context.Context, roomID id.RoomID) (resp *RoomInfo, err error) { + reqURL := cli.BuildAdminURL("v1", "rooms", roomID) + _, err = cli.Client.MakeRequest(ctx, http.MethodGet, reqURL, nil, &resp) + return +} + type RespRoomMessages = mautrix.RespMessages // RoomMessages returns a list of messages in a room. @@ -104,13 +109,14 @@ func (cli *Client) RoomMessages(ctx context.Context, roomID id.RoomID, from, to if limit != 0 { query["limit"] = strconv.Itoa(limit) } - urlPath := cli.BuildURLWithQuery(mautrix.SynapseAdminURLPath{"v1", "rooms", roomID, "messages"}, query) - _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp) + urlPath := cli.Client.BuildURLWithQuery(mautrix.SynapseAdminURLPath{"v1", "rooms", roomID, "messages"}, query) + _, err = cli.Client.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp) return resp, err } type ReqDeleteRoom struct { Purge bool `json:"purge,omitempty"` + ForcePurge bool `json:"force_purge,omitempty"` Block bool `json:"block,omitempty"` Message string `json:"message,omitempty"` RoomName string `json:"room_name,omitempty"` @@ -121,6 +127,19 @@ type RespDeleteRoom struct { DeleteID string `json:"delete_id"` } +type RespDeleteRoomResult struct { + KickedUsers []id.UserID `json:"kicked_users,omitempty"` + FailedToKickUsers []id.UserID `json:"failed_to_kick_users,omitempty"` + LocalAliases []id.RoomAlias `json:"local_aliases,omitempty"` + NewRoomID id.RoomID `json:"new_room_id,omitempty"` +} + +type RespDeleteRoomStatus struct { + Status string `json:"status,omitempty"` + Error string `json:"error,omitempty"` + ShutdownRoom RespDeleteRoomResult `json:"shutdown_room,omitempty"` +} + // DeleteRoom deletes a room from the server, optionally blocking it and/or purging all data from the database. // // This calls the async version of the endpoint, which will return immediately and delete the room in the background. @@ -129,10 +148,37 @@ type RespDeleteRoom struct { func (cli *Client) DeleteRoom(ctx context.Context, roomID id.RoomID, req ReqDeleteRoom) (RespDeleteRoom, error) { reqURL := cli.BuildAdminURL("v2", "rooms", roomID) var resp RespDeleteRoom - _, err := cli.MakeRequest(ctx, http.MethodDelete, reqURL, &req, &resp) + _, err := cli.Client.MakeRequest(ctx, http.MethodDelete, reqURL, &req, &resp) return resp, err } +func (cli *Client) DeleteRoomStatus(ctx context.Context, deleteID string) (resp RespDeleteRoomStatus, err error) { + reqURL := cli.BuildAdminURL("v2", "rooms", "delete_status", deleteID) + _, err = cli.Client.MakeRequest(ctx, http.MethodGet, reqURL, nil, &resp) + return +} + +// DeleteRoomSync deletes a room from the server, optionally blocking it and/or purging all data from the database. +// +// This calls the synchronous version of the endpoint, which will block until the room is deleted. +// +// https://element-hq.github.io/synapse/latest/admin_api/rooms.html#version-1-old-version +func (cli *Client) DeleteRoomSync(ctx context.Context, roomID id.RoomID, req ReqDeleteRoom) (resp RespDeleteRoomResult, err error) { + reqURL := cli.BuildAdminURL("v1", "rooms", roomID) + httpClient := &http.Client{} + _, err = cli.Client.MakeFullRequest(ctx, mautrix.FullRequest{ + Method: http.MethodDelete, + URL: reqURL, + RequestJSON: &req, + ResponseJSON: &resp, + MaxAttempts: 1, + // Use a fresh HTTP client without timeouts + Client: httpClient, + }) + httpClient.CloseIdleConnections() + return +} + type RespRoomsMembers struct { Members []id.UserID `json:"members"` Total int `json:"total"` @@ -144,7 +190,7 @@ type RespRoomsMembers struct { func (cli *Client) RoomMembers(ctx context.Context, roomID id.RoomID) (RespRoomsMembers, error) { reqURL := cli.BuildAdminURL("v1", "rooms", roomID, "members") var resp RespRoomsMembers - _, err := cli.MakeRequest(ctx, http.MethodGet, reqURL, nil, &resp) + _, err := cli.Client.MakeRequest(ctx, http.MethodGet, reqURL, nil, &resp) return resp, err } @@ -157,7 +203,7 @@ type ReqMakeRoomAdmin struct { // https://matrix-org.github.io/synapse/latest/admin_api/rooms.html#make-room-admin-api func (cli *Client) MakeRoomAdmin(ctx context.Context, roomIDOrAlias string, req ReqMakeRoomAdmin) error { reqURL := cli.BuildAdminURL("v1", "rooms", roomIDOrAlias, "make_room_admin") - _, err := cli.MakeRequest(ctx, http.MethodPost, reqURL, &req, nil) + _, err := cli.Client.MakeRequest(ctx, http.MethodPost, reqURL, &req, nil) return err } @@ -170,7 +216,7 @@ type ReqJoinUserToRoom struct { // https://matrix-org.github.io/synapse/latest/admin_api/room_membership.html func (cli *Client) JoinUserToRoom(ctx context.Context, roomID id.RoomID, req ReqJoinUserToRoom) error { reqURL := cli.BuildAdminURL("v1", "join", roomID) - _, err := cli.MakeRequest(ctx, http.MethodPost, reqURL, &req, nil) + _, err := cli.Client.MakeRequest(ctx, http.MethodPost, reqURL, &req, nil) return err } @@ -183,7 +229,7 @@ type ReqBlockRoom struct { // https://matrix-org.github.io/synapse/latest/admin_api/rooms.html#block-room-api func (cli *Client) BlockRoom(ctx context.Context, roomID id.RoomID, req ReqBlockRoom) error { reqURL := cli.BuildAdminURL("v1", "rooms", roomID, "block") - _, err := cli.MakeRequest(ctx, http.MethodPut, reqURL, &req, nil) + _, err := cli.Client.MakeRequest(ctx, http.MethodPut, reqURL, &req, nil) return err } @@ -199,6 +245,6 @@ type RoomsBlockResponse struct { func (cli *Client) GetRoomBlockStatus(ctx context.Context, roomID id.RoomID) (RoomsBlockResponse, error) { var resp RoomsBlockResponse reqURL := cli.BuildAdminURL("v1", "rooms", roomID, "block") - _, err := cli.MakeRequest(ctx, http.MethodGet, reqURL, nil, &resp) + _, err := cli.Client.MakeRequest(ctx, http.MethodGet, reqURL, nil, &resp) return resp, err } diff --git a/synapseadmin/userapi.go b/synapseadmin/userapi.go index 9cbb17e4..b1de55b6 100644 --- a/synapseadmin/userapi.go +++ b/synapseadmin/userapi.go @@ -32,7 +32,7 @@ type ReqResetPassword struct { // https://matrix-org.github.io/synapse/latest/admin_api/user_admin_api.html#reset-password func (cli *Client) ResetPassword(ctx context.Context, req ReqResetPassword) error { reqURL := cli.BuildAdminURL("v1", "reset_password", req.UserID) - _, err := cli.MakeRequest(ctx, http.MethodPost, reqURL, &req, nil) + _, err := cli.Client.MakeRequest(ctx, http.MethodPost, reqURL, &req, nil) return err } @@ -43,8 +43,8 @@ func (cli *Client) ResetPassword(ctx context.Context, req ReqResetPassword) erro // // https://matrix-org.github.io/synapse/latest/admin_api/user_admin_api.html#check-username-availability func (cli *Client) UsernameAvailable(ctx context.Context, username string) (resp *mautrix.RespRegisterAvailable, err error) { - u := cli.BuildURLWithQuery(mautrix.SynapseAdminURLPath{"v1", "username_available"}, map[string]string{"username": username}) - _, err = cli.MakeRequest(ctx, http.MethodGet, u, nil, &resp) + u := cli.Client.BuildURLWithQuery(mautrix.SynapseAdminURLPath{"v1", "username_available"}, map[string]string{"username": username}) + _, err = cli.Client.MakeRequest(ctx, http.MethodGet, u, nil, &resp) if err == nil && !resp.Available { err = fmt.Errorf(`request returned OK status without "available": true`) } @@ -65,7 +65,7 @@ type RespListDevices struct { // // https://matrix-org.github.io/synapse/latest/admin_api/user_admin_api.html#list-all-devices func (cli *Client) ListDevices(ctx context.Context, userID id.UserID) (resp *RespListDevices, err error) { - _, err = cli.MakeRequest(ctx, http.MethodGet, cli.BuildAdminURL("v2", "users", userID, "devices"), nil, &resp) + _, err = cli.Client.MakeRequest(ctx, http.MethodGet, cli.BuildAdminURL("v2", "users", userID, "devices"), nil, &resp) return } @@ -89,7 +89,7 @@ type RespUserInfo struct { // // https://matrix-org.github.io/synapse/latest/admin_api/user_admin_api.html#query-user-account func (cli *Client) GetUserInfo(ctx context.Context, userID id.UserID) (resp *RespUserInfo, err error) { - _, err = cli.MakeRequest(ctx, http.MethodGet, cli.BuildAdminURL("v2", "users", userID), nil, &resp) + _, err = cli.Client.MakeRequest(ctx, http.MethodGet, cli.BuildAdminURL("v2", "users", userID), nil, &resp) return } @@ -102,7 +102,20 @@ type ReqDeleteUser struct { // https://matrix-org.github.io/synapse/latest/admin_api/user_admin_api.html#deactivate-account func (cli *Client) DeactivateAccount(ctx context.Context, userID id.UserID, req ReqDeleteUser) error { reqURL := cli.BuildAdminURL("v1", "deactivate", userID) - _, err := cli.MakeRequest(ctx, http.MethodPost, reqURL, &req, nil) + _, err := cli.Client.MakeRequest(ctx, http.MethodPost, reqURL, &req, nil) + return err +} + +type ReqSuspendUser struct { + Suspend bool `json:"suspend"` +} + +// SuspendAccount suspends or unsuspends a specific local user account. +// +// https://element-hq.github.io/synapse/latest/admin_api/user_admin_api.html#suspendunsuspend-account +func (cli *Client) SuspendAccount(ctx context.Context, userID id.UserID, req ReqSuspendUser) error { + reqURL := cli.BuildAdminURL("v1", "suspend", userID) + _, err := cli.Client.MakeRequest(ctx, http.MethodPut, reqURL, &req, nil) return err } @@ -124,7 +137,7 @@ type ReqCreateOrModifyAccount struct { // https://matrix-org.github.io/synapse/latest/admin_api/user_admin_api.html#create-or-modify-account func (cli *Client) CreateOrModifyAccount(ctx context.Context, userID id.UserID, req ReqCreateOrModifyAccount) error { reqURL := cli.BuildAdminURL("v2", "users", userID) - _, err := cli.MakeRequest(ctx, http.MethodPut, reqURL, &req, nil) + _, err := cli.Client.MakeRequest(ctx, http.MethodPut, reqURL, &req, nil) return err } @@ -140,7 +153,7 @@ type ReqSetRatelimit = RatelimitOverride // https://matrix-org.github.io/synapse/latest/admin_api/user_admin_api.html#set-ratelimit func (cli *Client) SetUserRatelimit(ctx context.Context, userID id.UserID, req ReqSetRatelimit) error { reqURL := cli.BuildAdminURL("v1", "users", userID, "override_ratelimit") - _, err := cli.MakeRequest(ctx, http.MethodPost, reqURL, &req, nil) + _, err := cli.Client.MakeRequest(ctx, http.MethodPost, reqURL, &req, nil) return err } @@ -150,7 +163,7 @@ type RespUserRatelimit = RatelimitOverride // // https://matrix-org.github.io/synapse/latest/admin_api/user_admin_api.html#get-status-of-ratelimit func (cli *Client) GetUserRatelimit(ctx context.Context, userID id.UserID) (resp RespUserRatelimit, err error) { - _, err = cli.MakeRequest(ctx, http.MethodGet, cli.BuildAdminURL("v1", "users", userID, "override_ratelimit"), nil, &resp) + _, err = cli.Client.MakeRequest(ctx, http.MethodGet, cli.BuildAdminURL("v1", "users", userID, "override_ratelimit"), nil, &resp) return } @@ -158,6 +171,6 @@ func (cli *Client) GetUserRatelimit(ctx context.Context, userID id.UserID) (resp // // https://matrix-org.github.io/synapse/latest/admin_api/user_admin_api.html#delete-ratelimit func (cli *Client) DeleteUserRatelimit(ctx context.Context, userID id.UserID) (err error) { - _, err = cli.MakeRequest(ctx, http.MethodDelete, cli.BuildAdminURL("v1", "users", userID, "override_ratelimit"), nil, nil) + _, err = cli.Client.MakeRequest(ctx, http.MethodDelete, cli.BuildAdminURL("v1", "users", userID, "override_ratelimit"), nil, nil) return } diff --git a/sync.go b/sync.go index d4208404..598df8e0 100644 --- a/sync.go +++ b/sync.go @@ -90,6 +90,7 @@ func (s *DefaultSyncer) ProcessResponse(ctx context.Context, res *RespSync, sinc err = fmt.Errorf("ProcessResponse panicked! since=%s panic=%s\n%s", since, r, debug.Stack()) } }() + ctx = context.WithValue(ctx, SyncTokenContextKey, since) for _, listener := range s.syncListeners { if !listener(ctx, res, since) { @@ -97,33 +98,38 @@ func (s *DefaultSyncer) ProcessResponse(ctx context.Context, res *RespSync, sinc } } - s.processSyncEvents(ctx, "", res.ToDevice.Events, event.SourceToDevice) - s.processSyncEvents(ctx, "", res.Presence.Events, event.SourcePresence) - s.processSyncEvents(ctx, "", res.AccountData.Events, event.SourceAccountData) + s.processSyncEvents(ctx, "", res.ToDevice.Events, event.SourceToDevice, false) + s.processSyncEvents(ctx, "", res.Presence.Events, event.SourcePresence, false) + s.processSyncEvents(ctx, "", res.AccountData.Events, event.SourceAccountData, false) for roomID, roomData := range res.Rooms.Join { - s.processSyncEvents(ctx, roomID, roomData.State.Events, event.SourceJoin|event.SourceState) - s.processSyncEvents(ctx, roomID, roomData.Timeline.Events, event.SourceJoin|event.SourceTimeline) - s.processSyncEvents(ctx, roomID, roomData.Ephemeral.Events, event.SourceJoin|event.SourceEphemeral) - s.processSyncEvents(ctx, roomID, roomData.AccountData.Events, event.SourceJoin|event.SourceAccountData) + if roomData.StateAfter == nil { + s.processSyncEvents(ctx, roomID, roomData.State.Events, event.SourceJoin|event.SourceState, false) + s.processSyncEvents(ctx, roomID, roomData.Timeline.Events, event.SourceJoin|event.SourceTimeline, false) + } else { + s.processSyncEvents(ctx, roomID, roomData.Timeline.Events, event.SourceJoin|event.SourceTimeline, true) + s.processSyncEvents(ctx, roomID, roomData.StateAfter.Events, event.SourceJoin|event.SourceState, false) + } + s.processSyncEvents(ctx, roomID, roomData.Ephemeral.Events, event.SourceJoin|event.SourceEphemeral, false) + s.processSyncEvents(ctx, roomID, roomData.AccountData.Events, event.SourceJoin|event.SourceAccountData, false) } for roomID, roomData := range res.Rooms.Invite { - s.processSyncEvents(ctx, roomID, roomData.State.Events, event.SourceInvite|event.SourceState) + s.processSyncEvents(ctx, roomID, roomData.State.Events, event.SourceInvite|event.SourceState, false) } for roomID, roomData := range res.Rooms.Leave { - s.processSyncEvents(ctx, roomID, roomData.State.Events, event.SourceLeave|event.SourceState) - s.processSyncEvents(ctx, roomID, roomData.Timeline.Events, event.SourceLeave|event.SourceTimeline) + s.processSyncEvents(ctx, roomID, roomData.State.Events, event.SourceLeave|event.SourceState, false) + s.processSyncEvents(ctx, roomID, roomData.Timeline.Events, event.SourceLeave|event.SourceTimeline, false) } return } -func (s *DefaultSyncer) processSyncEvents(ctx context.Context, roomID id.RoomID, events []*event.Event, source event.Source) { +func (s *DefaultSyncer) processSyncEvents(ctx context.Context, roomID id.RoomID, events []*event.Event, source event.Source, ignoreState bool) { for _, evt := range events { - s.processSyncEvent(ctx, roomID, evt, source) + s.processSyncEvent(ctx, roomID, evt, source, ignoreState) } } -func (s *DefaultSyncer) processSyncEvent(ctx context.Context, roomID id.RoomID, evt *event.Event, source event.Source) { +func (s *DefaultSyncer) processSyncEvent(ctx context.Context, roomID id.RoomID, evt *event.Event, source event.Source, ignoreState bool) { evt.RoomID = roomID // Ensure the type class is correct. It's safe to mutate the class since the event type is not a pointer. @@ -149,6 +155,7 @@ func (s *DefaultSyncer) processSyncEvent(ctx context.Context, roomID id.RoomID, } evt.Mautrix.EventSource = source + evt.Mautrix.IgnoreState = ignoreState s.Dispatch(ctx, evt) } @@ -191,8 +198,8 @@ func (s *DefaultSyncer) OnFailedSync(res *RespSync, err error) (time.Duration, e } var defaultFilter = Filter{ - Room: RoomFilter{ - Timeline: FilterPart{ + Room: &RoomFilter{ + Timeline: &FilterPart{ Limit: 50, }, }, @@ -257,7 +264,7 @@ func dontProcessOldEvents(userID id.UserID, resp *RespSync, since string) bool { // cli.Syncer.(mautrix.ExtensibleSyncer).OnSync(cli.MoveInviteState) func (cli *Client) MoveInviteState(ctx context.Context, resp *RespSync, _ string) bool { for _, meta := range resp.Rooms.Invite { - var inviteState []event.StrippedState + var inviteState []*event.Event var inviteEvt *event.Event for _, evt := range meta.State.Events { if evt.Type == event.StateMember && evt.GetStateKey() == cli.UserID.String() { @@ -265,12 +272,7 @@ func (cli *Client) MoveInviteState(ctx context.Context, resp *RespSync, _ string } else { evt.Type.Class = event.StateEventType _ = evt.Content.ParseRaw(evt.Type) - inviteState = append(inviteState, event.StrippedState{ - Content: evt.Content, - Type: evt.Type, - StateKey: evt.GetStateKey(), - Sender: evt.Sender, - }) + inviteState = append(inviteState, evt) } } if inviteEvt != nil { diff --git a/url.go b/url.go index f35ae5e2..91b3d49d 100644 --- a/url.go +++ b/url.go @@ -57,13 +57,13 @@ func BuildURL(baseURL *url.URL, path ...any) *url.URL { // BuildURL builds a URL with the Client's homeserver and appservice user ID set already. func (cli *Client) BuildURL(urlPath PrefixableURLPath) string { - return cli.BuildURLWithQuery(urlPath, nil) + return cli.BuildURLWithFullQuery(urlPath, nil) } // BuildClientURL builds a URL with the Client's homeserver and appservice user ID set already. // This method also automatically prepends the client API prefix (/_matrix/client). func (cli *Client) BuildClientURL(urlPath ...any) string { - return cli.BuildURLWithQuery(ClientURLPath(urlPath), nil) + return cli.BuildURLWithFullQuery(ClientURLPath(urlPath), nil) } type PrefixableURLPath interface { @@ -97,6 +97,19 @@ func (saup SynapseAdminURLPath) FullPath() []any { // BuildURLWithQuery builds a URL with query parameters in addition to the Client's homeserver // and appservice user ID set already. func (cli *Client) BuildURLWithQuery(urlPath PrefixableURLPath, urlQuery map[string]string) string { + return cli.BuildURLWithFullQuery(urlPath, func(q url.Values) { + for k, v := range urlQuery { + q.Set(k, v) + } + }) +} + +// BuildURLWithQuery builds a URL with query parameters in addition to the Client's homeserver +// and appservice user ID set already. +func (cli *Client) BuildURLWithFullQuery(urlPath PrefixableURLPath, fn func(q url.Values)) string { + if cli == nil { + return "client is nil" + } hsURL := *BuildURL(cli.HomeserverURL, urlPath.FullPath()...) query := hsURL.Query() if cli.SetAppServiceUserID { @@ -106,10 +119,8 @@ func (cli *Client) BuildURLWithQuery(urlPath PrefixableURLPath, urlQuery map[str query.Set("device_id", string(cli.DeviceID)) query.Set("org.matrix.msc3202.device_id", string(cli.DeviceID)) } - if urlQuery != nil { - for k, v := range urlQuery { - query.Set(k, v) - } + if fn != nil { + fn(query) } hsURL.RawQuery = query.Encode() return hsURL.String() diff --git a/version.go b/version.go index 362a684b..f00bbf39 100644 --- a/version.go +++ b/version.go @@ -4,10 +4,11 @@ import ( "fmt" "regexp" "runtime" + "runtime/debug" "strings" ) -const Version = "v0.22.1" +const Version = "v0.26.3" var GoModVersion = "" var Commit = "" @@ -15,11 +16,20 @@ var VersionWithCommit = Version var DefaultUserAgent = "mautrix-go/" + Version + " go/" + strings.TrimPrefix(runtime.Version(), "go") -var goModVersionRegex = regexp.MustCompile(`v.+\d{14}-([0-9a-f]{12})`) - func init() { + if GoModVersion == "" { + info, _ := debug.ReadBuildInfo() + if info != nil { + for _, mod := range info.Deps { + if mod.Path == "maunium.net/go/mautrix" { + GoModVersion = mod.Version + break + } + } + } + } if GoModVersion != "" { - match := goModVersionRegex.FindStringSubmatch(GoModVersion) + match := regexp.MustCompile(`v.+\d{14}-([0-9a-f]{12})`).FindStringSubmatch(GoModVersion) if match != nil { Commit = match[1] } diff --git a/versions.go b/versions.go index 5c0d6eaa..61b2e4ea 100644 --- a/versions.go +++ b/versions.go @@ -60,18 +60,28 @@ type UnstableFeature struct { } var ( - FeatureAsyncUploads = UnstableFeature{UnstableFlag: "fi.mau.msc2246.stable", SpecVersion: SpecV17} - FeatureAppservicePing = UnstableFeature{UnstableFlag: "fi.mau.msc2659.stable", SpecVersion: SpecV17} - FeatureAuthenticatedMedia = UnstableFeature{UnstableFlag: "org.matrix.msc3916.stable", SpecVersion: SpecV111} - FeatureMutualRooms = UnstableFeature{UnstableFlag: "uk.half-shot.msc2666.query_mutual_rooms"} + FeatureAsyncUploads = UnstableFeature{UnstableFlag: "fi.mau.msc2246.stable", SpecVersion: SpecV17} + FeatureAppservicePing = UnstableFeature{UnstableFlag: "fi.mau.msc2659.stable", SpecVersion: SpecV17} + FeatureAuthenticatedMedia = UnstableFeature{UnstableFlag: "org.matrix.msc3916.stable", SpecVersion: SpecV111} + FeatureUnstableMutualRooms = UnstableFeature{UnstableFlag: "uk.half-shot.msc2666.query_mutual_rooms"} + FeatureStableMutualRooms = UnstableFeature{UnstableFlag: "uk.half-shot.msc2666.query_mutual_rooms.stable" /*, SpecVersion: SpecV118*/} + FeatureUserRedaction = UnstableFeature{UnstableFlag: "org.matrix.msc4194"} + FeatureViewRedactedContent = UnstableFeature{UnstableFlag: "fi.mau.msc2815"} + FeatureUnstableAccountModeration = UnstableFeature{UnstableFlag: "uk.timedout.msc4323"} + FeatureStableAccountModeration = UnstableFeature{UnstableFlag: "uk.timedout.msc4323.stable" /*, SpecVersion: SpecV118*/} + FeatureUnstableProfileFields = UnstableFeature{UnstableFlag: "uk.tcpip.msc4133"} + FeatureArbitraryProfileFields = UnstableFeature{UnstableFlag: "uk.tcpip.msc4133.stable", SpecVersion: SpecV116} + FeatureRedactSendAsEvent = UnstableFeature{UnstableFlag: "com.beeper.msc4169"} - BeeperFeatureHungry = UnstableFeature{UnstableFlag: "com.beeper.hungry"} - BeeperFeatureBatchSending = UnstableFeature{UnstableFlag: "com.beeper.batch_sending"} - BeeperFeatureRoomYeeting = UnstableFeature{UnstableFlag: "com.beeper.room_yeeting"} - BeeperFeatureAutojoinInvites = UnstableFeature{UnstableFlag: "com.beeper.room_create_autojoin_invites"} - BeeperFeatureArbitraryProfileMeta = UnstableFeature{UnstableFlag: "com.beeper.arbitrary_profile_meta"} - BeeperFeatureAccountDataMute = UnstableFeature{UnstableFlag: "com.beeper.account_data_mute"} - BeeperFeatureInboxState = UnstableFeature{UnstableFlag: "com.beeper.inbox_state"} + BeeperFeatureHungry = UnstableFeature{UnstableFlag: "com.beeper.hungry"} + BeeperFeatureBatchSending = UnstableFeature{UnstableFlag: "com.beeper.batch_sending"} + BeeperFeatureRoomYeeting = UnstableFeature{UnstableFlag: "com.beeper.room_yeeting"} + BeeperFeatureAutojoinInvites = UnstableFeature{UnstableFlag: "com.beeper.room_create_autojoin_invites"} + BeeperFeatureArbitraryProfileMeta = UnstableFeature{UnstableFlag: "com.beeper.arbitrary_profile_meta"} + BeeperFeatureAccountDataMute = UnstableFeature{UnstableFlag: "com.beeper.account_data_mute"} + BeeperFeatureInboxState = UnstableFeature{UnstableFlag: "com.beeper.inbox_state"} + BeeperFeatureArbitraryMemberChange = UnstableFeature{UnstableFlag: "com.beeper.arbitrary_member_change"} + BeeperFeatureEphemeralEvents = UnstableFeature{UnstableFlag: "com.beeper.ephemeral"} ) func (versions *RespVersions) Supports(feature UnstableFeature) bool { @@ -111,6 +121,12 @@ var ( SpecV19 = MustParseSpecVersion("v1.9") SpecV110 = MustParseSpecVersion("v1.10") SpecV111 = MustParseSpecVersion("v1.11") + SpecV112 = MustParseSpecVersion("v1.12") + SpecV113 = MustParseSpecVersion("v1.13") + SpecV114 = MustParseSpecVersion("v1.14") + SpecV115 = MustParseSpecVersion("v1.15") + SpecV116 = MustParseSpecVersion("v1.16") + SpecV117 = MustParseSpecVersion("v1.17") ) func (svf SpecVersionFormat) String() string {