diff --git a/.editorconfig b/.editorconfig index 21d312a1..1a167e7e 100644 --- a/.editorconfig +++ b/.editorconfig @@ -10,3 +10,6 @@ insert_final_newline = true [*.{yaml,yml}] indent_style = space + +[provisioning.yaml] +indent_size = 2 diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 66f6aee1..c0add220 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -2,17 +2,20 @@ name: Go on: [push, pull_request] +env: + GOTOOLCHAIN: local + jobs: lint: runs-on: ubuntu-latest name: Lint (latest) steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - name: Set up Go - uses: actions/setup-go@v5 + uses: actions/setup-go@v6 with: - go-version: "1.22" + go-version: "1.26" cache: true - name: Install libolm @@ -21,27 +24,25 @@ jobs: - name: Install goimports run: | go install golang.org/x/tools/cmd/goimports@latest + go install honnef.co/go/tools/cmd/staticcheck@latest export PATH="$HOME/go/bin:$PATH" - - name: Install pre-commit - run: pip install pre-commit - - - name: Lint - run: pre-commit run -a + - name: Run pre-commit + uses: pre-commit/action@v3.0.1 build: runs-on: ubuntu-latest strategy: fail-fast: false matrix: - go-version: ["1.21", "1.22"] - name: Build (${{ matrix.go-version == '1.22' && 'latest' || 'old' }}, libolm) + go-version: ["1.25", "1.26"] + name: Build (${{ matrix.go-version == '1.26' && 'latest' || 'old' }}, libolm) steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - name: Set up Go ${{ matrix.go-version }} - uses: actions/setup-go@v5 + uses: actions/setup-go@v6 with: go-version: ${{ matrix.go-version }} cache: true @@ -60,30 +61,29 @@ jobs: - name: Test run: go test -json -v ./... 2>&1 | gotestfmt + - name: Test (jsonv2) + env: + GOEXPERIMENT: jsonv2 + run: go test -json -v ./... 2>&1 | gotestfmt + build-goolm: runs-on: ubuntu-latest strategy: fail-fast: false matrix: - go-version: ["1.21", "1.22"] - name: Build (${{ matrix.go-version == '1.22' && 'latest' || 'old' }}, goolm) + go-version: ["1.25", "1.26"] + name: Build (${{ matrix.go-version == '1.26' && 'latest' || 'old' }}, goolm) steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - name: Set up Go ${{ matrix.go-version }} - uses: actions/setup-go@v5 + uses: actions/setup-go@v6 with: go-version: ${{ matrix.go-version }} cache: true - - name: Set up gotestfmt - uses: GoTestTools/gotestfmt-action@v2 - with: - token: ${{ secrets.GITHUB_TOKEN }} - - name: Build - run: go build -tags=goolm -v ./... - - - name: Test - run: go test -tags=goolm -json -v ./... 2>&1 | gotestfmt + run: | + rm -rf crypto/libolm + go build -tags=goolm -v ./... diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml new file mode 100644 index 00000000..9a9e7375 --- /dev/null +++ b/.github/workflows/stale.yml @@ -0,0 +1,29 @@ +name: 'Lock old issues' + +on: + schedule: + - cron: '0 6 * * *' + workflow_dispatch: + +permissions: + issues: write +# pull-requests: write +# discussions: write + +concurrency: + group: lock-threads + +jobs: + lock-stale: + runs-on: ubuntu-latest + steps: + - uses: dessant/lock-threads@v6 + id: lock + with: + issue-inactive-days: 90 + process-only: issues + - name: Log processed threads + run: | + if [ '${{ steps.lock.outputs.issues }}' ]; then + echo "Issues:" && echo '${{ steps.lock.outputs.issues }}' | jq -r '.[] | "https://github.com/\(.owner)/\(.repo)/issues/\(.issue_number)"' + fi diff --git a/.gitignore b/.gitignore index f37a7d0c..c01f2f30 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,4 @@ .idea/ .vscode/ -*.db +*.db* *.log diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 5fffa9fb..616fccb2 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.4.0 + rev: v6.0.0 hooks: - id: trailing-whitespace exclude_types: [markdown] @@ -9,7 +9,7 @@ repos: - id: check-added-large-files - repo: https://github.com/tekwizely/pre-commit-golang - rev: v1.0.0-rc.1 + rev: v1.0.0-rc.4 hooks: - id: go-imports-repo args: @@ -17,8 +17,13 @@ repos: - "maunium.net/go/mautrix" - "-w" - id: go-vet-repo-mod + - id: go-mod-tidy + - id: go-staticcheck-repo-mod - repo: https://github.com/beeper/pre-commit-go - rev: v0.3.1 + rev: v0.4.2 hooks: - id: prevent-literal-http-methods + - id: zerolog-ban-global-log + - id: zerolog-ban-msgf + - id: zerolog-use-stringer diff --git a/CHANGELOG.md b/CHANGELOG.md index cece9947..f2829199 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,602 @@ +## v0.26.3 (2026-02-16) + +* Bumped minimum Go version to 1.25. +* *(client)* Added fields for sending [MSC4354] sticky events. +* *(bridgev2)* Added automatic message request accepting when sending message. +* *(mediaproxy)* Added support for federation thumbnail endpoint. +* *(crypto/ssss)* Improved support for recovery keys with slightly broken + metadata. +* *(crypto)* Changed key import to call session received callback even for + sessions that already exist in the database. +* *(appservice)* Fixed building websocket URL accidentally using file path + separators instead of always `/`. +* *(crypto)* Fixed key exports not including the `sender_claimed_keys` field. +* *(client)* Fixed incorrect context usage in async uploads. +* *(crypto)* Fixed panic when passing invalid input to megolm message index + parser used for debugging. +* *(bridgev2/provisioning)* Fixed completed or failed logins not being cleaned + up properly. + +[MSC4354]: https://github.com/matrix-org/matrix-spec-proposals/pull/4354 + +## v0.26.2 (2026-01-16) + +* *(bridgev2)* Added chunked portal deletion to avoid database locks when + deleting large portals. +* *(crypto,bridgev2)* Added option to encrypt reaction and reply metadata + as per [MSC4392]. +* *(bridgev2/login)* Added `default_value` for user input fields. +* *(bridgev2)* Added interfaces to let the Matrix connector provide suggested + HTTP client settings and to reset active connections of the network connector. +* *(bridgev2)* Added interface to let network connectors get the provisioning + API HTTP router and add new endpoints. +* *(event)* Added blurhash field to Beeper link preview objects. +* *(event)* Added [MSC4391] support for bot commands. +* *(event)* Dropped [MSC4332] support for bot commands. +* *(client)* Changed media download methods to return an error if the provided + MXC URI is empty. +* *(client)* Stabilized support for [MSC4323]. +* *(bridgev2/matrix)* Fixed `GetEvent` panicking when trying to decrypt events. +* *(bridgev2)* Fixed some deadlocks when room creation happens in parallel with + a portal re-ID call. + +[MSC4391]: https://github.com/matrix-org/matrix-spec-proposals/pull/4391 +[MSC4392]: https://github.com/matrix-org/matrix-spec-proposals/pull/4392 + +## v0.26.1 (2025-12-16) + +* **Breaking change *(mediaproxy)*** Changed `GetMediaResponseFile` to return + the mime type from the callback rather than in the return get media return + value. The callback can now also redirect the caller to a different file. +* *(federation)* Added join/knock/leave functions + (thanks to [@nexy7574] in [#422]). +* *(federation/eventauth)* Fixed various incorrect checks. +* *(client)* Added backoff for retrying media uploads to external URLs + (with MSC3870). +* *(bridgev2/config)* Added support for overriding config fields using + environment variables. +* *(bridgev2/commands)* Added command to mute chat on remote network. +* *(bridgev2)* Added interface for network connectors to redirect to a different + user ID when handling an invite from Matrix. +* *(bridgev2)* Added interface for signaling message request status of portals. +* *(bridgev2)* Changed portal creation to not backfill unless `CanBackfill` flag + is set in chat info. +* *(bridgev2)* Changed Matrix reaction handling to only delete old reaction if + bridging the new one is successful. +* *(bridgev2/mxmain)* Improved error message when trying to run bridge with + pre-megabridge database when no database migration exists. +* *(bridgev2)* Improved reliability of database migration when enabling split + portals. +* *(bridgev2)* Improved detection of orphaned DM rooms when starting new chats. +* *(bridgev2)* Stopped sending redundant invites when joining ghosts to public + portal rooms. +* *(bridgev2)* Stopped hardcoding room versions in favor of checking + server capabilities to determine appropriate `/createRoom` parameters. + +[#422]: https://github.com/mautrix/go/pull/422 + +## v0.26.0 (2025-11-16) + +* *(client,appservice)* Deprecated `SendMassagedStateEvent` as `SendStateEvent` + has been able to do the same for a while now. +* *(client,federation)* Added size limits for responses to make it safer to send + requests to untrusted servers. +* *(client)* Added wrapper for `/admin/whois` client API + (thanks to [@nexy7574] in [#411]). +* *(synapseadmin)* Added `force_purge` option to DeleteRoom + (thanks to [@nexy7574] in [#420]). +* *(statestore)* Added saving join rules for rooms. +* *(bridgev2)* Added optional automatic rollback of room state if bridging the + change to the remote network fails. +* *(bridgev2)* Added management room notices if transient disconnect state + doesn't resolve within 3 minutes. +* *(bridgev2)* Added interface to signal that certain participants couldn't be + invited when creating a group. +* *(bridgev2)* Added `select` type for user input fields in login. +* *(bridgev2)* Added interface to let network connector customize personal + filtering space. +* *(bridgev2/matrix)* Added checks to avoid sending error messages in reply to + other bots. +* *(bridgev2/matrix)* Switched to using [MSC4169] to send redactions whenever + possible. +* *(bridgev2/publicmedia)* Added support for custom path prefixes, file names, + and encrypted files. +* *(bridgev2/commands)* Added command to resync a single portal. +* *(bridgev2/commands)* Added create group command. +* *(bridgev2/config)* Added option to limit maximum number of logins. +* *(bridgev2)* Changed ghost joining to skip unnecessary invite if portal room + is public. +* *(bridgev2/disappear)* Changed read receipt handling to only start + disappearing timers for messages up to the read message (note: may not work in + all cases if the read receipt points at an unknown event). +* *(event/reply)* Changed plaintext reply fallback removal to only happen when + an HTML reply fallback is removed successfully. +* *(bridgev2/matrix)* Fixed unnecessary sleep after registering bot on first run. +* *(crypto/goolm)* Fixed panic when processing certain malformed Olm messages. +* *(federation)* Fixed HTTP method for sending transactions + (thanks to [@nexy7574] in [#426]). +* *(federation)* Fixed response body being closed even when using `DontReadBody` + parameter. +* *(federation)* Fixed validating auth for requests with query params. +* *(federation/eventauth)* Fixed typo causing restricted joins to not work. + +[MSC4169]: https://github.com/matrix-org/matrix-spec-proposals/pull/4169 +[#411]: github.com/mautrix/go/pull/411 +[#420]: github.com/mautrix/go/pull/420 +[#426]: github.com/mautrix/go/pull/426 + +## v0.25.2 (2025-10-16) + +* **Breaking change *(id)*** Split `UserID.ParseAndValidate` into + `ParseAndValidateRelaxed` and `ParseAndValidateStrict`. Strict is the old + behavior, but most users likely want the relaxed version, as there are real + users whose user IDs aren't valid under the strict rules. +* *(crypto)* Added helper methods for generating and verifying with recovery + keys. +* *(bridgev2/matrix)* Added config option to automatically generate a recovery + key for the bridge bot and self-sign the bridge's device. +* *(bridgev2/matrix)* Added initial support for using appservice/MSC3202 mode + for encryption with standard servers like Synapse. +* *(bridgev2)* Added optional support for implicit read receipts. +* *(bridgev2)* Added interface for deleting chats on remote network. +* *(bridgev2)* Added local enforcement of media duration and size limits. +* *(bridgev2)* Extended event duration logging to log any event taking too long. +* *(bridgev2)* Improved validation in group creation provisioning API. +* *(event)* Added event type constant for poll end events. +* *(client)* Added wrapper for searching user directory. +* *(client)* Improved support for managing [MSC4140] delayed events. +* *(crypto/helper)* Changed default sync handling to not block on waiting for + decryption keys. On initial sync, keys won't be requested at all by default. +* *(crypto)* Fixed olm unwedging not working (regressed in v0.25.1). +* *(bridgev2)* Fixed various bugs with migrating to split portals. +* *(event)* Fixed poll start events having incorrect null `m.relates_to`. +* *(client)* Fixed `RespUserProfile` losing standard fields when re-marshaling. +* *(federation)* Fixed various bugs in event auth. + +## v0.25.1 (2025-09-16) + +* *(client)* Fixed HTTP method of delete devices API call + (thanks to [@fmseals] in [#393]). +* *(client)* Added wrappers for [MSC4323]: User suspension & locking endpoints + (thanks to [@nexy7574] in [#407]). +* *(client)* Stabilized support for extensible profiles. +* *(client)* Stabilized support for `state_after` in sync. +* *(client)* Removed deprecated MSC2716 requests. +* *(crypto)* Added fallback to ensure `m.relates_to` is always copied even if + the content struct doesn't implement `Relatable`. +* *(crypto)* Changed olm unwedging to ignore newly created sessions if they + haven't been used successfully in either direction. +* *(federation)* Added utilities for generating, parsing, validating and + authorizing PDUs. + * Note: the new PDU code depends on `GOEXPERIMENT=jsonv2` +* *(event)* Added `is_animated` flag from [MSC4230] to file info. +* *(event)* Added types for [MSC4332]: In-room bot commands. +* *(event)* Added missing poll end event type for [MSC3381]. +* *(appservice)* Fixed URLs not being escaped properly when using unix socket + for homeserver connections. +* *(format)* Added more helpers for forming markdown links. +* *(event,bridgev2)* Added support for Beeper's disappearing message state event. +* *(bridgev2)* Redesigned group creation interface and added support in commands + and provisioning API. +* *(bridgev2)* Added GetEvent to Matrix interface to allow network connectors to + get an old event. The method is best effort only, as some configurations don't + allow fetching old events. +* *(bridgev2)* Added shared logic for provisioning that can be reused by the + API, commands and other sources. +* *(bridgev2)* Fixed mentions and URL previews not being copied over when + caption and media are merged. +* *(bridgev2)* Removed config option to change provisioning API prefix, which + had already broken in the previous release. + +[@fmseals]: https://github.com/fmseals +[#393]: https://github.com/mautrix/go/pull/393 +[#407]: https://github.com/mautrix/go/pull/407 +[MSC3381]: https://github.com/matrix-org/matrix-spec-proposals/pull/3381 +[MSC4230]: https://github.com/matrix-org/matrix-spec-proposals/pull/4230 +[MSC4323]: https://github.com/matrix-org/matrix-spec-proposals/pull/4323 +[MSC4332]: https://github.com/matrix-org/matrix-spec-proposals/pull/4332 + +## v0.25.0 (2025-08-16) + +* Bumped minimum Go version to 1.24. +* **Breaking change *(appservice,bridgev2,federation)*** Replaced gorilla/mux + with standard library ServeMux. +* *(client,bridgev2)* Added support for creator power in room v12. +* *(client)* Added option to not set `User-Agent` header for improved Wasm + compatibility. +* *(bridgev2)* Added support for following tombstones. +* *(bridgev2)* Added interface for getting arbitrary state event from Matrix. +* *(bridgev2)* Added batching to disappearing message queue to ensure it doesn't + use too many resources even if there are a large number of messages. +* *(bridgev2/commands)* Added support for canceling QR login with `cancel` + command. +* *(client)* Added option to override HTTP client used for .well-known + resolution. +* *(crypto/backup)* Added method for encrypting key backup session without + private keys. +* *(event->id)* Moved room version type and constants to id package. +* *(bridgev2)* Bots in DM portals will now be added to the functional members + state event to hide them from the room name calculation. +* *(bridgev2)* Changed message delete handling to ignore "delete for me" events + if there are multiple Matrix users in the room. +* *(format/htmlparser)* Changed text processing to collapse multiple spaces into + one when outside `pre`/`code` tags. +* *(format/htmlparser)* Removed link suffix in plaintext output when link text + is only missing protocol part of href. + * e.g. `example.com` will turn into + `example.com` rather than `example.com (https://example.com)` +* *(appservice)* Switched appservice websockets from gorilla/websocket to + coder/websocket. +* *(bridgev2/matrix)* Fixed encryption key sharing not ignoring ghosts properly. +* *(crypto/attachments)* Fixed hash check when decrypting file streams. +* *(crypto)* Removed unnecessary `AlreadyShared` error in `ShareGroupSession`. + The function will now act as if it was successful instead. + +## v0.24.2 (2025-07-16) + +* *(bridgev2)* Added support for return values from portal event handlers. Note + that the return value will always be "queued" unless the event buffer is + disabled. +* *(bridgev2)* Added support for [MSC4144] per-message profile passthrough in + relay mode. +* *(bridgev2)* Added option to auto-reconnect logins after a certain period if + they hit an `UNKNOWN_ERROR` state. +* *(bridgev2)* Added analytics for event handler panics. +* *(bridgev2)* Changed new room creation to hardcode room v11 to avoid v12 rooms + being created before proper support for them can be added. +* *(bridgev2)* Changed queuing events to block instead of dropping events if the + buffer is full. +* *(bridgev2)* Fixed assumption that replies to unknown messages are cross-room. +* *(id)* Fixed server name validation not including ports correctly + (thanks to [@krombel] in [#392]). +* *(federation)* Fixed base64 algorithm in signature generation. +* *(event)* Fixed [MSC4144] fallbacks not being removed from edits. + +[@krombel]: https://github.com/krombel +[#392]: https://github.com/mautrix/go/pull/392 + +## v0.24.1 (2025-06-16) + +* *(commands)* Added framework for using reactions as buttons that execute + command handlers. +* *(client)* Added wrapper for `/relations` endpoints. +* *(client)* Added support for stable version of room summary endpoint. +* *(client)* Fixed parsing URL preview responses where width/height are strings. +* *(federation)* Fixed bugs in server auth. +* *(id)* Added utilities for validating server names. +* *(event)* Fixed incorrect empty `entity` field when sending hashed moderation + policy events. +* *(event)* Added [MSC4293] redact events field to member events. +* *(event)* Added support for fallbacks in [MSC4144] per-message profiles. +* *(format)* Added `MarkdownLink` and `MarkdownMention` utility functions for + generating properly escaped markdown. +* *(synapseadmin)* Added support for synchronous (v1) room delete endpoint. +* *(synapseadmin)* Changed `Client` struct to not embed the `mautrix.Client`. + This is a breaking change if you were relying on accessing non-admin functions + from the admin client. +* *(bridgev2/provisioning)* Fixed `/display_and_wait` not passing through errors + from the network connector properly. +* *(bridgev2/crypto)* Fixed encryption not working if the user's ID had the same + prefix as the bridge ghosts (e.g. `@whatsappbridgeuser:example.com` with a + `@whatsapp_` prefix). +* *(bridgev2)* Fixed portals not being saved after creating a DM portal from a + Matrix DM invite. +* *(bridgev2)* Added config option to determine whether cross-room replies + should be bridged. +* *(appservice)* Fixed `EnsureRegistered` not being called when sending a custom + member event for the controlled user. + +[MSC4293]: https://github.com/matrix-org/matrix-spec-proposals/pull/4293 + +## v0.24.0 (2025-05-16) + +* *(commands)* Added generic framework for implementing bot commands. +* *(client)* Added support for specifying maximum number of HTTP retries using + a context value instead of having to call `MakeFullRequest` manually. +* *(client,federation)* Added methods for fetching room directories. +* *(federation)* Added support for server side of request authentication. +* *(synapseadmin)* Added wrapper for the account suspension endpoint. +* *(format)* Added method for safely wrapping a string in markdown inline code. +* *(crypto)* Added method to import key backup without persisting to database, + to allow the client more control over the process. +* *(bridgev2)* Added viewing chat interface to signal when the user is viewing + a given chat. +* *(bridgev2)* Added option to pass through transaction ID from client when + sending messages to remote network. +* *(crypto)* Fixed unnecessary error log when decrypting dummy events used for + unwedging Olm sessions. +* *(crypto)* Fixed `forwarding_curve25519_key_chain` not being set consistently + when backing up keys. +* *(event)* Fixed marshaling legacy VoIP events with no version field. +* *(bridgev2)* Fixed disappearing message references not being deleted when the + portal is deleted. +* *(bridgev2)* Fixed read receipt bridging not ignoring fake message entries + and causing unnecessary error logs. + +## v0.23.3 (2025-04-16) + +* *(commands)* Added generic command processing framework for bots. +* *(client)* Added `allowed_room_ids` field to room summary responses + (thanks to [@nexy7574] in [#367]). +* *(bridgev2)* Added support for custom timeouts on outgoing messages which have + to wait for a remote echo. +* *(bridgev2)* Added automatic typing stop event if the ghost user had sent a + typing event before a message. +* *(bridgev2)* The saved management room is now cleared if the user leaves the + room, allowing the next DM to be automatically marked as a management room. +* *(bridge)* Removed deprecated fallback package for bridge statuses. + The status package is now only available under bridgev2. + +[#367]: https://github.com/mautrix/go/pull/367 + +## v0.23.2 (2025-03-16) + +* **Breaking change *(bridge)*** Removed legacy bridge module. +* **Breaking change *(event)*** Changed `m.federate` field in room create event + content to a pointer to allow detecting omitted values. +* *(bridgev2/commands)* Added `set-management-room` command to set a new + management room. +* *(bridgev2/portal)* Changed edit bridging to ignore remote edits if the + original sender on Matrix can't be puppeted. +* *(bridgv2)* Added config option to disable bridging `m.notice` messages. +* *(appservice/http)* Switched access token validation to use constant time + comparisons. +* *(event)* Added support for [MSC3765] rich text topics. +* *(event)* Added fields to policy list event contents for [MSC4204] and + [MSC4205]. +* *(client)* Added method for getting the content of a redacted event using + [MSC2815]. +* *(client)* Added methods for sending and updating [MSC4140] delayed events. +* *(client)* Added support for [MSC4222] in sync payloads. +* *(crypto/cryptohelper)* Switched to using `sqlite3-fk-wal` instead of plain + `sqlite3` by default. +* *(crypto/encryptolm)* Added generic method for encrypting to-device events. +* *(crypto/ssss)* Fixed panic if server-side key metadata is corrupted. +* *(crypto/sqlstore)* Fixed error when marking over 32 thousand device lists + as outdated on SQLite. + +[MSC2815]: https://github.com/matrix-org/matrix-spec-proposals/pull/2815 +[MSC3765]: https://github.com/matrix-org/matrix-spec-proposals/pull/3765 +[MSC4140]: https://github.com/matrix-org/matrix-spec-proposals/pull/4140 +[MSC4204]: https://github.com/matrix-org/matrix-spec-proposals/pull/4204 +[MSC4205]: https://github.com/matrix-org/matrix-spec-proposals/pull/4205 +[MSC4222]: https://github.com/matrix-org/matrix-spec-proposals/pull/4222 + +## v0.23.1 (2025-02-16) + +* *(client)* Added `FullStateEvent` method to get a state event including + metadata (using the `?format=event` query parameter). +* *(client)* Added wrapper method for [MSC4194]'s redact endpoint. +* *(pushrules)* Fixed content rules not considering word boundaries and being + case-sensitive. +* *(crypto)* Fixed bugs that would cause key exports to fail for no reason. +* *(crypto)* Deprecated `ResolveTrust` in favor of `ResolveTrustContext`. +* *(crypto)* Stopped accepting secret shares from unverified devices. +* **Breaking change *(crypto)*** Changed `GetAndVerifyLatestKeyBackupVersion` + to take an optional private key parameter. The method will now trust the + public key if it matches the provided private key even if there are no valid + signatures. +* **Breaking change *(crypto)*** Added context parameter to `IsDeviceTrusted`. + +[MSC4194]: https://github.com/matrix-org/matrix-spec-proposals/pull/4194 + +## v0.23.0 (2025-01-16) + +* **Breaking change *(client)*** Changed `JoinRoom` parameters to allow multiple + `via`s. +* **Breaking change *(bridgev2)*** Updated capability system. + * The return type of `NetworkAPI.GetCapabilities` is now different. + * Media type capabilities are enforced automatically by bridgev2. + * Capabilities are now sent to Matrix rooms using the + `com.beeper.room_features` state event. +* *(client)* Added `GetRoomSummary` to implement [MSC3266]. +* *(client)* Added support for arbitrary profile fields to implement [MSC4133] + (thanks to [@nexy7574] in [#337]). +* *(crypto)* Started storing olm message hashes to prevent decryption errors + if messages are repeated (e.g. if the app crashes right after decrypting). +* *(crypto)* Improved olm session unwedging to check when the last session was + created instead of only relying on an in-memory map. +* *(crypto/verificationhelper)* Fixed emoji verification not doing cross-signing + properly after a successful verification. +* *(bridgev2/config)* Moved MSC4190 flag from `appservice` to `encryption`. +* *(bridgev2/space)* Fixed failing to add rooms to spaces if the room create + call was made with a temporary context. +* *(bridgev2/commands)* Changed `help` command to hide commands which require + interfaces that aren't implemented by the network connector. +* *(bridgev2/matrixinterface)* Moved deterministic room ID generation to Matrix + connector. +* *(bridgev2)* Fixed service member state event not being set correctly when + creating a DM by inviting a ghost user. +* *(bridgev2)* Fixed `RemoteReactionSync` events replacing all reactions every + time instead of only changed ones. + +[MSC3266]: https://github.com/matrix-org/matrix-spec-proposals/pull/3266 +[MSC4133]: https://github.com/matrix-org/matrix-spec-proposals/pull/4133 +[@nexy7574]: https://github.com/nexy7574 +[#337]: https://github.com/mautrix/go/pull/337 + +## v0.22.1 (2024-12-16) + +* *(crypto)* Added automatic cleanup when there are too many olm sessions with + a single device. +* *(crypto)* Added helper for getting cached device list with cross-signing + status. +* *(crypto/verificationhelper)* Added interface for persisting the state of + in-progress verifications. +* *(client)* Added `GetMutualRooms` wrapper for [MSC2666]. +* *(client)* Switched `JoinRoom` to use the `via` query param instead of + `server_name` as per [MSC4156]. +* *(bridgev2/commands)* Fixed `pm` command not actually starting the chat. +* *(bridgev2/interface)* Added separate network API interface for starting + chats with a Matrix ghost user. This allows treating internal user IDs + differently than arbitrary user-input strings. +* *(bridgev2/crypto)* Added support for [MSC4190] + (thanks to [@onestacked] in [#288]). + +[MSC2666]: https://github.com/matrix-org/matrix-spec-proposals/pull/2666 +[MSC4156]: https://github.com/matrix-org/matrix-spec-proposals/pull/4156 +[MSC4190]: https://github.com/matrix-org/matrix-spec-proposals/pull/4190 +[#288]: https://github.com/mautrix/go/pull/288 +[@onestacked]: https://github.com/onestacked + +## v0.22.0 (2024-11-16) + +* *(hicli)* Moved package into gomuks repo. +* *(bridgev2/commands)* Fixed cookie unescaping in login commands. +* *(bridgev2/portal)* Added special `DefaultChatName` constant to explicitly + reset portal names to the default (based on members). +* *(bridgev2/config)* Added options to disable room tag bridging. +* *(bridgev2/database)* Fixed reaction queries not including portal receiver. +* *(appservice)* Updated [MSC2409] stable registration field name from + `push_ephemeral` to `receive_ephemeral`. Homeserver admins must update + existing registrations manually. +* *(format)* Added support for `img` tags. +* *(format/mdext)* Added goldmark extensions for Matrix math and custom emojis. +* *(event/reply)* Removed support for generating reply fallbacks ([MSC2781]). +* *(pushrules)* Added support for `sender_notification_permission` condition + kind (used for `@room` mentions). +* *(crypto)* Added support for `json.RawMessage` in `EncryptMegolmEvent`. +* *(mediaproxy)* Added `GetMediaResponseCallback` and `GetMediaResponseFile` + to write proxied data directly to http response or temp file instead of + having to use an `io.Reader`. +* *(mediaproxy)* Dropped support for legacy media download endpoints. +* *(mediaproxy,bridgev2)* Made interface pass through query parameters. + +[MSC2781]: https://github.com/matrix-org/matrix-spec-proposals/pull/2781 + +## v0.21.1 (2024-10-16) + +* *(bridgev2)* Added more features and fixed bugs. +* *(hicli)* Added more features and fixed bugs. +* *(appservice)* Removed TLS support. A reverse proxy should be used if TLS + is needed. +* *(format/mdext)* Added goldmark extension to fix indented paragraphs when + disabling indented code block parser. +* *(event)* Added `Has` method for `Mentions`. +* *(event)* Added basic support for the unstable version of polls. + +## v0.21.0 (2024-09-16) + +* **Breaking change *(client)*** Dropped support for unauthenticated media. + Matrix v1.11 support is now required from the homeserver, although it's not + enforced using `/versions` as some servers don't advertise it. +* *(bridgev2)* Added more features and fixed bugs. +* *(appservice,crypto)* Added support for using MSC3202 for appservice + encryption. +* *(crypto/olm)* Made everything into an interface to allow side-by-side + testing of libolm and goolm, as well as potentially support vodozemac + in the future. +* *(client)* Fixed requests being retried even after context is canceled. +* *(client)* Added option to move `/sync` request logs to trace level. +* *(error)* Added `Write` and `WithMessage` helpers to `RespError` to make it + easier to use on servers. +* *(event)* Fixed `org.matrix.msc1767.audio` field allowing omitting the + duration and waveform. +* *(id)* Changed `MatrixURI` methods to not panic if the receiver is nil. +* *(federation)* Added limit to response size when fetching `.well-known` files. + +## v0.20.0 (2024-08-16) + +* Bumped minimum Go version to 1.22. +* *(bridgev2)* Added more features and fixed bugs. +* *(event)* Added types for [MSC4144]: Per-message profiles. +* *(federation)* Added implementation of server name resolution and a basic + client for making federation requests. +* *(crypto/ssss)* Changed recovery key/passphrase verify functions to take the + key ID as a parameter to ensure it's correctly set even if the key metadata + wasn't fetched via `GetKeyData`. +* *(format/mdext)* Added goldmark extensions for single-character bold, italic + and strikethrough parsing (as in `*foo*` -> **foo**, `_foo_` -> _foo_ and + `~foo~` -> ~~foo~~) +* *(format)* Changed `RenderMarkdown` et al to always include `m.mentions` in + returned content. The mention list is filled with matrix.to URLs from the + input by default. + +[MSC4144]: https://github.com/matrix-org/matrix-spec-proposals/pull/4144 + +## v0.19.0 (2024-07-16) + +* Renamed `master` branch to `main`. +* *(bridgev2)* Added more features. +* *(crypto)* Fixed bug with copying `m.relates_to` from wire content to + decrypted content. +* *(mediaproxy)* Added module for implementing simple media repos that proxy + requests elsewhere. +* *(client)* Changed `Members()` to automatically parse event content for all + returned events. +* *(bridge)* Added `/register` call if `/versions` fails with `M_FORBIDDEN`. +* *(crypto)* Fixed `DecryptMegolmEvent` sometimes calling database without + transaction by using the non-context version of `ResolveTrust`. +* *(crypto/attachment)* Implemented `io.Seeker` in `EncryptStream` to allow + using it in retriable HTTP requests. +* *(event)* Added helper method to add user ID to a `Mentions` object. +* *(event)* Fixed default power level for invites + (thanks to [@rudis] in [#250]). +* *(client)* Fixed incorrect warning log in `State()` when state store returns + no error (thanks to [@rudis] in [#249]). +* *(crypto/verificationhelper)* Fixed deadlock when ignoring unknown + cancellation events (thanks to [@rudis] in [#247]). + +[@rudis]: https://github.com/rudis +[#250]: https://github.com/mautrix/go/pull/250 +[#249]: https://github.com/mautrix/go/pull/249 +[#247]: https://github.com/mautrix/go/pull/247 + +### beta.1 (2024-06-16) + +* *(bridgev2)* Added experimental high-level bridge framework. +* *(hicli)* Added experimental high-level client framework. +* **Slightly breaking changes** + * *(crypto)* Added room ID and first known index parameters to + `SessionReceived` callback. + * *(crypto)* Changed `ImportRoomKeyFromBackup` to return the imported + session. + * *(client)* Added `error` parameter to `ResponseHook`. + * *(client)* Changed `Download` to return entire response instead of just an + `io.Reader`. +* *(crypto)* Changed initial olm device sharing to save keys before sharing to + ensure keys aren't accidentally regenerated in case the request fails. +* *(crypto)* Changed `EncryptMegolmEvent` and `ShareGroupSession` to return + more errors instead of only logging and ignoring them. +* *(crypto)* Added option to completely disable megolm ratchet tracking. + * The tracking is meant for bots and bridges which may want to delete old + keys, but for normal clients it's just unnecessary overhead. +* *(crypto)* Changed Megolm session storage methods in `Store` to not take + sender key as parameter. + * This causes a breaking change to the layout of the `MemoryStore` struct. + Using MemoryStore in production is not recommended. +* *(crypto)* Changed `DecryptMegolmEvent` to copy `m.relates_to` in the raw + content too instead of only in the parsed struct. +* *(crypto)* Exported function to parse megolm message index from raw + ciphertext bytes. +* *(crypto/sqlstore)* Fixed schema of `crypto_secrets` table to include + account ID. +* *(crypto/verificationhelper)* Fixed more bugs. +* *(client)* Added `UpdateRequestOnRetry` hook which is called immediately + before retrying a normal HTTP request. +* *(client)* Added support for MSC3916 media download endpoint. + * Support is automatically detected from spec versions. The `SpecVersions` + property can either be filled manually, or `Versions` can be called to + automatically populate the field with the response. +* *(event)* Added constants for known room versions. + +## v0.18.1 (2024-04-16) + +* *(format)* Added a `context.Context` field to HTMLParser's Context struct. +* *(bridge)* Added support for handling join rules, knocks, invites and bans + (thanks to [@maltee1] in [#193] and [#204]). +* *(crypto)* Changed forwarded room key handling to only accept keys with a + lower first known index than the existing session if there is one. +* *(crypto)* Changed key backup restore to assume own device list is up to date + to avoid re-requesting device list for every deleted device that has signed + key backup. +* *(crypto)* Fixed memory cache not being invalidated when storing own + cross-signing keys + +[@maltee1]: https://github.com/maltee1 +[#193]: https://github.com/mautrix/go/pull/193 +[#204]: https://github.com/mautrix/go/pull/204 + ## v0.18.0 (2024-03-16) * **Breaking change *(client, bridge, appservice)*** Dropped support for diff --git a/README.md b/README.md index d45860e7..b1a2edf8 100644 --- a/README.md +++ b/README.md @@ -1,11 +1,12 @@ # mautrix-go [![GoDoc](https://pkg.go.dev/badge/maunium.net/go/mautrix)](https://pkg.go.dev/maunium.net/go/mautrix) -A Golang Matrix framework. Used by [gomuks](https://matrix.org/docs/projects/client/gomuks), -[go-neb](https://github.com/matrix-org/go-neb), [mautrix-whatsapp](https://github.com/mautrix/whatsapp) +A Golang Matrix framework. Used by [gomuks](https://gomuks.app), +[go-neb](https://github.com/matrix-org/go-neb), +[mautrix-whatsapp](https://github.com/mautrix/whatsapp) and others. -Matrix room: [`#maunium:maunium.net`](https://matrix.to/#/#maunium:maunium.net) +Matrix room: [`#go:maunium.net`](https://matrix.to/#/#go:maunium.net) This project is based on [matrix-org/gomatrix](https://github.com/matrix-org/gomatrix). The original project is licensed under [Apache 2.0](https://github.com/matrix-org/gomatrix/blob/master/LICENSE). @@ -13,7 +14,11 @@ The original project is licensed under [Apache 2.0](https://github.com/matrix-or In addition to the basic client API features the original project has, this framework also has: * Appservice support (Intent API like mautrix-python, room state storage, etc) -* End-to-end encryption support (incl. interactive SAS verification) +* End-to-end encryption support (incl. key backup, cross-signing, interactive verification, etc) +* High-level module for building puppeting bridges +* Partial federation module (making requests, PDU processing and event authorization) +* A media proxy server which can be used to expose anything as a Matrix media repo +* Wrapper functions for the Synapse admin API * Structs for parsing event content * Helpers for parsing and generating Matrix HTML * Helpers for handling push rules diff --git a/appservice/appservice.go b/appservice/appservice.go index ef9c6236..d7037ef6 100644 --- a/appservice/appservice.go +++ b/appservice/appservice.go @@ -1,4 +1,4 @@ -// Copyright (c) 2023 Tulir Asokan +// Copyright (c) 2025 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this @@ -19,8 +19,7 @@ import ( "syscall" "time" - "github.com/gorilla/mux" - "github.com/gorilla/websocket" + "github.com/coder/websocket" "github.com/rs/zerolog" "golang.org/x/net/publicsuffix" "gopkg.in/yaml.v3" @@ -32,7 +31,7 @@ import ( // EventChannelSize is the size for the Events channel in Appservice instances. var EventChannelSize = 64 -var OTKChannelSize = 4 +var OTKChannelSize = 64 // Create creates a blank appservice instance. func Create() *AppService { @@ -43,7 +42,7 @@ func Create() *AppService { intents: make(map[id.UserID]*IntentAPI), HTTPClient: &http.Client{Timeout: 180 * time.Second, Jar: jar}, StateStore: mautrix.NewMemoryStateStore().(StateStore), - Router: mux.NewRouter(), + Router: http.NewServeMux(), UserAgent: mautrix.DefaultUserAgent, txnIDC: NewTransactionIDCache(128), Live: true, @@ -56,15 +55,17 @@ func Create() *AppService { DeviceLists: make(chan *mautrix.DeviceLists, EventChannelSize), QueryHandler: &QueryHandlerStub{}, + SpecVersions: &mautrix.RespVersions{}, + DefaultHTTPRetries: 4, } - as.Router.HandleFunc("/_matrix/app/v1/transactions/{txnID}", as.PutTransaction).Methods(http.MethodPut) - as.Router.HandleFunc("/_matrix/app/v1/rooms/{roomAlias}", as.GetRoom).Methods(http.MethodGet) - as.Router.HandleFunc("/_matrix/app/v1/users/{userID}", as.GetUser).Methods(http.MethodGet) - as.Router.HandleFunc("/_matrix/app/v1/ping", as.PostPing).Methods(http.MethodPost) - as.Router.HandleFunc("/_matrix/mau/live", as.GetLive).Methods(http.MethodGet) - as.Router.HandleFunc("/_matrix/mau/ready", as.GetReady).Methods(http.MethodGet) + as.Router.HandleFunc("PUT /_matrix/app/v1/transactions/{txnID}", as.PutTransaction) + as.Router.HandleFunc("GET /_matrix/app/v1/rooms/{roomAlias}", as.GetRoom) + as.Router.HandleFunc("GET /_matrix/app/v1/users/{userID}", as.GetUser) + as.Router.HandleFunc("POST /_matrix/app/v1/ping", as.PostPing) + as.Router.HandleFunc("GET /_matrix/mau/live", as.GetLive) + as.Router.HandleFunc("GET /_matrix/mau/ready", as.GetReady) return as } @@ -112,13 +113,13 @@ var _ StateStore = (*mautrix.MemoryStateStore)(nil) // QueryHandler handles room alias and user ID queries from the homeserver. type QueryHandler interface { - QueryAlias(alias string) bool + QueryAlias(alias id.RoomAlias) bool QueryUser(userID id.UserID) bool } type QueryHandlerStub struct{} -func (qh *QueryHandlerStub) QueryAlias(alias string) bool { +func (qh *QueryHandlerStub) QueryAlias(alias id.RoomAlias) bool { return false } @@ -126,7 +127,7 @@ func (qh *QueryHandlerStub) QueryUser(userID id.UserID) bool { return false } -type WebsocketHandler func(WebsocketCommand) (ok bool, data interface{}) +type WebsocketHandler func(WebsocketCommand) (ok bool, data any) type StateStore interface { mautrix.StateStore @@ -158,12 +159,13 @@ type AppService struct { QueryHandler QueryHandler StateStore StateStore - Router *mux.Router - UserAgent string - server *http.Server - HTTPClient *http.Client - botClient *mautrix.Client - botIntent *IntentAPI + Router *http.ServeMux + UserAgent string + server *http.Server + HTTPClient *http.Client + botClient *mautrix.Client + botIntent *IntentAPI + SpecVersions *mautrix.RespVersions DefaultHTTPRetries int @@ -176,7 +178,6 @@ type AppService struct { intentsLock sync.RWMutex ws *websocket.Conn - wsWriteLock sync.Mutex StopWebsocket func(error) websocketHandlers map[string]WebsocketHandler websocketHandlersLock sync.RWMutex @@ -193,6 +194,7 @@ type AppService struct { } const DoublePuppetKey = "fi.mau.double_puppet_source" +const DoublePuppetTSKey = "fi.mau.double_puppet_ts" func getDefaultProcessID() string { pid := syscall.Getpid() @@ -220,9 +222,6 @@ type HostConfig struct { Hostname string `yaml:"hostname"` // Port is required when Hostname is an IP address, optional for unix sockets Port uint16 `yaml:"port"` - - TLSKey string `yaml:"tls_key,omitempty"` - TLSCert string `yaml:"tls_cert,omitempty"` } // Address gets the whole address of the Appservice. @@ -335,7 +334,7 @@ func (as *AppService) SetHomeserverURL(homeserverURL string) error { } else if as.hsURLForClient.Scheme == "" { as.hsURLForClient.Scheme = "https" } - as.hsURLForClient.RawPath = parsedURL.EscapedPath() + as.hsURLForClient.RawPath = as.hsURLForClient.EscapedPath() jar, _ := cookiejar.New(&cookiejar.Options{PublicSuffixList: publicsuffix.List}) as.HTTPClient = &http.Client{Timeout: 180 * time.Second, Jar: jar} @@ -361,9 +360,10 @@ func (as *AppService) NewMautrixClient(userID id.UserID) *mautrix.Client { AccessToken: as.Registration.AppToken, UserAgent: as.UserAgent, StateStore: as.StateStore, - Log: as.Log.With().Str("as_user_id", userID.String()).Logger(), + Log: as.Log.With().Stringer("as_user_id", userID).Logger(), Client: as.HTTPClient, DefaultHTTPRetries: as.DefaultHTTPRetries, + SpecVersions: as.SpecVersions, } } diff --git a/appservice/http.go b/appservice/http.go index 38bcecf8..27ce6288 100644 --- a/appservice/http.go +++ b/appservice/http.go @@ -1,4 +1,4 @@ -// Copyright (c) 2023 Tulir Asokan +// Copyright (c) 2025 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this @@ -17,8 +17,9 @@ import ( "syscall" "time" - "github.com/gorilla/mux" "github.com/rs/zerolog" + "go.mau.fi/util/exhttp" + "go.mau.fi/util/exstrings" "maunium.net/go/mautrix" "maunium.net/go/mautrix/event" @@ -59,13 +60,8 @@ func (as *AppService) listenUnix() error { } func (as *AppService) listenTCP() error { - if len(as.Host.TLSCert) == 0 || len(as.Host.TLSKey) == 0 { - as.Log.Info().Str("address", as.server.Addr).Msg("Starting HTTP listener") - return as.server.ListenAndServe() - } else { - as.Log.Info().Str("address", as.server.Addr).Msg("Starting HTTP listener with TLS") - return as.server.ListenAndServeTLS(as.Host.TLSCert, as.Host.TLSKey) - } + as.Log.Info().Str("address", as.server.Addr).Msg("Starting HTTP listener") + return as.server.ListenAndServe() } func (as *AppService) Stop() { @@ -83,17 +79,9 @@ func (as *AppService) Stop() { func (as *AppService) CheckServerToken(w http.ResponseWriter, r *http.Request) (isValid bool) { authHeader := r.Header.Get("Authorization") if !strings.HasPrefix(authHeader, "Bearer ") { - Error{ - ErrorCode: ErrUnknownToken, - HTTPStatus: http.StatusForbidden, - Message: "Missing access token", - }.Write(w) - } else if authHeader[len("Bearer "):] != as.Registration.ServerToken { - Error{ - ErrorCode: ErrUnknownToken, - HTTPStatus: http.StatusForbidden, - Message: "Incorrect access token", - }.Write(w) + mautrix.MMissingToken.WithMessage("Missing access token").Write(w) + } else if !exstrings.ConstantTimeEqual(authHeader[len("Bearer "):], as.Registration.ServerToken) { + mautrix.MUnknownToken.WithMessage("Invalid access token").Write(w) } else { isValid = true } @@ -106,24 +94,15 @@ func (as *AppService) PutTransaction(w http.ResponseWriter, r *http.Request) { return } - vars := mux.Vars(r) - txnID := vars["txnID"] + txnID := r.PathValue("txnID") if len(txnID) == 0 { - Error{ - ErrorCode: ErrNoTransactionID, - HTTPStatus: http.StatusBadRequest, - Message: "Missing transaction ID", - }.Write(w) + mautrix.MInvalidParam.WithMessage("Missing transaction ID").Write(w) return } defer r.Body.Close() body, err := io.ReadAll(r.Body) if err != nil || len(body) == 0 { - Error{ - ErrorCode: ErrNotJSON, - HTTPStatus: http.StatusBadRequest, - Message: "Missing request body", - }.Write(w) + mautrix.MNotJSON.WithMessage("Failed to read response body").Write(w) return } log := as.Log.With().Str("transaction_id", txnID).Logger() @@ -132,7 +111,7 @@ func (as *AppService) PutTransaction(w http.ResponseWriter, r *http.Request) { ctx = log.WithContext(ctx) if as.txnIDC.IsProcessed(txnID) { // Duplicate transaction ID: no-op - WriteBlankOK(w) + exhttp.WriteEmptyJSONResponse(w, http.StatusOK) log.Debug().Msg("Ignoring duplicate transaction") return } @@ -141,14 +120,10 @@ func (as *AppService) PutTransaction(w http.ResponseWriter, r *http.Request) { err = json.Unmarshal(body, &txn) if err != nil { log.Error().Err(err).Msg("Failed to parse transaction content") - Error{ - ErrorCode: ErrBadJSON, - HTTPStatus: http.StatusBadRequest, - Message: "Failed to parse body JSON", - }.Write(w) + mautrix.MBadJSON.WithMessage("Failed to parse transaction content").Write(w) } else { as.handleTransaction(ctx, txnID, &txn) - WriteBlankOK(w) + exhttp.WriteEmptyJSONResponse(w, http.StatusOK) } } @@ -211,15 +186,22 @@ func (as *AppService) handleEvents(ctx context.Context, evts []*event.Event, def for _, evt := range evts { evt.Mautrix.ReceivedAt = time.Now() if defaultTypeClass != event.UnknownEventType { + if defaultTypeClass == event.EphemeralEventType { + evt.Mautrix.EventSource = event.SourceEphemeral + } else if defaultTypeClass == event.ToDeviceEventType { + evt.Mautrix.EventSource = event.SourceToDevice + } evt.Type.Class = defaultTypeClass } else if evt.StateKey != nil { + evt.Mautrix.EventSource = event.SourceTimeline & event.SourceJoin evt.Type.Class = event.StateEventType } else { + evt.Mautrix.EventSource = event.SourceTimeline evt.Type.Class = event.MessageEventType } err := evt.Content.ParseRaw(evt.Type) if errors.Is(err, event.ErrUnsupportedContentType) { - log.Debug().Str("event_id", evt.ID.String()).Msg("Not parsing content of unsupported event") + log.Debug().Stringer("event_id", evt.ID).Msg("Not parsing content of unsupported event") } else if err != nil { log.Warn().Err(err). Str("event_id", evt.ID.String()). @@ -256,16 +238,12 @@ func (as *AppService) GetRoom(w http.ResponseWriter, r *http.Request) { return } - vars := mux.Vars(r) - roomAlias := vars["roomAlias"] + roomAlias := id.RoomAlias(r.PathValue("roomAlias")) ok := as.QueryHandler.QueryAlias(roomAlias) if ok { - WriteBlankOK(w) + exhttp.WriteEmptyJSONResponse(w, http.StatusOK) } else { - Error{ - ErrorCode: ErrUnknown, - HTTPStatus: http.StatusNotFound, - }.Write(w) + mautrix.MNotFound.WithMessage("Alias not found").Write(w) } } @@ -275,16 +253,12 @@ func (as *AppService) GetUser(w http.ResponseWriter, r *http.Request) { return } - vars := mux.Vars(r) - userID := id.UserID(vars["userID"]) + userID := id.UserID(r.PathValue("userID")) ok := as.QueryHandler.QueryUser(userID) if ok { - WriteBlankOK(w) + exhttp.WriteEmptyJSONResponse(w, http.StatusOK) } else { - Error{ - ErrorCode: ErrUnknown, - HTTPStatus: http.StatusNotFound, - }.Write(w) + mautrix.MNotFound.WithMessage("User not found").Write(w) } } @@ -294,11 +268,7 @@ func (as *AppService) PostPing(w http.ResponseWriter, r *http.Request) { } body, err := io.ReadAll(r.Body) if err != nil || len(body) == 0 || !json.Valid(body) { - Error{ - ErrorCode: ErrNotJSON, - HTTPStatus: http.StatusBadRequest, - Message: "Missing request body", - }.Write(w) + mautrix.MNotJSON.WithMessage("Invalid or missing request body").Write(w) return } @@ -306,27 +276,21 @@ func (as *AppService) PostPing(w http.ResponseWriter, r *http.Request) { _ = json.Unmarshal(body, &txn) as.Log.Debug().Str("txn_id", txn.TxnID).Msg("Received ping from homeserver") - w.Header().Add("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - w.Write([]byte("{}")) + exhttp.WriteEmptyJSONResponse(w, http.StatusOK) } func (as *AppService) GetLive(w http.ResponseWriter, r *http.Request) { - w.Header().Add("Content-Type", "application/json") if as.Live { - w.WriteHeader(http.StatusOK) + exhttp.WriteEmptyJSONResponse(w, http.StatusOK) } else { - w.WriteHeader(http.StatusInternalServerError) + exhttp.WriteEmptyJSONResponse(w, http.StatusInternalServerError) } - w.Write([]byte("{}")) } func (as *AppService) GetReady(w http.ResponseWriter, r *http.Request) { - w.Header().Add("Content-Type", "application/json") if as.Ready { - w.WriteHeader(http.StatusOK) + exhttp.WriteEmptyJSONResponse(w, http.StatusOK) } else { - w.WriteHeader(http.StatusInternalServerError) + exhttp.WriteEmptyJSONResponse(w, http.StatusInternalServerError) } - w.Write([]byte("{}")) } diff --git a/appservice/intent.go b/appservice/intent.go index e091582a..5d43f190 100644 --- a/appservice/intent.go +++ b/appservice/intent.go @@ -51,7 +51,7 @@ func (as *AppService) NewIntentAPI(localpart string) *IntentAPI { } func (intent *IntentAPI) Register(ctx context.Context) error { - _, err := intent.Client.MakeRequest(ctx, http.MethodPost, intent.BuildClientURL("v3", "register"), &mautrix.ReqRegister{ + _, err := intent.Client.MakeRequest(ctx, http.MethodPost, intent.BuildClientURL("v3", "register"), &mautrix.ReqRegister[any]{ Username: intent.Localpart, Type: mautrix.AuthTypeAppservice, InhibitLogin: true, @@ -86,6 +86,7 @@ func (intent *IntentAPI) EnsureRegistered(ctx context.Context) error { type EnsureJoinedParams struct { IgnoreCache bool BotOverride *mautrix.Client + Via []string } func (intent *IntentAPI) EnsureJoined(ctx context.Context, roomID id.RoomID, extra ...EnsureJoinedParams) error { @@ -99,11 +100,17 @@ func (intent *IntentAPI) EnsureJoined(ctx context.Context, roomID id.RoomID, ext return nil } - if err := intent.EnsureRegistered(ctx); err != nil { + err := intent.EnsureRegistered(ctx) + if err != nil { return fmt.Errorf("failed to ensure joined: %w", err) } - resp, err := intent.JoinRoomByID(ctx, roomID) + var resp *mautrix.RespJoinRoom + if len(params.Via) > 0 { + resp, err = intent.JoinRoom(ctx, roomID.String(), &mautrix.ReqJoinRoom{Via: params.Via}) + } else { + resp, err = intent.JoinRoomByID(ctx, roomID) + } if err != nil { bot := intent.bot if params.BotOverride != nil { @@ -112,9 +119,21 @@ func (intent *IntentAPI) EnsureJoined(ctx context.Context, roomID id.RoomID, ext if !errors.Is(err, mautrix.MForbidden) || bot == nil { return fmt.Errorf("failed to ensure joined: %w", err) } - _, inviteErr := bot.InviteUser(ctx, roomID, &mautrix.ReqInviteUser{ - UserID: intent.UserID, - }) + var inviteErr error + if intent.IsCustomPuppet { + _, inviteErr = bot.SendStateEvent(ctx, roomID, event.StateMember, intent.UserID.String(), &event.Content{ + Raw: map[string]any{ + "fi.mau.will_auto_accept": true, + }, + Parsed: &event.MemberEventContent{ + Membership: event.MembershipInvite, + }, + }) + } else { + _, inviteErr = bot.InviteUser(ctx, roomID, &mautrix.ReqInviteUser{ + UserID: intent.UserID, + }) + } if inviteErr != nil { return fmt.Errorf("failed to invite in ensure joined: %w", inviteErr) } @@ -130,75 +149,110 @@ func (intent *IntentAPI) EnsureJoined(ctx context.Context, roomID id.RoomID, ext return nil } -func (intent *IntentAPI) AddDoublePuppetValue(into interface{}) interface{} { - if !intent.IsCustomPuppet || intent.as.DoublePuppetValue == "" { +func (intent *IntentAPI) IsDoublePuppet() bool { + return intent.IsCustomPuppet && intent.as.DoublePuppetValue != "" +} + +func (intent *IntentAPI) AddDoublePuppetValue(into any) any { + return intent.AddDoublePuppetValueWithTS(into, 0) +} + +func (intent *IntentAPI) AddDoublePuppetValueWithTS(into any, ts int64) any { + if !intent.IsDoublePuppet() { return into } + // Only use ts deduplication feature with appservice double puppeting + if !intent.SetAppServiceUserID { + ts = 0 + } switch val := into.(type) { - case *map[string]interface{}: + case *map[string]any: if *val == nil { - valNonPtr := make(map[string]interface{}) + valNonPtr := make(map[string]any) *val = valNonPtr } (*val)[DoublePuppetKey] = intent.as.DoublePuppetValue + if ts != 0 { + (*val)[DoublePuppetTSKey] = ts + } return val - case map[string]interface{}: + case map[string]any: val[DoublePuppetKey] = intent.as.DoublePuppetValue + if ts != 0 { + val[DoublePuppetTSKey] = ts + } return val case *event.Content: if val.Raw == nil { - val.Raw = make(map[string]interface{}) + val.Raw = make(map[string]any) } val.Raw[DoublePuppetKey] = intent.as.DoublePuppetValue + if ts != 0 { + val.Raw[DoublePuppetTSKey] = ts + } return val case event.Content: if val.Raw == nil { - val.Raw = make(map[string]interface{}) + val.Raw = make(map[string]any) } val.Raw[DoublePuppetKey] = intent.as.DoublePuppetValue + if ts != 0 { + val.Raw[DoublePuppetTSKey] = ts + } return val default: - return &event.Content{ - Raw: map[string]interface{}{ + content := &event.Content{ + Raw: map[string]any{ DoublePuppetKey: intent.as.DoublePuppetValue, }, Parsed: val, } + if ts != 0 { + content.Raw[DoublePuppetTSKey] = ts + } + return content } } -func (intent *IntentAPI) SendMessageEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, contentJSON interface{}) (*mautrix.RespSendEvent, error) { +func (intent *IntentAPI) SendMessageEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, contentJSON any, extra ...mautrix.ReqSendEvent) (*mautrix.RespSendEvent, error) { if err := intent.EnsureJoined(ctx, roomID); err != nil { return nil, err } contentJSON = intent.AddDoublePuppetValue(contentJSON) - return intent.Client.SendMessageEvent(ctx, roomID, eventType, contentJSON) + return intent.Client.SendMessageEvent(ctx, roomID, eventType, contentJSON, extra...) } +func (intent *IntentAPI) BeeperSendEphemeralEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, contentJSON any, extra ...mautrix.ReqSendEvent) (*mautrix.RespSendEvent, error) { + if err := intent.EnsureJoined(ctx, roomID); err != nil { + return nil, err + } + if !intent.SpecVersions.Supports(mautrix.BeeperFeatureEphemeralEvents) { + return nil, mautrix.MUnrecognized.WithMessage("Homeserver does not advertise com.beeper.ephemeral support") + } + contentJSON = intent.AddDoublePuppetValue(contentJSON) + return intent.Client.BeeperSendEphemeralEvent(ctx, roomID, eventType, contentJSON, extra...) +} + +// Deprecated: use SendMessageEvent with mautrix.ReqSendEvent.Timestamp instead func (intent *IntentAPI) SendMassagedMessageEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, contentJSON interface{}, ts int64) (*mautrix.RespSendEvent, error) { - if err := intent.EnsureJoined(ctx, roomID); err != nil { - return nil, err - } - contentJSON = intent.AddDoublePuppetValue(contentJSON) - return intent.Client.SendMessageEvent(ctx, roomID, eventType, contentJSON, mautrix.ReqSendEvent{Timestamp: ts}) + return intent.SendMessageEvent(ctx, roomID, eventType, contentJSON, mautrix.ReqSendEvent{Timestamp: ts}) } -func (intent *IntentAPI) SendStateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, contentJSON interface{}) (*mautrix.RespSendEvent, error) { +func (intent *IntentAPI) SendStateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, contentJSON any, extra ...mautrix.ReqSendEvent) (*mautrix.RespSendEvent, error) { if eventType != event.StateMember || stateKey != string(intent.UserID) { if err := intent.EnsureJoined(ctx, roomID); err != nil { return nil, err } - } - contentJSON = intent.AddDoublePuppetValue(contentJSON) - return intent.Client.SendStateEvent(ctx, roomID, eventType, stateKey, contentJSON) -} - -func (intent *IntentAPI) SendMassagedStateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, contentJSON interface{}, ts int64) (*mautrix.RespSendEvent, error) { - if err := intent.EnsureJoined(ctx, roomID); err != nil { + } else if err := intent.EnsureRegistered(ctx); err != nil { return nil, err } contentJSON = intent.AddDoublePuppetValue(contentJSON) - return intent.Client.SendMassagedStateEvent(ctx, roomID, eventType, stateKey, contentJSON, ts) + return intent.Client.SendStateEvent(ctx, roomID, eventType, stateKey, contentJSON, extra...) +} + +// Deprecated: use SendStateEvent with mautrix.ReqSendEvent.Timestamp instead +func (intent *IntentAPI) SendMassagedStateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, contentJSON interface{}, ts int64) (*mautrix.RespSendEvent, error) { + return intent.SendStateEvent(ctx, roomID, eventType, stateKey, contentJSON, mautrix.ReqSendEvent{Timestamp: ts}) } func (intent *IntentAPI) StateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, outContent interface{}) error { @@ -257,7 +311,7 @@ func (intent *IntentAPI) SendCustomMembershipEvent(ctx context.Context, roomID i func (intent *IntentAPI) JoinRoomByID(ctx context.Context, roomID id.RoomID, extraContent ...map[string]interface{}) (resp *mautrix.RespJoinRoom, err error) { if intent.IsCustomPuppet || len(extraContent) > 0 { _, err = intent.SendCustomMembershipEvent(ctx, roomID, intent.UserID, event.MembershipJoin, "", extraContent...) - return &mautrix.RespJoinRoom{}, err + return &mautrix.RespJoinRoom{RoomID: roomID}, err } return intent.Client.JoinRoomByID(ctx, roomID) } @@ -326,6 +380,24 @@ func (intent *IntentAPI) Member(ctx context.Context, roomID id.RoomID, userID id return member } +func (intent *IntentAPI) FillPowerLevelCreateEvent(ctx context.Context, roomID id.RoomID, pl *event.PowerLevelsEventContent) error { + if pl.CreateEvent != nil { + return nil + } + var err error + pl.CreateEvent, err = intent.StateStore.GetCreate(ctx, roomID) + if err != nil { + return fmt.Errorf("failed to get create event from cache: %w", err) + } else if pl.CreateEvent != nil { + return nil + } + pl.CreateEvent, err = intent.FullStateEvent(ctx, roomID, event.StateCreate, "") + if err != nil { + return fmt.Errorf("failed to get create event from server: %w", err) + } + return nil +} + func (intent *IntentAPI) PowerLevels(ctx context.Context, roomID id.RoomID) (pl *event.PowerLevelsEventContent, err error) { pl, err = intent.as.StateStore.GetPowerLevels(ctx, roomID) if err != nil { @@ -335,6 +407,12 @@ func (intent *IntentAPI) PowerLevels(ctx context.Context, roomID id.RoomID) (pl if pl == nil { pl = &event.PowerLevelsEventContent{} err = intent.StateEvent(ctx, roomID, event.StatePowerLevels, "", pl) + if err != nil { + return + } + } + if pl.CreateEvent == nil { + pl.CreateEvent, err = intent.FullStateEvent(ctx, roomID, event.StateCreate, "") } return } @@ -349,8 +427,7 @@ func (intent *IntentAPI) SetPowerLevel(ctx context.Context, roomID id.RoomID, us return nil, err } - if pl.GetUserLevel(userID) != level { - pl.SetUserLevel(userID, level) + if pl.EnsureUserLevelAs(intent.UserID, userID, level) { return intent.SendStateEvent(ctx, roomID, event.StatePowerLevels, "", &pl) } return nil, nil @@ -400,6 +477,20 @@ func (intent *IntentAPI) SetRoomTopic(ctx context.Context, roomID id.RoomID, top }) } +func (intent *IntentAPI) UploadMedia(ctx context.Context, data mautrix.ReqUploadMedia) (*mautrix.RespMediaUpload, error) { + if err := intent.EnsureRegistered(ctx); err != nil { + return nil, err + } + return intent.Client.UploadMedia(ctx, data) +} + +func (intent *IntentAPI) UploadAsync(ctx context.Context, data mautrix.ReqUploadMedia) (*mautrix.RespCreateMXC, error) { + if err := intent.EnsureRegistered(ctx); err != nil { + return nil, err + } + return intent.Client.UploadAsync(ctx, data) +} + func (intent *IntentAPI) SetDisplayName(ctx context.Context, displayName string) error { if err := intent.EnsureRegistered(ctx); err != nil { return err @@ -425,11 +516,11 @@ func (intent *IntentAPI) SetAvatarURL(ctx context.Context, avatarURL id.ContentU // No need to update return nil } - if !avatarURL.IsEmpty() { + if !avatarURL.IsEmpty() && !intent.SpecVersions.Supports(mautrix.BeeperFeatureHungry) { // Some homeservers require the avatar to be downloaded before setting it - body, _ := intent.Client.Download(ctx, avatarURL) - if body != nil { - _ = body.Close() + resp, _ := intent.Download(ctx, avatarURL) + if resp != nil { + _ = resp.Body.Close() } } return intent.Client.SetAvatarURL(ctx, avatarURL) diff --git a/appservice/ping.go b/appservice/ping.go new file mode 100644 index 00000000..774ec423 --- /dev/null +++ b/appservice/ping.go @@ -0,0 +1,68 @@ +// Copyright (c) 2025 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package appservice + +import ( + "context" + "encoding/json" + "errors" + "os" + "strings" + "time" + + "github.com/rs/zerolog" + + "maunium.net/go/mautrix" +) + +func (intent *IntentAPI) EnsureAppserviceConnection(ctx context.Context) { + var pingResp *mautrix.RespAppservicePing + var txnID string + var retryCount int + var err error + const maxRetries = 6 + for { + txnID = intent.TxnID() + pingResp, err = intent.AppservicePing(ctx, intent.as.Registration.ID, txnID) + if err == nil { + break + } + var httpErr mautrix.HTTPError + var pingErrBody string + if errors.As(err, &httpErr) && httpErr.RespError != nil { + if val, ok := httpErr.RespError.ExtraData["body"].(string); ok { + pingErrBody = strings.TrimSpace(val) + } + } + outOfRetries := retryCount >= maxRetries + level := zerolog.ErrorLevel + if outOfRetries { + level = zerolog.FatalLevel + } + evt := zerolog.Ctx(ctx).WithLevel(level).Err(err).Str("txn_id", txnID) + if pingErrBody != "" { + bodyBytes := []byte(pingErrBody) + if json.Valid(bodyBytes) { + evt.RawJSON("body", bodyBytes) + } else { + evt.Str("body", pingErrBody) + } + } + if outOfRetries { + evt.Msg("Homeserver -> appservice connection is not working") + zerolog.Ctx(ctx).Info().Msg("See https://docs.mau.fi/faq/as-ping for more info") + os.Exit(13) + } + evt.Msg("Homeserver -> appservice connection is not working, retrying in 5 seconds...") + time.Sleep(5 * time.Second) + retryCount++ + } + zerolog.Ctx(ctx).Debug(). + Str("txn_id", txnID). + Int64("duration_ms", pingResp.DurationMS). + Msg("Homeserver -> appservice connection works") +} diff --git a/appservice/protocol.go b/appservice/protocol.go index 7a9891ef..7c493bcb 100644 --- a/appservice/protocol.go +++ b/appservice/protocol.go @@ -1,4 +1,4 @@ -// Copyright (c) 2023 Tulir Asokan +// Copyright (c) 2025 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this @@ -7,9 +7,7 @@ package appservice import ( - "encoding/json" "fmt" - "net/http" "strings" "github.com/rs/zerolog" @@ -103,50 +101,3 @@ func (txn *Transaction) ContentString() string { // EventListener is a function that receives events. type EventListener func(evt *event.Event) - -// WriteBlankOK writes a blank OK message as a reply to a HTTP request. -func WriteBlankOK(w http.ResponseWriter) { - w.Header().Add("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - _, _ = w.Write([]byte("{}")) -} - -// Respond responds to a HTTP request with a JSON object. -func Respond(w http.ResponseWriter, data interface{}) error { - w.Header().Add("Content-Type", "application/json") - dataStr, err := json.Marshal(data) - if err != nil { - return err - } - _, err = w.Write(dataStr) - return err -} - -// Error represents a Matrix protocol error. -type Error struct { - HTTPStatus int `json:"-"` - ErrorCode ErrorCode `json:"errcode"` - Message string `json:"error"` -} - -func (err Error) Write(w http.ResponseWriter) { - w.Header().Add("Content-Type", "application/json") - w.WriteHeader(err.HTTPStatus) - _ = Respond(w, &err) -} - -// ErrorCode is the machine-readable code in an Error. -type ErrorCode string - -// Native ErrorCodes -const ( - ErrUnknownToken ErrorCode = "M_UNKNOWN_TOKEN" - ErrBadJSON ErrorCode = "M_BAD_JSON" - ErrNotJSON ErrorCode = "M_NOT_JSON" - ErrUnknown ErrorCode = "M_UNKNOWN" -) - -// Custom ErrorCodes -const ( - ErrNoTransactionID ErrorCode = "NET.MAUNIUM.NO_TRANSACTION_ID" -) diff --git a/appservice/registration.go b/appservice/registration.go index b11bd84b..54eff716 100644 --- a/appservice/registration.go +++ b/appservice/registration.go @@ -27,7 +27,9 @@ type Registration struct { Protocols []string `yaml:"protocols,omitempty" json:"protocols,omitempty"` SoruEphemeralEvents bool `yaml:"de.sorunome.msc2409.push_ephemeral,omitempty" json:"de.sorunome.msc2409.push_ephemeral,omitempty"` - EphemeralEvents bool `yaml:"push_ephemeral,omitempty" json:"push_ephemeral,omitempty"` + EphemeralEvents bool `yaml:"receive_ephemeral,omitempty" json:"receive_ephemeral,omitempty"` + MSC3202 bool `yaml:"org.matrix.msc3202,omitempty" json:"org.matrix.msc3202,omitempty"` + MSC4190 bool `yaml:"io.element.msc4190,omitempty" json:"io.element.msc4190,omitempty"` } // CreateRegistration creates a Registration with random appservice and homeserver tokens. diff --git a/appservice/websocket.go b/appservice/websocket.go index 598d70d1..ef65e65a 100644 --- a/appservice/websocket.go +++ b/appservice/websocket.go @@ -1,4 +1,4 @@ -// Copyright (c) 2023 Tulir Asokan +// Copyright (c) 2025 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this @@ -11,26 +11,26 @@ import ( "encoding/json" "errors" "fmt" + "io" "net/http" "net/url" - "path/filepath" + "path" "strings" "sync" "sync/atomic" - "time" - "github.com/gorilla/websocket" + "github.com/coder/websocket" "github.com/rs/zerolog" "github.com/tidwall/gjson" "github.com/tidwall/sjson" + + "maunium.net/go/mautrix" ) type WebsocketRequest struct { - ReqID int `json:"id,omitempty"` - Command string `json:"command"` - Data interface{} `json:"data"` - - Deadline time.Duration `json:"-"` + ReqID int `json:"id,omitempty"` + Command string `json:"command"` + Data any `json:"data"` } type WebsocketCommand struct { @@ -41,7 +41,7 @@ type WebsocketCommand struct { Ctx context.Context `json:"-"` } -func (wsc *WebsocketCommand) MakeResponse(ok bool, data interface{}) *WebsocketRequest { +func (wsc *WebsocketCommand) MakeResponse(ok bool, data any) *WebsocketRequest { if wsc.ReqID == 0 || wsc.Command == "response" || wsc.Command == "error" { return nil } @@ -56,7 +56,7 @@ func (wsc *WebsocketCommand) MakeResponse(ok bool, data interface{}) *WebsocketR var prefixMessage string for unwrappedErr != nil { errorData, jsonErr = json.Marshal(unwrappedErr) - if errorData != nil && len(errorData) > 2 && jsonErr == nil { + if len(errorData) > 2 && jsonErr == nil { prefixMessage = strings.Replace(err.Error(), unwrappedErr.Error(), "", 1) prefixMessage = strings.TrimRight(prefixMessage, ": ") break @@ -98,8 +98,8 @@ type WebsocketMessage struct { } const ( - WebsocketCloseConnReplaced = 4001 - WebsocketCloseTxnNotAcknowledged = 4002 + WebsocketCloseConnReplaced websocket.StatusCode = 4001 + WebsocketCloseTxnNotAcknowledged websocket.StatusCode = 4002 ) type MeowWebsocketCloseCode string @@ -133,7 +133,7 @@ func (mwcc MeowWebsocketCloseCode) String() string { } type CloseCommand struct { - Code int `json:"-"` + Code websocket.StatusCode `json:"-"` Command string `json:"command"` Status MeowWebsocketCloseCode `json:"status"` } @@ -143,15 +143,15 @@ func (cc CloseCommand) Error() string { } func parseCloseError(err error) error { - closeError := &websocket.CloseError{} + var closeError websocket.CloseError if !errors.As(err, &closeError) { return err } var closeCommand CloseCommand closeCommand.Code = closeError.Code closeCommand.Command = "disconnect" - if len(closeError.Text) > 0 { - jsonErr := json.Unmarshal([]byte(closeError.Text), &closeCommand) + if len(closeError.Reason) > 0 { + jsonErr := json.Unmarshal([]byte(closeError.Reason), &closeCommand) if jsonErr != nil { return err } @@ -159,7 +159,7 @@ func parseCloseError(err error) error { if len(closeCommand.Status) == 0 { if closeCommand.Code == WebsocketCloseConnReplaced { closeCommand.Status = MeowConnectionReplaced - } else if closeCommand.Code == websocket.CloseServiceRestart { + } else if closeCommand.Code == websocket.StatusServiceRestart { closeCommand.Status = MeowServerShuttingDown } } @@ -170,20 +170,23 @@ func (as *AppService) HasWebsocket() bool { return as.ws != nil } -func (as *AppService) SendWebsocket(cmd *WebsocketRequest) error { +func (as *AppService) SendWebsocket(ctx context.Context, cmd *WebsocketRequest) error { ws := as.ws if cmd == nil { return nil } else if ws == nil { return ErrWebsocketNotConnected } - as.wsWriteLock.Lock() - defer as.wsWriteLock.Unlock() - if cmd.Deadline == 0 { - cmd.Deadline = 3 * time.Minute + wr, err := ws.Writer(ctx, websocket.MessageText) + if err != nil { + return err } - _ = ws.SetWriteDeadline(time.Now().Add(cmd.Deadline)) - return ws.WriteJSON(cmd) + err = json.NewEncoder(wr).Encode(cmd) + if err != nil { + _ = wr.Close() + return err + } + return wr.Close() } func (as *AppService) clearWebsocketResponseWaiters() { @@ -220,12 +223,12 @@ func (er *ErrorResponse) Error() string { return fmt.Sprintf("%s: %s", er.Code, er.Message) } -func (as *AppService) RequestWebsocket(ctx context.Context, cmd *WebsocketRequest, response interface{}) error { +func (as *AppService) RequestWebsocket(ctx context.Context, cmd *WebsocketRequest, response any) error { cmd.ReqID = int(atomic.AddInt32(&as.websocketRequestID, 1)) respChan := make(chan *WebsocketCommand, 1) as.addWebsocketResponseWaiter(cmd.ReqID, respChan) defer as.removeWebsocketResponseWaiter(cmd.ReqID, respChan) - err := as.SendWebsocket(cmd) + err := as.SendWebsocket(ctx, cmd) if err != nil { return err } @@ -254,7 +257,7 @@ func (as *AppService) RequestWebsocket(ctx context.Context, cmd *WebsocketReques } } -func (as *AppService) unknownCommandHandler(cmd WebsocketCommand) (bool, interface{}) { +func (as *AppService) unknownCommandHandler(cmd WebsocketCommand) (bool, any) { zerolog.Ctx(cmd.Ctx).Warn().Msg("No handler for websocket command") return false, fmt.Errorf("unknown request type") } @@ -278,14 +281,28 @@ func (as *AppService) defaultHandleWebsocketTransaction(ctx context.Context, msg return true, &WebsocketTransactionResponse{TxnID: msg.TxnID} } -func (as *AppService) consumeWebsocket(stopFunc func(error), ws *websocket.Conn) { +func (as *AppService) consumeWebsocket(ctx context.Context, stopFunc func(error), ws *websocket.Conn) { defer stopFunc(ErrWebsocketUnknownError) - ctx := context.Background() for { - var msg WebsocketMessage - err := ws.ReadJSON(&msg) + msgType, reader, err := ws.Reader(ctx) if err != nil { - as.Log.Debug().Err(err).Msg("Error reading from websocket") + as.Log.Debug().Err(err).Msg("Error getting reader from websocket") + stopFunc(parseCloseError(err)) + return + } else if msgType != websocket.MessageText { + as.Log.Debug().Msg("Ignoring non-text message from websocket") + continue + } + data, err := io.ReadAll(reader) + if err != nil { + as.Log.Debug().Err(err).Msg("Error reading data from websocket") + stopFunc(parseCloseError(err)) + return + } + var msg WebsocketMessage + err = json.Unmarshal(data, &msg) + if err != nil { + as.Log.Debug().Err(err).Msg("Error parsing JSON received from websocket") stopFunc(parseCloseError(err)) return } @@ -296,11 +313,11 @@ func (as *AppService) consumeWebsocket(stopFunc func(error), ws *websocket.Conn) with = with.Str("transaction_id", msg.TxnID) } log := with.Logger() - ctx = log.WithContext(ctx) + ctx := log.WithContext(ctx) if msg.Command == "" || msg.Command == "transaction" { ok, resp := as.WebsocketTransactionHandler(ctx, msg) go func() { - err := as.SendWebsocket(msg.MakeResponse(ok, resp)) + err := as.SendWebsocket(ctx, msg.MakeResponse(ok, resp)) if err != nil { log.Warn().Err(err).Msg("Failed to send response to websocket transaction") } else { @@ -332,7 +349,7 @@ func (as *AppService) consumeWebsocket(stopFunc func(error), ws *websocket.Conn) } go func() { okResp, data := handler(msg.WebsocketCommand) - err := as.SendWebsocket(msg.MakeResponse(okResp, data)) + err := as.SendWebsocket(ctx, msg.MakeResponse(okResp, data)) if err != nil { log.Error().Err(err).Msg("Failed to send response to websocket command") } else if okResp { @@ -345,7 +362,7 @@ func (as *AppService) consumeWebsocket(stopFunc func(error), ws *websocket.Conn) } } -func (as *AppService) StartWebsocket(baseURL string, onConnect func()) error { +func (as *AppService) StartWebsocket(ctx context.Context, baseURL string, onConnect func()) error { var parsed *url.URL if baseURL != "" { var err error @@ -357,26 +374,29 @@ func (as *AppService) StartWebsocket(baseURL string, onConnect func()) error { copiedURL := *as.hsURLForClient parsed = &copiedURL } - parsed.Path = filepath.Join(parsed.Path, "_matrix/client/unstable/fi.mau.as_sync") + parsed.Path = path.Join(parsed.Path, "_matrix/client/unstable/fi.mau.as_sync") if parsed.Scheme == "http" { parsed.Scheme = "ws" } else if parsed.Scheme == "https" { parsed.Scheme = "wss" } - ws, resp, err := websocket.DefaultDialer.Dial(parsed.String(), http.Header{ - "Authorization": []string{fmt.Sprintf("Bearer %s", as.Registration.AppToken)}, - "User-Agent": []string{as.BotClient().UserAgent}, + ws, resp, err := websocket.Dial(ctx, parsed.String(), &websocket.DialOptions{ + HTTPClient: as.HTTPClient, + HTTPHeader: http.Header{ + "Authorization": []string{fmt.Sprintf("Bearer %s", as.Registration.AppToken)}, + "User-Agent": []string{as.BotClient().UserAgent}, - "X-Mautrix-Process-ID": []string{as.ProcessID}, - "X-Mautrix-Websocket-Version": []string{"3"}, + "X-Mautrix-Process-ID": []string{as.ProcessID}, + "X-Mautrix-Websocket-Version": []string{"3"}, + }, }) if resp != nil && resp.StatusCode >= 400 { - var errResp Error + var errResp mautrix.RespError err = json.NewDecoder(resp.Body).Decode(&errResp) if err != nil { return fmt.Errorf("websocket request returned HTTP %d with non-JSON body", resp.StatusCode) } else { - return fmt.Errorf("websocket request returned %s (HTTP %d): %s", errResp.ErrorCode, resp.StatusCode, errResp.Message) + return fmt.Errorf("websocket request returned %s (HTTP %d): %s", errResp.ErrCode, resp.StatusCode, errResp.Err) } } else if err != nil { return fmt.Errorf("failed to open websocket: %w", err) @@ -399,12 +419,13 @@ func (as *AppService) StartWebsocket(baseURL string, onConnect func()) error { } }) } + ws.SetReadLimit(50 * 1024 * 1024) as.ws = ws as.StopWebsocket = stopFunc as.PrepareWebsocket() as.Log.Debug().Msg("Appservice transaction websocket opened") - go as.consumeWebsocket(stopFunc, ws) + go as.consumeWebsocket(ctx, stopFunc, ws) var onConnectDone atomic.Bool if onConnect != nil { @@ -426,12 +447,7 @@ func (as *AppService) StartWebsocket(baseURL string, onConnect func()) error { as.ws = nil } - _ = ws.SetWriteDeadline(time.Now().Add(3 * time.Second)) - err = ws.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseGoingAway, "")) - if err != nil && !errors.Is(err, websocket.ErrCloseSent) { - as.Log.Warn().Err(err).Msg("Error writing close message to websocket") - } - err = ws.Close() + err = ws.Close(websocket.StatusGoingAway, "") if err != nil { as.Log.Warn().Err(err).Msg("Error closing websocket") } diff --git a/appservice/wshttp.go b/appservice/wshttp.go index 40ceda9d..c5f6a672 100644 --- a/appservice/wshttp.go +++ b/appservice/wshttp.go @@ -61,6 +61,11 @@ func (as *AppService) WebsocketHTTPProxy(cmd WebsocketCommand) (bool, interface{ if err != nil { return false, fmt.Errorf("failed to create fake HTTP request: %w", err) } + httpReq.RequestURI = req.Path + if req.Query != "" { + httpReq.RequestURI += "?" + req.Query + } + httpReq.RemoteAddr = "websocket" httpReq.Header = req.Headers var resp HTTPProxyResponse diff --git a/bridge/bridge.go b/bridge/bridge.go deleted file mode 100644 index cfc31044..00000000 --- a/bridge/bridge.go +++ /dev/null @@ -1,897 +0,0 @@ -// Copyright (c) 2023 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package bridge - -import ( - "context" - "encoding/json" - "errors" - "fmt" - "net/url" - "os" - "os/signal" - "runtime" - "strings" - "sync" - "syscall" - "time" - - "github.com/lib/pq" - "github.com/mattn/go-sqlite3" - "github.com/rs/zerolog" - "go.mau.fi/util/configupgrade" - "go.mau.fi/util/dbutil" - _ "go.mau.fi/util/dbutil/litestream" - "go.mau.fi/util/exzerolog" - "gopkg.in/yaml.v3" - flag "maunium.net/go/mauflag" - - "maunium.net/go/mautrix" - "maunium.net/go/mautrix/appservice" - "maunium.net/go/mautrix/bridge/bridgeconfig" - "maunium.net/go/mautrix/bridge/status" - "maunium.net/go/mautrix/event" - "maunium.net/go/mautrix/id" - "maunium.net/go/mautrix/sqlstatestore" -) - -var configPath = flag.MakeFull("c", "config", "The path to your config file.", "config.yaml").String() -var dontSaveConfig = flag.MakeFull("n", "no-update", "Don't save updated config to disk.", "false").Bool() -var registrationPath = flag.MakeFull("r", "registration", "The path where to save the appservice registration.", "registration.yaml").String() -var generateRegistration = flag.MakeFull("g", "generate-registration", "Generate registration and quit.", "false").Bool() -var version = flag.MakeFull("v", "version", "View bridge version and quit.", "false").Bool() -var versionJSON = flag.Make().LongKey("version-json").Usage("Print a JSON object representing the bridge version and quit.").Default("false").Bool() -var ignoreUnsupportedDatabase = flag.Make().LongKey("ignore-unsupported-database").Usage("Run even if the database schema is too new").Default("false").Bool() -var ignoreForeignTables = flag.Make().LongKey("ignore-foreign-tables").Usage("Run even if the database contains tables from other programs (like Synapse)").Default("false").Bool() -var ignoreUnsupportedServer = flag.Make().LongKey("ignore-unsupported-server").Usage("Run even if the Matrix homeserver is outdated").Default("false").Bool() -var wantHelp, _ = flag.MakeHelpFlag() - -var _ appservice.StateStore = (*sqlstatestore.SQLStateStore)(nil) - -type Portal interface { - IsEncrypted() bool - IsPrivateChat() bool - MarkEncrypted() - MainIntent() *appservice.IntentAPI - - ReceiveMatrixEvent(user User, evt *event.Event) - UpdateBridgeInfo(ctx context.Context) -} - -type MembershipHandlingPortal interface { - Portal - HandleMatrixLeave(sender User, evt *event.Event) - HandleMatrixKick(sender User, ghost Ghost, evt *event.Event) - HandleMatrixInvite(sender User, ghost Ghost, evt *event.Event) -} - -type ReadReceiptHandlingPortal interface { - Portal - HandleMatrixReadReceipt(sender User, eventID id.EventID, receipt event.ReadReceipt) -} - -type TypingPortal interface { - Portal - HandleMatrixTyping(userIDs []id.UserID) -} - -type MetaHandlingPortal interface { - Portal - HandleMatrixMeta(sender User, evt *event.Event) -} - -type DisappearingPortal interface { - Portal - ScheduleDisappearing() -} - -type PowerLevelHandlingPortal interface { - Portal - HandleMatrixPowerLevels(sender User, evt *event.Event) -} - -type User interface { - GetPermissionLevel() bridgeconfig.PermissionLevel - IsLoggedIn() bool - GetManagementRoomID() id.RoomID - SetManagementRoom(id.RoomID) - GetMXID() id.UserID - GetIDoublePuppet() DoublePuppet - GetIGhost() Ghost -} - -type DoublePuppet interface { - CustomIntent() *appservice.IntentAPI - SwitchCustomMXID(accessToken string, userID id.UserID) error - ClearCustomMXID() -} - -type Ghost interface { - DoublePuppet - DefaultIntent() *appservice.IntentAPI - GetMXID() id.UserID -} - -type GhostWithProfile interface { - Ghost - GetDisplayname() string - GetAvatarURL() id.ContentURI -} - -type ChildOverride interface { - GetExampleConfig() string - GetConfigPtr() interface{} - - Init() - Start() - Stop() - - GetIPortal(id.RoomID) Portal - GetAllIPortals() []Portal - GetIUser(id id.UserID, create bool) User - IsGhost(id.UserID) bool - GetIGhost(id.UserID) Ghost - CreatePrivatePortal(id.RoomID, User, Ghost) -} - -type ConfigValidatingBridge interface { - ChildOverride - ValidateConfig() error -} - -type FlagHandlingBridge interface { - ChildOverride - HandleFlags() bool -} - -type PreInitableBridge interface { - ChildOverride - PreInit() -} - -type WebsocketStartingBridge interface { - ChildOverride - OnWebsocketConnect() -} - -type CSFeatureRequirer interface { - CheckFeatures(versions *mautrix.RespVersions) (string, bool) -} - -type Bridge struct { - Name string - URL string - Description string - Version string - ProtocolName string - BeeperServiceName string - BeeperNetworkName string - - AdditionalShortFlags string - AdditionalLongFlags string - - VersionDesc string - LinkifiedVersion string - BuildTime string - commit string - baseVersion string - - PublicHSAddress *url.URL - - DoublePuppet *doublePuppetUtil - - AS *appservice.AppService - EventProcessor *appservice.EventProcessor - CommandProcessor CommandProcessor - MatrixHandler *MatrixHandler - Bot *appservice.IntentAPI - Config bridgeconfig.BaseConfig - ConfigPath string - RegistrationPath string - SaveConfig bool - ConfigUpgrader configupgrade.BaseUpgrader - DB *dbutil.Database - StateStore *sqlstatestore.SQLStateStore - Crypto Crypto - CryptoPickleKey string - - ZLog *zerolog.Logger - - MediaConfig mautrix.RespMediaConfig - SpecVersions mautrix.RespVersions - - Child ChildOverride - - manualStop chan int - Stopping bool - - latestState *status.BridgeState - - Websocket bool - wsStopPinger chan struct{} - wsStarted chan struct{} - wsStopped chan struct{} - wsShortCircuitReconnectBackoff chan struct{} - wsStartupWait *sync.WaitGroup -} - -type Crypto interface { - HandleMemberEvent(context.Context, *event.Event) - Decrypt(context.Context, *event.Event) (*event.Event, error) - Encrypt(context.Context, id.RoomID, event.Type, *event.Content) error - WaitForSession(context.Context, id.RoomID, id.SenderKey, id.SessionID, time.Duration) bool - RequestSession(context.Context, id.RoomID, id.SenderKey, id.SessionID, id.UserID, id.DeviceID) - ResetSession(context.Context, id.RoomID) - Init(ctx context.Context) error - Start() - Stop() - Reset(ctx context.Context, startAfterReset bool) - Client() *mautrix.Client - ShareKeys(context.Context) error -} - -func (br *Bridge) GenerateRegistration() { - if !br.SaveConfig { - // We need to save the generated as_token and hs_token in the config - _, _ = fmt.Fprintln(os.Stderr, "--no-update is not compatible with --generate-registration") - os.Exit(5) - } else if br.Config.Homeserver.Domain == "example.com" { - _, _ = fmt.Fprintln(os.Stderr, "Homeserver domain is not set") - os.Exit(20) - } - reg := br.Config.GenerateRegistration() - err := reg.Save(br.RegistrationPath) - if err != nil { - _, _ = fmt.Fprintln(os.Stderr, "Failed to save registration:", err) - os.Exit(21) - } - - updateTokens := func(helper *configupgrade.Helper) { - helper.Set(configupgrade.Str, reg.AppToken, "appservice", "as_token") - helper.Set(configupgrade.Str, reg.ServerToken, "appservice", "hs_token") - } - _, _, err = configupgrade.Do(br.ConfigPath, true, br.ConfigUpgrader, configupgrade.SimpleUpgrader(updateTokens)) - if err != nil { - _, _ = fmt.Fprintln(os.Stderr, "Failed to save config:", err) - os.Exit(22) - } - fmt.Println("Registration generated. See https://docs.mau.fi/bridges/general/registering-appservices.html for instructions on installing the registration.") - os.Exit(0) -} - -func (br *Bridge) InitVersion(tag, commit, buildTime string) { - br.baseVersion = br.Version - if len(tag) > 0 && tag[0] == 'v' { - tag = tag[1:] - } - if tag != br.Version { - suffix := "" - if !strings.HasSuffix(br.Version, "+dev") { - suffix = "+dev" - } - if len(commit) > 8 { - br.Version = fmt.Sprintf("%s%s.%s", br.Version, suffix, commit[:8]) - } else { - br.Version = fmt.Sprintf("%s%s.unknown", br.Version, suffix) - } - } - - br.LinkifiedVersion = fmt.Sprintf("v%s", br.Version) - if tag == br.Version { - br.LinkifiedVersion = fmt.Sprintf("[v%s](%s/releases/v%s)", br.Version, br.URL, tag) - } else if len(commit) > 8 { - br.LinkifiedVersion = strings.Replace(br.LinkifiedVersion, commit[:8], fmt.Sprintf("[%s](%s/commit/%s)", commit[:8], br.URL, commit), 1) - } - mautrix.DefaultUserAgent = fmt.Sprintf("%s/%s %s", br.Name, br.Version, mautrix.DefaultUserAgent) - br.VersionDesc = fmt.Sprintf("%s %s (%s with %s)", br.Name, br.Version, buildTime, runtime.Version()) - br.commit = commit - br.BuildTime = buildTime -} - -var MinSpecVersion = mautrix.SpecV14 - -func (br *Bridge) ensureConnection(ctx context.Context) { - for { - versions, err := br.Bot.Versions(ctx) - if err != nil { - br.ZLog.Err(err).Msg("Failed to connect to homeserver, retrying in 10 seconds...") - time.Sleep(10 * time.Second) - } else { - br.SpecVersions = *versions - break - } - } - - unsupportedServerLogLevel := zerolog.FatalLevel - if *ignoreUnsupportedServer { - unsupportedServerLogLevel = zerolog.ErrorLevel - } - if br.Config.Homeserver.Software == bridgeconfig.SoftwareHungry && !br.SpecVersions.Supports(mautrix.BeeperFeatureHungry) { - br.ZLog.WithLevel(zerolog.FatalLevel).Msg("The config claims the homeserver is hungryserv, but the /versions response didn't confirm it") - os.Exit(18) - } else if !br.SpecVersions.ContainsGreaterOrEqual(MinSpecVersion) { - br.ZLog.WithLevel(unsupportedServerLogLevel). - Stringer("server_supports", br.SpecVersions.GetLatest()). - Stringer("bridge_requires", MinSpecVersion). - Msg("The homeserver is outdated (supported spec versions are below minimum required by bridge)") - if !*ignoreUnsupportedServer { - os.Exit(18) - } - } else if fr, ok := br.Child.(CSFeatureRequirer); ok { - if msg, hasFeatures := fr.CheckFeatures(&br.SpecVersions); !hasFeatures { - br.ZLog.WithLevel(unsupportedServerLogLevel).Msg(msg) - if !*ignoreUnsupportedServer { - os.Exit(18) - } - } - } - - resp, err := br.Bot.Whoami(ctx) - if err != nil { - if errors.Is(err, mautrix.MUnknownToken) { - br.ZLog.WithLevel(zerolog.FatalLevel).Msg("The as_token was not accepted. Is the registration file installed in your homeserver correctly?") - br.ZLog.Info().Msg("See https://docs.mau.fi/faq/as-token for more info") - } else if errors.Is(err, mautrix.MExclusive) { - br.ZLog.WithLevel(zerolog.FatalLevel).Msg("The as_token was accepted, but the /register request was not. Are the homeserver domain, bot username and username template in the config correct, and do they match the values in the registration?") - br.ZLog.Info().Msg("See https://docs.mau.fi/faq/as-register for more info") - } else { - br.ZLog.WithLevel(zerolog.FatalLevel).Err(err).Msg("/whoami request failed with unknown error") - } - os.Exit(16) - } else if resp.UserID != br.Bot.UserID { - br.ZLog.WithLevel(zerolog.FatalLevel). - Stringer("got_user_id", resp.UserID). - Stringer("expected_user_id", br.Bot.UserID). - Msg("Unexpected user ID in whoami call") - os.Exit(17) - } - - if br.Websocket { - br.ZLog.Debug().Msg("Websocket mode: no need to check status of homeserver -> bridge connection") - return - } else if !br.SpecVersions.Supports(mautrix.FeatureAppservicePing) { - br.ZLog.Debug().Msg("Homeserver does not support checking status of homeserver -> bridge connection") - return - } - var pingResp *mautrix.RespAppservicePing - var txnID string - var retryCount int - const maxRetries = 6 - for { - txnID = br.Bot.TxnID() - pingResp, err = br.Bot.AppservicePing(ctx, br.Config.AppService.ID, txnID) - if err == nil { - break - } - var httpErr mautrix.HTTPError - var pingErrBody string - if errors.As(err, &httpErr) && httpErr.RespError != nil { - if val, ok := httpErr.RespError.ExtraData["body"].(string); ok { - pingErrBody = strings.TrimSpace(val) - } - } - outOfRetries := retryCount >= maxRetries - level := zerolog.ErrorLevel - if outOfRetries { - level = zerolog.FatalLevel - } - evt := br.ZLog.WithLevel(level).Err(err).Str("txn_id", txnID) - if pingErrBody != "" { - bodyBytes := []byte(pingErrBody) - if json.Valid(bodyBytes) { - evt.RawJSON("body", bodyBytes) - } else { - evt.Str("body", pingErrBody) - } - } - if outOfRetries { - evt.Msg("Homeserver -> bridge connection is not working") - br.ZLog.Info().Msg("See https://docs.mau.fi/faq/as-ping for more info") - os.Exit(13) - } - evt.Msg("Homeserver -> bridge connection is not working, retrying in 5 seconds...") - time.Sleep(5 * time.Second) - retryCount++ - } - br.ZLog.Debug(). - Str("txn_id", txnID). - Int64("duration_ms", pingResp.DurationMS). - Msg("Homeserver -> bridge connection works") -} - -func (br *Bridge) fetchMediaConfig(ctx context.Context) { - cfg, err := br.Bot.GetMediaConfig(ctx) - if err != nil { - br.ZLog.Warn().Err(err).Msg("Failed to fetch media config") - } else { - if cfg.UploadSize == 0 { - cfg.UploadSize = 50 * 1024 * 1024 - } - br.MediaConfig = *cfg - } -} - -func (br *Bridge) UpdateBotProfile(ctx context.Context) { - br.ZLog.Debug().Msg("Updating bot profile") - botConfig := &br.Config.AppService.Bot - - var err error - var mxc id.ContentURI - if botConfig.Avatar == "remove" { - err = br.Bot.SetAvatarURL(ctx, mxc) - } else if !botConfig.ParsedAvatar.IsEmpty() { - err = br.Bot.SetAvatarURL(ctx, botConfig.ParsedAvatar) - } - if err != nil { - br.ZLog.Warn().Err(err).Msg("Failed to update bot avatar") - } - - if botConfig.Displayname == "remove" { - err = br.Bot.SetDisplayName(ctx, "") - } else if len(botConfig.Displayname) > 0 { - err = br.Bot.SetDisplayName(ctx, botConfig.Displayname) - } - if err != nil { - br.ZLog.Warn().Err(err).Msg("Failed to update bot displayname") - } - - if br.SpecVersions.Supports(mautrix.BeeperFeatureArbitraryProfileMeta) && br.BeeperNetworkName != "" { - br.ZLog.Debug().Msg("Setting contact info on the appservice bot") - br.Bot.BeeperUpdateProfile(ctx, map[string]any{ - "com.beeper.bridge.service": br.BeeperServiceName, - "com.beeper.bridge.network": br.BeeperNetworkName, - "com.beeper.bridge.is_bridge_bot": true, - }) - } -} - -func (br *Bridge) loadConfig() { - configData, upgraded, err := configupgrade.Do(br.ConfigPath, br.SaveConfig, br.ConfigUpgrader) - if err != nil { - _, _ = fmt.Fprintln(os.Stderr, "Error updating config:", err) - if configData == nil { - os.Exit(10) - } - } - - target := br.Child.GetConfigPtr() - if !upgraded { - // Fallback: if config upgrading failed, load example config for base values - err = yaml.Unmarshal([]byte(br.Child.GetExampleConfig()), &target) - if err != nil { - _, _ = fmt.Fprintln(os.Stderr, "Failed to unmarshal example config:", err) - os.Exit(10) - } - } - err = yaml.Unmarshal(configData, target) - if err != nil { - _, _ = fmt.Fprintln(os.Stderr, "Failed to parse config:", err) - os.Exit(10) - } -} - -func (br *Bridge) validateConfig() error { - switch { - case br.Config.Homeserver.Address == "https://matrix.example.com": - return errors.New("homeserver.address not configured") - case br.Config.Homeserver.Domain == "example.com": - return errors.New("homeserver.domain not configured") - case !bridgeconfig.AllowedHomeserverSoftware[br.Config.Homeserver.Software]: - return errors.New("invalid value for homeserver.software (use `standard` if you don't know what the field is for)") - case br.Config.AppService.ASToken == "This value is generated when generating the registration": - return errors.New("appservice.as_token not configured. Did you forget to generate the registration? ") - case br.Config.AppService.HSToken == "This value is generated when generating the registration": - return errors.New("appservice.hs_token not configured. Did you forget to generate the registration? ") - case br.Config.AppService.Database.URI == "postgres://user:password@host/database?sslmode=disable": - return errors.New("appservice.database not configured") - default: - err := br.Config.Bridge.Validate() - if err != nil { - return err - } - validator, ok := br.Child.(ConfigValidatingBridge) - if ok { - return validator.ValidateConfig() - } - return nil - } -} - -func (br *Bridge) getProfile(userID id.UserID, roomID id.RoomID) *event.MemberEventContent { - ghost := br.Child.GetIGhost(userID) - if ghost == nil { - return nil - } - profilefulGhost, ok := ghost.(GhostWithProfile) - if ok { - return &event.MemberEventContent{ - Displayname: profilefulGhost.GetDisplayname(), - AvatarURL: profilefulGhost.GetAvatarURL().CUString(), - } - } - return nil -} - -func (br *Bridge) init() { - pib, ok := br.Child.(PreInitableBridge) - if ok { - pib.PreInit() - } - - var err error - - br.MediaConfig.UploadSize = 50 * 1024 * 1024 - - br.ZLog, err = br.Config.Logging.Compile() - if err != nil { - _, _ = fmt.Fprintln(os.Stderr, "Failed to initialize logger:", err) - os.Exit(12) - } - exzerolog.SetupDefaults(br.ZLog) - - br.DoublePuppet = &doublePuppetUtil{br: br, log: br.ZLog.With().Str("component", "double puppet").Logger()} - - err = br.validateConfig() - if err != nil { - br.ZLog.WithLevel(zerolog.FatalLevel).Err(err).Msg("Configuration error") - br.ZLog.Info().Msg("See https://docs.mau.fi/faq/field-unconfigured for more info") - os.Exit(11) - } - - br.ZLog.Info(). - Str("name", br.Name). - Str("version", br.Version). - Str("built_at", br.BuildTime). - Str("go_version", runtime.Version()). - Msg("Initializing bridge") - - br.ZLog.Debug().Msg("Initializing database connection") - dbConfig := br.Config.AppService.Database - if (dbConfig.Type == "sqlite3-fk-wal" || dbConfig.Type == "litestream") && dbConfig.MaxOpenConns != 1 && !strings.Contains(dbConfig.URI, "_txlock=immediate") { - var fixedExampleURI string - if !strings.HasPrefix(dbConfig.URI, "file:") { - fixedExampleURI = fmt.Sprintf("file:%s?_txlock=immediate", dbConfig.URI) - } else if !strings.ContainsRune(dbConfig.URI, '?') { - fixedExampleURI = fmt.Sprintf("%s?_txlock=immediate", dbConfig.URI) - } else { - fixedExampleURI = fmt.Sprintf("%s&_txlock=immediate", dbConfig.URI) - } - br.ZLog.Warn(). - Str("fixed_uri_example", fixedExampleURI). - Msg("Using SQLite without _txlock=immediate is not recommended") - } - br.DB, err = dbutil.NewFromConfig(br.Name, dbConfig, dbutil.ZeroLogger(br.ZLog.With().Str("db_section", "main").Logger())) - if err != nil { - br.ZLog.WithLevel(zerolog.FatalLevel).Err(err).Msg("Failed to initialize database connection") - if sqlError := (&sqlite3.Error{}); errors.As(err, sqlError) && sqlError.Code == sqlite3.ErrCorrupt { - os.Exit(18) - } - os.Exit(14) - } - br.DB.IgnoreUnsupportedDatabase = *ignoreUnsupportedDatabase - br.DB.IgnoreForeignTables = *ignoreForeignTables - - br.ZLog.Debug().Msg("Initializing state store") - br.StateStore = sqlstatestore.NewSQLStateStore(br.DB, dbutil.ZeroLogger(br.ZLog.With().Str("db_section", "matrix_state").Logger()), true) - - br.AS, err = appservice.CreateFull(appservice.CreateOpts{ - Registration: br.Config.AppService.GetRegistration(), - HomeserverDomain: br.Config.Homeserver.Domain, - HomeserverURL: br.Config.Homeserver.Address, - HostConfig: appservice.HostConfig{ - Hostname: br.Config.AppService.Hostname, - Port: br.Config.AppService.Port, - }, - StateStore: br.StateStore, - }) - if err != nil { - br.ZLog.WithLevel(zerolog.FatalLevel).Err(err). - Msg("Failed to initialize appservice") - os.Exit(15) - } - br.AS.Log = *br.ZLog - br.AS.DoublePuppetValue = br.Name - br.AS.GetProfile = br.getProfile - br.Bot = br.AS.BotIntent() - - br.ZLog.Debug().Msg("Initializing Matrix event processor") - br.EventProcessor = appservice.NewEventProcessor(br.AS) - if !br.Config.AppService.AsyncTransactions { - br.EventProcessor.ExecMode = appservice.Sync - } - br.ZLog.Debug().Msg("Initializing Matrix event handler") - br.MatrixHandler = NewMatrixHandler(br) - - br.Crypto = NewCryptoHelper(br) - - hsURL := br.Config.Homeserver.Address - if br.Config.Homeserver.PublicAddress != "" { - hsURL = br.Config.Homeserver.PublicAddress - } - br.PublicHSAddress, err = url.Parse(hsURL) - if err != nil { - br.ZLog.WithLevel(zerolog.FatalLevel).Err(err). - Str("input", hsURL). - Msg("Failed to parse public homeserver URL") - os.Exit(15) - } - - br.Child.Init() -} - -type zerologPQError pq.Error - -func (zpe *zerologPQError) MarshalZerologObject(evt *zerolog.Event) { - maybeStr := func(field, value string) { - if value != "" { - evt.Str(field, value) - } - } - maybeStr("severity", zpe.Severity) - if name := zpe.Code.Name(); name != "" { - evt.Str("code", name) - } else if zpe.Code != "" { - evt.Str("code", string(zpe.Code)) - } - //maybeStr("message", zpe.Message) - maybeStr("detail", zpe.Detail) - maybeStr("hint", zpe.Hint) - maybeStr("position", zpe.Position) - maybeStr("internal_position", zpe.InternalPosition) - maybeStr("internal_query", zpe.InternalQuery) - maybeStr("where", zpe.Where) - maybeStr("schema", zpe.Schema) - maybeStr("table", zpe.Table) - maybeStr("column", zpe.Column) - maybeStr("data_type_name", zpe.DataTypeName) - maybeStr("constraint", zpe.Constraint) - maybeStr("file", zpe.File) - maybeStr("line", zpe.Line) - maybeStr("routine", zpe.Routine) -} - -func (br *Bridge) LogDBUpgradeErrorAndExit(name string, err error) { - logEvt := br.ZLog.WithLevel(zerolog.FatalLevel). - Err(err). - Str("db_section", name) - var errWithLine *dbutil.PQErrorWithLine - if errors.As(err, &errWithLine) { - logEvt.Str("sql_line", errWithLine.Line) - } - var pqe *pq.Error - if errors.As(err, &pqe) { - logEvt.Object("pq_error", (*zerologPQError)(pqe)) - } - logEvt.Msg("Failed to initialize database") - if sqlError := (&sqlite3.Error{}); errors.As(err, sqlError) && sqlError.Code == sqlite3.ErrCorrupt { - os.Exit(18) - } else if errors.Is(err, dbutil.ErrForeignTables) { - br.ZLog.Info().Msg("You can use --ignore-foreign-tables to ignore this error") - br.ZLog.Info().Msg("See https://docs.mau.fi/faq/foreign-tables for more info") - } else if errors.Is(err, dbutil.ErrNotOwned) { - br.ZLog.Info().Msg("Sharing the same database with different programs is not supported") - } else if errors.Is(err, dbutil.ErrUnsupportedDatabaseVersion) { - br.ZLog.Info().Msg("Downgrading the bridge is not supported") - } - os.Exit(15) -} - -func (br *Bridge) WaitWebsocketConnected() { - if br.wsStartupWait != nil { - br.wsStartupWait.Wait() - } -} - -func (br *Bridge) start() { - br.ZLog.Debug().Msg("Running database upgrades") - err := br.DB.Upgrade(br.ZLog.With().Str("db_section", "main").Logger().WithContext(context.TODO())) - if err != nil { - br.LogDBUpgradeErrorAndExit("main", err) - } else if err = br.StateStore.Upgrade(br.ZLog.With().Str("db_section", "matrix_state").Logger().WithContext(context.TODO())); err != nil { - br.LogDBUpgradeErrorAndExit("matrix_state", err) - } - - if br.Config.Homeserver.Websocket || len(br.Config.Homeserver.WSProxy) > 0 { - br.Websocket = true - br.ZLog.Debug().Msg("Starting application service websocket") - var wg sync.WaitGroup - wg.Add(1) - br.wsStartupWait = &wg - br.wsShortCircuitReconnectBackoff = make(chan struct{}) - go br.startWebsocket(&wg) - } else if br.AS.Host.IsConfigured() { - br.ZLog.Debug().Msg("Starting application service HTTP server") - go br.AS.Start() - } else { - br.ZLog.WithLevel(zerolog.FatalLevel).Msg("Neither appservice HTTP listener nor websocket is enabled") - os.Exit(23) - } - br.ZLog.Debug().Msg("Checking connection to homeserver") - - ctx := br.ZLog.WithContext(context.Background()) - br.ensureConnection(ctx) - go br.fetchMediaConfig(ctx) - - if br.Crypto != nil { - err = br.Crypto.Init(ctx) - if err != nil { - br.ZLog.WithLevel(zerolog.FatalLevel).Err(err).Msg("Error initializing end-to-bridge encryption") - os.Exit(19) - } - } - - br.ZLog.Debug().Msg("Starting event processor") - br.EventProcessor.Start(ctx) - - go br.UpdateBotProfile(ctx) - if br.Crypto != nil { - go br.Crypto.Start() - } - - br.Child.Start() - br.WaitWebsocketConnected() - br.AS.Ready = true - - if br.Config.Bridge.GetResendBridgeInfo() { - go br.ResendBridgeInfo() - } - if br.Websocket && br.Config.Homeserver.WSPingInterval > 0 { - br.wsStopPinger = make(chan struct{}, 1) - go br.websocketServerPinger() - } -} - -func (br *Bridge) ResendBridgeInfo() { - if !br.SaveConfig { - br.ZLog.Warn().Msg("Not setting resend_bridge_info to false in config due to --no-update flag") - } else { - _, _, err := configupgrade.Do(br.ConfigPath, true, br.ConfigUpgrader, configupgrade.SimpleUpgrader(func(helper *configupgrade.Helper) { - helper.Set(configupgrade.Bool, "false", "bridge", "resend_bridge_info") - })) - if err != nil { - br.ZLog.Err(err).Msg("Failed to save config after setting resend_bridge_info to false") - } - } - br.ZLog.Info().Msg("Re-sending bridge info state event to all portals") - for _, portal := range br.Child.GetAllIPortals() { - portal.UpdateBridgeInfo(context.TODO()) - } - br.ZLog.Info().Msg("Finished re-sending bridge info state events") -} - -func sendStopSignal(ch chan struct{}) { - if ch != nil { - select { - case ch <- struct{}{}: - default: - } - } -} - -func (br *Bridge) stop() { - br.Stopping = true - if br.Crypto != nil { - br.Crypto.Stop() - } - waitForWS := false - if br.AS.StopWebsocket != nil { - br.ZLog.Debug().Msg("Stopping application service websocket") - br.AS.StopWebsocket(appservice.ErrWebsocketManualStop) - waitForWS = true - } - br.AS.Stop() - sendStopSignal(br.wsStopPinger) - sendStopSignal(br.wsShortCircuitReconnectBackoff) - br.EventProcessor.Stop() - br.Child.Stop() - err := br.DB.Close() - if err != nil { - br.ZLog.Warn().Err(err).Msg("Error closing database") - } - if waitForWS { - select { - case <-br.wsStopped: - case <-time.After(4 * time.Second): - br.ZLog.Warn().Msg("Timed out waiting for websocket to close") - } - } -} - -func (br *Bridge) ManualStop(exitCode int) { - if br.manualStop != nil { - br.manualStop <- exitCode - } else { - os.Exit(exitCode) - } -} - -type VersionJSONOutput struct { - Name string - URL string - - Version string - IsRelease bool - Commit string - FormattedVersion string - BuildTime string - - OS string - Arch string - - Mautrix struct { - Version string - Commit string - } -} - -func (br *Bridge) Main() { - flag.SetHelpTitles( - fmt.Sprintf("%s - %s", br.Name, br.Description), - fmt.Sprintf("%s [-hgvn%s] [-c ] [-r ]%s", br.Name, br.AdditionalShortFlags, br.AdditionalLongFlags)) - err := flag.Parse() - br.ConfigPath = *configPath - br.RegistrationPath = *registrationPath - br.SaveConfig = !*dontSaveConfig - if err != nil { - _, _ = fmt.Fprintln(os.Stderr, err) - flag.PrintHelp() - os.Exit(1) - } else if *wantHelp { - flag.PrintHelp() - os.Exit(0) - } else if *version { - fmt.Println(br.VersionDesc) - return - } else if *versionJSON { - output := VersionJSONOutput{ - URL: br.URL, - Name: br.Name, - - Version: br.baseVersion, - IsRelease: br.Version == br.baseVersion, - Commit: br.commit, - FormattedVersion: br.Version, - BuildTime: br.BuildTime, - - OS: runtime.GOOS, - Arch: runtime.GOARCH, - } - output.Mautrix.Commit = mautrix.Commit - output.Mautrix.Version = mautrix.Version - _ = json.NewEncoder(os.Stdout).Encode(output) - return - } else if flagHandler, ok := br.Child.(FlagHandlingBridge); ok && flagHandler.HandleFlags() { - return - } - - br.loadConfig() - - if *generateRegistration { - br.GenerateRegistration() - return - } - - br.manualStop = make(chan int, 1) - br.init() - br.ZLog.Info().Msg("Bridge initialization complete, starting...") - br.start() - br.ZLog.Info().Msg("Bridge started!") - - c := make(chan os.Signal, 1) - signal.Notify(c, os.Interrupt, syscall.SIGTERM) - var exitCode int - select { - case <-c: - br.ZLog.Info().Msg("Interrupt received, stopping...") - case exitCode = <-br.manualStop: - br.ZLog.Info().Int("exit_code", exitCode).Msg("Manual stop requested") - } - - br.stop() - br.ZLog.Info().Msg("Bridge stopped.") - os.Exit(exitCode) -} diff --git a/bridge/bridgeconfig/config.go b/bridge/bridgeconfig/config.go deleted file mode 100644 index 2e8548b5..00000000 --- a/bridge/bridgeconfig/config.go +++ /dev/null @@ -1,337 +0,0 @@ -// Copyright (c) 2023 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package bridgeconfig - -import ( - "fmt" - "os" - "path/filepath" - "regexp" - "strings" - - "github.com/rs/zerolog" - up "go.mau.fi/util/configupgrade" - "go.mau.fi/util/dbutil" - "go.mau.fi/util/random" - "go.mau.fi/zeroconfig" - "gopkg.in/yaml.v3" - - "maunium.net/go/mautrix/appservice" - "maunium.net/go/mautrix/id" -) - -type HomeserverSoftware string - -const ( - SoftwareStandard HomeserverSoftware = "standard" - SoftwareAsmux HomeserverSoftware = "asmux" - SoftwareHungry HomeserverSoftware = "hungry" -) - -var AllowedHomeserverSoftware = map[HomeserverSoftware]bool{ - SoftwareStandard: true, - SoftwareAsmux: true, - SoftwareHungry: true, -} - -type HomeserverConfig struct { - Address string `yaml:"address"` - Domain string `yaml:"domain"` - AsyncMedia bool `yaml:"async_media"` - - PublicAddress string `yaml:"public_address,omitempty"` - - Software HomeserverSoftware `yaml:"software"` - - StatusEndpoint string `yaml:"status_endpoint"` - MessageSendCheckpointEndpoint string `yaml:"message_send_checkpoint_endpoint"` - - Websocket bool `yaml:"websocket"` - WSProxy string `yaml:"websocket_proxy"` - WSPingInterval int `yaml:"ping_interval_seconds"` -} - -type AppserviceConfig struct { - Address string `yaml:"address"` - Hostname string `yaml:"hostname"` - Port uint16 `yaml:"port"` - - Database dbutil.Config `yaml:"database"` - - ID string `yaml:"id"` - Bot BotUserConfig `yaml:"bot"` - - ASToken string `yaml:"as_token"` - HSToken string `yaml:"hs_token"` - - EphemeralEvents bool `yaml:"ephemeral_events"` - AsyncTransactions bool `yaml:"async_transactions"` -} - -func (config *BaseConfig) MakeUserIDRegex(matcher string) *regexp.Regexp { - usernamePlaceholder := strings.ToLower(random.String(16)) - usernameTemplate := fmt.Sprintf("@%s:%s", - config.Bridge.FormatUsername(usernamePlaceholder), - config.Homeserver.Domain) - usernameTemplate = regexp.QuoteMeta(usernameTemplate) - usernameTemplate = strings.Replace(usernameTemplate, usernamePlaceholder, matcher, 1) - usernameTemplate = fmt.Sprintf("^%s$", usernameTemplate) - return regexp.MustCompile(usernameTemplate) -} - -// GenerateRegistration generates a registration file for the homeserver. -func (config *BaseConfig) GenerateRegistration() *appservice.Registration { - registration := appservice.CreateRegistration() - config.AppService.HSToken = registration.ServerToken - config.AppService.ASToken = registration.AppToken - config.AppService.copyToRegistration(registration) - - registration.SenderLocalpart = random.String(32) - botRegex := regexp.MustCompile(fmt.Sprintf("^@%s:%s$", - regexp.QuoteMeta(config.AppService.Bot.Username), - regexp.QuoteMeta(config.Homeserver.Domain))) - registration.Namespaces.UserIDs.Register(botRegex, true) - registration.Namespaces.UserIDs.Register(config.MakeUserIDRegex(".*"), true) - - return registration -} - -func (config *BaseConfig) MakeAppService() *appservice.AppService { - as := appservice.Create() - as.HomeserverDomain = config.Homeserver.Domain - _ = as.SetHomeserverURL(config.Homeserver.Address) - as.Host.Hostname = config.AppService.Hostname - as.Host.Port = config.AppService.Port - as.Registration = config.AppService.GetRegistration() - return as -} - -// GetRegistration copies the data from the bridge config into an *appservice.Registration struct. -// This can't be used with the homeserver, see GenerateRegistration for generating files for the homeserver. -func (asc *AppserviceConfig) GetRegistration() *appservice.Registration { - reg := &appservice.Registration{} - asc.copyToRegistration(reg) - reg.SenderLocalpart = asc.Bot.Username - reg.ServerToken = asc.HSToken - reg.AppToken = asc.ASToken - return reg -} - -func (asc *AppserviceConfig) copyToRegistration(registration *appservice.Registration) { - registration.ID = asc.ID - registration.URL = asc.Address - falseVal := false - registration.RateLimited = &falseVal - registration.EphemeralEvents = asc.EphemeralEvents - registration.SoruEphemeralEvents = asc.EphemeralEvents -} - -type BotUserConfig struct { - Username string `yaml:"username"` - Displayname string `yaml:"displayname"` - Avatar string `yaml:"avatar"` - - ParsedAvatar id.ContentURI `yaml:"-"` -} - -type serializableBUC BotUserConfig - -func (buc *BotUserConfig) UnmarshalYAML(unmarshal func(interface{}) error) error { - var sbuc serializableBUC - err := unmarshal(&sbuc) - if err != nil { - return err - } - *buc = (BotUserConfig)(sbuc) - if buc.Avatar != "" && buc.Avatar != "remove" { - buc.ParsedAvatar, err = id.ParseContentURI(buc.Avatar) - if err != nil { - return fmt.Errorf("%w in bot avatar", err) - } - } - return nil -} - -type BridgeConfig interface { - FormatUsername(username string) string - GetEncryptionConfig() EncryptionConfig - GetCommandPrefix() string - GetManagementRoomTexts() ManagementRoomTexts - GetDoublePuppetConfig() DoublePuppetConfig - GetResendBridgeInfo() bool - EnableMessageStatusEvents() bool - EnableMessageErrorNotices() bool - Validate() error -} - -type DoublePuppetConfig struct { - ServerMap map[string]string `yaml:"double_puppet_server_map"` - AllowDiscovery bool `yaml:"double_puppet_allow_discovery"` - SharedSecretMap map[string]string `yaml:"login_shared_secret_map"` -} - -type EncryptionConfig struct { - Allow bool `yaml:"allow"` - Default bool `yaml:"default"` - Require bool `yaml:"require"` - Appservice bool `yaml:"appservice"` - - PlaintextMentions bool `yaml:"plaintext_mentions"` - - DeleteKeys struct { - DeleteOutboundOnAck bool `yaml:"delete_outbound_on_ack"` - DontStoreOutbound bool `yaml:"dont_store_outbound"` - RatchetOnDecrypt bool `yaml:"ratchet_on_decrypt"` - DeleteFullyUsedOnDecrypt bool `yaml:"delete_fully_used_on_decrypt"` - DeletePrevOnNewSession bool `yaml:"delete_prev_on_new_session"` - DeleteOnDeviceDelete bool `yaml:"delete_on_device_delete"` - PeriodicallyDeleteExpired bool `yaml:"periodically_delete_expired"` - DeleteOutdatedInbound bool `yaml:"delete_outdated_inbound"` - } `yaml:"delete_keys"` - - VerificationLevels struct { - Receive id.TrustState `yaml:"receive"` - Send id.TrustState `yaml:"send"` - Share id.TrustState `yaml:"share"` - } `yaml:"verification_levels"` - AllowKeySharing bool `yaml:"allow_key_sharing"` - - Rotation struct { - EnableCustom bool `yaml:"enable_custom"` - Milliseconds int64 `yaml:"milliseconds"` - Messages int `yaml:"messages"` - - DisableDeviceChangeKeyRotation bool `yaml:"disable_device_change_key_rotation"` - } `yaml:"rotation"` -} - -type ManagementRoomTexts struct { - Welcome string `yaml:"welcome"` - WelcomeConnected string `yaml:"welcome_connected"` - WelcomeUnconnected string `yaml:"welcome_unconnected"` - AdditionalHelp string `yaml:"additional_help"` -} - -type BaseConfig struct { - Homeserver HomeserverConfig `yaml:"homeserver"` - AppService AppserviceConfig `yaml:"appservice"` - Bridge BridgeConfig `yaml:"-"` - Logging zeroconfig.Config `yaml:"logging"` -} - -func doUpgrade(helper *up.Helper) { - helper.Copy(up.Str, "homeserver", "address") - helper.Copy(up.Str, "homeserver", "domain") - if legacyAsmuxFlag, ok := helper.Get(up.Bool, "homeserver", "asmux"); ok && legacyAsmuxFlag == "true" { - helper.Set(up.Str, string(SoftwareAsmux), "homeserver", "software") - } else { - helper.Copy(up.Str, "homeserver", "software") - } - helper.Copy(up.Str|up.Null, "homeserver", "status_endpoint") - helper.Copy(up.Str|up.Null, "homeserver", "message_send_checkpoint_endpoint") - helper.Copy(up.Bool, "homeserver", "async_media") - helper.Copy(up.Str|up.Null, "homeserver", "websocket_proxy") - helper.Copy(up.Bool, "homeserver", "websocket") - helper.Copy(up.Int, "homeserver", "ping_interval_seconds") - - helper.Copy(up.Str|up.Null, "appservice", "address") - helper.Copy(up.Str|up.Null, "appservice", "hostname") - helper.Copy(up.Int|up.Null, "appservice", "port") - if dbType, ok := helper.Get(up.Str, "appservice", "database", "type"); ok && dbType == "sqlite3" { - helper.Set(up.Str, "sqlite3-fk-wal", "appservice", "database", "type") - } else { - helper.Copy(up.Str, "appservice", "database", "type") - } - helper.Copy(up.Str, "appservice", "database", "uri") - helper.Copy(up.Int, "appservice", "database", "max_open_conns") - helper.Copy(up.Int, "appservice", "database", "max_idle_conns") - helper.Copy(up.Str|up.Null, "appservice", "database", "max_conn_idle_time") - helper.Copy(up.Str|up.Null, "appservice", "database", "max_conn_lifetime") - helper.Copy(up.Str, "appservice", "id") - helper.Copy(up.Str, "appservice", "bot", "username") - helper.Copy(up.Str, "appservice", "bot", "displayname") - helper.Copy(up.Str, "appservice", "bot", "avatar") - helper.Copy(up.Bool, "appservice", "ephemeral_events") - helper.Copy(up.Bool, "appservice", "async_transactions") - helper.Copy(up.Str, "appservice", "as_token") - helper.Copy(up.Str, "appservice", "hs_token") - - if helper.GetNode("logging", "writers") == nil && (helper.GetNode("logging", "print_level") != nil || helper.GetNode("logging", "file_name_format") != nil) { - _, _ = fmt.Fprintln(os.Stderr, "Migrating legacy log config") - migrateLegacyLogConfig(helper) - } else if helper.GetNode("logging", "writers") == nil && (helper.GetNode("logging", "handlers") != nil) { - _, _ = fmt.Fprintln(os.Stderr, "Migrating Python log config is not currently supported") - // TODO implement? - //migratePythonLogConfig(helper) - } else { - helper.Copy(up.Map, "logging") - } -} - -type legacyLogConfig struct { - Directory string `yaml:"directory"` - FileNameFormat string `yaml:"file_name_format"` - FileDateFormat string `yaml:"file_date_format"` - FileMode uint32 `yaml:"file_mode"` - TimestampFormat string `yaml:"timestamp_format"` - RawPrintLevel string `yaml:"print_level"` - JSONStdout bool `yaml:"print_json"` - JSONFile bool `yaml:"file_json"` -} - -func migrateLegacyLogConfig(helper *up.Helper) { - var llc legacyLogConfig - var newConfig zeroconfig.Config - err := helper.GetBaseNode("logging").Decode(&newConfig) - if err != nil { - _, _ = fmt.Fprintln(os.Stderr, "Base config is corrupted: failed to decode example log config:", err) - return - } else if len(newConfig.Writers) != 2 || newConfig.Writers[0].Type != "stdout" || newConfig.Writers[1].Type != "file" { - _, _ = fmt.Fprintln(os.Stderr, "Base log config is not in expected format") - return - } - err = helper.GetNode("logging").Decode(&llc) - if err != nil { - _, _ = fmt.Fprintln(os.Stderr, "Failed to decode legacy log config:", err) - return - } - if llc.RawPrintLevel != "" { - level, err := zerolog.ParseLevel(llc.RawPrintLevel) - if err != nil { - _, _ = fmt.Fprintln(os.Stderr, "Failed to parse minimum stdout log level:", err) - } else { - newConfig.Writers[0].MinLevel = &level - } - } - if llc.Directory != "" && llc.FileNameFormat != "" { - if llc.FileNameFormat == "{{.Date}}-{{.Index}}.log" { - llc.FileNameFormat = "bridge.log" - } else { - llc.FileNameFormat = strings.ReplaceAll(llc.FileNameFormat, "{{.Date}}", "") - llc.FileNameFormat = strings.ReplaceAll(llc.FileNameFormat, "{{.Index}}", "") - } - newConfig.Writers[1].Filename = filepath.Join(llc.Directory, llc.FileNameFormat) - } else if llc.FileNameFormat == "" { - newConfig.Writers = newConfig.Writers[0:1] - } - if llc.JSONStdout { - newConfig.Writers[0].TimeFormat = "" - newConfig.Writers[0].Format = "json" - } else if llc.TimestampFormat != "" { - newConfig.Writers[0].TimeFormat = llc.TimestampFormat - } - var updatedConfig yaml.Node - err = updatedConfig.Encode(&newConfig) - if err != nil { - _, _ = fmt.Fprintln(os.Stderr, "Failed to encode migrated log config:", err) - return - } - *helper.GetBaseNode("logging").Node = updatedConfig -} - -// Upgrader is a config upgrader that copies the default fields in the homeserver, appservice and logging blocks. -var Upgrader = up.SimpleUpgrader(doUpgrade) diff --git a/bridge/bridgeconfig/permissions.go b/bridge/bridgeconfig/permissions.go deleted file mode 100644 index 198e140e..00000000 --- a/bridge/bridgeconfig/permissions.go +++ /dev/null @@ -1,71 +0,0 @@ -// Copyright (c) 2023 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package bridgeconfig - -import ( - "strconv" - "strings" - - "maunium.net/go/mautrix/id" -) - -type PermissionConfig map[string]PermissionLevel - -type PermissionLevel int - -const ( - PermissionLevelBlock PermissionLevel = 0 - PermissionLevelRelay PermissionLevel = 5 - PermissionLevelUser PermissionLevel = 10 - PermissionLevelAdmin PermissionLevel = 100 -) - -var namesToLevels = map[string]PermissionLevel{ - "block": PermissionLevelBlock, - "relay": PermissionLevelRelay, - "user": PermissionLevelUser, - "admin": PermissionLevelAdmin, -} - -func RegisterPermissionLevel(name string, level PermissionLevel) { - namesToLevels[name] = level -} - -func (pc *PermissionConfig) UnmarshalYAML(unmarshal func(interface{}) error) error { - rawPC := make(map[string]string) - err := unmarshal(&rawPC) - if err != nil { - return err - } - - if *pc == nil { - *pc = make(map[string]PermissionLevel) - } - for key, value := range rawPC { - level, ok := namesToLevels[strings.ToLower(value)] - if ok { - (*pc)[key] = level - } else if val, err := strconv.Atoi(value); err == nil { - (*pc)[key] = PermissionLevel(val) - } else { - (*pc)[key] = PermissionLevelBlock - } - } - return nil -} - -func (pc PermissionConfig) Get(userID id.UserID) PermissionLevel { - if level, ok := pc[string(userID)]; ok { - return level - } else if level, ok = pc[userID.Homeserver()]; len(userID.Homeserver()) > 0 && ok { - return level - } else if level, ok = pc["*"]; ok { - return level - } else { - return PermissionLevelBlock - } -} diff --git a/bridge/bridgestate.go b/bridge/bridgestate.go deleted file mode 100644 index f9c3a3c6..00000000 --- a/bridge/bridgestate.go +++ /dev/null @@ -1,156 +0,0 @@ -// Copyright (c) 2023 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package bridge - -import ( - "context" - "runtime/debug" - "time" - - "github.com/rs/zerolog" - - "maunium.net/go/mautrix/appservice" - "maunium.net/go/mautrix/bridge/status" -) - -func (br *Bridge) SendBridgeState(ctx context.Context, state *status.BridgeState) error { - if br.Websocket { - // FIXME this doesn't account for multiple users - br.latestState = state - - return br.AS.SendWebsocket(&appservice.WebsocketRequest{ - Command: "bridge_status", - Data: state, - }) - } else if br.Config.Homeserver.StatusEndpoint != "" { - return state.SendHTTP(ctx, br.Config.Homeserver.StatusEndpoint, br.Config.AppService.ASToken) - } else { - return nil - } -} - -func (br *Bridge) SendGlobalBridgeState(state status.BridgeState) { - if len(br.Config.Homeserver.StatusEndpoint) == 0 && !br.Websocket { - return - } - - for { - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - if err := br.SendBridgeState(ctx, &state); err != nil { - br.ZLog.Warn().Err(err).Msg("Failed to update global bridge state") - cancel() - time.Sleep(5 * time.Second) - continue - } else { - br.ZLog.Debug().Interface("bridge_state", state).Msg("Sent new global bridge state") - cancel() - break - } - } -} - -type BridgeStateQueue struct { - prev *status.BridgeState - ch chan status.BridgeState - bridge *Bridge - user status.BridgeStateFiller -} - -func (br *Bridge) NewBridgeStateQueue(user status.BridgeStateFiller) *BridgeStateQueue { - if len(br.Config.Homeserver.StatusEndpoint) == 0 && !br.Websocket { - return nil - } - bsq := &BridgeStateQueue{ - ch: make(chan status.BridgeState, 10), - bridge: br, - user: user, - } - go bsq.loop() - return bsq -} - -func (bsq *BridgeStateQueue) loop() { - defer func() { - err := recover() - if err != nil { - bsq.bridge.ZLog.Error(). - Str(zerolog.ErrorStackFieldName, string(debug.Stack())). - Interface(zerolog.ErrorFieldName, err). - Msg("Panic in bridge state loop") - } - }() - for state := range bsq.ch { - bsq.immediateSendBridgeState(state) - } -} - -func (bsq *BridgeStateQueue) immediateSendBridgeState(state status.BridgeState) { - retryIn := 2 - for { - if bsq.prev != nil && bsq.prev.ShouldDeduplicate(&state) { - bsq.bridge.ZLog.Debug(). - Str("state_event", string(state.StateEvent)). - Msg("Not sending bridge state as it's a duplicate") - return - } - - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - err := bsq.bridge.SendBridgeState(ctx, &state) - cancel() - - if err != nil { - bsq.bridge.ZLog.Warn().Err(err). - Int("retry_in_seconds", retryIn). - Msg("Failed to update bridge state") - time.Sleep(time.Duration(retryIn) * time.Second) - retryIn *= 2 - if retryIn > 64 { - retryIn = 64 - } - } else { - bsq.prev = &state - bsq.bridge.ZLog.Debug(). - Interface("bridge_state", state). - Msg("Sent new bridge state") - return - } - } -} - -func (bsq *BridgeStateQueue) Send(state status.BridgeState) { - if bsq == nil { - return - } - - state = state.Fill(bsq.user) - - if len(bsq.ch) >= 8 { - bsq.bridge.ZLog.Warn().Msg("Bridge state queue is nearly full, discarding an item") - select { - case <-bsq.ch: - default: - } - } - select { - case bsq.ch <- state: - default: - bsq.bridge.ZLog.Error().Msg("Bridge state queue is full, dropped new state") - } -} - -func (bsq *BridgeStateQueue) GetPrev() status.BridgeState { - if bsq != nil && bsq.prev != nil { - return *bsq.prev - } - return status.BridgeState{} -} - -func (bsq *BridgeStateQueue) SetPrev(prev status.BridgeState) { - if bsq != nil { - bsq.prev = &prev - } -} diff --git a/bridge/commands/doublepuppet.go b/bridge/commands/doublepuppet.go deleted file mode 100644 index 3f074951..00000000 --- a/bridge/commands/doublepuppet.go +++ /dev/null @@ -1,83 +0,0 @@ -// Copyright (c) 2022 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package commands - -var CommandLoginMatrix = &FullHandler{ - Func: fnLoginMatrix, - Name: "login-matrix", - Help: HelpMeta{ - Section: HelpSectionAuth, - Description: "Enable double puppeting.", - Args: "<_access token_>", - }, - RequiresLogin: true, -} - -func fnLoginMatrix(ce *Event) { - if len(ce.Args) == 0 { - ce.Reply("**Usage:** `login-matrix `") - return - } - puppet := ce.User.GetIDoublePuppet() - if puppet == nil { - puppet = ce.User.GetIGhost() - if puppet == nil { - ce.Reply("Didn't get a ghost :(") - return - } - } - err := puppet.SwitchCustomMXID(ce.Args[0], ce.User.GetMXID()) - if err != nil { - ce.Reply("Failed to enable double puppeting: %v", err) - } else { - ce.Reply("Successfully switched puppet") - } -} - -var CommandPingMatrix = &FullHandler{ - Func: fnPingMatrix, - Name: "ping-matrix", - Help: HelpMeta{ - Section: HelpSectionAuth, - Description: "Ping the Matrix server with the double puppet.", - }, - RequiresLogin: true, -} - -func fnPingMatrix(ce *Event) { - puppet := ce.User.GetIDoublePuppet() - if puppet == nil || puppet.CustomIntent() == nil { - ce.Reply("You are not logged in with your Matrix account.") - return - } - resp, err := puppet.CustomIntent().Whoami(ce.Ctx) - if err != nil { - ce.Reply("Failed to validate Matrix login: %v", err) - } else { - ce.Reply("Confirmed valid access token for %s / %s", resp.UserID, resp.DeviceID) - } -} - -var CommandLogoutMatrix = &FullHandler{ - Func: fnLogoutMatrix, - Name: "logout-matrix", - Help: HelpMeta{ - Section: HelpSectionAuth, - Description: "Disable double puppeting.", - }, - RequiresLogin: true, -} - -func fnLogoutMatrix(ce *Event) { - puppet := ce.User.GetIDoublePuppet() - if puppet == nil || puppet.CustomIntent() == nil { - ce.Reply("You don't have double puppeting enabled.") - return - } - puppet.ClearCustomMXID() - ce.Reply("Successfully disabled double puppeting.") -} diff --git a/bridge/commands/event.go b/bridge/commands/event.go deleted file mode 100644 index 42b49b68..00000000 --- a/bridge/commands/event.go +++ /dev/null @@ -1,95 +0,0 @@ -// Copyright (c) 2023 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package commands - -import ( - "context" - "fmt" - "strings" - - "github.com/rs/zerolog" - - "maunium.net/go/mautrix" - "maunium.net/go/mautrix/appservice" - "maunium.net/go/mautrix/bridge" - "maunium.net/go/mautrix/event" - "maunium.net/go/mautrix/format" - "maunium.net/go/mautrix/id" -) - -// Event stores all data which might be used to handle commands -type Event struct { - Bot *appservice.IntentAPI - Bridge *bridge.Bridge - Portal bridge.Portal - Processor *Processor - Handler MinimalHandler - RoomID id.RoomID - EventID id.EventID - User bridge.User - Command string - Args []string - RawArgs string - ReplyTo id.EventID - Ctx context.Context - ZLog *zerolog.Logger -} - -// MainIntent returns the intent to use when replying to the command. -// -// It prefers the bridge bot, but falls back to the other user in DMs if the bridge bot is not present. -func (ce *Event) MainIntent() *appservice.IntentAPI { - intent := ce.Bot - if ce.Portal != nil && ce.Portal.IsPrivateChat() && !ce.Portal.IsEncrypted() { - intent = ce.Portal.MainIntent() - } - return intent -} - -// Reply sends a reply to command as notice, with optional string formatting and automatic $cmdprefix replacement. -func (ce *Event) Reply(msg string, args ...interface{}) { - msg = strings.ReplaceAll(msg, "$cmdprefix ", ce.Bridge.Config.Bridge.GetCommandPrefix()+" ") - if len(args) > 0 { - msg = fmt.Sprintf(msg, args...) - } - ce.ReplyAdvanced(msg, true, false) -} - -// ReplyAdvanced sends a reply to command as notice. It allows using HTML and disabling markdown, -// but doesn't have built-in string formatting. -func (ce *Event) ReplyAdvanced(msg string, allowMarkdown, allowHTML bool) { - content := format.RenderMarkdown(msg, allowMarkdown, allowHTML) - content.MsgType = event.MsgNotice - _, err := ce.MainIntent().SendMessageEvent(ce.Ctx, ce.RoomID, event.EventMessage, content) - if err != nil { - ce.ZLog.Error().Err(err).Msgf("Failed to reply to command") - } -} - -// React sends a reaction to the command. -func (ce *Event) React(key string) { - _, err := ce.MainIntent().SendReaction(ce.Ctx, ce.RoomID, ce.EventID, key) - if err != nil { - ce.ZLog.Error().Err(err).Msgf("Failed to react to command") - } -} - -// Redact redacts the command. -func (ce *Event) Redact(req ...mautrix.ReqRedact) { - _, err := ce.MainIntent().RedactEvent(ce.Ctx, ce.RoomID, ce.EventID, req...) - if err != nil { - ce.ZLog.Error().Err(err).Msgf("Failed to redact command") - } -} - -// MarkRead marks the command event as read. -func (ce *Event) MarkRead() { - err := ce.MainIntent().SendReceipt(ce.Ctx, ce.RoomID, ce.EventID, event.ReceiptTypeRead, nil) - if err != nil { - ce.ZLog.Error().Err(err).Msgf("Failed to mark command as read") - } -} diff --git a/bridge/commands/handler.go b/bridge/commands/handler.go deleted file mode 100644 index ab6899c0..00000000 --- a/bridge/commands/handler.go +++ /dev/null @@ -1,100 +0,0 @@ -// Copyright (c) 2022 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package commands - -import ( - "maunium.net/go/mautrix/bridge" - "maunium.net/go/mautrix/bridge/bridgeconfig" - "maunium.net/go/mautrix/event" -) - -type MinimalHandler interface { - Run(*Event) -} - -type MinimalHandlerFunc func(*Event) - -func (mhf MinimalHandlerFunc) Run(ce *Event) { - mhf(ce) -} - -type CommandState struct { - Next MinimalHandler - Action string - Meta interface{} -} - -type CommandingUser interface { - bridge.User - GetCommandState() *CommandState - SetCommandState(*CommandState) -} - -type Handler interface { - MinimalHandler - GetName() string -} - -type AliasedHandler interface { - Handler - GetAliases() []string -} - -type FullHandler struct { - Func func(*Event) - - Name string - Aliases []string - Help HelpMeta - - RequiresAdmin bool - RequiresPortal bool - RequiresLogin bool - - RequiresEventLevel event.Type -} - -func (fh *FullHandler) GetHelp() HelpMeta { - fh.Help.Command = fh.Name - return fh.Help -} - -func (fh *FullHandler) GetName() string { - return fh.Name -} - -func (fh *FullHandler) GetAliases() []string { - return fh.Aliases -} - -func (fh *FullHandler) ShowInHelp(ce *Event) bool { - return !fh.RequiresAdmin || ce.User.GetPermissionLevel() >= bridgeconfig.PermissionLevelAdmin -} - -func (fh *FullHandler) userHasRoomPermission(ce *Event) bool { - levels, err := ce.MainIntent().PowerLevels(ce.Ctx, ce.RoomID) - if err != nil { - ce.ZLog.Warn().Err(err).Msg("Failed to check room power levels") - ce.Reply("Failed to get room power levels to see if you're allowed to use that command") - return false - } - return levels.GetUserLevel(ce.User.GetMXID()) >= levels.GetEventLevel(fh.RequiresEventLevel) -} - -func (fh *FullHandler) Run(ce *Event) { - if fh.RequiresAdmin && ce.User.GetPermissionLevel() < bridgeconfig.PermissionLevelAdmin { - ce.Reply("That command is limited to bridge administrators.") - } else if fh.RequiresEventLevel.Type != "" && ce.User.GetPermissionLevel() < bridgeconfig.PermissionLevelAdmin && !fh.userHasRoomPermission(ce) { - ce.Reply("That command requires room admin rights.") - } else if fh.RequiresPortal && ce.Portal == nil { - ce.Reply("That command can only be ran in portal rooms.") - } else if fh.RequiresLogin && !ce.User.IsLoggedIn() { - ce.Reply("That command requires you to be logged in.") - } else { - fh.Func(ce) - } -} diff --git a/bridge/commands/meta.go b/bridge/commands/meta.go deleted file mode 100644 index 615f6a34..00000000 --- a/bridge/commands/meta.go +++ /dev/null @@ -1,56 +0,0 @@ -// Copyright (c) 2022 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package commands - -var CommandHelp = &FullHandler{ - Func: func(ce *Event) { - ce.Reply(FormatHelp(ce)) - }, - Name: "help", - Help: HelpMeta{ - Section: HelpSectionGeneral, - Description: "Show this help message.", - }, -} - -var CommandVersion = &FullHandler{ - Func: func(ce *Event) { - ce.Reply("[%s](%s) %s (%s)", ce.Bridge.Name, ce.Bridge.URL, ce.Bridge.LinkifiedVersion, ce.Bridge.BuildTime) - }, - Name: "version", - Help: HelpMeta{ - Section: HelpSectionGeneral, - Description: "Get the bridge version.", - }, -} - -var CommandCancel = &FullHandler{ - Func: func(ce *Event) { - commandingUser, ok := ce.User.(CommandingUser) - if !ok { - ce.Reply("This bridge does not implement cancelable commands") - return - } - state := commandingUser.GetCommandState() - - if state != nil { - action := state.Action - if action == "" { - action = "Unknown action" - } - commandingUser.SetCommandState(nil) - ce.Reply("%s cancelled.", action) - } else { - ce.Reply("No ongoing command.") - } - }, - Name: "cancel", - Help: HelpMeta{ - Section: HelpSectionGeneral, - Description: "Cancel an ongoing action.", - }, -} diff --git a/bridge/commands/processor.go b/bridge/commands/processor.go deleted file mode 100644 index 6158a7cd..00000000 --- a/bridge/commands/processor.go +++ /dev/null @@ -1,122 +0,0 @@ -// Copyright (c) 2023 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package commands - -import ( - "context" - "runtime/debug" - "strings" - - "github.com/rs/zerolog" - - "maunium.net/go/mautrix/bridge" - "maunium.net/go/mautrix/id" -) - -type Processor struct { - bridge *bridge.Bridge - log *zerolog.Logger - - handlers map[string]Handler - aliases map[string]string -} - -// NewProcessor creates a Processor -func NewProcessor(bridge *bridge.Bridge) *Processor { - proc := &Processor{ - bridge: bridge, - log: bridge.ZLog, - - handlers: make(map[string]Handler), - aliases: make(map[string]string), - } - proc.AddHandlers( - CommandHelp, CommandVersion, CommandCancel, - CommandLoginMatrix, CommandLogoutMatrix, CommandPingMatrix, - CommandDiscardMegolmSession, CommandSetPowerLevel) - return proc -} - -func (proc *Processor) AddHandlers(handlers ...Handler) { - for _, handler := range handlers { - proc.AddHandler(handler) - } -} - -func (proc *Processor) AddHandler(handler Handler) { - proc.handlers[handler.GetName()] = handler - aliased, ok := handler.(AliasedHandler) - if ok { - for _, alias := range aliased.GetAliases() { - proc.aliases[alias] = handler.GetName() - } - } -} - -// Handle handles messages to the bridge -func (proc *Processor) Handle(ctx context.Context, roomID id.RoomID, eventID id.EventID, user bridge.User, message string, replyTo id.EventID) { - defer func() { - err := recover() - if err != nil { - zerolog.Ctx(ctx).Error(). - Str(zerolog.ErrorStackFieldName, string(debug.Stack())). - Interface(zerolog.ErrorFieldName, err). - Msg("Panic in Matrix command handler") - } - }() - args := strings.Fields(message) - if len(args) == 0 { - args = []string{"unknown-command"} - } - command := strings.ToLower(args[0]) - rawArgs := strings.TrimLeft(strings.TrimPrefix(message, command), " ") - log := zerolog.Ctx(ctx).With().Str("mx_command", command).Logger() - ctx = log.WithContext(ctx) - ce := &Event{ - Bot: proc.bridge.Bot, - Bridge: proc.bridge, - Portal: proc.bridge.Child.GetIPortal(roomID), - Processor: proc, - RoomID: roomID, - EventID: eventID, - User: user, - Command: command, - Args: args[1:], - RawArgs: rawArgs, - ReplyTo: replyTo, - Ctx: ctx, - ZLog: &log, - } - log.Debug().Msg("Received command") - - realCommand, ok := proc.aliases[ce.Command] - if !ok { - realCommand = ce.Command - } - commandingUser, ok := ce.User.(CommandingUser) - - var handler MinimalHandler - handler, ok = proc.handlers[realCommand] - if !ok { - var state *CommandState - if commandingUser != nil { - state = commandingUser.GetCommandState() - } - if state != nil && state.Next != nil { - ce.Command = "" - ce.RawArgs = message - ce.Args = args - ce.Handler = state.Next - state.Next.Run(ce) - } else { - ce.Reply("Unknown command, use the `help` command for help.") - } - } else { - ce.Handler = handler - handler.Run(ce) - } -} diff --git a/bridge/doublepuppet.go b/bridge/doublepuppet.go deleted file mode 100644 index 265d3d5c..00000000 --- a/bridge/doublepuppet.go +++ /dev/null @@ -1,173 +0,0 @@ -// Copyright (c) 2023 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package bridge - -import ( - "context" - "crypto/hmac" - "crypto/sha512" - "encoding/hex" - "errors" - "fmt" - "strings" - - "github.com/rs/zerolog" - - "maunium.net/go/mautrix" - "maunium.net/go/mautrix/appservice" - "maunium.net/go/mautrix/id" -) - -type doublePuppetUtil struct { - br *Bridge - log zerolog.Logger -} - -func (dp *doublePuppetUtil) newClient(ctx context.Context, mxid id.UserID, accessToken string) (*mautrix.Client, error) { - _, homeserver, err := mxid.Parse() - if err != nil { - return nil, err - } - homeserverURL, found := dp.br.Config.Bridge.GetDoublePuppetConfig().ServerMap[homeserver] - if !found { - if homeserver == dp.br.AS.HomeserverDomain { - homeserverURL = "" - } else if dp.br.Config.Bridge.GetDoublePuppetConfig().AllowDiscovery { - resp, err := mautrix.DiscoverClientAPI(ctx, homeserver) - if err != nil { - return nil, fmt.Errorf("failed to find homeserver URL for %s: %v", homeserver, err) - } - homeserverURL = resp.Homeserver.BaseURL - dp.log.Debug(). - Str("homeserver", homeserver). - Str("url", homeserverURL). - Str("user_id", mxid.String()). - Msg("Discovered URL to enable double puppeting for user") - } else { - return nil, fmt.Errorf("double puppeting from %s is not allowed", homeserver) - } - } - return dp.br.AS.NewExternalMautrixClient(mxid, accessToken, homeserverURL) -} - -func (dp *doublePuppetUtil) newIntent(ctx context.Context, mxid id.UserID, accessToken string) (*appservice.IntentAPI, error) { - client, err := dp.newClient(ctx, mxid, accessToken) - if err != nil { - return nil, err - } - - ia := dp.br.AS.NewIntentAPI("custom") - ia.Client = client - ia.Localpart, _, _ = mxid.Parse() - ia.UserID = mxid - ia.IsCustomPuppet = true - return ia, nil -} - -func (dp *doublePuppetUtil) autoLogin(ctx context.Context, mxid id.UserID, loginSecret string) (string, error) { - dp.log.Debug().Str("user_id", mxid.String()).Msg("Logging into user account with shared secret") - client, err := dp.newClient(ctx, mxid, "") - if err != nil { - return "", fmt.Errorf("failed to create mautrix client to log in: %v", err) - } - bridgeName := fmt.Sprintf("%s Bridge", dp.br.ProtocolName) - req := mautrix.ReqLogin{ - Identifier: mautrix.UserIdentifier{Type: mautrix.IdentifierTypeUser, User: string(mxid)}, - DeviceID: id.DeviceID(bridgeName), - InitialDeviceDisplayName: bridgeName, - } - if loginSecret == "appservice" { - client.AccessToken = dp.br.AS.Registration.AppToken - req.Type = mautrix.AuthTypeAppservice - } else { - loginFlows, err := client.GetLoginFlows(ctx) - if err != nil { - return "", fmt.Errorf("failed to get supported login flows: %w", err) - } - mac := hmac.New(sha512.New, []byte(loginSecret)) - mac.Write([]byte(mxid)) - token := hex.EncodeToString(mac.Sum(nil)) - switch { - case loginFlows.HasFlow(mautrix.AuthTypeDevtureSharedSecret): - req.Type = mautrix.AuthTypeDevtureSharedSecret - req.Token = token - case loginFlows.HasFlow(mautrix.AuthTypePassword): - req.Type = mautrix.AuthTypePassword - req.Password = token - default: - return "", fmt.Errorf("no supported auth types for shared secret auth found") - } - } - resp, err := client.Login(ctx, &req) - if err != nil { - return "", err - } - return resp.AccessToken, nil -} - -var ( - ErrMismatchingMXID = errors.New("whoami result does not match custom mxid") - ErrNoAccessToken = errors.New("no access token provided") - ErrNoMXID = errors.New("no mxid provided") -) - -const useConfigASToken = "appservice-config" -const asTokenModePrefix = "as_token:" - -func (dp *doublePuppetUtil) Setup(ctx context.Context, mxid id.UserID, savedAccessToken string, reloginOnFail bool) (intent *appservice.IntentAPI, newAccessToken string, err error) { - if len(mxid) == 0 { - err = ErrNoMXID - return - } - _, homeserver, _ := mxid.Parse() - loginSecret, hasSecret := dp.br.Config.Bridge.GetDoublePuppetConfig().SharedSecretMap[homeserver] - // Special case appservice: prefix to not login and use it as an as_token directly. - if hasSecret && strings.HasPrefix(loginSecret, asTokenModePrefix) { - intent, err = dp.newIntent(ctx, mxid, strings.TrimPrefix(loginSecret, asTokenModePrefix)) - if err != nil { - return - } - intent.SetAppServiceUserID = true - if savedAccessToken != useConfigASToken { - var resp *mautrix.RespWhoami - resp, err = intent.Whoami(ctx) - if err == nil && resp.UserID != mxid { - err = ErrMismatchingMXID - } - } - return intent, useConfigASToken, err - } - if savedAccessToken == "" || savedAccessToken == useConfigASToken { - if reloginOnFail && hasSecret { - savedAccessToken, err = dp.autoLogin(ctx, mxid, loginSecret) - } else { - err = ErrNoAccessToken - } - if err != nil { - return - } - } - intent, err = dp.newIntent(ctx, mxid, savedAccessToken) - if err != nil { - return - } - var resp *mautrix.RespWhoami - resp, err = intent.Whoami(ctx) - if err != nil { - if reloginOnFail && hasSecret && errors.Is(err, mautrix.MUnknownToken) { - intent.AccessToken, err = dp.autoLogin(ctx, mxid, loginSecret) - if err == nil { - newAccessToken = intent.AccessToken - } - } - } else if resp.UserID != mxid { - err = ErrMismatchingMXID - } else { - newAccessToken = savedAccessToken - } - return -} diff --git a/bridge/matrix.go b/bridge/matrix.go deleted file mode 100644 index b49279aa..00000000 --- a/bridge/matrix.go +++ /dev/null @@ -1,704 +0,0 @@ -// Copyright (c) 2023 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package bridge - -import ( - "context" - "errors" - "fmt" - "strings" - "time" - - "github.com/rs/zerolog" - - "maunium.net/go/mautrix" - "maunium.net/go/mautrix/appservice" - "maunium.net/go/mautrix/bridge/bridgeconfig" - "maunium.net/go/mautrix/bridge/status" - "maunium.net/go/mautrix/event" - "maunium.net/go/mautrix/format" - "maunium.net/go/mautrix/id" -) - -type CommandProcessor interface { - Handle(ctx context.Context, roomID id.RoomID, eventID id.EventID, user User, message string, replyTo id.EventID) -} - -type MatrixHandler struct { - bridge *Bridge - as *appservice.AppService - log *zerolog.Logger - - TrackEventDuration func(event.Type) func() -} - -func noop() {} - -func noopTrack(_ event.Type) func() { - return noop -} - -func NewMatrixHandler(br *Bridge) *MatrixHandler { - handler := &MatrixHandler{ - bridge: br, - as: br.AS, - log: br.ZLog, - - TrackEventDuration: noopTrack, - } - for evtType := range status.CheckpointTypes { - br.EventProcessor.On(evtType, handler.sendBridgeCheckpoint) - } - br.EventProcessor.On(event.EventMessage, handler.HandleMessage) - br.EventProcessor.On(event.EventEncrypted, handler.HandleEncrypted) - br.EventProcessor.On(event.EventSticker, handler.HandleMessage) - br.EventProcessor.On(event.EventReaction, handler.HandleReaction) - br.EventProcessor.On(event.EventRedaction, handler.HandleRedaction) - br.EventProcessor.On(event.StateMember, handler.HandleMembership) - br.EventProcessor.On(event.StateRoomName, handler.HandleRoomMetadata) - br.EventProcessor.On(event.StateRoomAvatar, handler.HandleRoomMetadata) - br.EventProcessor.On(event.StateTopic, handler.HandleRoomMetadata) - br.EventProcessor.On(event.StateEncryption, handler.HandleEncryption) - br.EventProcessor.On(event.EphemeralEventReceipt, handler.HandleReceipt) - br.EventProcessor.On(event.EphemeralEventTyping, handler.HandleTyping) - br.EventProcessor.On(event.StatePowerLevels, handler.HandlePowerLevels) - return handler -} - -func (mx *MatrixHandler) sendBridgeCheckpoint(_ context.Context, evt *event.Event) { - if !evt.Mautrix.CheckpointSent { - go mx.bridge.SendMessageSuccessCheckpoint(evt, status.MsgStepBridge, 0) - } -} - -func (mx *MatrixHandler) HandleEncryption(ctx context.Context, evt *event.Event) { - defer mx.TrackEventDuration(evt.Type)() - if evt.Content.AsEncryption().Algorithm != id.AlgorithmMegolmV1 { - return - } - portal := mx.bridge.Child.GetIPortal(evt.RoomID) - if portal != nil && !portal.IsEncrypted() { - mx.log.Debug(). - Str("user_id", evt.Sender.String()). - Str("room_id", evt.RoomID.String()). - Msg("Encryption was enabled in room") - portal.MarkEncrypted() - if portal.IsPrivateChat() { - err := mx.as.BotIntent().EnsureJoined(ctx, evt.RoomID, appservice.EnsureJoinedParams{BotOverride: portal.MainIntent().Client}) - if err != nil { - mx.log.Err(err). - Str("room_id", evt.RoomID.String()). - Msg("Failed to join bot to room after encryption was enabled") - } - } - } -} - -func (mx *MatrixHandler) joinAndCheckMembers(ctx context.Context, evt *event.Event, intent *appservice.IntentAPI) *mautrix.RespJoinedMembers { - log := zerolog.Ctx(ctx) - resp, err := intent.JoinRoomByID(ctx, evt.RoomID) - if err != nil { - log.Warn().Err(err).Msg("Failed to join room with invite") - return nil - } - - members, err := intent.JoinedMembers(ctx, resp.RoomID) - if err != nil { - log.Warn().Err(err).Msg("Failed to get members in room after accepting invite, leaving room") - _, _ = intent.LeaveRoom(ctx, resp.RoomID) - return nil - } - - if len(members.Joined) < 2 { - log.Debug().Msg("Leaving empty room after accepting invite") - _, _ = intent.LeaveRoom(ctx, resp.RoomID) - return nil - } - return members -} - -func (mx *MatrixHandler) sendNoticeWithMarkdown(ctx context.Context, roomID id.RoomID, message string) (*mautrix.RespSendEvent, error) { - intent := mx.as.BotIntent() - content := format.RenderMarkdown(message, true, false) - content.MsgType = event.MsgNotice - return intent.SendMessageEvent(ctx, roomID, event.EventMessage, content) -} - -func (mx *MatrixHandler) HandleBotInvite(ctx context.Context, evt *event.Event) { - intent := mx.as.BotIntent() - - user := mx.bridge.Child.GetIUser(evt.Sender, true) - if user == nil { - return - } - - members := mx.joinAndCheckMembers(ctx, evt, intent) - if members == nil { - return - } - - if user.GetPermissionLevel() < bridgeconfig.PermissionLevelUser { - _, _ = intent.SendNotice(ctx, evt.RoomID, "You are not whitelisted to use this bridge.\n"+ - "If you're the owner of this bridge, see the bridge.permissions section in your config file.") - _, _ = intent.LeaveRoom(ctx, evt.RoomID) - return - } - - texts := mx.bridge.Config.Bridge.GetManagementRoomTexts() - _, _ = mx.sendNoticeWithMarkdown(ctx, evt.RoomID, texts.Welcome) - - if len(members.Joined) == 2 && (len(user.GetManagementRoomID()) == 0 || evt.Content.AsMember().IsDirect) { - user.SetManagementRoom(evt.RoomID) - _, _ = intent.SendNotice(ctx, user.GetManagementRoomID(), "This room has been registered as your bridge management/status room.") - zerolog.Ctx(ctx).Debug().Msg("Registered room as management room with inviter") - } - - if evt.RoomID == user.GetManagementRoomID() { - if user.IsLoggedIn() { - _, _ = mx.sendNoticeWithMarkdown(ctx, evt.RoomID, texts.WelcomeConnected) - } else { - _, _ = mx.sendNoticeWithMarkdown(ctx, evt.RoomID, texts.WelcomeUnconnected) - } - - additionalHelp := texts.AdditionalHelp - if len(additionalHelp) > 0 { - _, _ = mx.sendNoticeWithMarkdown(ctx, evt.RoomID, additionalHelp) - } - } -} - -func (mx *MatrixHandler) HandleGhostInvite(ctx context.Context, evt *event.Event, inviter User, ghost Ghost) { - log := zerolog.Ctx(ctx) - intent := ghost.DefaultIntent() - - if inviter.GetPermissionLevel() < bridgeconfig.PermissionLevelUser { - log.Debug().Msg("Rejecting invite: inviter is not whitelisted") - _, err := intent.LeaveRoom(ctx, evt.RoomID, &mautrix.ReqLeave{ - Reason: "You're not whitelisted to use this bridge", - }) - if err != nil { - log.Error().Err(err).Msg("Failed to reject invite") - } - return - } else if !inviter.IsLoggedIn() { - log.Debug().Msg("Rejecting invite: inviter is not logged in") - _, err := intent.LeaveRoom(ctx, evt.RoomID, &mautrix.ReqLeave{ - Reason: "You're not logged into this bridge", - }) - if err != nil { - log.Error().Err(err).Msg("Failed to reject invite") - } - return - } - - members := mx.joinAndCheckMembers(ctx, evt, intent) - if members == nil { - return - } - var createEvent event.CreateEventContent - if err := intent.StateEvent(ctx, evt.RoomID, event.StateCreate, "", &createEvent); err != nil { - log.Warn().Err(err).Msg("Failed to check m.room.create event in room") - } else if createEvent.Type != "" { - log.Warn().Str("room_type", string(createEvent.Type)).Msg("Non-standard room type, leaving room") - _, err = intent.LeaveRoom(ctx, evt.RoomID, &mautrix.ReqLeave{ - Reason: "Unsupported room type", - }) - if err != nil { - log.Error().Err(err).Msg("Failed to leave room") - } - return - } - var hasBridgeBot, hasOtherUsers bool - for mxid, _ := range members.Joined { - if mxid == intent.UserID || mxid == inviter.GetMXID() { - continue - } else if mxid == mx.bridge.Bot.UserID { - hasBridgeBot = true - } else { - hasOtherUsers = true - } - } - if !hasBridgeBot && !hasOtherUsers && evt.Content.AsMember().IsDirect { - mx.bridge.Child.CreatePrivatePortal(evt.RoomID, inviter, ghost) - } else if !hasBridgeBot { - log.Debug().Msg("Leaving multi-user room after accepting invite") - _, _ = intent.SendNotice(ctx, evt.RoomID, "Please invite the bridge bot first if you want to bridge to a remote chat.") - _, _ = intent.LeaveRoom(ctx, evt.RoomID) - } else { - _, _ = intent.SendNotice(ctx, evt.RoomID, "This puppet will remain inactive until this room is bridged to a remote chat.") - } -} - -func (mx *MatrixHandler) HandleMembership(ctx context.Context, evt *event.Event) { - if evt.Sender == mx.bridge.Bot.UserID || mx.bridge.Child.IsGhost(evt.Sender) { - return - } - defer mx.TrackEventDuration(evt.Type)() - - if mx.bridge.Crypto != nil { - mx.bridge.Crypto.HandleMemberEvent(ctx, evt) - } - - log := mx.log.With(). - Str("sender", evt.Sender.String()). - Str("target", evt.GetStateKey()). - Str("room_id", evt.RoomID.String()). - Logger() - ctx = log.WithContext(ctx) - - content := evt.Content.AsMember() - if content.Membership == event.MembershipInvite && id.UserID(evt.GetStateKey()) == mx.as.BotMXID() { - mx.HandleBotInvite(ctx, evt) - return - } - - if mx.shouldIgnoreEvent(evt) { - return - } - - user := mx.bridge.Child.GetIUser(evt.Sender, true) - if user == nil { - return - } - isSelf := id.UserID(evt.GetStateKey()) == evt.Sender - ghost := mx.bridge.Child.GetIGhost(id.UserID(evt.GetStateKey())) - portal := mx.bridge.Child.GetIPortal(evt.RoomID) - if portal == nil { - if ghost != nil && content.Membership == event.MembershipInvite { - mx.HandleGhostInvite(ctx, evt, user, ghost) - } - return - } else if user.GetPermissionLevel() < bridgeconfig.PermissionLevelUser || !user.IsLoggedIn() { - return - } - - mhp, ok := portal.(MembershipHandlingPortal) - if !ok { - return - } - - if content.Membership == event.MembershipLeave { - if evt.Unsigned.PrevContent != nil { - _ = evt.Unsigned.PrevContent.ParseRaw(evt.Type) - prevContent, ok := evt.Unsigned.PrevContent.Parsed.(*event.MemberEventContent) - if ok && prevContent.Membership != "join" { - return - } - } - if isSelf { - mhp.HandleMatrixLeave(user, evt) - } else if ghost != nil { - mhp.HandleMatrixKick(user, ghost, evt) - } - } else if content.Membership == event.MembershipInvite && !isSelf && ghost != nil { - mhp.HandleMatrixInvite(user, ghost, evt) - } - // TODO kicking/inviting non-ghost users users -} - -func (mx *MatrixHandler) HandleRoomMetadata(ctx context.Context, evt *event.Event) { - defer mx.TrackEventDuration(evt.Type)() - if mx.shouldIgnoreEvent(evt) { - return - } - - user := mx.bridge.Child.GetIUser(evt.Sender, true) - if user == nil { - return - } - - portal := mx.bridge.Child.GetIPortal(evt.RoomID) - if portal == nil || portal.IsPrivateChat() { - return - } - - metaPortal, ok := portal.(MetaHandlingPortal) - if !ok { - return - } - - metaPortal.HandleMatrixMeta(user, evt) -} - -func (mx *MatrixHandler) shouldIgnoreEvent(evt *event.Event) bool { - if evt.Sender == mx.bridge.Bot.UserID || mx.bridge.Child.IsGhost(evt.Sender) { - return true - } - user := mx.bridge.Child.GetIUser(evt.Sender, true) - if user == nil || user.GetPermissionLevel() <= 0 { - return true - } else if val, ok := evt.Content.Raw[appservice.DoublePuppetKey]; ok && val == mx.bridge.Name && user.GetIDoublePuppet() != nil { - return true - } - return false -} - -const initialSessionWaitTimeout = 3 * time.Second -const extendedSessionWaitTimeout = 22 * time.Second - -func (mx *MatrixHandler) sendCryptoStatusError(ctx context.Context, evt *event.Event, editEvent id.EventID, err error, retryCount int, isFinal bool) id.EventID { - mx.bridge.SendMessageErrorCheckpoint(evt, status.MsgStepDecrypted, err, isFinal, retryCount) - - if mx.bridge.Config.Bridge.EnableMessageStatusEvents() { - statusEvent := &event.BeeperMessageStatusEventContent{ - // TODO: network - RelatesTo: event.RelatesTo{ - Type: event.RelReference, - EventID: evt.ID, - }, - Status: event.MessageStatusRetriable, - Reason: event.MessageStatusUndecryptable, - Error: err.Error(), - Message: errorToHumanMessage(err), - } - if !isFinal { - statusEvent.Status = event.MessageStatusPending - } - _, sendErr := mx.bridge.Bot.SendMessageEvent(ctx, evt.RoomID, event.BeeperMessageStatus, statusEvent) - if sendErr != nil { - zerolog.Ctx(ctx).Error().Err(err).Msg("Failed to send message status event") - } - } - if mx.bridge.Config.Bridge.EnableMessageErrorNotices() { - update := event.MessageEventContent{ - MsgType: event.MsgNotice, - Body: fmt.Sprintf("\u26a0 Your message was not bridged: %v.", err), - } - if errors.Is(err, errNoCrypto) { - update.Body = "🔒 This bridge has not been configured to support encryption" - } - relatable, ok := evt.Content.Parsed.(event.Relatable) - if editEvent != "" { - update.SetEdit(editEvent) - } else if ok && relatable.OptionalGetRelatesTo().GetThreadParent() != "" { - update.GetRelatesTo().SetThread(relatable.OptionalGetRelatesTo().GetThreadParent(), evt.ID) - } - resp, sendErr := mx.bridge.Bot.SendMessageEvent(ctx, evt.RoomID, event.EventMessage, &update) - if sendErr != nil { - zerolog.Ctx(ctx).Error().Err(sendErr).Msg("Failed to send decryption error notice") - } else if resp != nil { - return resp.EventID - } - } - return "" -} - -var ( - errDeviceNotTrusted = errors.New("your device is not trusted") - errMessageNotEncrypted = errors.New("unencrypted message") - errNoDecryptionKeys = errors.New("the bridge hasn't received the decryption keys") - errNoCrypto = errors.New("this bridge has not been configured to support encryption") -) - -func errorToHumanMessage(err error) string { - var withheld *event.RoomKeyWithheldEventContent - switch { - case errors.Is(err, errDeviceNotTrusted), errors.Is(err, errNoDecryptionKeys): - return err.Error() - case errors.Is(err, UnknownMessageIndex): - return "the keys received by the bridge can't decrypt the message" - case errors.Is(err, DuplicateMessageIndex): - return "your client encrypted multiple messages with the same key" - case errors.As(err, &withheld): - if withheld.Code == event.RoomKeyWithheldBeeperRedacted { - return "your client used an outdated encryption session" - } - return "your client refused to share decryption keys with the bridge" - case errors.Is(err, errMessageNotEncrypted): - return "the message is not encrypted" - default: - return "the bridge failed to decrypt the message" - } -} - -func deviceUnverifiedErrorWithExplanation(trust id.TrustState) error { - var explanation string - switch trust { - case id.TrustStateBlacklisted: - explanation = "device is blacklisted" - case id.TrustStateUnset: - explanation = "unverified" - case id.TrustStateUnknownDevice: - explanation = "device info not found" - case id.TrustStateForwarded: - explanation = "keys were forwarded from an unknown device" - case id.TrustStateCrossSignedUntrusted: - explanation = "cross-signing keys changed after setting up the bridge" - default: - return errDeviceNotTrusted - } - return fmt.Errorf("%w (%s)", errDeviceNotTrusted, explanation) -} - -func copySomeKeys(original, decrypted *event.Event) { - isScheduled, _ := original.Content.Raw["com.beeper.scheduled"].(bool) - _, alreadyExists := decrypted.Content.Raw["com.beeper.scheduled"] - if isScheduled && !alreadyExists { - decrypted.Content.Raw["com.beeper.scheduled"] = true - } -} - -func (mx *MatrixHandler) postDecrypt(ctx context.Context, original, decrypted *event.Event, retryCount int, errorEventID id.EventID, duration time.Duration) { - log := zerolog.Ctx(ctx) - minLevel := mx.bridge.Config.Bridge.GetEncryptionConfig().VerificationLevels.Send - if decrypted.Mautrix.TrustState < minLevel { - logEvt := log.Warn(). - Str("user_id", decrypted.Sender.String()). - Bool("forwarded_keys", decrypted.Mautrix.ForwardedKeys). - Stringer("device_trust", decrypted.Mautrix.TrustState). - Stringer("min_trust", minLevel) - if decrypted.Mautrix.TrustSource != nil { - dev := decrypted.Mautrix.TrustSource - logEvt. - Str("device_id", dev.DeviceID.String()). - Str("device_signing_key", dev.SigningKey.String()) - } else { - logEvt.Str("device_id", "unknown") - } - logEvt.Msg("Dropping event due to insufficient verification level") - err := deviceUnverifiedErrorWithExplanation(decrypted.Mautrix.TrustState) - go mx.sendCryptoStatusError(ctx, decrypted, errorEventID, err, retryCount, true) - return - } - copySomeKeys(original, decrypted) - - mx.bridge.SendMessageSuccessCheckpoint(decrypted, status.MsgStepDecrypted, retryCount) - decrypted.Mautrix.CheckpointSent = true - decrypted.Mautrix.DecryptionDuration = duration - decrypted.Mautrix.EventSource |= event.SourceDecrypted - mx.bridge.EventProcessor.Dispatch(ctx, decrypted) - if errorEventID != "" { - _, _ = mx.bridge.Bot.RedactEvent(ctx, decrypted.RoomID, errorEventID) - } -} - -func (mx *MatrixHandler) HandleEncrypted(ctx context.Context, evt *event.Event) { - defer mx.TrackEventDuration(evt.Type)() - if mx.shouldIgnoreEvent(evt) { - return - } - content := evt.Content.AsEncrypted() - log := zerolog.Ctx(ctx).With(). - Str("event_id", evt.ID.String()). - Str("session_id", content.SessionID.String()). - Logger() - ctx = log.WithContext(ctx) - if mx.bridge.Crypto == nil { - go mx.sendCryptoStatusError(ctx, evt, "", errNoCrypto, 0, true) - return - } - log.Debug().Msg("Decrypting received event") - - decryptionStart := time.Now() - decrypted, err := mx.bridge.Crypto.Decrypt(ctx, evt) - decryptionRetryCount := 0 - if errors.Is(err, NoSessionFound) { - decryptionRetryCount = 1 - log.Debug(). - Int("wait_seconds", int(initialSessionWaitTimeout.Seconds())). - Msg("Couldn't find session, waiting for keys to arrive...") - mx.bridge.SendMessageErrorCheckpoint(evt, status.MsgStepDecrypted, err, false, 0) - if mx.bridge.Crypto.WaitForSession(ctx, evt.RoomID, content.SenderKey, content.SessionID, initialSessionWaitTimeout) { - log.Debug().Msg("Got keys after waiting, trying to decrypt event again") - decrypted, err = mx.bridge.Crypto.Decrypt(ctx, evt) - } else { - go mx.waitLongerForSession(ctx, evt, decryptionStart) - return - } - } - if err != nil { - mx.bridge.SendMessageErrorCheckpoint(evt, status.MsgStepDecrypted, err, true, decryptionRetryCount) - log.Warn().Err(err).Msg("Failed to decrypt event") - go mx.sendCryptoStatusError(ctx, evt, "", err, decryptionRetryCount, true) - return - } - mx.postDecrypt(ctx, evt, decrypted, decryptionRetryCount, "", time.Since(decryptionStart)) -} - -func (mx *MatrixHandler) waitLongerForSession(ctx context.Context, evt *event.Event, decryptionStart time.Time) { - log := zerolog.Ctx(ctx) - content := evt.Content.AsEncrypted() - log.Debug(). - Int("wait_seconds", int(extendedSessionWaitTimeout.Seconds())). - Msg("Couldn't find session, requesting keys and waiting longer...") - - go mx.bridge.Crypto.RequestSession(ctx, evt.RoomID, content.SenderKey, content.SessionID, evt.Sender, content.DeviceID) - errorEventID := mx.sendCryptoStatusError(ctx, evt, "", fmt.Errorf("%w. The bridge will retry for %d seconds", errNoDecryptionKeys, int(extendedSessionWaitTimeout.Seconds())), 1, false) - - if !mx.bridge.Crypto.WaitForSession(ctx, evt.RoomID, content.SenderKey, content.SessionID, extendedSessionWaitTimeout) { - log.Debug().Msg("Didn't get session, giving up trying to decrypt event") - mx.sendCryptoStatusError(ctx, evt, errorEventID, errNoDecryptionKeys, 2, true) - return - } - - log.Debug().Msg("Got keys after waiting longer, trying to decrypt event again") - decrypted, err := mx.bridge.Crypto.Decrypt(ctx, evt) - if err != nil { - log.Error().Err(err).Msg("Failed to decrypt event") - mx.sendCryptoStatusError(ctx, evt, errorEventID, err, 2, true) - return - } - - mx.postDecrypt(ctx, evt, decrypted, 2, errorEventID, time.Since(decryptionStart)) -} - -func (mx *MatrixHandler) HandleMessage(ctx context.Context, evt *event.Event) { - defer mx.TrackEventDuration(evt.Type)() - log := zerolog.Ctx(ctx).With(). - Str("event_id", evt.ID.String()). - Str("room_id", evt.RoomID.String()). - Str("sender", evt.Sender.String()). - Logger() - ctx = log.WithContext(ctx) - if mx.shouldIgnoreEvent(evt) { - return - } else if !evt.Mautrix.WasEncrypted && mx.bridge.Config.Bridge.GetEncryptionConfig().Require { - log.Warn().Msg("Dropping unencrypted event") - mx.sendCryptoStatusError(ctx, evt, "", errMessageNotEncrypted, 0, true) - return - } - - user := mx.bridge.Child.GetIUser(evt.Sender, true) - if user == nil { - return - } - - content := evt.Content.AsMessage() - content.RemoveReplyFallback() - if user.GetPermissionLevel() >= bridgeconfig.PermissionLevelUser && content.MsgType == event.MsgText { - commandPrefix := mx.bridge.Config.Bridge.GetCommandPrefix() - hasCommandPrefix := strings.HasPrefix(content.Body, commandPrefix) - if hasCommandPrefix { - content.Body = strings.TrimLeft(strings.TrimPrefix(content.Body, commandPrefix), " ") - } - if hasCommandPrefix || evt.RoomID == user.GetManagementRoomID() { - go mx.bridge.CommandProcessor.Handle(ctx, evt.RoomID, evt.ID, user, content.Body, content.RelatesTo.GetReplyTo()) - go mx.bridge.SendMessageSuccessCheckpoint(evt, status.MsgStepCommand, 0) - if mx.bridge.Config.Bridge.EnableMessageStatusEvents() { - statusEvent := &event.BeeperMessageStatusEventContent{ - // TODO: network - RelatesTo: event.RelatesTo{ - Type: event.RelReference, - EventID: evt.ID, - }, - Status: event.MessageStatusSuccess, - } - _, sendErr := mx.bridge.Bot.SendMessageEvent(ctx, evt.RoomID, event.BeeperMessageStatus, statusEvent) - if sendErr != nil { - log.Warn().Err(sendErr).Msg("Failed to send message status event for command") - } - } - return - } - } - - portal := mx.bridge.Child.GetIPortal(evt.RoomID) - if portal != nil { - portal.ReceiveMatrixEvent(user, evt) - } else { - mx.bridge.SendMessageErrorCheckpoint(evt, status.MsgStepRemote, fmt.Errorf("unknown room"), true, 0) - } -} - -func (mx *MatrixHandler) HandleReaction(_ context.Context, evt *event.Event) { - defer mx.TrackEventDuration(evt.Type)() - if mx.shouldIgnoreEvent(evt) { - return - } - - user := mx.bridge.Child.GetIUser(evt.Sender, true) - if user == nil || user.GetPermissionLevel() < bridgeconfig.PermissionLevelUser || !user.IsLoggedIn() { - return - } - - portal := mx.bridge.Child.GetIPortal(evt.RoomID) - if portal != nil { - portal.ReceiveMatrixEvent(user, evt) - } else { - mx.bridge.SendMessageErrorCheckpoint(evt, status.MsgStepRemote, fmt.Errorf("unknown room"), true, 0) - } -} - -func (mx *MatrixHandler) HandleRedaction(_ context.Context, evt *event.Event) { - defer mx.TrackEventDuration(evt.Type)() - if mx.shouldIgnoreEvent(evt) { - return - } - - user := mx.bridge.Child.GetIUser(evt.Sender, true) - if user == nil { - return - } - - portal := mx.bridge.Child.GetIPortal(evt.RoomID) - if portal != nil { - portal.ReceiveMatrixEvent(user, evt) - } else { - mx.bridge.SendMessageErrorCheckpoint(evt, status.MsgStepRemote, fmt.Errorf("unknown room"), true, 0) - } -} - -func (mx *MatrixHandler) HandleReceipt(_ context.Context, evt *event.Event) { - portal := mx.bridge.Child.GetIPortal(evt.RoomID) - if portal == nil { - return - } - - rrPortal, ok := portal.(ReadReceiptHandlingPortal) - if !ok { - return - } - - for eventID, receipts := range *evt.Content.AsReceipt() { - for userID, receipt := range receipts[event.ReceiptTypeRead] { - user := mx.bridge.Child.GetIUser(userID, false) - if user == nil { - // Not a bridge user - continue - } - customPuppet := user.GetIDoublePuppet() - if val, ok := receipt.Extra[appservice.DoublePuppetKey].(string); ok && customPuppet != nil && val == mx.bridge.Name { - // Ignore double puppeted read receipts. - mx.log.Debug().Interface("content", evt.Content.Raw).Msg("Ignoring double-puppeted read receipt") - // But do start disappearing messages, because the user read the chat - dp, ok := portal.(DisappearingPortal) - if ok { - dp.ScheduleDisappearing() - } - } else { - rrPortal.HandleMatrixReadReceipt(user, eventID, receipt) - } - } - } -} - -func (mx *MatrixHandler) HandleTyping(_ context.Context, evt *event.Event) { - portal := mx.bridge.Child.GetIPortal(evt.RoomID) - if portal == nil { - return - } - typingPortal, ok := portal.(TypingPortal) - if !ok { - return - } - typingPortal.HandleMatrixTyping(evt.Content.AsTyping().UserIDs) -} - -func (mx *MatrixHandler) HandlePowerLevels(_ context.Context, evt *event.Event) { - if mx.shouldIgnoreEvent(evt) { - return - } - portal := mx.bridge.Child.GetIPortal(evt.RoomID) - if portal == nil { - return - } - powerLevelPortal, ok := portal.(PowerLevelHandlingPortal) - if ok { - user := mx.bridge.Child.GetIUser(evt.Sender, true) - powerLevelPortal.HandleMatrixPowerLevels(user, evt) - } -} diff --git a/bridge/messagecheckpoint.go b/bridge/messagecheckpoint.go deleted file mode 100644 index a95d2160..00000000 --- a/bridge/messagecheckpoint.go +++ /dev/null @@ -1,61 +0,0 @@ -// Copyright (c) 2021 Sumner Evans -// Copyright (c) 2023 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package bridge - -import ( - "maunium.net/go/mautrix/appservice" - "maunium.net/go/mautrix/bridge/status" - "maunium.net/go/mautrix/event" -) - -func (br *Bridge) SendMessageSuccessCheckpoint(evt *event.Event, step status.MessageCheckpointStep, retryNum int) { - br.SendMessageCheckpoint(evt, step, nil, status.MsgStatusSuccess, retryNum) -} - -func (br *Bridge) SendMessageErrorCheckpoint(evt *event.Event, step status.MessageCheckpointStep, err error, permanent bool, retryNum int) { - s := status.MsgStatusWillRetry - if permanent { - s = status.MsgStatusPermFailure - } - br.SendMessageCheckpoint(evt, step, err, s, retryNum) -} - -func (br *Bridge) SendMessageCheckpoint(evt *event.Event, step status.MessageCheckpointStep, err error, s status.MessageCheckpointStatus, retryNum int) { - checkpoint := status.NewMessageCheckpoint(evt, step, s, retryNum) - if err != nil { - checkpoint.Info = err.Error() - } - go br.SendRawMessageCheckpoint(checkpoint) -} - -func (br *Bridge) SendRawMessageCheckpoint(cp *status.MessageCheckpoint) { - err := br.SendMessageCheckpoints([]*status.MessageCheckpoint{cp}) - if err != nil { - br.ZLog.Warn().Err(err).Interface("message_checkpoint", cp).Msg("Error sending message checkpoint") - } else { - br.ZLog.Debug().Interface("message_checkpoint", cp).Msg("Sent message checkpoint") - } -} - -func (br *Bridge) SendMessageCheckpoints(checkpoints []*status.MessageCheckpoint) error { - checkpointsJSON := status.CheckpointsJSON{Checkpoints: checkpoints} - - if br.Websocket { - return br.AS.SendWebsocket(&appservice.WebsocketRequest{ - Command: "message_checkpoint", - Data: checkpointsJSON, - }) - } - - endpoint := br.Config.Homeserver.MessageSendCheckpointEndpoint - if endpoint == "" { - return nil - } - - return checkpointsJSON.SendHTTP(endpoint, br.AS.Registration.AppToken) -} diff --git a/bridgev2/backfillqueue.go b/bridgev2/backfillqueue.go new file mode 100644 index 00000000..61318d94 --- /dev/null +++ b/bridgev2/backfillqueue.go @@ -0,0 +1,248 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package bridgev2 + +import ( + "context" + "fmt" + "runtime/debug" + "time" + + "github.com/rs/zerolog" + + "maunium.net/go/mautrix/bridgev2/database" +) + +const BackfillMinBackoffAfterRoomCreate = 1 * time.Minute +const BackfillQueueErrorBackoff = 1 * time.Minute +const BackfillQueueMaxEmptyBackoff = 10 * time.Minute + +func (br *Bridge) WakeupBackfillQueue() { + select { + case br.wakeupBackfillQueue <- struct{}{}: + default: + } +} + +func (br *Bridge) RunBackfillQueue() { + if !br.Config.Backfill.Queue.Enabled || !br.Config.Backfill.Enabled { + return + } + log := br.Log.With().Str("component", "backfill queue").Logger() + if !br.Matrix.GetCapabilities().BatchSending { + log.Warn().Msg("Backfill queue is enabled in config, but Matrix server doesn't support batch sending") + return + } + ctx, cancel := context.WithCancel(log.WithContext(context.Background())) + br.stopBackfillQueue.Clear() + stopChan := br.stopBackfillQueue.GetChan() + go func() { + <-stopChan + cancel() + }() + batchDelay := time.Duration(br.Config.Backfill.Queue.BatchDelay) * time.Second + log.Info().Stringer("batch_delay", batchDelay).Msg("Backfill queue starting") + noTasksFoundCount := 0 + for { + nextDelay := batchDelay + if noTasksFoundCount > 0 { + extraDelay := batchDelay * time.Duration(noTasksFoundCount) + nextDelay += min(BackfillQueueMaxEmptyBackoff, extraDelay) + } + timer := time.NewTimer(nextDelay) + select { + case <-br.wakeupBackfillQueue: + if !timer.Stop() { + select { + case <-timer.C: + default: + } + } + noTasksFoundCount = 0 + case <-stopChan: + if !timer.Stop() { + select { + case <-timer.C: + default: + } + } + log.Info().Msg("Stopping backfill queue") + return + case <-timer.C: + } + backfillTask, err := br.DB.BackfillTask.GetNext(ctx) + if err != nil { + log.Err(err).Msg("Failed to get next backfill queue entry") + time.Sleep(BackfillQueueErrorBackoff) + continue + } else if backfillTask != nil { + br.DoBackfillTask(ctx, backfillTask) + noTasksFoundCount = 0 + } + } +} + +func (br *Bridge) DoBackfillTask(ctx context.Context, task *database.BackfillTask) { + log := zerolog.Ctx(ctx).With(). + Object("portal_key", task.PortalKey). + Str("login_id", string(task.UserLoginID)). + Logger() + defer func() { + err := recover() + if err != nil { + logEvt := log.Error(). + Bytes(zerolog.ErrorStackFieldName, debug.Stack()) + if realErr, ok := err.(error); ok { + logEvt = logEvt.Err(realErr) + } else { + logEvt = logEvt.Any(zerolog.ErrorFieldName, err) + } + logEvt.Msg("Panic in backfill queue") + } + }() + ctx = log.WithContext(ctx) + err := br.DB.BackfillTask.MarkDispatched(ctx, task) + if err != nil { + log.Err(err).Msg("Failed to mark backfill task as dispatched") + time.Sleep(BackfillQueueErrorBackoff) + return + } + completed, err := br.actuallyDoBackfillTask(ctx, task) + if err != nil { + log.Err(err).Msg("Failed to do backfill task") + time.Sleep(BackfillQueueErrorBackoff) + return + } else if completed { + log.Info(). + Int("batch_count", task.BatchCount). + Bool("is_done", task.IsDone). + Msg("Backfill task completed successfully") + } else { + log.Info(). + Int("batch_count", task.BatchCount). + Bool("is_done", task.IsDone). + Msg("Backfill task canceled") + } + err = br.DB.BackfillTask.Update(ctx, task) + if err != nil { + log.Err(err).Msg("Failed to update backfill task") + time.Sleep(BackfillQueueErrorBackoff) + } +} + +func (portal *Portal) deleteBackfillQueueTaskIfRoomDoesNotExist(ctx context.Context) bool { + // Acquire the room create lock to ensure that task deletion doesn't race with room creation + portal.roomCreateLock.Lock() + defer portal.roomCreateLock.Unlock() + if portal.MXID == "" { + zerolog.Ctx(ctx).Debug().Msg("Portal for backfill task doesn't exist, deleting entry") + err := portal.Bridge.DB.BackfillTask.Delete(ctx, portal.PortalKey) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to delete backfill task after portal wasn't found") + } + return true + } + return false +} + +func (br *Bridge) actuallyDoBackfillTask(ctx context.Context, task *database.BackfillTask) (bool, error) { + log := zerolog.Ctx(ctx) + portal, err := br.GetExistingPortalByKey(ctx, task.PortalKey) + if err != nil { + return false, fmt.Errorf("failed to get portal for backfill task: %w", err) + } else if portal == nil { + log.Warn().Msg("Portal not found for backfill task") + err = br.DB.BackfillTask.Delete(ctx, task.PortalKey) + if err != nil { + log.Err(err).Msg("Failed to delete backfill task after portal wasn't found") + time.Sleep(BackfillQueueErrorBackoff) + } + return false, nil + } else if portal.MXID == "" { + portal.deleteBackfillQueueTaskIfRoomDoesNotExist(ctx) + return false, nil + } + login, err := br.GetExistingUserLoginByID(ctx, task.UserLoginID) + if err != nil { + return false, fmt.Errorf("failed to get user login for backfill task: %w", err) + } else if login == nil || !login.Client.IsLoggedIn() { + if login == nil { + log.Warn().Msg("User login not found for backfill task") + } else { + log.Warn().Msg("User login not logged in for backfill task") + } + logins, err := br.GetUserLoginsInPortal(ctx, portal.PortalKey) + if err != nil { + return false, fmt.Errorf("failed to get user portals for backfill task: %w", err) + } else if len(logins) == 0 { + log.Debug().Msg("No user logins found for backfill task") + task.NextDispatchMinTS = database.BackfillNextDispatchNever + if login == nil { + task.UserLoginID = "" + } + return false, nil + } + if login == nil { + task.UserLoginID = "" + } + foundLogin := false + for _, login = range logins { + if login.Client.IsLoggedIn() { + foundLogin = true + task.UserLoginID = login.ID + log.UpdateContext(func(c zerolog.Context) zerolog.Context { + return c.Str("overridden_login_id", string(login.ID)) + }) + log.Debug().Msg("Found user login for backfill task") + break + } + } + if !foundLogin { + log.Debug().Msg("No logged in user logins found for backfill task") + task.NextDispatchMinTS = database.BackfillNextDispatchNever + return false, nil + } + } + if task.BatchCount < 0 { + var msgCount int + msgCount, err = br.DB.Message.CountMessagesInPortal(ctx, task.PortalKey) + if err != nil { + return false, fmt.Errorf("failed to count messages in portal: %w", err) + } + task.BatchCount = msgCount / br.Config.Backfill.Queue.BatchSize + log.Debug(). + Int("message_count", msgCount). + Int("batch_count", task.BatchCount). + Msg("Calculated existing batch count") + } + maxBatches := br.Config.Backfill.Queue.MaxBatches + api, ok := login.Client.(BackfillingNetworkAPI) + if !ok { + return false, fmt.Errorf("network API does not support backfilling") + } + limiterAPI, ok := api.(BackfillingNetworkAPIWithLimits) + if ok { + maxBatches = limiterAPI.GetBackfillMaxBatchCount(ctx, portal, task) + } + if maxBatches < 0 || maxBatches > task.BatchCount { + err = portal.DoBackwardsBackfill(ctx, login, task) + if err != nil { + return false, fmt.Errorf("failed to backfill: %w", err) + } + task.BatchCount++ + } else { + log.Debug(). + Int("max_batches", maxBatches). + Int("batch_count", task.BatchCount). + Msg("Not actually backfilling: max batches reached") + } + task.IsDone = task.IsDone || (maxBatches > 0 && task.BatchCount >= maxBatches) + batchDelay := time.Duration(br.Config.Backfill.Queue.BatchDelay) * time.Second + task.CompletedAt = time.Now() + task.NextDispatchMinTS = task.CompletedAt.Add(batchDelay) + return true, nil +} diff --git a/bridgev2/bridge.go b/bridgev2/bridge.go new file mode 100644 index 00000000..226adc90 --- /dev/null +++ b/bridgev2/bridge.go @@ -0,0 +1,458 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package bridgev2 + +import ( + "context" + "fmt" + "os" + "sync" + "sync/atomic" + "time" + + "github.com/rs/zerolog" + "go.mau.fi/util/dbutil" + "go.mau.fi/util/exhttp" + "go.mau.fi/util/exsync" + + "maunium.net/go/mautrix/bridgev2/bridgeconfig" + "maunium.net/go/mautrix/bridgev2/database" + "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/bridgev2/status" + "maunium.net/go/mautrix/id" +) + +type CommandProcessor interface { + Handle(ctx context.Context, roomID id.RoomID, eventID id.EventID, user *User, message string, replyTo id.EventID) +} + +type Bridge struct { + ID networkid.BridgeID + DB *database.Database + Log zerolog.Logger + + Matrix MatrixConnector + Bot MatrixAPI + Network NetworkConnector + Commands CommandProcessor + Config *bridgeconfig.BridgeConfig + + DisappearLoop *DisappearLoop + + usersByMXID map[id.UserID]*User + userLoginsByID map[networkid.UserLoginID]*UserLogin + portalsByKey map[networkid.PortalKey]*Portal + portalsByMXID map[id.RoomID]*Portal + ghostsByID map[networkid.UserID]*Ghost + cacheLock sync.Mutex + + didSplitPortals bool + + Background bool + ExternallyManagedDB bool + stopping atomic.Bool + + wakeupBackfillQueue chan struct{} + stopBackfillQueue *exsync.Event + + BackgroundCtx context.Context + cancelBackgroundCtx context.CancelFunc +} + +func NewBridge( + bridgeID networkid.BridgeID, + db *dbutil.Database, + log zerolog.Logger, + cfg *bridgeconfig.BridgeConfig, + matrix MatrixConnector, + network NetworkConnector, + newCommandProcessor func(*Bridge) CommandProcessor, +) *Bridge { + br := &Bridge{ + ID: bridgeID, + DB: database.New(bridgeID, network.GetDBMetaTypes(), db), + Log: log, + + Matrix: matrix, + Network: network, + Config: cfg, + + usersByMXID: make(map[id.UserID]*User), + userLoginsByID: make(map[networkid.UserLoginID]*UserLogin), + portalsByKey: make(map[networkid.PortalKey]*Portal), + portalsByMXID: make(map[id.RoomID]*Portal), + ghostsByID: make(map[networkid.UserID]*Ghost), + + wakeupBackfillQueue: make(chan struct{}), + stopBackfillQueue: exsync.NewEvent(), + } + if br.Config == nil { + br.Config = &bridgeconfig.BridgeConfig{CommandPrefix: "!bridge"} + } + br.Commands = newCommandProcessor(br) + br.Matrix.Init(br) + br.Bot = br.Matrix.BotIntent() + br.Network.Init(br) + br.DisappearLoop = &DisappearLoop{br: br} + return br +} + +type DBUpgradeError struct { + Err error + Section string +} + +func (e DBUpgradeError) Error() string { + return e.Err.Error() +} + +func (e DBUpgradeError) Unwrap() error { + return e.Err +} + +func (br *Bridge) Start(ctx context.Context) error { + ctx = br.Log.WithContext(ctx) + err := br.StartConnectors(ctx) + if err != nil { + return err + } + err = br.StartLogins(ctx) + if err != nil { + return err + } + go br.PostStart(ctx) + return nil +} + +func (br *Bridge) RunOnce(ctx context.Context, loginID networkid.UserLoginID, params *ConnectBackgroundParams) error { + br.Background = true + br.stopping.Store(false) + err := br.StartConnectors(ctx) + if err != nil { + return err + } + + if loginID == "" { + br.Log.Info().Msg("No login ID provided to RunOnce, running all logins for 20 seconds") + err = br.StartLogins(ctx) + if err != nil { + return err + } + defer br.StopWithTimeout(5 * time.Second) + select { + case <-time.After(20 * time.Second): + case <-ctx.Done(): + } + return nil + } + + defer br.stop(true, 5*time.Second) + login, err := br.GetExistingUserLoginByID(ctx, loginID) + if err != nil { + return fmt.Errorf("failed to get user login: %w", err) + } else if login == nil { + return ErrNotLoggedIn + } + syncClient, ok := login.Client.(BackgroundSyncingNetworkAPI) + if !ok { + br.Log.Warn().Msg("Network connector doesn't implement background mode, using fallback mechanism for RunOnce") + login.Client.Connect(ctx) + defer login.DisconnectWithTimeout(5 * time.Second) + select { + case <-time.After(20 * time.Second): + case <-ctx.Done(): + } + br.stopping.Store(true) + return nil + } else { + br.Log.Info().Str("user_login_id", string(login.ID)).Msg("Starting individual user login in background mode") + return syncClient.ConnectBackground(login.Log.WithContext(ctx), params) + } +} + +func (br *Bridge) StartConnectors(ctx context.Context) error { + br.Log.Info().Msg("Starting bridge") + br.stopping.Store(false) + if br.BackgroundCtx == nil || br.BackgroundCtx.Err() != nil { + br.BackgroundCtx, br.cancelBackgroundCtx = context.WithCancel(context.Background()) + br.BackgroundCtx = br.Log.WithContext(br.BackgroundCtx) + } + + if !br.ExternallyManagedDB { + err := br.DB.Upgrade(ctx) + if err != nil { + return DBUpgradeError{Err: err, Section: "main"} + } + } + if !br.Background { + var postMigrate func() + br.didSplitPortals, postMigrate = br.MigrateToSplitPortals(ctx) + if postMigrate != nil { + defer postMigrate() + } + } + br.Log.Info().Msg("Starting Matrix connector") + err := br.Matrix.Start(ctx) + if err != nil { + return fmt.Errorf("failed to start Matrix connector: %w", err) + } + br.Log.Info().Msg("Starting network connector") + err = br.Network.Start(ctx) + if err != nil { + return fmt.Errorf("failed to start network connector: %w", err) + } + if br.Network.GetCapabilities().DisappearingMessages && !br.Background { + go br.DisappearLoop.Start() + } + return nil +} + +func (br *Bridge) PostStart(ctx context.Context) { + if br.Background { + return + } + rawBridgeInfoVer := br.DB.KV.Get(ctx, database.KeyBridgeInfoVersion) + bridgeInfoVer, capVer, err := parseBridgeInfoVersion(rawBridgeInfoVer) + if err != nil { + br.Log.Err(err).Str("db_bridge_info_version", rawBridgeInfoVer).Msg("Failed to parse bridge info version") + return + } + expectedBridgeInfoVer, expectedCapVer := br.Network.GetBridgeInfoVersion() + doResendBridgeInfo := bridgeInfoVer != expectedBridgeInfoVer || br.didSplitPortals || br.Config.ResendBridgeInfo + doResendCapabilities := capVer != expectedCapVer || br.didSplitPortals + if doResendBridgeInfo || doResendCapabilities { + br.ResendBridgeInfo(ctx, doResendBridgeInfo, doResendCapabilities) + } + br.DB.KV.Set(ctx, database.KeyBridgeInfoVersion, fmt.Sprintf("%d,%d", expectedBridgeInfoVer, expectedCapVer)) +} + +func parseBridgeInfoVersion(version string) (info, capabilities int, err error) { + _, err = fmt.Sscanf(version, "%d,%d", &info, &capabilities) + if version == "" { + err = nil + } + return +} + +func (br *Bridge) ResendBridgeInfo(ctx context.Context, resendInfo, resendCaps bool) { + log := zerolog.Ctx(ctx).With().Str("action", "resend bridge info").Logger() + portals, err := br.GetAllPortalsWithMXID(ctx) + if err != nil { + log.Err(err).Msg("Failed to get portals") + return + } + for _, portal := range portals { + if resendInfo { + portal.UpdateBridgeInfo(ctx) + } + if resendCaps { + logins, err := br.GetUserLoginsInPortal(ctx, portal.PortalKey) + if err != nil { + log.Err(err). + Stringer("room_id", portal.MXID). + Object("portal_key", portal.PortalKey). + Msg("Failed to get user logins in portal") + } else { + found := false + for _, login := range logins { + if portal.CapState.ID == "" || login.ID == portal.CapState.Source { + portal.UpdateCapabilities(ctx, login, true) + found = true + } + } + if !found && len(logins) > 0 { + portal.CapState.Source = "" + portal.UpdateCapabilities(ctx, logins[0], true) + } else if !found { + log.Warn(). + Stringer("room_id", portal.MXID). + Object("portal_key", portal.PortalKey). + Msg("No user login found to update capabilities") + } + } + } + } + log.Info(). + Bool("capabilities", resendCaps). + Bool("info", resendInfo). + Msg("Resent bridge info to all portals") +} + +func (br *Bridge) MigrateToSplitPortals(ctx context.Context) (bool, func()) { + log := zerolog.Ctx(ctx).With().Str("action", "migrate to split portals").Logger() + ctx = log.WithContext(ctx) + if !br.Config.SplitPortals || br.DB.KV.Get(ctx, database.KeySplitPortalsEnabled) == "true" { + return false, nil + } + affected, err := br.DB.Portal.MigrateToSplitPortals(ctx) + if err != nil { + log.WithLevel(zerolog.FatalLevel).Err(err).Msg("Failed to migrate portals") + os.Exit(31) + return false, nil + } + log.Info().Int64("rows_affected", affected).Msg("Migrated to split portals") + affected2, err := br.DB.Portal.FixParentsAfterSplitPortalMigration(ctx) + if err != nil { + log.Err(err).Msg("Failed to fix parent portals after split portal migration") + os.Exit(31) + return false, nil + } + log.Info().Int64("rows_affected", affected2).Msg("Updated parent receivers after split portal migration") + withoutReceiver, err := br.DB.Portal.GetAllWithoutReceiver(ctx) + if err != nil { + log.WithLevel(zerolog.FatalLevel).Err(err).Msg("Failed to get portals that failed to migrate") + os.Exit(31) + return false, nil + } + var roomsToDelete []id.RoomID + log.Info().Int("remaining_portals", len(withoutReceiver)).Msg("Deleting remaining portals without receiver") + for _, portal := range withoutReceiver { + if err = br.DB.Portal.Delete(ctx, portal.PortalKey); err != nil { + log.Err(err). + Str("portal_id", string(portal.ID)). + Stringer("mxid", portal.MXID). + Msg("Failed to delete portal database row that failed to migrate") + } else if portal.MXID != "" { + log.Debug(). + Str("portal_id", string(portal.ID)). + Stringer("mxid", portal.MXID). + Msg("Marked portal room for deletion from homeserver") + roomsToDelete = append(roomsToDelete, portal.MXID) + } else { + log.Debug(). + Str("portal_id", string(portal.ID)). + Msg("Deleted portal row with no Matrix room") + } + } + br.DB.KV.Set(ctx, database.KeySplitPortalsEnabled, "true") + log.Info().Msg("Finished split portal migration successfully") + return affected > 0, func() { + for _, roomID := range roomsToDelete { + if err = br.Bot.DeleteRoom(ctx, roomID, true); err != nil { + log.Err(err). + Stringer("mxid", roomID). + Msg("Failed to delete portal room that failed to migrate") + } + } + log.Info().Int("room_count", len(roomsToDelete)).Msg("Finished deleting rooms that failed to migrate") + } +} + +func (br *Bridge) StartLogins(ctx context.Context) error { + userIDs, err := br.DB.UserLogin.GetAllUserIDsWithLogins(ctx) + if err != nil { + return fmt.Errorf("failed to get users with logins: %w", err) + } + startedAny := false + for _, userID := range userIDs { + br.Log.Info().Stringer("user_id", userID).Msg("Loading user") + var user *User + user, err = br.GetUserByMXID(ctx, userID) + if err != nil { + br.Log.Err(err).Stringer("user_id", userID).Msg("Failed to load user") + } else { + for _, login := range user.GetUserLogins() { + startedAny = true + br.Log.Info().Str("id", string(login.ID)).Msg("Starting user login") + login.Client.Connect(login.Log.WithContext(ctx)) + } + } + } + if !startedAny { + br.Log.Info().Msg("No user logins found") + br.SendGlobalBridgeState(status.BridgeState{StateEvent: status.StateUnconfigured}) + } + if !br.Background { + go br.RunBackfillQueue() + } + + br.Log.Info().Msg("Bridge started") + return nil +} + +func (br *Bridge) ResetNetworkConnections() { + nrn, ok := br.Network.(NetworkResettingNetwork) + if ok { + br.Log.Info().Msg("Resetting network connections with NetworkConnector.ResetNetworkConnections") + nrn.ResetNetworkConnections() + return + } + + br.Log.Info().Msg("Network connector doesn't support ResetNetworkConnections, recreating clients manually") + for _, login := range br.GetAllCachedUserLogins() { + login.Log.Debug().Msg("Disconnecting and recreating client for network reset") + ctx := login.Log.WithContext(br.BackgroundCtx) + login.Client.Disconnect() + err := login.recreateClient(ctx) + if err != nil { + login.Log.Err(err).Msg("Failed to recreate client during network reset") + login.BridgeState.Send(status.BridgeState{ + StateEvent: status.StateUnknownError, + Error: "bridgev2-network-reset-fail", + Info: map[string]any{"go_error": err.Error()}, + }) + } else { + login.Client.Connect(ctx) + } + } + br.Log.Info().Msg("Finished resetting all user logins") +} + +func (br *Bridge) GetHTTPClientSettings() exhttp.ClientSettings { + mchs, ok := br.Matrix.(MatrixConnectorWithHTTPSettings) + if ok { + return mchs.GetHTTPClientSettings() + } + return exhttp.SensibleClientSettings +} + +func (br *Bridge) IsStopping() bool { + return br.stopping.Load() +} + +func (br *Bridge) Stop() { + br.stop(false, 0) +} + +func (br *Bridge) StopWithTimeout(timeout time.Duration) { + br.stop(false, timeout) +} + +func (br *Bridge) stop(isRunOnce bool, timeout time.Duration) { + br.Log.Info().Msg("Shutting down bridge") + br.stopping.Store(true) + br.DisappearLoop.Stop() + br.stopBackfillQueue.Set() + br.Matrix.PreStop() + if !isRunOnce { + br.cacheLock.Lock() + var wg sync.WaitGroup + wg.Add(len(br.userLoginsByID)) + for _, login := range br.userLoginsByID { + go func() { + login.DisconnectWithTimeout(timeout) + wg.Done() + }() + } + br.cacheLock.Unlock() + wg.Wait() + } + br.Matrix.Stop() + if br.cancelBackgroundCtx != nil { + br.cancelBackgroundCtx() + } + if stopNet, ok := br.Network.(StoppableNetwork); ok { + stopNet.Stop() + } + if !br.ExternallyManagedDB { + err := br.DB.Close() + if err != nil { + br.Log.Warn().Err(err).Msg("Failed to close database") + } + } + br.Log.Info().Msg("Shutdown complete") +} diff --git a/bridgev2/bridgeconfig/appservice.go b/bridgev2/bridgeconfig/appservice.go new file mode 100644 index 00000000..f709c8e0 --- /dev/null +++ b/bridgev2/bridgeconfig/appservice.go @@ -0,0 +1,140 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package bridgeconfig + +import ( + "fmt" + "regexp" + "strings" + "text/template" + + "go.mau.fi/util/exerrors" + "go.mau.fi/util/random" + "gopkg.in/yaml.v3" + + "maunium.net/go/mautrix/appservice" + "maunium.net/go/mautrix/id" +) + +type AppserviceConfig struct { + Address string `yaml:"address"` + PublicAddress string `yaml:"public_address"` + Hostname string `yaml:"hostname"` + Port uint16 `yaml:"port"` + + ID string `yaml:"id"` + Bot BotUserConfig `yaml:"bot"` + + ASToken string `yaml:"as_token"` + HSToken string `yaml:"hs_token"` + + EphemeralEvents bool `yaml:"ephemeral_events"` + AsyncTransactions bool `yaml:"async_transactions"` + + UsernameTemplate string `yaml:"username_template"` + usernameTemplate *template.Template `yaml:"-"` +} + +func (asc *AppserviceConfig) FormatUsername(username string) string { + if asc.usernameTemplate == nil { + asc.usernameTemplate = exerrors.Must(template.New("username").Parse(asc.UsernameTemplate)) + } + var buf strings.Builder + _ = asc.usernameTemplate.Execute(&buf, username) + return buf.String() +} + +func (config *Config) MakeUserIDRegex(matcher string) *regexp.Regexp { + usernamePlaceholder := strings.ToLower(random.String(16)) + usernameTemplate := fmt.Sprintf("@%s:%s", + config.AppService.FormatUsername(usernamePlaceholder), + config.Homeserver.Domain) + usernameTemplate = regexp.QuoteMeta(usernameTemplate) + usernameTemplate = strings.Replace(usernameTemplate, usernamePlaceholder, matcher, 1) + usernameTemplate = fmt.Sprintf("^%s$", usernameTemplate) + return regexp.MustCompile(usernameTemplate) +} + +// GetRegistration copies the data from the bridge config into an *appservice.Registration struct. +// This can't be used with the homeserver, see GenerateRegistration for generating files for the homeserver. +func (asc *AppserviceConfig) GetRegistration() *appservice.Registration { + reg := &appservice.Registration{} + asc.copyToRegistration(reg) + reg.SenderLocalpart = asc.Bot.Username + reg.ServerToken = asc.HSToken + reg.AppToken = asc.ASToken + return reg +} + +func (asc *AppserviceConfig) copyToRegistration(registration *appservice.Registration) { + registration.ID = asc.ID + registration.URL = asc.Address + falseVal := false + registration.RateLimited = &falseVal + registration.EphemeralEvents = asc.EphemeralEvents + registration.SoruEphemeralEvents = asc.EphemeralEvents +} + +func (ec *EncryptionConfig) applyUnstableFlags(registration *appservice.Registration) { + registration.MSC4190 = ec.MSC4190 + registration.MSC3202 = ec.Appservice +} + +// GenerateRegistration generates a registration file for the homeserver. +func (config *Config) GenerateRegistration() *appservice.Registration { + registration := appservice.CreateRegistration() + config.AppService.HSToken = registration.ServerToken + config.AppService.ASToken = registration.AppToken + config.AppService.copyToRegistration(registration) + config.Encryption.applyUnstableFlags(registration) + + registration.SenderLocalpart = random.String(32) + botRegex := regexp.MustCompile(fmt.Sprintf("^@%s:%s$", + regexp.QuoteMeta(config.AppService.Bot.Username), + regexp.QuoteMeta(config.Homeserver.Domain))) + registration.Namespaces.UserIDs.Register(botRegex, true) + registration.Namespaces.UserIDs.Register(config.MakeUserIDRegex(".*"), true) + + return registration +} + +func (config *Config) MakeAppService() *appservice.AppService { + as := appservice.Create() + as.HomeserverDomain = config.Homeserver.Domain + _ = as.SetHomeserverURL(config.Homeserver.Address) + as.Host.Hostname = config.AppService.Hostname + as.Host.Port = config.AppService.Port + as.Registration = config.AppService.GetRegistration() + config.Encryption.applyUnstableFlags(as.Registration) + return as +} + +type BotUserConfig struct { + Username string `yaml:"username"` + Displayname string `yaml:"displayname"` + Avatar string `yaml:"avatar"` + + ParsedAvatar id.ContentURI `yaml:"-"` +} + +type serializableBUC BotUserConfig + +func (buc *BotUserConfig) UnmarshalYAML(node *yaml.Node) error { + var sbuc serializableBUC + err := node.Decode(&sbuc) + if err != nil { + return err + } + *buc = (BotUserConfig)(sbuc) + if buc.Avatar != "" && buc.Avatar != "remove" { + buc.ParsedAvatar, err = id.ParseContentURI(buc.Avatar) + if err != nil { + return fmt.Errorf("%w in bot avatar", err) + } + } + return nil +} diff --git a/bridgev2/bridgeconfig/backfill.go b/bridgev2/bridgeconfig/backfill.go new file mode 100644 index 00000000..eedae1e8 --- /dev/null +++ b/bridgev2/bridgeconfig/backfill.go @@ -0,0 +1,45 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package bridgeconfig + +type BackfillConfig struct { + Enabled bool `yaml:"enabled"` + MaxInitialMessages int `yaml:"max_initial_messages"` + MaxCatchupMessages int `yaml:"max_catchup_messages"` + UnreadHoursThreshold int `yaml:"unread_hours_threshold"` + + Threads BackfillThreadsConfig `yaml:"threads"` + Queue BackfillQueueConfig `yaml:"queue"` + + // Flag to indicate that the creator will not run the backfill queue but will still paginate + // backfill by calling DoBackfillTask directly. Note that this is not used anywhere within + // mautrix-go and exists so bridges can use it to decide when to drop backfill data. + WillPaginateManually bool `yaml:"will_paginate_manually"` +} + +type BackfillThreadsConfig struct { + MaxInitialMessages int `yaml:"max_initial_messages"` +} + +type BackfillQueueConfig struct { + Enabled bool `yaml:"enabled"` + BatchSize int `yaml:"batch_size"` + BatchDelay int `yaml:"batch_delay"` + MaxBatches int `yaml:"max_batches"` + + MaxBatchesOverride map[string]int `yaml:"max_batches_override"` +} + +func (bqc *BackfillQueueConfig) GetOverride(names ...string) int { + for _, name := range names { + override, ok := bqc.MaxBatchesOverride[name] + if ok { + return override + } + } + return bqc.MaxBatches +} diff --git a/bridgev2/bridgeconfig/config.go b/bridgev2/bridgeconfig/config.go new file mode 100644 index 00000000..bd6b9c06 --- /dev/null +++ b/bridgev2/bridgeconfig/config.go @@ -0,0 +1,139 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package bridgeconfig + +import ( + "time" + + "go.mau.fi/util/dbutil" + "go.mau.fi/zeroconfig" + "gopkg.in/yaml.v3" + + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/mediaproxy" +) + +type Config struct { + Network yaml.Node `yaml:"network"` + Bridge BridgeConfig `yaml:"bridge"` + Database dbutil.Config `yaml:"database"` + Homeserver HomeserverConfig `yaml:"homeserver"` + AppService AppserviceConfig `yaml:"appservice"` + Matrix MatrixConfig `yaml:"matrix"` + Analytics AnalyticsConfig `yaml:"analytics"` + Provisioning ProvisioningConfig `yaml:"provisioning"` + PublicMedia PublicMediaConfig `yaml:"public_media"` + DirectMedia DirectMediaConfig `yaml:"direct_media"` + Backfill BackfillConfig `yaml:"backfill"` + DoublePuppet DoublePuppetConfig `yaml:"double_puppet"` + Encryption EncryptionConfig `yaml:"encryption"` + Logging zeroconfig.Config `yaml:"logging"` + + EnvConfigPrefix string `yaml:"env_config_prefix"` + + ManagementRoomTexts ManagementRoomTexts `yaml:"management_room_texts"` +} + +type CleanupAction string + +const ( + CleanupActionNull CleanupAction = "" + CleanupActionNothing CleanupAction = "nothing" + CleanupActionKick CleanupAction = "kick" + CleanupActionUnbridge CleanupAction = "unbridge" + CleanupActionDelete CleanupAction = "delete" +) + +type CleanupOnLogout struct { + Private CleanupAction `yaml:"private"` + Relayed CleanupAction `yaml:"relayed"` + SharedNoUsers CleanupAction `yaml:"shared_no_users"` + SharedHasUsers CleanupAction `yaml:"shared_has_users"` +} + +type CleanupOnLogouts struct { + Enabled bool `yaml:"enabled"` + Manual CleanupOnLogout `yaml:"manual"` + BadCredentials CleanupOnLogout `yaml:"bad_credentials"` +} + +type BridgeConfig struct { + CommandPrefix string `yaml:"command_prefix"` + PersonalFilteringSpaces bool `yaml:"personal_filtering_spaces"` + PrivateChatPortalMeta bool `yaml:"private_chat_portal_meta"` + AsyncEvents bool `yaml:"async_events"` + SplitPortals bool `yaml:"split_portals"` + ResendBridgeInfo bool `yaml:"resend_bridge_info"` + NoBridgeInfoStateKey bool `yaml:"no_bridge_info_state_key"` + BridgeStatusNotices string `yaml:"bridge_status_notices"` + UnknownErrorAutoReconnect time.Duration `yaml:"unknown_error_auto_reconnect"` + UnknownErrorMaxAutoReconnects int `yaml:"unknown_error_max_auto_reconnects"` + BridgeMatrixLeave bool `yaml:"bridge_matrix_leave"` + BridgeNotices bool `yaml:"bridge_notices"` + TagOnlyOnCreate bool `yaml:"tag_only_on_create"` + OnlyBridgeTags []event.RoomTag `yaml:"only_bridge_tags"` + MuteOnlyOnCreate bool `yaml:"mute_only_on_create"` + DeduplicateMatrixMessages bool `yaml:"deduplicate_matrix_messages"` + CrossRoomReplies bool `yaml:"cross_room_replies"` + OutgoingMessageReID bool `yaml:"outgoing_message_re_id"` + RevertFailedStateChanges bool `yaml:"revert_failed_state_changes"` + KickMatrixUsers bool `yaml:"kick_matrix_users"` + CleanupOnLogout CleanupOnLogouts `yaml:"cleanup_on_logout"` + Relay RelayConfig `yaml:"relay"` + Permissions PermissionConfig `yaml:"permissions"` + Backfill BackfillConfig `yaml:"backfill"` +} + +type MatrixConfig struct { + MessageStatusEvents bool `yaml:"message_status_events"` + DeliveryReceipts bool `yaml:"delivery_receipts"` + MessageErrorNotices bool `yaml:"message_error_notices"` + SyncDirectChatList bool `yaml:"sync_direct_chat_list"` + FederateRooms bool `yaml:"federate_rooms"` + UploadFileThreshold int64 `yaml:"upload_file_threshold"` + GhostExtraProfileInfo bool `yaml:"ghost_extra_profile_info"` +} + +type AnalyticsConfig struct { + Token string `yaml:"token"` + URL string `yaml:"url"` + UserID string `yaml:"user_id"` +} + +type ProvisioningConfig struct { + SharedSecret string `yaml:"shared_secret"` + DebugEndpoints bool `yaml:"debug_endpoints"` + EnableSessionTransfers bool `yaml:"enable_session_transfers"` +} + +type DirectMediaConfig struct { + Enabled bool `yaml:"enabled"` + MediaIDPrefix string `yaml:"media_id_prefix"` + mediaproxy.BasicConfig `yaml:",inline"` +} + +type PublicMediaConfig struct { + Enabled bool `yaml:"enabled"` + SigningKey string `yaml:"signing_key"` + Expiry int `yaml:"expiry"` + HashLength int `yaml:"hash_length"` + PathPrefix string `yaml:"path_prefix"` + UseDatabase bool `yaml:"use_database"` +} + +type DoublePuppetConfig struct { + Servers map[string]string `yaml:"servers"` + AllowDiscovery bool `yaml:"allow_discovery"` + Secrets map[string]string `yaml:"secrets"` +} + +type ManagementRoomTexts struct { + Welcome string `yaml:"welcome"` + WelcomeConnected string `yaml:"welcome_connected"` + WelcomeUnconnected string `yaml:"welcome_unconnected"` + AdditionalHelp string `yaml:"additional_help"` +} diff --git a/bridgev2/bridgeconfig/encryption.go b/bridgev2/bridgeconfig/encryption.go new file mode 100644 index 00000000..934613ca --- /dev/null +++ b/bridgev2/bridgeconfig/encryption.go @@ -0,0 +1,51 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package bridgeconfig + +import ( + "maunium.net/go/mautrix/id" +) + +type EncryptionConfig struct { + Allow bool `yaml:"allow"` + Default bool `yaml:"default"` + Require bool `yaml:"require"` + Appservice bool `yaml:"appservice"` + MSC4190 bool `yaml:"msc4190"` + MSC4392 bool `yaml:"msc4392"` + SelfSign bool `yaml:"self_sign"` + + PlaintextMentions bool `yaml:"plaintext_mentions"` + + PickleKey string `yaml:"pickle_key"` + + DeleteKeys struct { + DeleteOutboundOnAck bool `yaml:"delete_outbound_on_ack"` + DontStoreOutbound bool `yaml:"dont_store_outbound"` + RatchetOnDecrypt bool `yaml:"ratchet_on_decrypt"` + DeleteFullyUsedOnDecrypt bool `yaml:"delete_fully_used_on_decrypt"` + DeletePrevOnNewSession bool `yaml:"delete_prev_on_new_session"` + DeleteOnDeviceDelete bool `yaml:"delete_on_device_delete"` + PeriodicallyDeleteExpired bool `yaml:"periodically_delete_expired"` + DeleteOutdatedInbound bool `yaml:"delete_outdated_inbound"` + } `yaml:"delete_keys"` + + VerificationLevels struct { + Receive id.TrustState `yaml:"receive"` + Send id.TrustState `yaml:"send"` + Share id.TrustState `yaml:"share"` + } `yaml:"verification_levels"` + AllowKeySharing bool `yaml:"allow_key_sharing"` + + Rotation struct { + EnableCustom bool `yaml:"enable_custom"` + Milliseconds int64 `yaml:"milliseconds"` + Messages int `yaml:"messages"` + + DisableDeviceChangeKeyRotation bool `yaml:"disable_device_change_key_rotation"` + } `yaml:"rotation"` +} diff --git a/bridgev2/bridgeconfig/homeserver.go b/bridgev2/bridgeconfig/homeserver.go new file mode 100644 index 00000000..8d888d4f --- /dev/null +++ b/bridgev2/bridgeconfig/homeserver.go @@ -0,0 +1,38 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package bridgeconfig + +type HomeserverSoftware string + +const ( + SoftwareStandard HomeserverSoftware = "standard" + SoftwareAsmux HomeserverSoftware = "asmux" + SoftwareHungry HomeserverSoftware = "hungry" +) + +var AllowedHomeserverSoftware = map[HomeserverSoftware]bool{ + SoftwareStandard: true, + SoftwareAsmux: true, + SoftwareHungry: true, +} + +type HomeserverConfig struct { + Address string `yaml:"address"` + Domain string `yaml:"domain"` + AsyncMedia bool `yaml:"async_media"` + + PublicAddress string `yaml:"public_address,omitempty"` + + Software HomeserverSoftware `yaml:"software"` + + StatusEndpoint string `yaml:"status_endpoint"` + MessageSendCheckpointEndpoint string `yaml:"message_send_checkpoint_endpoint"` + + Websocket bool `yaml:"websocket"` + WSProxy string `yaml:"websocket_proxy"` + WSPingInterval int `yaml:"ping_interval_seconds"` +} diff --git a/bridgev2/bridgeconfig/legacymigrate.go b/bridgev2/bridgeconfig/legacymigrate.go new file mode 100644 index 00000000..954a37c3 --- /dev/null +++ b/bridgev2/bridgeconfig/legacymigrate.go @@ -0,0 +1,174 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package bridgeconfig + +import ( + "fmt" + "net/url" + "os" + "strings" + + up "go.mau.fi/util/configupgrade" +) + +var HackyMigrateLegacyNetworkConfig func(up.Helper) + +func CopyToOtherLocation(helper up.Helper, fieldType up.YAMLType, source, dest []string) { + val, ok := helper.Get(fieldType, source...) + if ok { + helper.Set(fieldType, val, dest...) + } +} + +func CopyMapToOtherLocation(helper up.Helper, source, dest []string) { + val := helper.GetNode(source...) + if val != nil && val.Map != nil { + helper.SetMap(val.Map, dest...) + } +} + +func doMigrateLegacy(helper up.Helper, python bool) { + if HackyMigrateLegacyNetworkConfig == nil { + _, _ = fmt.Fprintln(os.Stderr, "Legacy bridge config detected, but hacky network config migrator is not set") + os.Exit(1) + } + _, _ = fmt.Fprintln(os.Stderr, "Migrating legacy bridge config") + + helper.Copy(up.Str, "homeserver", "address") + helper.Copy(up.Str, "homeserver", "domain") + helper.Copy(up.Str, "homeserver", "software") + helper.Copy(up.Str|up.Null, "homeserver", "status_endpoint") + helper.Copy(up.Str|up.Null, "homeserver", "message_send_checkpoint_endpoint") + helper.Copy(up.Bool, "homeserver", "async_media") + helper.Copy(up.Str|up.Null, "homeserver", "websocket_proxy") + helper.Copy(up.Bool, "homeserver", "websocket") + helper.Copy(up.Int, "homeserver", "ping_interval_seconds") + + helper.Copy(up.Str|up.Null, "appservice", "address") + helper.Copy(up.Str|up.Null, "appservice", "hostname") + helper.Copy(up.Int|up.Null, "appservice", "port") + helper.Copy(up.Str, "appservice", "id") + if python { + CopyToOtherLocation(helper, up.Str, []string{"appservice", "bot_username"}, []string{"appservice", "bot", "username"}) + CopyToOtherLocation(helper, up.Str, []string{"appservice", "bot_displayname"}, []string{"appservice", "bot", "displayname"}) + CopyToOtherLocation(helper, up.Str, []string{"appservice", "bot_avatar"}, []string{"appservice", "bot", "avatar"}) + } else { + helper.Copy(up.Str, "appservice", "bot", "username") + helper.Copy(up.Str, "appservice", "bot", "displayname") + helper.Copy(up.Str, "appservice", "bot", "avatar") + } + helper.Copy(up.Bool, "appservice", "ephemeral_events") + helper.Copy(up.Bool, "appservice", "async_transactions") + helper.Copy(up.Str, "appservice", "as_token") + helper.Copy(up.Str, "appservice", "hs_token") + + helper.Copy(up.Str, "bridge", "command_prefix") + helper.Copy(up.Bool, "bridge", "personal_filtering_spaces") + if oldPM, ok := helper.Get(up.Str, "bridge", "private_chat_portal_meta"); ok && (oldPM == "default" || oldPM == "always") { + helper.Set(up.Bool, "true", "bridge", "private_chat_portal_meta") + } else { + helper.Set(up.Bool, "false", "bridge", "private_chat_portal_meta") + } + helper.Copy(up.Bool, "bridge", "relay", "enabled") + helper.Copy(up.Bool, "bridge", "relay", "admin_only") + helper.Copy(up.Map, "bridge", "permissions") + + if python { + legacyDB, ok := helper.Get(up.Str, "appservice", "database") + if ok { + if strings.HasPrefix(legacyDB, "postgres") { + parsedDB, err := url.Parse(legacyDB) + if err != nil { + panic(err) + } + q := parsedDB.Query() + if parsedDB.Host == "" && !q.Has("host") { + q.Set("host", "/var/run/postgresql") + } else if !q.Has("sslmode") { + q.Set("sslmode", "disable") + } + parsedDB.RawQuery = q.Encode() + helper.Set(up.Str, parsedDB.String(), "database", "uri") + helper.Set(up.Str, "postgres", "database", "type") + } else { + dbPath := strings.TrimPrefix(strings.TrimPrefix(legacyDB, "sqlite:"), "///") + helper.Set(up.Str, fmt.Sprintf("file:%s?_txlock=immediate", dbPath), "database", "uri") + helper.Set(up.Str, "sqlite3-fk-wal", "database", "type") + } + } + if legacyDBMinSize, ok := helper.Get(up.Int, "appservice", "database_opts", "min_size"); ok { + helper.Set(up.Int, legacyDBMinSize, "database", "max_idle_conns") + } + if legacyDBMaxSize, ok := helper.Get(up.Int, "appservice", "database_opts", "max_size"); ok { + helper.Set(up.Int, legacyDBMaxSize, "database", "max_open_conns") + } + } else { + if dbType, ok := helper.Get(up.Str, "appservice", "database", "type"); ok && dbType == "sqlite3" { + helper.Set(up.Str, "sqlite3-fk-wal", "database", "type") + } else { + CopyToOtherLocation(helper, up.Str, []string{"appservice", "database", "type"}, []string{"database", "type"}) + } + CopyToOtherLocation(helper, up.Str, []string{"appservice", "database", "uri"}, []string{"database", "uri"}) + CopyToOtherLocation(helper, up.Int, []string{"appservice", "database", "max_open_conns"}, []string{"database", "max_open_conns"}) + CopyToOtherLocation(helper, up.Int, []string{"appservice", "database", "max_idle_conns"}, []string{"database", "max_idle_conns"}) + CopyToOtherLocation(helper, up.Int, []string{"appservice", "database", "max_conn_idle_time"}, []string{"database", "max_conn_idle_time"}) + CopyToOtherLocation(helper, up.Int, []string{"appservice", "database", "max_conn_lifetime"}, []string{"database", "max_conn_lifetime"}) + } + + if python { + if usernameTemplate, ok := helper.Get(up.Str, "bridge", "username_template"); ok && strings.Contains(usernameTemplate, "{userid}") { + helper.Set(up.Str, strings.ReplaceAll(usernameTemplate, "{userid}", "{{.}}"), "appservice", "username_template") + } + } else { + CopyToOtherLocation(helper, up.Str, []string{"bridge", "username_template"}, []string{"appservice", "username_template"}) + } + + CopyToOtherLocation(helper, up.Bool, []string{"bridge", "message_status_events"}, []string{"matrix", "message_status_events"}) + CopyToOtherLocation(helper, up.Bool, []string{"bridge", "delivery_receipts"}, []string{"matrix", "delivery_receipts"}) + CopyToOtherLocation(helper, up.Bool, []string{"bridge", "message_error_notices"}, []string{"matrix", "message_error_notices"}) + CopyToOtherLocation(helper, up.Bool, []string{"bridge", "sync_direct_chat_list"}, []string{"matrix", "sync_direct_chat_list"}) + CopyToOtherLocation(helper, up.Bool, []string{"bridge", "federate_rooms"}, []string{"matrix", "federate_rooms"}) + + CopyToOtherLocation(helper, up.Str, []string{"bridge", "provisioning", "shared_secret"}, []string{"provisioning", "shared_secret"}) + CopyToOtherLocation(helper, up.Str, []string{"appservice", "provisioning", "shared_secret"}, []string{"provisioning", "shared_secret"}) + CopyToOtherLocation(helper, up.Bool, []string{"bridge", "provisioning", "debug_endpoints"}, []string{"provisioning", "debug_endpoints"}) + + CopyToOtherLocation(helper, up.Bool, []string{"bridge", "double_puppet_allow_discovery"}, []string{"double_puppet", "allow_discovery"}) + CopyMapToOtherLocation(helper, []string{"bridge", "double_puppet_server_map"}, []string{"double_puppet", "servers"}) + CopyMapToOtherLocation(helper, []string{"bridge", "login_shared_secret_map"}, []string{"double_puppet", "secrets"}) + + CopyToOtherLocation(helper, up.Bool, []string{"bridge", "encryption", "allow"}, []string{"encryption", "allow"}) + CopyToOtherLocation(helper, up.Bool, []string{"bridge", "encryption", "default"}, []string{"encryption", "default"}) + CopyToOtherLocation(helper, up.Bool, []string{"bridge", "encryption", "require"}, []string{"encryption", "require"}) + CopyToOtherLocation(helper, up.Bool, []string{"bridge", "encryption", "appservice"}, []string{"encryption", "appservice"}) + CopyToOtherLocation(helper, up.Bool, []string{"bridge", "encryption", "allow_key_sharing"}, []string{"encryption", "allow_key_sharing"}) + CopyToOtherLocation(helper, up.Bool, []string{"bridge", "encryption", "delete_keys", "delete_outbound_on_ack"}, []string{"encryption", "delete_keys", "delete_outbound_on_ack"}) + CopyToOtherLocation(helper, up.Bool, []string{"bridge", "encryption", "delete_keys", "dont_store_outbound"}, []string{"encryption", "delete_keys", "dont_store_outbound"}) + CopyToOtherLocation(helper, up.Bool, []string{"bridge", "encryption", "delete_keys", "ratchet_on_decrypt"}, []string{"encryption", "delete_keys", "ratchet_on_decrypt"}) + CopyToOtherLocation(helper, up.Bool, []string{"bridge", "encryption", "delete_keys", "delete_fully_used_on_decrypt"}, []string{"encryption", "delete_keys", "delete_fully_used_on_decrypt"}) + CopyToOtherLocation(helper, up.Bool, []string{"bridge", "encryption", "delete_keys", "delete_prev_on_new_session"}, []string{"encryption", "delete_keys", "delete_prev_on_new_session"}) + CopyToOtherLocation(helper, up.Bool, []string{"bridge", "encryption", "delete_keys", "delete_on_device_delete"}, []string{"encryption", "delete_keys", "delete_on_device_delete"}) + CopyToOtherLocation(helper, up.Bool, []string{"bridge", "encryption", "delete_keys", "periodically_delete_expired"}, []string{"encryption", "delete_keys", "periodically_delete_expired"}) + CopyToOtherLocation(helper, up.Bool, []string{"bridge", "encryption", "delete_keys", "delete_outdated_inbound"}, []string{"encryption", "delete_keys", "delete_outdated_inbound"}) + CopyToOtherLocation(helper, up.Str, []string{"bridge", "encryption", "verification_levels", "receive"}, []string{"encryption", "verification_levels", "receive"}) + CopyToOtherLocation(helper, up.Str, []string{"bridge", "encryption", "verification_levels", "send"}, []string{"encryption", "verification_levels", "send"}) + CopyToOtherLocation(helper, up.Str, []string{"bridge", "encryption", "verification_levels", "share"}, []string{"encryption", "verification_levels", "share"}) + CopyToOtherLocation(helper, up.Bool, []string{"bridge", "encryption", "rotation", "enable_custom"}, []string{"encryption", "rotation", "enable_custom"}) + CopyToOtherLocation(helper, up.Int, []string{"bridge", "encryption", "rotation", "milliseconds"}, []string{"encryption", "rotation", "milliseconds"}) + CopyToOtherLocation(helper, up.Int, []string{"bridge", "encryption", "rotation", "messages"}, []string{"encryption", "rotation", "messages"}) + CopyToOtherLocation(helper, up.Bool, []string{"bridge", "encryption", "rotation", "disable_device_change_key_rotation"}, []string{"encryption", "rotation", "disable_device_change_key_rotation"}) + + if helper.GetNode("logging", "writers") == nil && (helper.GetNode("logging", "print_level") != nil || helper.GetNode("logging", "file_name_format") != nil) { + _, _ = fmt.Fprintln(os.Stderr, "Migrating maulogger configs is not supported") + } else if (helper.GetNode("logging", "writers") == nil && (helper.GetNode("logging", "handlers") != nil)) || python { + _, _ = fmt.Fprintln(os.Stderr, "Migrating Python log configs is not supported") + } else { + helper.Copy(up.Map, "logging") + } + + HackyMigrateLegacyNetworkConfig(helper) +} diff --git a/bridgev2/bridgeconfig/permissions.go b/bridgev2/bridgeconfig/permissions.go new file mode 100644 index 00000000..9efe068e --- /dev/null +++ b/bridgev2/bridgeconfig/permissions.go @@ -0,0 +1,124 @@ +// Copyright (c) 2023 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package bridgeconfig + +import ( + "fmt" + "os" + "strconv" + "strings" + + "gopkg.in/yaml.v3" + + "maunium.net/go/mautrix/id" +) + +type Permissions struct { + SendEvents bool `yaml:"send_events"` + Commands bool `yaml:"commands"` + Login bool `yaml:"login"` + DoublePuppet bool `yaml:"double_puppet"` + Admin bool `yaml:"admin"` + ManageRelay bool `yaml:"manage_relay"` + MaxLogins int `yaml:"max_logins"` +} + +type PermissionConfig map[string]*Permissions + +func boolToInt(val bool) int { + if val { + return 1 + } + return 0 +} + +func (pc PermissionConfig) IsConfigured() bool { + _, hasWildcard := pc["*"] + _, hasExampleDomain := pc["example.com"] + _, hasExampleUser := pc["@admin:example.com"] + exampleLen := boolToInt(hasWildcard) + boolToInt(hasExampleUser) + boolToInt(hasExampleDomain) + return len(pc) > exampleLen +} + +func (pc PermissionConfig) Get(userID id.UserID) Permissions { + if level, ok := pc[string(userID)]; ok { + return *level + } else if level, ok = pc[userID.Homeserver()]; len(userID.Homeserver()) > 0 && ok { + return *level + } else if level, ok = pc["*"]; ok { + return *level + } else { + return PermissionLevelBlock + } +} + +var ( + PermissionLevelBlock = Permissions{} + PermissionLevelRelay = Permissions{SendEvents: true} + PermissionLevelCommands = Permissions{SendEvents: true, Commands: true, ManageRelay: true} + PermissionLevelUser = Permissions{SendEvents: true, Commands: true, ManageRelay: true, Login: true, DoublePuppet: true} + PermissionLevelAdmin = Permissions{SendEvents: true, Commands: true, ManageRelay: true, Login: true, DoublePuppet: true, Admin: true} +) + +var namesToLevels = map[string]Permissions{ + "block": PermissionLevelBlock, + "relay": PermissionLevelRelay, + "commands": PermissionLevelCommands, + "user": PermissionLevelUser, + "admin": PermissionLevelAdmin, +} + +var levelsToNames = map[Permissions]string{ + PermissionLevelBlock: "block", + PermissionLevelRelay: "relay", + PermissionLevelCommands: "commands", + PermissionLevelUser: "user", + PermissionLevelAdmin: "admin", +} + +type umPerm Permissions + +func (p *Permissions) UnmarshalYAML(perm *yaml.Node) error { + switch perm.Tag { + case "!!str": + var ok bool + *p, ok = namesToLevels[strings.ToLower(perm.Value)] + if !ok { + return fmt.Errorf("invalid permissions level %s", perm.Value) + } + return nil + case "!!map": + err := perm.Decode((*umPerm)(p)) + return err + case "!!int": + val, err := strconv.Atoi(perm.Value) + if err != nil { + return fmt.Errorf("invalid permissions level %s", perm.Value) + } + _, _ = fmt.Fprintln(os.Stderr, "Warning: config contains deprecated integer permission values") + // Integer values are deprecated, so they're hardcoded + if val < 5 { + *p = PermissionLevelBlock + } else if val < 10 { + *p = PermissionLevelRelay + } else if val < 100 { + *p = PermissionLevelUser + } else { + *p = PermissionLevelAdmin + } + return nil + default: + return fmt.Errorf("invalid permissions type %s", perm.Tag) + } +} + +func (p *Permissions) MarshalYAML() (any, error) { + if level, ok := levelsToNames[*p]; ok { + return level, nil + } + return umPerm(*p), nil +} diff --git a/bridgev2/bridgeconfig/relay.go b/bridgev2/bridgeconfig/relay.go new file mode 100644 index 00000000..c802f85e --- /dev/null +++ b/bridgev2/bridgeconfig/relay.go @@ -0,0 +1,109 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package bridgeconfig + +import ( + "fmt" + "strings" + "text/template" + + "gopkg.in/yaml.v3" + + "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/format" +) + +type RelayConfig struct { + Enabled bool `yaml:"enabled"` + AdminOnly bool `yaml:"admin_only"` + DefaultRelays []networkid.UserLoginID `yaml:"default_relays"` + MessageFormats map[event.MessageType]string `yaml:"message_formats"` + DisplaynameFormat string `yaml:"displayname_format"` + messageTemplates *template.Template `yaml:"-"` + nameTemplate *template.Template `yaml:"-"` +} + +type umRelayConfig RelayConfig + +func (rc *RelayConfig) UnmarshalYAML(node *yaml.Node) error { + err := node.Decode((*umRelayConfig)(rc)) + if err != nil { + return err + } + + rc.messageTemplates = template.New("messageTemplates") + for key, template := range rc.MessageFormats { + _, err = rc.messageTemplates.New(string(key)).Parse(template) + if err != nil { + return err + } + } + + rc.nameTemplate, err = template.New("nameTemplate").Parse(rc.DisplaynameFormat) + if err != nil { + return err + } + + return nil +} + +type formatData struct { + Sender any + Content *event.MessageEventContent + Caption string + Message string + FileName string +} + +func isMedia(msgType event.MessageType) bool { + switch msgType { + case event.MsgImage, event.MsgVideo, event.MsgAudio, event.MsgFile: + return true + default: + return false + } +} + +func (rc *RelayConfig) FormatMessage(content *event.MessageEventContent, sender any) (*event.MessageEventContent, error) { + _, isSupported := rc.MessageFormats[content.MsgType] + if !isSupported { + return nil, fmt.Errorf("unsupported msgtype for relaying") + } + contentCopy := *content + content = &contentCopy + content.EnsureHasHTML() + fd := &formatData{ + Sender: sender, + Content: content, + Message: content.FormattedBody, + } + fd.Message = content.FormattedBody + if content.FileName != "" { + fd.FileName = content.FileName + if content.FileName != content.Body { + fd.Caption = fd.Message + } + } else if isMedia(content.MsgType) { + content.FileName = content.Body + fd.FileName = content.Body + } + var output strings.Builder + err := rc.messageTemplates.ExecuteTemplate(&output, string(content.MsgType), fd) + if err != nil { + return nil, err + } + content.FormattedBody = output.String() + content.Body = format.HTMLToText(content.FormattedBody) + return content, nil +} + +func (rc *RelayConfig) FormatName(sender any) string { + var buf strings.Builder + _ = rc.nameTemplate.Execute(&buf, sender) + return strings.TrimSpace(buf.String()) +} diff --git a/bridgev2/bridgeconfig/upgrade.go b/bridgev2/bridgeconfig/upgrade.go new file mode 100644 index 00000000..92515ea0 --- /dev/null +++ b/bridgev2/bridgeconfig/upgrade.go @@ -0,0 +1,227 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package bridgeconfig + +import ( + "fmt" + + up "go.mau.fi/util/configupgrade" + "go.mau.fi/util/random" + + "maunium.net/go/mautrix/federation" +) + +func doUpgrade(helper up.Helper) { + if _, isLegacyConfig := helper.Get(up.Str, "appservice", "database", "uri"); isLegacyConfig { + doMigrateLegacy(helper, false) + return + } else if _, isLegacyPython := helper.Get(up.Str, "appservice", "database"); isLegacyPython { + doMigrateLegacy(helper, true) + return + } + + helper.Copy(up.Str, "bridge", "command_prefix") + helper.Copy(up.Bool, "bridge", "personal_filtering_spaces") + helper.Copy(up.Bool, "bridge", "private_chat_portal_meta") + helper.Copy(up.Bool, "bridge", "async_events") + helper.Copy(up.Bool, "bridge", "split_portals") + helper.Copy(up.Bool, "bridge", "resend_bridge_info") + helper.Copy(up.Bool, "bridge", "no_bridge_info_state_key") + helper.Copy(up.Str|up.Null, "bridge", "bridge_status_notices") + helper.Copy(up.Str|up.Int|up.Null, "bridge", "unknown_error_auto_reconnect") + helper.Copy(up.Int, "bridge", "unknown_error_max_auto_reconnects") + helper.Copy(up.Bool, "bridge", "bridge_matrix_leave") + helper.Copy(up.Bool, "bridge", "bridge_notices") + helper.Copy(up.Bool, "bridge", "tag_only_on_create") + helper.Copy(up.List, "bridge", "only_bridge_tags") + helper.Copy(up.Bool, "bridge", "mute_only_on_create") + helper.Copy(up.Bool, "bridge", "deduplicate_matrix_messages") + helper.Copy(up.Bool, "bridge", "cross_room_replies") + helper.Copy(up.Bool, "bridge", "revert_failed_state_changes") + helper.Copy(up.Bool, "bridge", "kick_matrix_users") + helper.Copy(up.Bool, "bridge", "cleanup_on_logout", "enabled") + helper.Copy(up.Str, "bridge", "cleanup_on_logout", "manual", "private") + helper.Copy(up.Str, "bridge", "cleanup_on_logout", "manual", "relayed") + helper.Copy(up.Str, "bridge", "cleanup_on_logout", "manual", "shared_no_users") + helper.Copy(up.Str, "bridge", "cleanup_on_logout", "manual", "shared_has_users") + helper.Copy(up.Str, "bridge", "cleanup_on_logout", "bad_credentials", "private") + helper.Copy(up.Str, "bridge", "cleanup_on_logout", "bad_credentials", "relayed") + helper.Copy(up.Str, "bridge", "cleanup_on_logout", "bad_credentials", "shared_no_users") + helper.Copy(up.Str, "bridge", "cleanup_on_logout", "bad_credentials", "shared_has_users") + helper.Copy(up.Bool, "bridge", "relay", "enabled") + helper.Copy(up.Bool, "bridge", "relay", "admin_only") + helper.Copy(up.List, "bridge", "relay", "default_relays") + helper.Copy(up.Map, "bridge", "relay", "message_formats") + helper.Copy(up.Str, "bridge", "relay", "displayname_format") + helper.Copy(up.Map, "bridge", "permissions") + + if dbType, ok := helper.Get(up.Str, "database", "type"); ok && dbType == "sqlite3" { + fmt.Println("Warning: invalid database type sqlite3 in config. Autocorrecting to sqlite3-fk-wal") + helper.Set(up.Str, "sqlite3-fk-wal", "database", "type") + } else { + helper.Copy(up.Str, "database", "type") + } + helper.Copy(up.Str, "database", "uri") + helper.Copy(up.Int, "database", "max_open_conns") + helper.Copy(up.Int, "database", "max_idle_conns") + helper.Copy(up.Str|up.Null, "database", "max_conn_idle_time") + helper.Copy(up.Str|up.Null, "database", "max_conn_lifetime") + + helper.Copy(up.Str, "homeserver", "address") + helper.Copy(up.Str, "homeserver", "domain") + helper.Copy(up.Str, "homeserver", "software") + helper.Copy(up.Str|up.Null, "homeserver", "status_endpoint") + helper.Copy(up.Str|up.Null, "homeserver", "message_send_checkpoint_endpoint") + helper.Copy(up.Bool, "homeserver", "async_media") + helper.Copy(up.Str|up.Null, "homeserver", "websocket_proxy") + helper.Copy(up.Bool, "homeserver", "websocket") + helper.Copy(up.Int, "homeserver", "ping_interval_seconds") + + helper.Copy(up.Str|up.Null, "appservice", "address") + helper.Copy(up.Str|up.Null, "appservice", "public_address") + helper.Copy(up.Str|up.Null, "appservice", "hostname") + helper.Copy(up.Int|up.Null, "appservice", "port") + helper.Copy(up.Str, "appservice", "id") + helper.Copy(up.Str, "appservice", "bot", "username") + helper.Copy(up.Str, "appservice", "bot", "displayname") + helper.Copy(up.Str, "appservice", "bot", "avatar") + helper.Copy(up.Bool, "appservice", "ephemeral_events") + helper.Copy(up.Bool, "appservice", "async_transactions") + helper.Copy(up.Str, "appservice", "as_token") + helper.Copy(up.Str, "appservice", "hs_token") + helper.Copy(up.Str, "appservice", "username_template") + + helper.Copy(up.Bool, "matrix", "message_status_events") + helper.Copy(up.Bool, "matrix", "delivery_receipts") + helper.Copy(up.Bool, "matrix", "message_error_notices") + helper.Copy(up.Bool, "matrix", "sync_direct_chat_list") + helper.Copy(up.Bool, "matrix", "federate_rooms") + helper.Copy(up.Int, "matrix", "upload_file_threshold") + helper.Copy(up.Bool, "matrix", "ghost_extra_profile_info") + + helper.Copy(up.Str|up.Null, "analytics", "token") + helper.Copy(up.Str|up.Null, "analytics", "url") + helper.Copy(up.Str|up.Null, "analytics", "user_id") + + if secret, ok := helper.Get(up.Str, "provisioning", "shared_secret"); !ok || secret == "generate" { + sharedSecret := random.String(64) + helper.Set(up.Str, sharedSecret, "provisioning", "shared_secret") + } else { + helper.Copy(up.Str, "provisioning", "shared_secret") + } + helper.Copy(up.Bool, "provisioning", "debug_endpoints") + helper.Copy(up.Bool, "provisioning", "enable_session_transfers") + + helper.Copy(up.Bool, "direct_media", "enabled") + helper.Copy(up.Str|up.Null, "direct_media", "media_id_prefix") + helper.Copy(up.Str, "direct_media", "server_name") + helper.Copy(up.Str|up.Null, "direct_media", "well_known_response") + helper.Copy(up.Bool, "direct_media", "allow_proxy") + if serverKey, ok := helper.Get(up.Str, "direct_media", "server_key"); !ok || serverKey == "generate" { + serverKey = federation.GenerateSigningKey().SynapseString() + helper.Set(up.Str, serverKey, "direct_media", "server_key") + } else { + helper.Copy(up.Str, "direct_media", "server_key") + } + + helper.Copy(up.Bool, "public_media", "enabled") + if signingKey, ok := helper.Get(up.Str, "public_media", "signing_key"); !ok || signingKey == "generate" { + helper.Set(up.Str, random.String(64), "public_media", "signing_key") + } else { + helper.Copy(up.Str, "public_media", "signing_key") + } + helper.Copy(up.Int, "public_media", "expiry") + helper.Copy(up.Int, "public_media", "hash_length") + helper.Copy(up.Str|up.Null, "public_media", "path_prefix") + helper.Copy(up.Bool, "public_media", "use_database") + + helper.Copy(up.Bool, "backfill", "enabled") + helper.Copy(up.Int, "backfill", "max_initial_messages") + helper.Copy(up.Int, "backfill", "max_catchup_messages") + helper.Copy(up.Int, "backfill", "unread_hours_threshold") + helper.Copy(up.Int, "backfill", "threads", "max_initial_messages") + helper.Copy(up.Bool, "backfill", "queue", "enabled") + helper.Copy(up.Int, "backfill", "queue", "batch_size") + helper.Copy(up.Int, "backfill", "queue", "batch_delay") + helper.Copy(up.Int, "backfill", "queue", "max_batches") + helper.Copy(up.Map, "backfill", "queue", "max_batches_override") + + helper.Copy(up.Map, "double_puppet", "servers") + helper.Copy(up.Bool, "double_puppet", "allow_discovery") + helper.Copy(up.Map, "double_puppet", "secrets") + + helper.Copy(up.Bool, "encryption", "allow") + helper.Copy(up.Bool, "encryption", "default") + helper.Copy(up.Bool, "encryption", "require") + helper.Copy(up.Bool, "encryption", "appservice") + if val, ok := helper.Get(up.Bool, "appservice", "msc4190"); ok { + helper.Set(up.Bool, val, "encryption", "msc4190") + } else { + helper.Copy(up.Bool, "encryption", "msc4190") + } + helper.Copy(up.Bool, "encryption", "msc4392") + helper.Copy(up.Bool, "encryption", "self_sign") + helper.Copy(up.Bool, "encryption", "allow_key_sharing") + if secret, ok := helper.Get(up.Str, "encryption", "pickle_key"); !ok || secret == "generate" { + helper.Set(up.Str, random.String(64), "encryption", "pickle_key") + } else { + helper.Copy(up.Str, "encryption", "pickle_key") + } + helper.Copy(up.Bool, "encryption", "delete_keys", "delete_outbound_on_ack") + helper.Copy(up.Bool, "encryption", "delete_keys", "dont_store_outbound") + helper.Copy(up.Bool, "encryption", "delete_keys", "ratchet_on_decrypt") + helper.Copy(up.Bool, "encryption", "delete_keys", "delete_fully_used_on_decrypt") + helper.Copy(up.Bool, "encryption", "delete_keys", "delete_prev_on_new_session") + helper.Copy(up.Bool, "encryption", "delete_keys", "delete_on_device_delete") + helper.Copy(up.Bool, "encryption", "delete_keys", "periodically_delete_expired") + helper.Copy(up.Bool, "encryption", "delete_keys", "delete_outdated_inbound") + helper.Copy(up.Str, "encryption", "verification_levels", "receive") + helper.Copy(up.Str, "encryption", "verification_levels", "send") + helper.Copy(up.Str, "encryption", "verification_levels", "share") + helper.Copy(up.Bool, "encryption", "rotation", "enable_custom") + helper.Copy(up.Int, "encryption", "rotation", "milliseconds") + helper.Copy(up.Int, "encryption", "rotation", "messages") + helper.Copy(up.Bool, "encryption", "rotation", "disable_device_change_key_rotation") + + helper.Copy(up.Str|up.Null, "env_config_prefix") + + helper.Copy(up.Map, "logging") +} + +var SpacedBlocks = [][]string{ + {"bridge"}, + {"bridge", "bridge_matrix_leave"}, + {"bridge", "cleanup_on_logout"}, + {"bridge", "relay"}, + {"bridge", "permissions"}, + {"database"}, + {"homeserver"}, + {"homeserver", "software"}, + {"homeserver", "websocket"}, + {"appservice"}, + {"appservice", "hostname"}, + {"appservice", "id"}, + {"appservice", "ephemeral_events"}, + {"appservice", "as_token"}, + {"appservice", "username_template"}, + {"matrix"}, + {"analytics"}, + {"provisioning"}, + {"public_media"}, + {"direct_media"}, + {"backfill"}, + {"double_puppet"}, + {"encryption"}, + {"env_config_prefix"}, + {"logging"}, +} + +// Upgrader is a config upgrader that copies the default fields in the homeserver, appservice and logging blocks. +var Upgrader up.SpacedUpgrader = &up.StructUpgrader{ + SimpleUpgrader: up.SimpleUpgrader(doUpgrade), + Blocks: SpacedBlocks, +} diff --git a/bridgev2/bridgestate.go b/bridgev2/bridgestate.go new file mode 100644 index 00000000..96d9fd5c --- /dev/null +++ b/bridgev2/bridgestate.go @@ -0,0 +1,333 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package bridgev2 + +import ( + "context" + "fmt" + "math/rand/v2" + "runtime/debug" + "sync/atomic" + "time" + + "github.com/rs/zerolog" + "go.mau.fi/util/exfmt" + + "maunium.net/go/mautrix/bridgev2/status" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/format" +) + +var CatchBridgeStateQueuePanics = true + +type BridgeStateQueue struct { + prevUnsent *status.BridgeState + prevSent *status.BridgeState + errorSent bool + ch chan status.BridgeState + bridge *Bridge + login *UserLogin + + firstTransientDisconnect time.Time + cancelScheduledNotice atomic.Pointer[context.CancelFunc] + + stopChan chan struct{} + stopReconnect atomic.Pointer[context.CancelFunc] + + unknownErrorReconnects int +} + +func (br *Bridge) SendGlobalBridgeState(state status.BridgeState) { + state = state.Fill(nil) + for { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + if err := br.Matrix.SendBridgeStatus(ctx, &state); err != nil { + br.Log.Warn().Err(err).Msg("Failed to update global bridge state") + cancel() + time.Sleep(5 * time.Second) + continue + } else { + br.Log.Debug().Any("bridge_state", state).Msg("Sent new global bridge state") + cancel() + break + } + } +} + +func (br *Bridge) NewBridgeStateQueue(login *UserLogin) *BridgeStateQueue { + bsq := &BridgeStateQueue{ + ch: make(chan status.BridgeState, 10), + stopChan: make(chan struct{}), + bridge: br, + login: login, + } + go bsq.loop() + return bsq +} + +func (bsq *BridgeStateQueue) Destroy() { + close(bsq.stopChan) + close(bsq.ch) + bsq.StopUnknownErrorReconnect() +} + +func (bsq *BridgeStateQueue) StopUnknownErrorReconnect() { + if bsq == nil { + return + } + if cancelFn := bsq.stopReconnect.Swap(nil); cancelFn != nil { + (*cancelFn)() + } + if cancelFn := bsq.cancelScheduledNotice.Swap(nil); cancelFn != nil { + (*cancelFn)() + } +} + +func (bsq *BridgeStateQueue) loop() { + if CatchBridgeStateQueuePanics { + defer func() { + err := recover() + if err != nil { + bsq.login.Log.Error(). + Bytes(zerolog.ErrorStackFieldName, debug.Stack()). + Any(zerolog.ErrorFieldName, err). + Msg("Panic in bridge state loop") + } + }() + } + for state := range bsq.ch { + bsq.immediateSendBridgeState(state) + } +} + +func (bsq *BridgeStateQueue) scheduleNotice(triggeredBy status.BridgeState) { + log := bsq.login.Log.With().Str("action", "transient disconnect notice").Logger() + ctx := log.WithContext(bsq.bridge.BackgroundCtx) + if !bsq.waitForTransientDisconnectReconnect(ctx) { + return + } + prevUnsent := bsq.GetPrevUnsent() + prev := bsq.GetPrev() + if triggeredBy.Timestamp != prev.Timestamp || len(bsq.ch) > 0 || bsq.errorSent || + prevUnsent.StateEvent != status.StateTransientDisconnect || prev.StateEvent != status.StateTransientDisconnect { + log.Trace().Any("triggered_by", triggeredBy).Msg("Not sending delayed transient disconnect notice") + return + } + log.Debug().Any("triggered_by", triggeredBy).Msg("Sending delayed transient disconnect notice") + bsq.sendNotice(ctx, triggeredBy, true) +} + +func (bsq *BridgeStateQueue) sendNotice(ctx context.Context, state status.BridgeState, isDelayed bool) { + noticeConfig := bsq.bridge.Config.BridgeStatusNotices + isError := state.StateEvent == status.StateBadCredentials || + state.StateEvent == status.StateUnknownError || + state.UserAction == status.UserActionOpenNative || + (isDelayed && state.StateEvent == status.StateTransientDisconnect) + sendNotice := noticeConfig == "all" || (noticeConfig == "errors" && + (isError || (bsq.errorSent && state.StateEvent == status.StateConnected))) + if state.StateEvent != status.StateTransientDisconnect && state.StateEvent != status.StateUnknownError { + bsq.firstTransientDisconnect = time.Time{} + } + if !sendNotice { + if !bsq.errorSent && !isDelayed && noticeConfig == "errors" && state.StateEvent == status.StateTransientDisconnect { + if bsq.firstTransientDisconnect.IsZero() { + bsq.firstTransientDisconnect = time.Now() + } + go bsq.scheduleNotice(state) + } + return + } + managementRoom, err := bsq.login.User.GetManagementRoom(ctx) + if err != nil { + bsq.login.Log.Err(err).Msg("Failed to get management room") + return + } + name := bsq.login.RemoteName + if name == "" { + name = fmt.Sprintf("`%s`", bsq.login.ID) + } + message := fmt.Sprintf("State update for %s: `%s`", name, state.StateEvent) + if state.Error != "" { + message += fmt.Sprintf(" (`%s`)", state.Error) + } + if isDelayed { + message += fmt.Sprintf(" not resolved after waiting %s", exfmt.Duration(TransientDisconnectNoticeDelay)) + } + if state.Message != "" { + message += fmt.Sprintf(": %s", state.Message) + } + content := format.RenderMarkdown(message, true, false) + if !isError { + content.MsgType = event.MsgNotice + } + _, err = bsq.bridge.Bot.SendMessage(ctx, managementRoom, event.EventMessage, &event.Content{ + Parsed: content, + Raw: map[string]any{ + "fi.mau.bridge_state": state, + }, + }, nil) + if err != nil { + bsq.login.Log.Err(err).Msg("Failed to send bridge state notice") + } else { + bsq.errorSent = isError + } +} + +func (bsq *BridgeStateQueue) unknownErrorReconnect(triggeredBy status.BridgeState) { + log := bsq.login.Log.With().Str("action", "unknown error reconnect").Logger() + ctx := log.WithContext(bsq.bridge.BackgroundCtx) + if !bsq.waitForUnknownErrorReconnect(ctx) { + return + } + prevUnsent := bsq.GetPrevUnsent() + prev := bsq.GetPrev() + if triggeredBy.Timestamp != prev.Timestamp { + log.Debug().Msg("Not reconnecting as a new bridge state was sent after the unknown error") + return + } else if len(bsq.ch) > 0 { + log.Warn().Msg("Not reconnecting as there are unsent bridge states") + return + } else if prevUnsent.StateEvent != status.StateUnknownError || prev.StateEvent != status.StateUnknownError { + log.Debug().Msg("Not reconnecting as the previous state was not an unknown error") + return + } else if bsq.unknownErrorReconnects > bsq.bridge.Config.UnknownErrorMaxAutoReconnects { + log.Warn().Msg("Not reconnecting as the maximum number of unknown error reconnects has been reached") + return + } + bsq.unknownErrorReconnects++ + log.Info(). + Int("reconnect_num", bsq.unknownErrorReconnects). + Msg("Disconnecting and reconnecting login due to unknown error") + bsq.login.Disconnect() + log.Debug().Msg("Disconnection finished, recreating client and reconnecting") + err := bsq.login.recreateClient(ctx) + if err != nil { + log.Err(err).Msg("Failed to recreate client after unknown error") + return + } + bsq.login.Client.Connect(ctx) + log.Debug().Msg("Reconnection finished") +} + +func (bsq *BridgeStateQueue) waitForUnknownErrorReconnect(ctx context.Context) bool { + reconnectIn := bsq.bridge.Config.UnknownErrorAutoReconnect + // Don't allow too low values + if reconnectIn < 1*time.Minute { + return false + } + reconnectIn += time.Duration(rand.Int64N(int64(float64(reconnectIn)*0.4)) - int64(float64(reconnectIn)*0.2)) + return bsq.waitForReconnect(ctx, reconnectIn, &bsq.stopReconnect) +} + +const TransientDisconnectNoticeDelay = 3 * time.Minute + +func (bsq *BridgeStateQueue) waitForTransientDisconnectReconnect(ctx context.Context) bool { + timeUntilSchedule := time.Until(bsq.firstTransientDisconnect.Add(TransientDisconnectNoticeDelay)) + zerolog.Ctx(ctx).Trace(). + Stringer("duration", timeUntilSchedule). + Msg("Waiting before sending notice about transient disconnect") + return bsq.waitForReconnect(ctx, timeUntilSchedule, &bsq.cancelScheduledNotice) +} + +func (bsq *BridgeStateQueue) waitForReconnect( + ctx context.Context, reconnectIn time.Duration, ptr *atomic.Pointer[context.CancelFunc], +) bool { + cancelCtx, cancel := context.WithCancel(ctx) + defer cancel() + if oldCancel := ptr.Swap(&cancel); oldCancel != nil { + (*oldCancel)() + } + select { + case <-time.After(reconnectIn): + return ptr.CompareAndSwap(&cancel, nil) + case <-cancelCtx.Done(): + return false + case <-bsq.stopChan: + return false + } +} + +func (bsq *BridgeStateQueue) immediateSendBridgeState(state status.BridgeState) { + if bsq.prevSent != nil && bsq.prevSent.ShouldDeduplicate(&state) { + bsq.login.Log.Debug(). + Str("state_event", string(state.StateEvent)). + Msg("Not sending bridge state as it's a duplicate") + return + } + if state.StateEvent == status.StateUnknownError { + go bsq.unknownErrorReconnect(state) + } + + ctx := bsq.login.Log.WithContext(context.Background()) + bsq.sendNotice(ctx, state, false) + + retryIn := 2 + for { + ctx, cancel := context.WithTimeout(ctx, 30*time.Second) + err := bsq.bridge.Matrix.SendBridgeStatus(ctx, &state) + cancel() + + if err != nil { + bsq.login.Log.Warn().Err(err). + Int("retry_in_seconds", retryIn). + Msg("Failed to update bridge state") + time.Sleep(time.Duration(retryIn) * time.Second) + retryIn *= 2 + if retryIn > 64 { + retryIn = 64 + } + } else { + bsq.prevSent = &state + bsq.login.Log.Debug(). + Any("bridge_state", state). + Msg("Sent new bridge state") + return + } + } +} + +func (bsq *BridgeStateQueue) Send(state status.BridgeState) { + if bsq == nil { + return + } + + state = state.Fill(bsq.login) + bsq.prevUnsent = &state + + if len(bsq.ch) >= 8 { + bsq.login.Log.Warn().Msg("Bridge state queue is nearly full, discarding an item") + select { + case <-bsq.ch: + default: + } + } + select { + case bsq.ch <- state: + default: + bsq.login.Log.Error().Msg("Bridge state queue is full, dropped new state") + } +} + +func (bsq *BridgeStateQueue) GetPrev() status.BridgeState { + if bsq != nil && bsq.prevSent != nil { + return *bsq.prevSent + } + return status.BridgeState{} +} + +func (bsq *BridgeStateQueue) GetPrevUnsent() status.BridgeState { + if bsq != nil && bsq.prevSent != nil { + return *bsq.prevUnsent + } + return status.BridgeState{} +} + +func (bsq *BridgeStateQueue) SetPrev(prev status.BridgeState) { + if bsq != nil { + bsq.prevSent = &prev + } +} diff --git a/bridgev2/commands/cleanup.go b/bridgev2/commands/cleanup.go new file mode 100644 index 00000000..dc21a16e --- /dev/null +++ b/bridgev2/commands/cleanup.go @@ -0,0 +1,97 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package commands + +import ( + "maunium.net/go/mautrix/bridgev2" +) + +var CommandDeletePortal = &FullHandler{ + Func: func(ce *Event) { + // TODO clean up child portals? + err := ce.Portal.Delete(ce.Ctx) + if err != nil { + ce.Reply("Failed to delete portal: %v", err) + return + } + err = ce.Bot.DeleteRoom(ce.Ctx, ce.Portal.MXID, false) + if err != nil { + ce.Reply("Failed to clean up room: %v", err) + } + ce.MessageStatus.DisableMSS = true + }, + Name: "delete-portal", + Help: HelpMeta{ + Section: HelpSectionAdmin, + Description: "Delete the current portal room", + }, + RequiresAdmin: true, + RequiresPortal: true, +} + +var CommandDeleteAllPortals = &FullHandler{ + Func: func(ce *Event) { + portals, err := ce.Bridge.GetAllPortals(ce.Ctx) + if err != nil { + ce.Reply("Failed to get portals: %v", err) + return + } + bridgev2.DeleteManyPortals(ce.Ctx, portals, func(portal *bridgev2.Portal, delete bool, err error) { + if !delete { + ce.Reply("Failed to delete portal %s: %v", portal.MXID, err) + } else { + ce.Reply("Failed to clean up room %s: %v", portal.MXID, err) + } + }) + }, + Name: "delete-all-portals", + Help: HelpMeta{ + Section: HelpSectionAdmin, + Description: "Delete all portals the bridge knows about", + }, + RequiresAdmin: true, +} + +var CommandSetManagementRoom = &FullHandler{ + Func: func(ce *Event) { + if ce.User.ManagementRoom == ce.RoomID { + ce.Reply("This room is already your management room") + return + } else if ce.Portal != nil { + ce.Reply("This is a portal room: you can't set this as your management room") + return + } + members, err := ce.Bridge.Matrix.GetMembers(ce.Ctx, ce.RoomID) + if err != nil { + ce.Log.Err(err).Msg("Failed to get room members to check if room can be a management room") + ce.Reply("Failed to get room members") + return + } + _, hasBot := members[ce.Bot.GetMXID()] + if !hasBot { + // This reply will probably fail, but whatever + ce.Reply("The bridge bot must be in the room to set it as your management room") + return + } else if len(members) != 2 { + ce.Reply("Your management room must not have any members other than you and the bridge bot") + return + } + ce.User.ManagementRoom = ce.RoomID + err = ce.User.Save(ce.Ctx) + if err != nil { + ce.Log.Err(err).Msg("Failed to save management room") + ce.Reply("Failed to save management room") + } else { + ce.Reply("Management room updated") + } + }, + Name: "set-management-room", + Help: HelpMeta{ + Section: HelpSectionGeneral, + Description: "Mark this room as your management room", + }, +} diff --git a/bridgev2/commands/debug.go b/bridgev2/commands/debug.go new file mode 100644 index 00000000..1cae98fe --- /dev/null +++ b/bridgev2/commands/debug.go @@ -0,0 +1,125 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package commands + +import ( + "encoding/json" + "strings" + "time" + + "maunium.net/go/mautrix/bridgev2" + "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/event" +) + +var CommandRegisterPush = &FullHandler{ + Func: func(ce *Event) { + if len(ce.Args) < 3 { + ce.Reply("Usage: `$cmdprefix debug-register-push `\n\nYour logins:\n\n%s", ce.User.GetFormattedUserLogins()) + return + } + pushType := bridgev2.PushTypeFromString(ce.Args[1]) + if pushType == bridgev2.PushTypeUnknown { + ce.Reply("Unknown push type `%s`. Allowed types: `web`, `apns`, `fcm`", ce.Args[1]) + return + } + login := ce.Bridge.GetCachedUserLoginByID(networkid.UserLoginID(ce.Args[0])) + if login == nil || login.UserMXID != ce.User.MXID { + ce.Reply("Login `%s` not found", ce.Args[0]) + return + } + pushable, ok := login.Client.(bridgev2.PushableNetworkAPI) + if !ok { + ce.Reply("This network connector does not support push registration") + return + } + pushToken := strings.Join(ce.Args[2:], " ") + if pushToken == "null" { + pushToken = "" + } + err := pushable.RegisterPushNotifications(ce.Ctx, pushType, pushToken) + if err != nil { + ce.Reply("Failed to register pusher: %v", err) + return + } + if pushToken == "" { + ce.Reply("Pusher de-registered successfully") + } else { + ce.Reply("Pusher registered successfully") + } + }, + Name: "debug-register-push", + Help: HelpMeta{ + Section: HelpSectionAdmin, + Description: "Register a pusher", + Args: "<_login ID_> <_push type_> <_push token_>", + }, + RequiresAdmin: true, + RequiresLogin: true, + NetworkAPI: NetworkAPIImplements[bridgev2.PushableNetworkAPI], +} + +var CommandSendAccountData = &FullHandler{ + Func: func(ce *Event) { + if len(ce.Args) < 2 { + ce.Reply("Usage: `$cmdprefix debug-account-data ") + return + } + var content event.Content + evtType := event.Type{Type: ce.Args[0], Class: event.AccountDataEventType} + ce.RawArgs = strings.TrimSpace(strings.Trim(ce.RawArgs, ce.Args[0])) + err := json.Unmarshal([]byte(ce.RawArgs), &content) + if err != nil { + ce.Reply("Failed to parse JSON: %v", err) + return + } + err = content.ParseRaw(evtType) + if err != nil { + ce.Reply("Failed to deserialize content: %v", err) + return + } + res := ce.Bridge.QueueMatrixEvent(ce.Ctx, &event.Event{ + Sender: ce.User.MXID, + Type: evtType, + Timestamp: time.Now().UnixMilli(), + RoomID: ce.RoomID, + Content: content, + }) + ce.Reply("Result: %+v", res) + }, + Name: "debug-account-data", + Help: HelpMeta{ + Section: HelpSectionAdmin, + Description: "Send a room account data event to the bridge", + Args: "<_type_> <_content_>", + }, + RequiresAdmin: true, + RequiresPortal: true, + RequiresLogin: true, +} + +var CommandResetNetwork = &FullHandler{ + Func: func(ce *Event) { + if strings.Contains(strings.ToLower(ce.RawArgs), "--reset-transport") { + nrn, ok := ce.Bridge.Network.(bridgev2.NetworkResettingNetwork) + if ok { + nrn.ResetHTTPTransport() + } else { + ce.Reply("Network connector does not support resetting HTTP transport") + } + } + ce.Bridge.ResetNetworkConnections() + ce.React("✅️") + }, + Name: "debug-reset-network", + Help: HelpMeta{ + Section: HelpSectionAdmin, + Description: "Reset network connections to the remote network", + Args: "[--reset-transport]", + }, + RequiresAdmin: true, +} diff --git a/bridgev2/commands/event.go b/bridgev2/commands/event.go new file mode 100644 index 00000000..88ba9698 --- /dev/null +++ b/bridgev2/commands/event.go @@ -0,0 +1,100 @@ +// Copyright (c) 2023 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package commands + +import ( + "context" + "fmt" + "strings" + "time" + + "github.com/rs/zerolog" + + "maunium.net/go/mautrix/bridgev2" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/format" + "maunium.net/go/mautrix/id" +) + +// Event stores all data which might be used to handle commands +type Event struct { + Bot bridgev2.MatrixAPI + Bridge *bridgev2.Bridge + Portal *bridgev2.Portal + Processor *Processor + Handler MinimalCommandHandler + RoomID id.RoomID + OrigRoomID id.RoomID + EventID id.EventID + User *bridgev2.User + Command string + Args []string + RawArgs string + ReplyTo id.EventID + Ctx context.Context + Log *zerolog.Logger + + MessageStatus *bridgev2.MessageStatus +} + +// Reply sends a reply to command as notice, with optional string formatting and automatic $cmdprefix replacement. +func (ce *Event) Reply(msg string, args ...any) { + msg = strings.ReplaceAll(msg, "$cmdprefix ", ce.Bridge.Config.CommandPrefix+" ") + if len(args) > 0 { + msg = fmt.Sprintf(msg, args...) + } + ce.ReplyAdvanced(msg, true, false) +} + +// ReplyAdvanced sends a reply to command as notice. It allows using HTML and disabling markdown, +// but doesn't have built-in string formatting. +func (ce *Event) ReplyAdvanced(msg string, allowMarkdown, allowHTML bool) { + content := format.RenderMarkdown(msg, allowMarkdown, allowHTML) + content.MsgType = event.MsgNotice + _, err := ce.Bot.SendMessage(ce.Ctx, ce.OrigRoomID, event.EventMessage, &event.Content{Parsed: &content}, nil) + if err != nil { + ce.Log.Err(err).Msg("Failed to reply to command") + } +} + +// React sends a reaction to the command. +func (ce *Event) React(key string) { + _, err := ce.Bot.SendMessage(ce.Ctx, ce.OrigRoomID, event.EventReaction, &event.Content{ + Parsed: &event.ReactionEventContent{ + RelatesTo: event.RelatesTo{ + Type: event.RelAnnotation, + EventID: ce.EventID, + Key: key, + }, + }, + }, nil) + if err != nil { + ce.Log.Err(err).Msg("Failed to react to command") + } +} + +// Redact redacts the command. +func (ce *Event) Redact(req ...mautrix.ReqRedact) { + _, err := ce.Bot.SendMessage(ce.Ctx, ce.OrigRoomID, event.EventRedaction, &event.Content{ + Parsed: &event.RedactionEventContent{ + Redacts: ce.EventID, + }, + }, nil) + if err != nil { + ce.Log.Err(err).Msg("Failed to redact command") + } +} + +// MarkRead marks the command event as read. +func (ce *Event) MarkRead() { + err := ce.Bot.MarkRead(ce.Ctx, ce.RoomID, ce.EventID, time.Now()) + if err != nil { + ce.Log.Err(err).Msg("Failed to mark command as read") + } +} diff --git a/bridgev2/commands/handler.go b/bridgev2/commands/handler.go new file mode 100644 index 00000000..672c81dc --- /dev/null +++ b/bridgev2/commands/handler.go @@ -0,0 +1,118 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package commands + +import ( + "maunium.net/go/mautrix/bridgev2" + "maunium.net/go/mautrix/event" +) + +type MinimalCommandHandler interface { + Run(*Event) +} + +type MinimalCommandHandlerFunc func(*Event) + +func (mhf MinimalCommandHandlerFunc) Run(ce *Event) { + mhf(ce) +} + +type CommandState struct { + Next MinimalCommandHandler + Action string + Meta any + Cancel func() +} + +type CommandHandler interface { + MinimalCommandHandler + GetName() string +} + +type AliasedCommandHandler interface { + CommandHandler + GetAliases() []string +} + +func NetworkAPIImplements[T bridgev2.NetworkAPI](val bridgev2.NetworkAPI) bool { + _, ok := val.(T) + return ok +} + +func NetworkConnectorImplements[T bridgev2.NetworkConnector](val bridgev2.NetworkConnector) bool { + _, ok := val.(T) + return ok +} + +type ImplementationChecker[T any] func(val T) bool + +type FullHandler struct { + Func func(*Event) + + Name string + Aliases []string + Help HelpMeta + + RequiresAdmin bool + RequiresPortal bool + RequiresLogin bool + RequiresEventLevel event.Type + RequiresLoginPermission bool + + NetworkAPI ImplementationChecker[bridgev2.NetworkAPI] + NetworkConnector ImplementationChecker[bridgev2.NetworkConnector] +} + +func (fh *FullHandler) GetHelp() HelpMeta { + fh.Help.Command = fh.Name + return fh.Help +} + +func (fh *FullHandler) GetName() string { + return fh.Name +} + +func (fh *FullHandler) GetAliases() []string { + return fh.Aliases +} + +func (fh *FullHandler) ImplementationsFulfilled(ce *Event) bool { + // TODO add dedicated method to get an empty NetworkAPI instead of getting default login + client := ce.User.GetDefaultLogin() + return (fh.NetworkAPI == nil || client == nil || fh.NetworkAPI(client.Client)) && + (fh.NetworkConnector == nil || fh.NetworkConnector(ce.Bridge.Network)) +} + +func (fh *FullHandler) ShowInHelp(ce *Event) bool { + return fh.ImplementationsFulfilled(ce) && (!fh.RequiresAdmin || ce.User.Permissions.Admin) +} + +func (fh *FullHandler) userHasRoomPermission(ce *Event) bool { + levels, err := ce.Bridge.Matrix.GetPowerLevels(ce.Ctx, ce.RoomID) + if err != nil { + ce.Log.Warn().Err(err).Msg("Failed to check room power levels") + ce.Reply("Failed to get room power levels to see if you're allowed to use that command") + return false + } + return levels.GetUserLevel(ce.User.MXID) >= levels.GetEventLevel(fh.RequiresEventLevel) +} + +func (fh *FullHandler) Run(ce *Event) { + if fh.RequiresAdmin && !ce.User.Permissions.Admin { + ce.Reply("That command is limited to bridge administrators.") + } else if fh.RequiresLoginPermission && !ce.User.Permissions.Login { + ce.Reply("You do not have permissions to log into this bridge.") + } else if fh.RequiresEventLevel.Type != "" && !ce.User.Permissions.Admin && !fh.userHasRoomPermission(ce) { + ce.Reply("That command requires room admin rights.") + } else if fh.RequiresPortal && ce.Portal == nil { + ce.Reply("That command can only be ran in portal rooms.") + } else if fh.RequiresLogin && ce.User.GetDefaultLogin() == nil { + ce.Reply("That command requires you to be logged in.") + } else { + fh.Func(ce) + } +} diff --git a/bridge/commands/help.go b/bridgev2/commands/help.go similarity index 88% rename from bridge/commands/help.go rename to bridgev2/commands/help.go index f4891555..5c91a4d1 100644 --- a/bridge/commands/help.go +++ b/bridgev2/commands/help.go @@ -13,7 +13,7 @@ import ( ) type HelpfulHandler interface { - Handler + CommandHandler GetHelp() HelpMeta ShowInHelp(*Event) bool } @@ -29,6 +29,7 @@ var ( HelpSectionGeneral = HelpSection{"General", 0} HelpSectionAuth = HelpSection{"Authentication", 10} + HelpSectionChats = HelpSection{"Starting and managing chats", 20} HelpSectionAdmin = HelpSection{"Administration", 50} ) @@ -101,14 +102,14 @@ func FormatHelp(ce *Event) string { output.Grow(10240) var prefixMsg string - if ce.RoomID == ce.User.GetManagementRoomID() { + if ce.RoomID == ce.User.ManagementRoom { prefixMsg = "This is your management room: prefixing commands with `%s` is not required." } else if ce.Portal != nil { prefixMsg = "**This is a portal room**: you must always prefix commands with `%s`. Management commands will not be bridged." } else { prefixMsg = "This is not your management room: prefixing commands with `%s` is required." } - _, _ = fmt.Fprintf(&output, prefixMsg, ce.Bridge.Config.Bridge.GetCommandPrefix()) + _, _ = fmt.Fprintf(&output, prefixMsg, ce.Bridge.Config.CommandPrefix) output.WriteByte('\n') output.WriteString("Parameters in [square brackets] are optional, while parameters in are required.") output.WriteByte('\n') @@ -127,3 +128,14 @@ func FormatHelp(ce *Event) string { } return output.String() } + +var CommandHelp = &FullHandler{ + Func: func(ce *Event) { + ce.Reply(FormatHelp(ce)) + }, + Name: "help", + Help: HelpMeta{ + Section: HelpSectionGeneral, + Description: "Show this help message.", + }, +} diff --git a/bridgev2/commands/login.go b/bridgev2/commands/login.go new file mode 100644 index 00000000..96d62d3e --- /dev/null +++ b/bridgev2/commands/login.go @@ -0,0 +1,616 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package commands + +import ( + "context" + "encoding/json" + "fmt" + "html" + "net/url" + "regexp" + "slices" + "strings" + + "github.com/skip2/go-qrcode" + "go.mau.fi/util/curl" + + "maunium.net/go/mautrix/bridgev2" + "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/bridgev2/status" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +var CommandLogin = &FullHandler{ + Func: fnLogin, + Name: "login", + Help: HelpMeta{ + Section: HelpSectionAuth, + Description: "Log into the bridge", + Args: "[_flow ID_]", + }, + RequiresLoginPermission: true, +} + +var CommandRelogin = &FullHandler{ + Func: fnLogin, + Name: "relogin", + Help: HelpMeta{ + Section: HelpSectionAuth, + Description: "Re-authenticate an existing login", + Args: "<_login ID_> [_flow ID_]", + }, + RequiresLoginPermission: true, +} + +func formatFlowsReply(flows []bridgev2.LoginFlow) string { + var buf strings.Builder + for _, flow := range flows { + _, _ = fmt.Fprintf(&buf, "* `%s` - %s\n", flow.ID, flow.Description) + } + return buf.String() +} + +func fnLogin(ce *Event) { + var reauth *bridgev2.UserLogin + if ce.Command == "relogin" { + if len(ce.Args) == 0 { + ce.Reply("Usage: `$cmdprefix relogin [_flow ID_]`\n\nYour logins:\n\n%s", ce.User.GetFormattedUserLogins()) + return + } + reauth = ce.Bridge.GetCachedUserLoginByID(networkid.UserLoginID(ce.Args[0])) + if reauth == nil { + ce.Reply("Login `%s` not found", ce.Args[0]) + return + } + ce.Args = ce.Args[1:] + } + if reauth == nil && ce.User.HasTooManyLogins() { + ce.Reply( + "You have reached the maximum number of logins (%d). "+ + "Please logout from an existing login before creating a new one. "+ + "If you want to re-authenticate an existing login, use the `$cmdprefix relogin` command.", + ce.User.Permissions.MaxLogins, + ) + return + } + flows := ce.Bridge.Network.GetLoginFlows() + var chosenFlowID string + if len(ce.Args) > 0 { + inputFlowID := strings.ToLower(ce.Args[0]) + ce.Args = ce.Args[1:] + for _, flow := range flows { + if flow.ID == inputFlowID { + chosenFlowID = flow.ID + break + } + } + if chosenFlowID == "" { + ce.Reply("Invalid login flow `%s`. Available options:\n\n%s", inputFlowID, formatFlowsReply(flows)) + return + } + } else if len(flows) == 1 { + chosenFlowID = flows[0].ID + } else { + if reauth != nil { + ce.Reply("Please specify a login flow, e.g. `relogin %s %s`.\n\n%s", reauth.ID, flows[0].ID, formatFlowsReply(flows)) + } else { + ce.Reply("Please specify a login flow, e.g. `login %s`.\n\n%s", flows[0].ID, formatFlowsReply(flows)) + } + return + } + + login, err := ce.Bridge.Network.CreateLogin(ce.Ctx, ce.User, chosenFlowID) + if err != nil { + ce.Reply("Failed to prepare login process: %v", err) + return + } + overridable, ok := login.(bridgev2.LoginProcessWithOverride) + var nextStep *bridgev2.LoginStep + if ok && reauth != nil { + nextStep, err = overridable.StartWithOverride(ce.Ctx, reauth) + } else { + nextStep, err = login.Start(ce.Ctx) + } + if err != nil { + ce.Reply("Failed to start login: %v", err) + return + } + ce.Log.Debug().Any("first_step", nextStep).Msg("Created login process") + + nextStep = checkLoginCommandDirectParams(ce, login, nextStep) + if nextStep != nil { + doLoginStep(ce, login, nextStep, reauth) + } +} + +func checkLoginCommandDirectParams(ce *Event, login bridgev2.LoginProcess, nextStep *bridgev2.LoginStep) *bridgev2.LoginStep { + if len(ce.Args) == 0 { + return nextStep + } + var ok bool + defer func() { + if !ok { + login.Cancel() + } + }() + var err error + switch nextStep.Type { + case bridgev2.LoginStepTypeDisplayAndWait: + ce.Reply("Invalid extra parameters for display and wait login step") + return nil + case bridgev2.LoginStepTypeUserInput: + if len(ce.Args) != len(nextStep.UserInputParams.Fields) { + ce.Reply("Invalid number of extra parameters (expected 0 or %d, got %d)", len(nextStep.UserInputParams.Fields), len(ce.Args)) + return nil + } + input := make(map[string]string) + var shouldRedact bool + for i, param := range nextStep.UserInputParams.Fields { + param.FillDefaultValidate() + input[param.ID], err = param.Validate(ce.Args[i]) + if err != nil { + ce.Reply("Invalid value for %s: %v", param.Name, err) + return nil + } + if param.Type == bridgev2.LoginInputFieldTypePassword || param.Type == bridgev2.LoginInputFieldTypeToken { + shouldRedact = true + } + } + if shouldRedact { + ce.Redact() + } + nextStep, err = login.(bridgev2.LoginProcessUserInput).SubmitUserInput(ce.Ctx, input) + case bridgev2.LoginStepTypeCookies: + if len(ce.Args) != len(nextStep.CookiesParams.Fields) { + ce.Reply("Invalid number of extra parameters (expected 0 or %d, got %d)", len(nextStep.CookiesParams.Fields), len(ce.Args)) + return nil + } + input := make(map[string]string) + for i, param := range nextStep.CookiesParams.Fields { + val := maybeURLDecodeCookie(ce.Args[i], ¶m) + if match, _ := regexp.MatchString(param.Pattern, val); !match { + ce.Reply("Invalid value for %s: `%s` doesn't match regex `%s`", param.ID, val, param.Pattern) + return nil + } + input[param.ID] = val + } + ce.Redact() + nextStep, err = login.(bridgev2.LoginProcessCookies).SubmitCookies(ce.Ctx, input) + } + if err != nil { + ce.Reply("Failed to submit input: %v", err) + return nil + } + ok = true + return nextStep +} + +type userInputLoginCommandState struct { + Login bridgev2.LoginProcessUserInput + Data map[string]string + RemainingFields []bridgev2.LoginInputDataField + Override *bridgev2.UserLogin +} + +func (uilcs *userInputLoginCommandState) promptNext(ce *Event) { + field := uilcs.RemainingFields[0] + parts := []string{fmt.Sprintf("Please enter your %s", field.Name)} + if field.Description != "" { + parts = append(parts, field.Description) + } + if len(field.Options) > 0 { + parts = append(parts, fmt.Sprintf("Options: `%s`", strings.Join(field.Options, "`, `"))) + } + ce.Reply(strings.Join(parts, "\n")) + StoreCommandState(ce.User, &CommandState{ + Next: MinimalCommandHandlerFunc(uilcs.submitNext), + Action: "Login", + Meta: uilcs, + Cancel: uilcs.Login.Cancel, + }) +} + +func (uilcs *userInputLoginCommandState) submitNext(ce *Event) { + field := uilcs.RemainingFields[0] + field.FillDefaultValidate() + if field.Type == bridgev2.LoginInputFieldTypePassword || field.Type == bridgev2.LoginInputFieldTypeToken { + ce.Redact() + } + var err error + uilcs.Data[field.ID], err = field.Validate(ce.RawArgs) + if err != nil { + ce.Reply("Invalid value: %v", err) + return + } else if len(uilcs.RemainingFields) > 1 { + uilcs.RemainingFields = uilcs.RemainingFields[1:] + uilcs.promptNext(ce) + return + } + StoreCommandState(ce.User, nil) + if nextStep, err := uilcs.Login.SubmitUserInput(ce.Ctx, uilcs.Data); err != nil { + ce.Reply("Failed to submit input: %v", err) + } else { + doLoginStep(ce, uilcs.Login, nextStep, uilcs.Override) + } +} + +const qrSizePx = 512 + +func sendQR(ce *Event, qr string, prevEventID *id.EventID) error { + qrData, err := qrcode.Encode(qr, qrcode.Low, qrSizePx) + if err != nil { + return fmt.Errorf("failed to encode QR code: %w", err) + } + qrMXC, qrFile, err := ce.Bot.UploadMedia(ce.Ctx, ce.RoomID, qrData, "qr.png", "image/png") + if err != nil { + return fmt.Errorf("failed to upload image: %w", err) + } + content := &event.MessageEventContent{ + MsgType: event.MsgImage, + FileName: "qr.png", + URL: qrMXC, + File: qrFile, + Body: qr, + Format: event.FormatHTML, + FormattedBody: fmt.Sprintf("
%s
", html.EscapeString(qr)), + Info: &event.FileInfo{ + MimeType: "image/png", + Width: qrSizePx, + Height: qrSizePx, + Size: len(qrData), + }, + } + if *prevEventID != "" { + content.SetEdit(*prevEventID) + } + newEventID, err := ce.Bot.SendMessage(ce.Ctx, ce.RoomID, event.EventMessage, &event.Content{Parsed: content}, nil) + if err != nil { + return err + } + if *prevEventID == "" { + *prevEventID = newEventID.EventID + } + return nil +} + +func sendUserInputAttachments(ce *Event, atts []*bridgev2.LoginUserInputAttachment) error { + for _, att := range atts { + if att.FileName == "" { + return fmt.Errorf("missing attachment filename") + } + mxc, file, err := ce.Bot.UploadMedia(ce.Ctx, ce.RoomID, att.Content, att.FileName, att.Info.MimeType) + if err != nil { + return fmt.Errorf("failed to upload attachment %q: %w", att.FileName, err) + } + content := &event.MessageEventContent{ + MsgType: att.Type, + FileName: att.FileName, + URL: mxc, + File: file, + Info: &event.FileInfo{ + MimeType: att.Info.MimeType, + Width: att.Info.Width, + Height: att.Info.Height, + Size: att.Info.Size, + }, + Body: att.FileName, + } + _, err = ce.Bot.SendMessage(ce.Ctx, ce.RoomID, event.EventMessage, &event.Content{Parsed: content}, nil) + if err != nil { + return nil + } + } + return nil +} + +type contextKey int + +const ( + contextKeyPrevEventID contextKey = iota +) + +func doLoginDisplayAndWait(ce *Event, login bridgev2.LoginProcessDisplayAndWait, step *bridgev2.LoginStep, override *bridgev2.UserLogin) { + prevEvent, ok := ce.Ctx.Value(contextKeyPrevEventID).(*id.EventID) + if !ok { + prevEvent = new(id.EventID) + ce.Ctx = context.WithValue(ce.Ctx, contextKeyPrevEventID, prevEvent) + } + cancelCtx, cancelFunc := context.WithCancel(ce.Ctx) + defer cancelFunc() + StoreCommandState(ce.User, &CommandState{ + Action: "Login", + Cancel: cancelFunc, + }) + defer StoreCommandState(ce.User, nil) + switch step.DisplayAndWaitParams.Type { + case bridgev2.LoginDisplayTypeQR: + err := sendQR(ce, step.DisplayAndWaitParams.Data, prevEvent) + if err != nil { + ce.Reply("Failed to send QR code: %v", err) + login.Cancel() + return + } + case bridgev2.LoginDisplayTypeEmoji: + ce.ReplyAdvanced(step.DisplayAndWaitParams.Data, false, false) + case bridgev2.LoginDisplayTypeCode: + ce.ReplyAdvanced(fmt.Sprintf("%s", html.EscapeString(step.DisplayAndWaitParams.Data)), false, true) + case bridgev2.LoginDisplayTypeNothing: + // Do nothing + default: + ce.Reply("Unsupported display type %q", step.DisplayAndWaitParams.Type) + login.Cancel() + return + } + nextStep, err := login.Wait(cancelCtx) + // Redact the QR code, unless the next step is refreshing the code (in which case the event is just edited) + if *prevEvent != "" && (nextStep == nil || nextStep.StepID != step.StepID) { + _, _ = ce.Bot.SendMessage(ce.Ctx, ce.RoomID, event.EventRedaction, &event.Content{ + Parsed: &event.RedactionEventContent{ + Redacts: *prevEvent, + }, + }, nil) + *prevEvent = "" + } + if err != nil { + ce.Reply("Login failed: %v", err) + return + } + doLoginStep(ce, login, nextStep, override) +} + +type cookieLoginCommandState struct { + Login bridgev2.LoginProcessCookies + Data *bridgev2.LoginCookiesParams + Override *bridgev2.UserLogin +} + +func (clcs *cookieLoginCommandState) prompt(ce *Event) { + ce.Reply("Login URL: <%s>", clcs.Data.URL) + StoreCommandState(ce.User, &CommandState{ + Next: MinimalCommandHandlerFunc(clcs.submit), + Action: "Login", + Meta: clcs, + Cancel: clcs.Login.Cancel, + }) +} + +func (clcs *cookieLoginCommandState) submit(ce *Event) { + ce.Redact() + + cookiesInput := make(map[string]string) + if strings.HasPrefix(strings.TrimSpace(ce.RawArgs), "curl") { + parsed, err := curl.Parse(ce.RawArgs) + if err != nil { + ce.Reply("Failed to parse curl: %v", err) + return + } + reqCookies := make(map[string]string) + for _, cookie := range parsed.Cookies() { + reqCookies[cookie.Name], err = url.PathUnescape(cookie.Value) + if err != nil { + ce.Reply("Failed to parse cookie %s: %v", cookie.Name, err) + return + } + } + var missingKeys, unsupportedKeys []string + for _, field := range clcs.Data.Fields { + var value string + var supported bool + for _, src := range field.Sources { + switch src.Type { + case bridgev2.LoginCookieTypeCookie: + supported = true + value = reqCookies[src.Name] + case bridgev2.LoginCookieTypeRequestHeader: + supported = true + value = parsed.Header.Get(src.Name) + case bridgev2.LoginCookieTypeRequestBody: + supported = true + switch { + case parsed.MultipartForm != nil: + values, ok := parsed.MultipartForm.Value[src.Name] + if ok && len(values) > 0 { + value = values[0] + } + case parsed.ParsedJSON != nil: + untypedValue, ok := parsed.ParsedJSON[src.Name] + if ok { + value = fmt.Sprintf("%v", untypedValue) + } + } + } + if value != "" { + cookiesInput[field.ID] = value + break + } + } + if value == "" && field.Required { + if supported { + missingKeys = append(missingKeys, field.ID) + } else { + unsupportedKeys = append(unsupportedKeys, field.ID) + } + } + } + if len(unsupportedKeys) > 0 { + ce.Reply("Some keys can't be extracted from a cURL request: %+v\n\nPlease provide a JSON object instead.", unsupportedKeys) + return + } else if len(missingKeys) > 0 { + ce.Reply("Missing some keys: %+v", missingKeys) + return + } + } else { + err := json.Unmarshal([]byte(ce.RawArgs), &cookiesInput) + if err != nil { + ce.Reply("Failed to parse input as JSON: %v", err) + return + } + for _, field := range clcs.Data.Fields { + val, ok := cookiesInput[field.ID] + if ok { + cookiesInput[field.ID] = maybeURLDecodeCookie(val, &field) + } + } + } + var missingKeys []string + for _, field := range clcs.Data.Fields { + val, ok := cookiesInput[field.ID] + if !ok && field.Required { + missingKeys = append(missingKeys, field.ID) + } + if match, _ := regexp.MatchString(field.Pattern, val); !match { + ce.Reply("Invalid value for %s: `%s` doesn't match regex `%s`", field.ID, val, field.Pattern) + return + } + } + if len(missingKeys) > 0 { + ce.Reply("Missing some keys: %+v", missingKeys) + return + } + StoreCommandState(ce.User, nil) + nextStep, err := clcs.Login.SubmitCookies(ce.Ctx, cookiesInput) + if err != nil { + ce.Reply("Login failed: %v", err) + return + } + doLoginStep(ce, clcs.Login, nextStep, clcs.Override) +} + +func maybeURLDecodeCookie(val string, field *bridgev2.LoginCookieField) string { + if val == "" { + return val + } + isCookie := slices.ContainsFunc(field.Sources, func(src bridgev2.LoginCookieFieldSource) bool { + return src.Type == bridgev2.LoginCookieTypeCookie + }) + if !isCookie { + return val + } + decoded, err := url.PathUnescape(val) + if err != nil { + return val + } + return decoded +} + +func doLoginStep(ce *Event, login bridgev2.LoginProcess, step *bridgev2.LoginStep, override *bridgev2.UserLogin) { + ce.Log.Debug().Any("next_step", step).Msg("Got next login step") + if step.Instructions != "" { + ce.Reply(step.Instructions) + } + + switch step.Type { + case bridgev2.LoginStepTypeDisplayAndWait: + doLoginDisplayAndWait(ce, login.(bridgev2.LoginProcessDisplayAndWait), step, override) + case bridgev2.LoginStepTypeCookies: + (&cookieLoginCommandState{ + Login: login.(bridgev2.LoginProcessCookies), + Data: step.CookiesParams, + Override: override, + }).prompt(ce) + case bridgev2.LoginStepTypeUserInput: + err := sendUserInputAttachments(ce, step.UserInputParams.Attachments) + if err != nil { + ce.Reply("Failed to send attachments: %v", err) + } + (&userInputLoginCommandState{ + Login: login.(bridgev2.LoginProcessUserInput), + RemainingFields: step.UserInputParams.Fields, + Data: make(map[string]string), + Override: override, + }).promptNext(ce) + case bridgev2.LoginStepTypeComplete: + if override != nil && override.ID != step.CompleteParams.UserLoginID { + ce.Log.Info(). + Str("old_login_id", string(override.ID)). + Str("new_login_id", string(step.CompleteParams.UserLoginID)). + Msg("Login resulted in different remote ID than what was being overridden. Deleting previous login") + override.Delete(ce.Ctx, status.BridgeState{ + StateEvent: status.StateLoggedOut, + Reason: "LOGIN_OVERRIDDEN", + }, bridgev2.DeleteOpts{LogoutRemote: true}) + } + default: + panic(fmt.Errorf("unknown login step type %q", step.Type)) + } +} + +var CommandListLogins = &FullHandler{ + Func: fnListLogins, + Name: "list-logins", + Help: HelpMeta{ + Section: HelpSectionAuth, + Description: "List your logins", + }, + RequiresLoginPermission: true, +} + +func fnListLogins(ce *Event) { + logins := ce.User.GetFormattedUserLogins() + if len(logins) == 0 { + ce.Reply("You're not logged in") + } else { + ce.Reply("%s", logins) + } +} + +var CommandLogout = &FullHandler{ + Func: fnLogout, + Name: "logout", + Help: HelpMeta{ + Section: HelpSectionAuth, + Description: "Log out of the bridge", + Args: "<_login ID_>", + }, +} + +func fnLogout(ce *Event) { + if len(ce.Args) == 0 { + ce.Reply("Usage: `$cmdprefix logout `\n\nYour logins:\n\n%s", ce.User.GetFormattedUserLogins()) + return + } + login := ce.Bridge.GetCachedUserLoginByID(networkid.UserLoginID(ce.Args[0])) + if login == nil || login.UserMXID != ce.User.MXID { + ce.Reply("Login `%s` not found", ce.Args[0]) + return + } + login.Logout(ce.Ctx) + ce.Reply("Logged out") +} + +var CommandSetPreferredLogin = &FullHandler{ + Func: fnSetPreferredLogin, + Name: "set-preferred-login", + Aliases: []string{"prefer"}, + Help: HelpMeta{ + Section: HelpSectionAuth, + Description: "Set the preferred login ID for sending messages to this portal (only relevant when logged into multiple accounts via the bridge)", + Args: "<_login ID_>", + }, + RequiresPortal: true, + RequiresLoginPermission: true, +} + +func fnSetPreferredLogin(ce *Event) { + if len(ce.Args) == 0 { + ce.Reply("Usage: `$cmdprefix set-preferred-login `\n\nYour logins:\n\n%s", ce.User.GetFormattedUserLogins()) + return + } + login := ce.Bridge.GetCachedUserLoginByID(networkid.UserLoginID(ce.Args[0])) + if login == nil || login.UserMXID != ce.User.MXID { + ce.Reply("Login `%s` not found", ce.Args[0]) + return + } + err := login.MarkAsPreferredIn(ce.Ctx, ce.Portal) + if err != nil { + ce.Reply("Failed to set preferred login: %v", err) + } else { + ce.Reply("Preferred login set") + } +} diff --git a/bridgev2/commands/processor.go b/bridgev2/commands/processor.go new file mode 100644 index 00000000..391c3685 --- /dev/null +++ b/bridgev2/commands/processor.go @@ -0,0 +1,206 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package commands + +import ( + "context" + "fmt" + "runtime/debug" + "strings" + "sync/atomic" + "unsafe" + + "github.com/rs/zerolog" + + "maunium.net/go/mautrix/bridgev2" + "maunium.net/go/mautrix/bridgev2/status" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +type Processor struct { + bridge *bridgev2.Bridge + log *zerolog.Logger + + handlers map[string]CommandHandler + aliases map[string]string +} + +// NewProcessor creates a Processor +func NewProcessor(bridge *bridgev2.Bridge) bridgev2.CommandProcessor { + proc := &Processor{ + bridge: bridge, + log: &bridge.Log, + + handlers: make(map[string]CommandHandler), + aliases: make(map[string]string), + } + proc.AddHandlers( + CommandHelp, CommandCancel, + CommandRegisterPush, CommandSendAccountData, CommandResetNetwork, + CommandDeletePortal, CommandDeleteAllPortals, CommandSetManagementRoom, + CommandLogin, CommandRelogin, CommandListLogins, CommandLogout, CommandSetPreferredLogin, + CommandSetRelay, CommandUnsetRelay, + CommandResolveIdentifier, CommandStartChat, CommandCreateGroup, CommandSearch, CommandSyncChat, CommandMute, + CommandSudo, CommandDoIn, + ) + return proc +} + +func (proc *Processor) AddHandlers(handlers ...CommandHandler) { + for _, handler := range handlers { + proc.AddHandler(handler) + } +} + +func (proc *Processor) AddHandler(handler CommandHandler) { + proc.handlers[handler.GetName()] = handler + aliased, ok := handler.(AliasedCommandHandler) + if ok { + for _, alias := range aliased.GetAliases() { + proc.aliases[alias] = handler.GetName() + } + } +} + +// Handle handles messages to the bridge +func (proc *Processor) Handle(ctx context.Context, roomID id.RoomID, eventID id.EventID, user *bridgev2.User, message string, replyTo id.EventID) { + ms := &bridgev2.MessageStatus{ + Step: status.MsgStepCommand, + Status: event.MessageStatusSuccess, + } + logCopy := zerolog.Ctx(ctx).With().Logger() + log := &logCopy + defer func() { + statusInfo := &bridgev2.MessageStatusEventInfo{ + RoomID: roomID, + SourceEventID: eventID, + EventType: event.EventMessage, + Sender: user.MXID, + } + err := recover() + if err != nil { + logEvt := log.Error(). + Bytes(zerolog.ErrorStackFieldName, debug.Stack()) + if realErr, ok := err.(error); ok { + logEvt = logEvt.Err(realErr) + } else { + logEvt = logEvt.Any(zerolog.ErrorFieldName, err) + } + logEvt.Msg("Panic in Matrix command handler") + ms.Status = event.MessageStatusFail + ms.IsCertain = true + if realErr, ok := err.(error); ok { + ms.InternalError = realErr + } else { + ms.InternalError = fmt.Errorf("%v", err) + } + ms.ErrorAsMessage = true + } + proc.bridge.Matrix.SendMessageStatus(ctx, ms, statusInfo) + }() + args := strings.Fields(message) + if len(args) == 0 { + args = []string{"unknown-command"} + } + command := strings.ToLower(args[0]) + rawArgs := strings.TrimLeft(strings.TrimPrefix(message, command), " ") + portal, err := proc.bridge.GetPortalByMXID(ctx, roomID) + if err != nil { + log.Err(err).Msg("Failed to get portal") + // :( + } + ce := &Event{ + Bot: proc.bridge.Bot, + Bridge: proc.bridge, + Portal: portal, + Processor: proc, + RoomID: roomID, + OrigRoomID: roomID, + EventID: eventID, + User: user, + Command: command, + Args: args[1:], + RawArgs: rawArgs, + ReplyTo: replyTo, + Ctx: ctx, + Log: log, + + MessageStatus: ms, + } + proc.handleCommand(ctx, ce, message, args) +} + +func (proc *Processor) handleCommand(ctx context.Context, ce *Event, origMessage string, origArgs []string) { + realCommand, ok := proc.aliases[ce.Command] + if !ok { + realCommand = ce.Command + } + log := zerolog.Ctx(ctx) + + var handler MinimalCommandHandler + handler, ok = proc.handlers[realCommand] + if !ok { + state := LoadCommandState(ce.User) + if state != nil && state.Next != nil { + ce.Command = "" + ce.RawArgs = origMessage + ce.Args = origArgs + ce.Handler = state.Next + log.UpdateContext(func(c zerolog.Context) zerolog.Context { + return c.Str("action", state.Action) + }) + log.Debug().Msg("Received reply to command state") + state.Next.Run(ce) + } else { + zerolog.Ctx(ctx).Debug().Str("mx_command", ce.Command).Msg("Received unknown command") + ce.Reply("Unknown command, use the `help` command for help.") + } + } else { + log.UpdateContext(func(c zerolog.Context) zerolog.Context { + return c.Str("mx_command", ce.Command) + }) + log.Debug().Msg("Received command") + ce.Handler = handler + handler.Run(ce) + } +} + +func LoadCommandState(user *bridgev2.User) *CommandState { + return (*CommandState)(atomic.LoadPointer(&user.CommandState)) +} + +func StoreCommandState(user *bridgev2.User, cs *CommandState) { + atomic.StorePointer(&user.CommandState, unsafe.Pointer(cs)) +} + +func SwapCommandState(user *bridgev2.User, cs *CommandState) *CommandState { + return (*CommandState)(atomic.SwapPointer(&user.CommandState, unsafe.Pointer(cs))) +} + +var CommandCancel = &FullHandler{ + Func: func(ce *Event) { + state := SwapCommandState(ce.User, nil) + if state != nil { + action := state.Action + if action == "" { + action = "Unknown action" + } + if state.Cancel != nil { + state.Cancel() + } + ce.Reply("%s cancelled.", action) + } else { + ce.Reply("No ongoing command.") + } + }, + Name: "cancel", + Help: HelpMeta{ + Section: HelpSectionGeneral, + Description: "Cancel an ongoing action.", + }, +} diff --git a/bridgev2/commands/relay.go b/bridgev2/commands/relay.go new file mode 100644 index 00000000..94c19739 --- /dev/null +++ b/bridgev2/commands/relay.go @@ -0,0 +1,156 @@ +// Copyright (c) 2022 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package commands + +import ( + "golang.org/x/exp/slices" + + "maunium.net/go/mautrix/bridgev2" + "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/event" +) + +var fakeEvtSetRelay = event.Type{Type: "fi.mau.bridge.set_relay", Class: event.StateEventType} + +var CommandSetRelay = &FullHandler{ + Func: fnSetRelay, + Name: "set-relay", + Help: HelpMeta{ + Section: HelpSectionAuth, + Description: "Use your account to relay messages sent by users who haven't logged in", + Args: "[_login ID_]", + }, + RequiresPortal: true, +} + +func fnSetRelay(ce *Event) { + if !ce.Bridge.Config.Relay.Enabled { + ce.Reply("This bridge does not allow relay mode") + return + } else if !canManageRelay(ce) { + ce.Reply("You don't have permission to manage the relay in this room") + return + } + onlySetDefaultRelays := !ce.User.Permissions.Admin && ce.Bridge.Config.Relay.AdminOnly + var relay *bridgev2.UserLogin + if len(ce.Args) == 0 && ce.Portal.Receiver == "" { + relay = ce.User.GetDefaultLogin() + isLoggedIn := relay != nil + if onlySetDefaultRelays { + relay = nil + } + if relay == nil { + if len(ce.Bridge.Config.Relay.DefaultRelays) == 0 { + ce.Reply("You're not logged in and there are no default relay users configured") + return + } + logins, err := ce.Bridge.GetUserLoginsInPortal(ce.Ctx, ce.Portal.PortalKey) + if err != nil { + ce.Log.Err(err).Msg("Failed to get user logins in portal") + ce.Reply("Failed to get logins in portal to find default relay") + return + } + Outer: + for _, loginID := range ce.Bridge.Config.Relay.DefaultRelays { + for _, login := range logins { + if login.ID == loginID { + relay = login + break Outer + } + } + } + if relay == nil { + if isLoggedIn { + ce.Reply("You're not allowed to use yourself as relay and none of the default relay users are in the chat") + } else { + ce.Reply("You're not logged in and none of the default relay users are in the chat") + } + return + } + } + } else { + var targetID networkid.UserLoginID + if ce.Portal.Receiver != "" { + targetID = ce.Portal.Receiver + if len(ce.Args) > 0 && ce.Args[0] != string(targetID) { + ce.Reply("In split portals, only the receiver (%s) can be set as relay", targetID) + return + } + } else { + targetID = networkid.UserLoginID(ce.Args[0]) + } + relay = ce.Bridge.GetCachedUserLoginByID(targetID) + if relay == nil { + ce.Reply("User login with ID `%s` not found", targetID) + return + } else if slices.Contains(ce.Bridge.Config.Relay.DefaultRelays, relay.ID) { + // All good + } else if relay.UserMXID != ce.User.MXID && !ce.User.Permissions.Admin { + ce.Reply("Only bridge admins can set another user's login as the relay") + return + } else if onlySetDefaultRelays { + ce.Reply("You're not allowed to use yourself as relay") + return + } + } + err := ce.Portal.SetRelay(ce.Ctx, relay) + if err != nil { + ce.Log.Err(err).Msg("Failed to unset relay") + ce.Reply("Failed to save relay settings") + } else { + ce.Reply( + "Messages sent by users who haven't logged in will now be relayed through %s ([%s](%s)'s login)", + relay.RemoteName, + relay.UserMXID, + // TODO this will need to stop linkifying if we ever allow UserLogins that aren't bound to a real user. + relay.UserMXID.URI().MatrixToURL(), + ) + } +} + +var CommandUnsetRelay = &FullHandler{ + Func: fnUnsetRelay, + Name: "unset-relay", + Help: HelpMeta{ + Section: HelpSectionAuth, + Description: "Stop relaying messages sent by users who haven't logged in", + }, + RequiresPortal: true, +} + +func fnUnsetRelay(ce *Event) { + if ce.Portal.Relay == nil { + ce.Reply("This portal doesn't have a relay set.") + return + } else if !canManageRelay(ce) { + ce.Reply("You don't have permission to manage the relay in this room") + return + } + err := ce.Portal.SetRelay(ce.Ctx, nil) + if err != nil { + ce.Log.Err(err).Msg("Failed to unset relay") + ce.Reply("Failed to save relay settings") + } else { + ce.Reply("Stopped relaying messages for users who haven't logged in") + } +} + +func canManageRelay(ce *Event) bool { + return ce.User.Permissions.ManageRelay && + (ce.User.Permissions.Admin || + (ce.Portal.Relay != nil && ce.Portal.Relay.UserMXID == ce.User.MXID) || + hasRelayRoomPermissions(ce)) +} + +func hasRelayRoomPermissions(ce *Event) bool { + levels, err := ce.Bridge.Matrix.GetPowerLevels(ce.Ctx, ce.RoomID) + if err != nil { + ce.Log.Err(err).Msg("Failed to check room power levels") + return false + } + return levels.GetUserLevel(ce.User.MXID) >= levels.GetEventLevel(fakeEvtSetRelay) +} diff --git a/bridgev2/commands/startchat.go b/bridgev2/commands/startchat.go new file mode 100644 index 00000000..c7b05a6e --- /dev/null +++ b/bridgev2/commands/startchat.go @@ -0,0 +1,333 @@ +// Copyright (c) 2025 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package commands + +import ( + "context" + "errors" + "fmt" + "html" + "maps" + "slices" + "strings" + "time" + + "github.com/rs/zerolog" + + "maunium.net/go/mautrix/bridgev2" + "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/bridgev2/provisionutil" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/format" + "maunium.net/go/mautrix/id" +) + +var CommandResolveIdentifier = &FullHandler{ + Func: fnResolveIdentifier, + Name: "resolve-identifier", + Help: HelpMeta{ + Section: HelpSectionChats, + Description: "Check if a given identifier is on the remote network", + Args: "[_login ID_] <_identifier_>", + }, + RequiresLogin: true, + NetworkAPI: NetworkAPIImplements[bridgev2.IdentifierResolvingNetworkAPI], +} + +var CommandSyncChat = &FullHandler{ + Func: func(ce *Event) { + login, _, err := ce.Portal.FindPreferredLogin(ce.Ctx, ce.User, false) + if err != nil { + ce.Log.Err(err).Msg("Failed to find login for sync") + ce.Reply("Failed to find login: %v", err) + return + } else if login == nil { + ce.Reply("No login found for sync") + return + } + info, err := login.Client.GetChatInfo(ce.Ctx, ce.Portal) + if err != nil { + ce.Log.Err(err).Msg("Failed to get chat info for sync") + ce.Reply("Failed to get chat info: %v", err) + return + } + ce.Portal.UpdateInfo(ce.Ctx, info, login, nil, time.Time{}) + ce.React("✅️") + }, + Name: "sync-portal", + Help: HelpMeta{ + Section: HelpSectionChats, + Description: "Sync the current portal room", + }, + RequiresPortal: true, + RequiresLogin: true, +} + +var CommandStartChat = &FullHandler{ + Func: fnResolveIdentifier, + Name: "start-chat", + Aliases: []string{"pm"}, + Help: HelpMeta{ + Section: HelpSectionChats, + Description: "Start a direct chat with the given user", + Args: "[_login ID_] <_identifier_>", + }, + RequiresLogin: true, + NetworkAPI: NetworkAPIImplements[bridgev2.IdentifierResolvingNetworkAPI], +} + +func getClientForStartingChat[T bridgev2.NetworkAPI](ce *Event, thing string) (*bridgev2.UserLogin, T, []string) { + var remainingArgs []string + if len(ce.Args) > 1 { + remainingArgs = ce.Args[1:] + } + var login *bridgev2.UserLogin + if len(ce.Args) > 0 { + login = ce.Bridge.GetCachedUserLoginByID(networkid.UserLoginID(ce.Args[0])) + } + if login == nil || login.UserMXID != ce.User.MXID { + remainingArgs = ce.Args + login = ce.User.GetDefaultLogin() + } + api, ok := login.Client.(T) + if !ok { + ce.Reply("This bridge does not support %s", thing) + } + return login, api, remainingArgs +} + +func formatResolveIdentifierResult(resp *provisionutil.RespResolveIdentifier) string { + if resp.MXID != "" { + return fmt.Sprintf("`%s` / [%s](%s)", resp.ID, resp.Name, resp.MXID.URI().MatrixToURL()) + } else if resp.Name != "" { + return fmt.Sprintf("`%s` / %s", resp.ID, resp.Name) + } else { + return fmt.Sprintf("`%s`", resp.ID) + } +} + +func fnResolveIdentifier(ce *Event) { + if len(ce.Args) == 0 { + ce.Reply("Usage: `$cmdprefix %s `", ce.Command) + return + } + login, api, identifierParts := getClientForStartingChat[bridgev2.IdentifierResolvingNetworkAPI](ce, "resolving identifiers") + if api == nil { + return + } + allLogins := ce.User.GetUserLogins() + createChat := ce.Command == "start-chat" || ce.Command == "pm" + identifier := strings.Join(identifierParts, " ") + resp, err := provisionutil.ResolveIdentifier(ce.Ctx, login, identifier, createChat) + for i := 0; i < len(allLogins) && errors.Is(err, bridgev2.ErrResolveIdentifierTryNext); i++ { + resp, err = provisionutil.ResolveIdentifier(ce.Ctx, allLogins[i], identifier, createChat) + } + if err != nil { + ce.Reply("Failed to resolve identifier: %v", err) + return + } else if resp == nil { + ce.ReplyAdvanced(fmt.Sprintf("Identifier %s not found", html.EscapeString(identifier)), false, true) + return + } + formattedName := formatResolveIdentifierResult(resp) + if createChat { + name := resp.Portal.Name + if name == "" { + name = resp.Portal.MXID.String() + } + if !resp.JustCreated { + ce.Reply("You already have a direct chat with %s at [%s](%s)", formattedName, name, resp.Portal.MXID.URI().MatrixToURL()) + } else { + ce.Reply("Created chat with %s: [%s](%s)", formattedName, name, resp.Portal.MXID.URI().MatrixToURL()) + } + } else { + ce.Reply("Found %s", formattedName) + } +} + +var CommandCreateGroup = &FullHandler{ + Func: fnCreateGroup, + Name: "create-group", + Aliases: []string{"create"}, + Help: HelpMeta{ + Section: HelpSectionChats, + Description: "Create a new group chat for the current Matrix room", + Args: "[_group type_]", + }, + RequiresLogin: true, + NetworkAPI: NetworkAPIImplements[bridgev2.GroupCreatingNetworkAPI], +} + +func getState[T any](ctx context.Context, roomID id.RoomID, evtType event.Type, provider bridgev2.MatrixConnectorWithArbitraryRoomState) (content T) { + evt, err := provider.GetStateEvent(ctx, roomID, evtType, "") + if err != nil { + zerolog.Ctx(ctx).Err(err).Stringer("event_type", evtType).Msg("Failed to get state event for group creation") + } else if evt != nil { + content, _ = evt.Content.Parsed.(T) + } + return +} + +func fnCreateGroup(ce *Event) { + ce.Bridge.Matrix.GetCapabilities() + login, api, remainingArgs := getClientForStartingChat[bridgev2.GroupCreatingNetworkAPI](ce, "creating group") + if api == nil { + return + } + stateProvider, ok := ce.Bridge.Matrix.(bridgev2.MatrixConnectorWithArbitraryRoomState) + if !ok { + ce.Reply("Matrix connector doesn't support fetching room state") + return + } + members, err := ce.Bridge.Matrix.GetMembers(ce.Ctx, ce.RoomID) + if err != nil { + ce.Log.Err(err).Msg("Failed to get room members for group creation") + ce.Reply("Failed to get room members: %v", err) + return + } + caps := ce.Bridge.Network.GetCapabilities() + params := &bridgev2.GroupCreateParams{ + Username: "", + Participants: make([]networkid.UserID, 0, len(members)-2), + Parent: nil, // TODO check space parent event + Name: getState[*event.RoomNameEventContent](ce.Ctx, ce.RoomID, event.StateRoomName, stateProvider), + Avatar: getState[*event.RoomAvatarEventContent](ce.Ctx, ce.RoomID, event.StateRoomAvatar, stateProvider), + Topic: getState[*event.TopicEventContent](ce.Ctx, ce.RoomID, event.StateTopic, stateProvider), + Disappear: getState[*event.BeeperDisappearingTimer](ce.Ctx, ce.RoomID, event.StateBeeperDisappearingTimer, stateProvider), + RoomID: ce.RoomID, + } + for userID, member := range members { + if userID == ce.User.MXID || userID == ce.Bot.GetMXID() || !member.Membership.IsInviteOrJoin() { + continue + } + if parsedUserID, ok := ce.Bridge.Matrix.ParseGhostMXID(userID); ok { + params.Participants = append(params.Participants, parsedUserID) + } else if !ce.Bridge.Config.SplitPortals { + if user, err := ce.Bridge.GetExistingUserByMXID(ce.Ctx, userID); err != nil { + ce.Log.Err(err).Stringer("user_id", userID).Msg("Failed to get user for room member") + } else if user != nil { + // TODO add user logins to participants + //for _, login := range user.GetUserLogins() { + // params.Participants = append(params.Participants, login.GetUserID()) + //} + } + } + } + + if len(caps.Provisioning.GroupCreation) == 0 { + ce.Reply("No group creation types defined in network capabilities") + return + } else if len(remainingArgs) > 0 { + params.Type = remainingArgs[0] + } else if len(caps.Provisioning.GroupCreation) == 1 { + for params.Type = range caps.Provisioning.GroupCreation { + // The loop assigns the variable we want + } + } else { + types := strings.Join(slices.Collect(maps.Keys(caps.Provisioning.GroupCreation)), "`, `") + ce.Reply("Please specify type of group to create: `%s`", types) + return + } + resp, err := provisionutil.CreateGroup(ce.Ctx, login, params) + if err != nil { + ce.Reply("Failed to create group: %v", err) + return + } + var postfix string + if len(resp.FailedParticipants) > 0 { + failedParticipantsStrings := make([]string, len(resp.FailedParticipants)) + i := 0 + for participantID, meta := range resp.FailedParticipants { + failedParticipantsStrings[i] = fmt.Sprintf("* %s: %s", format.SafeMarkdownCode(participantID), meta.Reason) + i++ + } + postfix += "\n\nFailed to add some participants:\n" + strings.Join(failedParticipantsStrings, "\n") + } + ce.Reply("Successfully created group `%s`%s", resp.ID, postfix) +} + +var CommandSearch = &FullHandler{ + Func: fnSearch, + Name: "search", + Help: HelpMeta{ + Section: HelpSectionChats, + Description: "Search for users on the remote network", + Args: "<_query_>", + }, + RequiresLogin: true, + NetworkAPI: NetworkAPIImplements[bridgev2.UserSearchingNetworkAPI], +} + +func fnSearch(ce *Event) { + if len(ce.Args) == 0 { + ce.Reply("Usage: `$cmdprefix search `") + return + } + login, api, queryParts := getClientForStartingChat[bridgev2.UserSearchingNetworkAPI](ce, "searching users") + if api == nil { + return + } + resp, err := provisionutil.SearchUsers(ce.Ctx, login, strings.Join(queryParts, " ")) + if err != nil { + ce.Reply("Failed to search for users: %v", err) + return + } + resultsString := make([]string, len(resp.Results)) + for i, res := range resp.Results { + formattedName := formatResolveIdentifierResult(res) + resultsString[i] = fmt.Sprintf("* %s", formattedName) + if res.Portal != nil && res.Portal.MXID != "" { + portalName := res.Portal.Name + if portalName == "" { + portalName = res.Portal.MXID.String() + } + resultsString[i] = fmt.Sprintf("%s - DM portal: [%s](%s)", resultsString[i], portalName, res.Portal.MXID.URI().MatrixToURL()) + } + } + ce.Reply("Search results:\n\n%s", strings.Join(resultsString, "\n")) +} + +var CommandMute = &FullHandler{ + Func: fnMute, + Name: "mute", + Aliases: []string{"unmute"}, + Help: HelpMeta{ + Section: HelpSectionChats, + Description: "Mute or unmute a chat on the remote network", + Args: "[duration]", + }, + RequiresPortal: true, + RequiresLogin: true, + NetworkAPI: NetworkAPIImplements[bridgev2.MuteHandlingNetworkAPI], +} + +func fnMute(ce *Event) { + _, api, _ := getClientForStartingChat[bridgev2.MuteHandlingNetworkAPI](ce, "muting chats") + var mutedUntil int64 + if ce.Command == "mute" { + mutedUntil = -1 + if len(ce.Args) > 0 { + duration, err := time.ParseDuration(ce.Args[0]) + if err != nil { + ce.Reply("Invalid duration: %v", err) + return + } + mutedUntil = time.Now().Add(duration).UnixMilli() + } + } + err := api.HandleMute(ce.Ctx, &bridgev2.MatrixMute{ + MatrixEventBase: bridgev2.MatrixEventBase[*event.BeeperMuteEventContent]{ + Content: &event.BeeperMuteEventContent{MutedUntil: mutedUntil}, + Portal: ce.Portal, + }, + }) + if err != nil { + ce.Reply("Failed to %s chat: %v", ce.Command, err) + } else { + ce.React("✅️") + } +} diff --git a/bridgev2/commands/sudo.go b/bridgev2/commands/sudo.go new file mode 100644 index 00000000..f05ca1bb --- /dev/null +++ b/bridgev2/commands/sudo.go @@ -0,0 +1,107 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package commands + +import ( + "strings" + + "maunium.net/go/mautrix/bridgev2" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +var CommandSudo = &FullHandler{ + Func: fnSudo, + Name: "sudo", + Aliases: []string{"doas", "do-as", "runas", "run-as"}, + Help: HelpMeta{ + Section: HelpSectionAdmin, + Description: "Run a command as a different user.", + Args: "[--create] <_user ID_> <_command_> [_args..._]", + }, + RequiresAdmin: true, +} + +func fnSudo(ce *Event) { + forceNonexistentUser := len(ce.Args) > 0 && strings.ToLower(ce.Args[0]) == "--create" + if forceNonexistentUser { + ce.Args = ce.Args[1:] + } + if len(ce.Args) < 2 { + ce.Reply("Usage: `$cmdprefix sudo [--create] [args...]`") + return + } + targetUserID := id.UserID(ce.Args[0]) + if _, _, err := targetUserID.Parse(); err != nil || len(targetUserID) > id.UserIDMaxLength { + ce.Reply("Invalid user ID `%s`", targetUserID) + return + } + var targetUser *bridgev2.User + var err error + if forceNonexistentUser { + targetUser, err = ce.Bridge.GetUserByMXID(ce.Ctx, targetUserID) + } else { + targetUser, err = ce.Bridge.GetExistingUserByMXID(ce.Ctx, targetUserID) + } + if err != nil { + ce.Log.Err(err).Msg("Failed to get user from database") + ce.Reply("Failed to get user") + return + } else if targetUser == nil { + ce.Reply("User not found. Use `--create` if you want to run commands as a user who has never used the bridge.") + return + } + ce.User = targetUser + origArgs := ce.Args[1:] + ce.Command = strings.ToLower(ce.Args[1]) + ce.Args = ce.Args[2:] + ce.RawArgs = strings.Join(ce.Args, " ") + ce.Processor.handleCommand(ce.Ctx, ce, strings.Join(origArgs, " "), origArgs) +} + +var CommandDoIn = &FullHandler{ + Func: fnDoIn, + Name: "doin", + Aliases: []string{"do-in", "runin", "run-in"}, + Help: HelpMeta{ + Section: HelpSectionAdmin, + Description: "Run a command in a different room.", + Args: "<_room ID_> <_command_> [_args..._]", + }, +} + +func fnDoIn(ce *Event) { + if len(ce.Args) < 2 { + ce.Reply("Usage: `$cmdprefix doin [args...]`") + return + } + targetRoomID := id.RoomID(ce.Args[0]) + if !ce.User.Permissions.Admin { + memberInfo, err := ce.Bridge.Matrix.GetMemberInfo(ce.Ctx, targetRoomID, ce.User.MXID) + if err != nil { + ce.Log.Err(err).Msg("Failed to check if user is in doin target room") + ce.Reply("Failed to check if you're in the target room") + return + } else if memberInfo == nil || memberInfo.Membership != event.MembershipJoin { + ce.Reply("You must be in the target room to run commands there") + return + } + } + ce.RoomID = targetRoomID + var err error + ce.Portal, err = ce.Bridge.GetPortalByMXID(ce.Ctx, targetRoomID) + if err != nil { + ce.Log.Err(err).Msg("Failed to get target portal") + ce.Reply("Failed to get portal") + return + } + origArgs := ce.Args[1:] + ce.Command = strings.ToLower(ce.Args[1]) + ce.Args = ce.Args[2:] + ce.RawArgs = strings.Join(ce.Args, " ") + ce.Processor.handleCommand(ce.Ctx, ce, strings.Join(origArgs, " "), origArgs) +} diff --git a/bridgev2/database/backfillqueue.go b/bridgev2/database/backfillqueue.go new file mode 100644 index 00000000..1f920640 --- /dev/null +++ b/bridgev2/database/backfillqueue.go @@ -0,0 +1,182 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package database + +import ( + "context" + "database/sql" + "time" + + "go.mau.fi/util/dbutil" + + "maunium.net/go/mautrix/bridgev2/networkid" +) + +type BackfillTaskQuery struct { + BridgeID networkid.BridgeID + *dbutil.QueryHelper[*BackfillTask] +} + +type BackfillTask struct { + BridgeID networkid.BridgeID + PortalKey networkid.PortalKey + UserLoginID networkid.UserLoginID + + BatchCount int + IsDone bool + Cursor networkid.PaginationCursor + OldestMessageID networkid.MessageID + DispatchedAt time.Time + CompletedAt time.Time + NextDispatchMinTS time.Time +} + +var BackfillNextDispatchNever = time.Unix(0, (1<<63)-1) + +const ( + ensureBackfillExistsQuery = ` + INSERT INTO backfill_task (bridge_id, portal_id, portal_receiver, user_login_id, batch_count, is_done, next_dispatch_min_ts) + VALUES ($1, $2, $3, $4, -1, false, $5) + ON CONFLICT (bridge_id, portal_id, portal_receiver) DO UPDATE + SET user_login_id=CASE + WHEN backfill_task.user_login_id='' + THEN excluded.user_login_id + ELSE backfill_task.user_login_id + END, + next_dispatch_min_ts=CASE + WHEN backfill_task.next_dispatch_min_ts=9223372036854775807 + THEN excluded.next_dispatch_min_ts + ELSE backfill_task.next_dispatch_min_ts + END + ` + upsertBackfillQueueQuery = ` + INSERT INTO backfill_task ( + bridge_id, portal_id, portal_receiver, user_login_id, batch_count, is_done, cursor, + oldest_message_id, dispatched_at, completed_at, next_dispatch_min_ts + ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) + ON CONFLICT (bridge_id, portal_id, portal_receiver) DO UPDATE + SET user_login_id=excluded.user_login_id, + batch_count=excluded.batch_count, + is_done=excluded.is_done, + cursor=excluded.cursor, + oldest_message_id=excluded.oldest_message_id, + dispatched_at=excluded.dispatched_at, + completed_at=excluded.completed_at, + next_dispatch_min_ts=excluded.next_dispatch_min_ts + ` + markBackfillDispatchedQuery = ` + UPDATE backfill_task SET dispatched_at=$4, completed_at=NULL, next_dispatch_min_ts=$5 + WHERE bridge_id = $1 AND portal_id = $2 AND portal_receiver = $3 + ` + updateBackfillQueueQuery = ` + UPDATE backfill_task + SET user_login_id=$4, batch_count=$5, is_done=$6, cursor=$7, oldest_message_id=$8, + dispatched_at=$9, completed_at=$10, next_dispatch_min_ts=$11 + WHERE bridge_id = $1 AND portal_id = $2 AND portal_receiver = $3 + ` + markBackfillTaskNotDoneQuery = ` + UPDATE backfill_task + SET is_done = false + WHERE bridge_id = $1 AND portal_id = $2 AND portal_receiver = $3 AND user_login_id = $4 + ` + getNextBackfillQuery = ` + SELECT + bridge_id, portal_id, portal_receiver, user_login_id, batch_count, is_done, + cursor, oldest_message_id, dispatched_at, completed_at, next_dispatch_min_ts + FROM backfill_task + WHERE bridge_id = $1 AND next_dispatch_min_ts < $2 AND is_done = false AND user_login_id <> '' + ORDER BY next_dispatch_min_ts LIMIT 1 + ` + getNextBackfillQueryForPortal = ` + SELECT + bridge_id, portal_id, portal_receiver, user_login_id, batch_count, is_done, + cursor, oldest_message_id, dispatched_at, completed_at, next_dispatch_min_ts + FROM backfill_task + WHERE bridge_id = $1 AND portal_id = $2 AND portal_receiver = $3 AND is_done = false AND user_login_id <> '' + ` + deleteBackfillQueueQuery = ` + DELETE FROM backfill_task + WHERE bridge_id = $1 AND portal_id = $2 AND portal_receiver = $3 + ` +) + +func (btq *BackfillTaskQuery) EnsureExists(ctx context.Context, portal networkid.PortalKey, loginID networkid.UserLoginID) error { + return btq.Exec(ctx, ensureBackfillExistsQuery, btq.BridgeID, portal.ID, portal.Receiver, loginID, time.Now().UnixNano()) +} + +func (btq *BackfillTaskQuery) Upsert(ctx context.Context, bq *BackfillTask) error { + ensureBridgeIDMatches(&bq.BridgeID, btq.BridgeID) + return btq.Exec(ctx, upsertBackfillQueueQuery, bq.sqlVariables()...) +} + +const UnfinishedBackfillBackoff = 1 * time.Hour + +func (btq *BackfillTaskQuery) MarkDispatched(ctx context.Context, bq *BackfillTask) error { + ensureBridgeIDMatches(&bq.BridgeID, btq.BridgeID) + bq.DispatchedAt = time.Now() + bq.CompletedAt = time.Time{} + bq.NextDispatchMinTS = bq.DispatchedAt.Add(UnfinishedBackfillBackoff) + return btq.Exec( + ctx, markBackfillDispatchedQuery, + bq.BridgeID, bq.PortalKey.ID, bq.PortalKey.Receiver, + bq.DispatchedAt.UnixNano(), bq.NextDispatchMinTS.UnixNano(), + ) +} + +func (btq *BackfillTaskQuery) Update(ctx context.Context, bq *BackfillTask) error { + ensureBridgeIDMatches(&bq.BridgeID, btq.BridgeID) + return btq.Exec(ctx, updateBackfillQueueQuery, bq.sqlVariables()...) +} + +func (btq *BackfillTaskQuery) MarkNotDone(ctx context.Context, portalKey networkid.PortalKey, userLoginID networkid.UserLoginID) error { + return btq.Exec(ctx, markBackfillTaskNotDoneQuery, btq.BridgeID, portalKey.ID, portalKey.Receiver, userLoginID) +} + +func (btq *BackfillTaskQuery) GetNext(ctx context.Context) (*BackfillTask, error) { + return btq.QueryOne(ctx, getNextBackfillQuery, btq.BridgeID, time.Now().UnixNano()) +} + +func (btq *BackfillTaskQuery) GetNextForPortal(ctx context.Context, portalKey networkid.PortalKey) (*BackfillTask, error) { + return btq.QueryOne(ctx, getNextBackfillQueryForPortal, btq.BridgeID, portalKey.ID, portalKey.Receiver) +} + +func (btq *BackfillTaskQuery) Delete(ctx context.Context, portalKey networkid.PortalKey) error { + return btq.Exec(ctx, deleteBackfillQueueQuery, btq.BridgeID, portalKey.ID, portalKey.Receiver) +} + +func (bt *BackfillTask) Scan(row dbutil.Scannable) (*BackfillTask, error) { + var cursor, oldestMessageID sql.NullString + var dispatchedAt, completedAt, nextDispatchMinTS sql.NullInt64 + err := row.Scan( + &bt.BridgeID, &bt.PortalKey.ID, &bt.PortalKey.Receiver, &bt.UserLoginID, &bt.BatchCount, &bt.IsDone, + &cursor, &oldestMessageID, &dispatchedAt, &completedAt, &nextDispatchMinTS) + if err != nil { + return nil, err + } + bt.Cursor = networkid.PaginationCursor(cursor.String) + bt.OldestMessageID = networkid.MessageID(oldestMessageID.String) + if dispatchedAt.Valid { + bt.DispatchedAt = time.Unix(0, dispatchedAt.Int64) + } + if completedAt.Valid { + bt.CompletedAt = time.Unix(0, completedAt.Int64) + } + if nextDispatchMinTS.Valid { + bt.NextDispatchMinTS = time.Unix(0, nextDispatchMinTS.Int64) + } + return bt, nil +} + +func (bt *BackfillTask) sqlVariables() []any { + return []any{ + bt.BridgeID, bt.PortalKey.ID, bt.PortalKey.Receiver, bt.UserLoginID, bt.BatchCount, bt.IsDone, + dbutil.StrPtr(bt.Cursor), dbutil.StrPtr(bt.OldestMessageID), + dbutil.ConvertedPtr(bt.DispatchedAt, time.Time.UnixNano), + dbutil.ConvertedPtr(bt.CompletedAt, time.Time.UnixNano), + bt.NextDispatchMinTS.UnixNano(), + } +} diff --git a/bridgev2/database/database.go b/bridgev2/database/database.go new file mode 100644 index 00000000..05abddf0 --- /dev/null +++ b/bridgev2/database/database.go @@ -0,0 +1,154 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package database + +import ( + "go.mau.fi/util/dbutil" + + "maunium.net/go/mautrix/bridgev2/networkid" + + "maunium.net/go/mautrix/bridgev2/database/upgrades" +) + +type Database struct { + *dbutil.Database + + BridgeID networkid.BridgeID + Portal *PortalQuery + Ghost *GhostQuery + Message *MessageQuery + DisappearingMessage *DisappearingMessageQuery + Reaction *ReactionQuery + User *UserQuery + UserLogin *UserLoginQuery + UserPortal *UserPortalQuery + BackfillTask *BackfillTaskQuery + KV *KVQuery + PublicMedia *PublicMediaQuery +} + +type MetaMerger interface { + CopyFrom(other any) +} + +type MetaTypeCreator func() any + +type MetaTypes struct { + Portal MetaTypeCreator + Ghost MetaTypeCreator + Message MetaTypeCreator + Reaction MetaTypeCreator + UserLogin MetaTypeCreator +} + +type blankMeta struct{} + +var blankMetaItem = &blankMeta{} + +func blankMetaCreator() any { + return blankMetaItem +} + +func New(bridgeID networkid.BridgeID, mt MetaTypes, db *dbutil.Database) *Database { + if mt.Portal == nil { + mt.Portal = blankMetaCreator + } + if mt.Ghost == nil { + mt.Ghost = blankMetaCreator + } + if mt.Message == nil { + mt.Message = blankMetaCreator + } + if mt.Reaction == nil { + mt.Reaction = blankMetaCreator + } + if mt.UserLogin == nil { + mt.UserLogin = blankMetaCreator + } + db.UpgradeTable = upgrades.Table + return &Database{ + Database: db, + BridgeID: bridgeID, + Portal: &PortalQuery{ + BridgeID: bridgeID, + MetaType: mt.Portal, + QueryHelper: dbutil.MakeQueryHelper(db, func(_ *dbutil.QueryHelper[*Portal]) *Portal { + return (&Portal{}).ensureHasMetadata(mt.Portal) + }), + }, + Ghost: &GhostQuery{ + BridgeID: bridgeID, + MetaType: mt.Ghost, + QueryHelper: dbutil.MakeQueryHelper(db, func(_ *dbutil.QueryHelper[*Ghost]) *Ghost { + return (&Ghost{}).ensureHasMetadata(mt.Ghost) + }), + }, + Message: &MessageQuery{ + BridgeID: bridgeID, + MetaType: mt.Message, + QueryHelper: dbutil.MakeQueryHelper(db, func(_ *dbutil.QueryHelper[*Message]) *Message { + return (&Message{}).ensureHasMetadata(mt.Message) + }), + }, + DisappearingMessage: &DisappearingMessageQuery{ + BridgeID: bridgeID, + QueryHelper: dbutil.MakeQueryHelper(db, func(_ *dbutil.QueryHelper[*DisappearingMessage]) *DisappearingMessage { + return &DisappearingMessage{} + }), + }, + Reaction: &ReactionQuery{ + BridgeID: bridgeID, + MetaType: mt.Reaction, + QueryHelper: dbutil.MakeQueryHelper(db, func(_ *dbutil.QueryHelper[*Reaction]) *Reaction { + return (&Reaction{}).ensureHasMetadata(mt.Reaction) + }), + }, + User: &UserQuery{ + BridgeID: bridgeID, + QueryHelper: dbutil.MakeQueryHelper(db, func(_ *dbutil.QueryHelper[*User]) *User { + return &User{} + }), + }, + UserLogin: &UserLoginQuery{ + BridgeID: bridgeID, + MetaType: mt.UserLogin, + QueryHelper: dbutil.MakeQueryHelper(db, func(_ *dbutil.QueryHelper[*UserLogin]) *UserLogin { + return (&UserLogin{}).ensureHasMetadata(mt.UserLogin) + }), + }, + UserPortal: &UserPortalQuery{ + BridgeID: bridgeID, + QueryHelper: dbutil.MakeQueryHelper(db, func(_ *dbutil.QueryHelper[*UserPortal]) *UserPortal { + return &UserPortal{} + }), + }, + BackfillTask: &BackfillTaskQuery{ + BridgeID: bridgeID, + QueryHelper: dbutil.MakeQueryHelper(db, func(_ *dbutil.QueryHelper[*BackfillTask]) *BackfillTask { + return &BackfillTask{} + }), + }, + KV: &KVQuery{ + BridgeID: bridgeID, + Database: db, + }, + PublicMedia: &PublicMediaQuery{ + BridgeID: bridgeID, + QueryHelper: dbutil.MakeQueryHelper(db, func(_ *dbutil.QueryHelper[*PublicMedia]) *PublicMedia { + return &PublicMedia{} + }), + }, + } +} + +func ensureBridgeIDMatches(ptr *networkid.BridgeID, expected networkid.BridgeID) { + if *ptr == "" { + *ptr = expected + } else if *ptr != expected { + panic("bridge ID mismatch") + } +} diff --git a/bridgev2/database/disappear.go b/bridgev2/database/disappear.go new file mode 100644 index 00000000..df36b205 --- /dev/null +++ b/bridgev2/database/disappear.go @@ -0,0 +1,142 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package database + +import ( + "context" + "database/sql" + "time" + + "go.mau.fi/util/dbutil" + "go.mau.fi/util/jsontime" + + "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +// Deprecated: use [event.DisappearingType] +type DisappearingType = event.DisappearingType + +// Deprecated: use constants in event package +const ( + DisappearingTypeNone = event.DisappearingTypeNone + DisappearingTypeAfterRead = event.DisappearingTypeAfterRead + DisappearingTypeAfterSend = event.DisappearingTypeAfterSend +) + +// DisappearingSetting represents a disappearing message timer setting +// by combining a type with a timer and an optional start timestamp. +type DisappearingSetting struct { + Type event.DisappearingType + Timer time.Duration + DisappearAt time.Time +} + +func DisappearingSettingFromEvent(evt *event.BeeperDisappearingTimer) DisappearingSetting { + if evt == nil || evt.Type == event.DisappearingTypeNone { + return DisappearingSetting{} + } + return DisappearingSetting{ + Type: evt.Type, + Timer: evt.Timer.Duration, + } +} + +func (ds DisappearingSetting) Normalize() DisappearingSetting { + if ds.Type == event.DisappearingTypeNone { + ds.Timer = 0 + } else if ds.Timer == 0 { + ds.Type = event.DisappearingTypeNone + } + return ds +} + +func (ds DisappearingSetting) StartingAt(start time.Time) DisappearingSetting { + ds.DisappearAt = start.Add(ds.Timer) + return ds +} + +func (ds DisappearingSetting) ToEventContent() *event.BeeperDisappearingTimer { + if ds.Type == event.DisappearingTypeNone || ds.Timer == 0 { + return &event.BeeperDisappearingTimer{} + } + return &event.BeeperDisappearingTimer{ + Type: ds.Type, + Timer: jsontime.MS(ds.Timer), + } +} + +type DisappearingMessageQuery struct { + BridgeID networkid.BridgeID + *dbutil.QueryHelper[*DisappearingMessage] +} + +type DisappearingMessage struct { + BridgeID networkid.BridgeID + RoomID id.RoomID + EventID id.EventID + Timestamp time.Time + DisappearingSetting +} + +const ( + upsertDisappearingMessageQuery = ` + INSERT INTO disappearing_message (bridge_id, mx_room, mxid, timestamp, type, timer, disappear_at) + VALUES ($1, $2, $3, $4, $5, $6, $7) + ON CONFLICT (bridge_id, mxid) DO UPDATE SET timer=excluded.timer, disappear_at=excluded.disappear_at + ` + startDisappearingMessagesQuery = ` + UPDATE disappearing_message + SET disappear_at=$1 + timer + WHERE bridge_id=$2 AND mx_room=$3 AND disappear_at IS NULL AND type='after_read' AND timestamp<=$4 + RETURNING bridge_id, mx_room, mxid, timestamp, type, timer, disappear_at + ` + getUpcomingDisappearingMessagesQuery = ` + SELECT bridge_id, mx_room, mxid, timestamp, type, timer, disappear_at + FROM disappearing_message WHERE bridge_id = $1 AND disappear_at IS NOT NULL AND disappear_at < $2 + ORDER BY disappear_at LIMIT $3 + ` + deleteDisappearingMessageQuery = ` + DELETE FROM disappearing_message WHERE bridge_id=$1 AND mxid=$2 + ` +) + +func (dmq *DisappearingMessageQuery) Put(ctx context.Context, dm *DisappearingMessage) error { + ensureBridgeIDMatches(&dm.BridgeID, dmq.BridgeID) + return dmq.Exec(ctx, upsertDisappearingMessageQuery, dm.sqlVariables()...) +} + +func (dmq *DisappearingMessageQuery) StartAllBefore(ctx context.Context, roomID id.RoomID, beforeTS time.Time) ([]*DisappearingMessage, error) { + return dmq.QueryMany(ctx, startDisappearingMessagesQuery, time.Now().UnixNano(), dmq.BridgeID, roomID, beforeTS.UnixNano()) +} + +func (dmq *DisappearingMessageQuery) GetUpcoming(ctx context.Context, duration time.Duration, limit int) ([]*DisappearingMessage, error) { + return dmq.QueryMany(ctx, getUpcomingDisappearingMessagesQuery, dmq.BridgeID, time.Now().Add(duration).UnixNano(), limit) +} + +func (dmq *DisappearingMessageQuery) Delete(ctx context.Context, eventID id.EventID) error { + return dmq.Exec(ctx, deleteDisappearingMessageQuery, dmq.BridgeID, eventID) +} + +func (d *DisappearingMessage) Scan(row dbutil.Scannable) (*DisappearingMessage, error) { + var timestamp int64 + var disappearAt sql.NullInt64 + err := row.Scan(&d.BridgeID, &d.RoomID, &d.EventID, ×tamp, &d.Type, &d.Timer, &disappearAt) + if err != nil { + return nil, err + } + if disappearAt.Valid { + d.DisappearAt = time.Unix(0, disappearAt.Int64) + } + d.Timestamp = time.Unix(0, timestamp) + return d, nil +} + +func (d *DisappearingMessage) sqlVariables() []any { + return []any{d.BridgeID, d.RoomID, d.EventID, d.Timestamp.UnixNano(), d.Type, d.Timer, dbutil.ConvertedPtr(d.DisappearAt, time.Time.UnixNano)} +} diff --git a/bridgev2/database/ghost.go b/bridgev2/database/ghost.go new file mode 100644 index 00000000..16af35ca --- /dev/null +++ b/bridgev2/database/ghost.go @@ -0,0 +1,177 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package database + +import ( + "bytes" + "context" + "encoding/hex" + "encoding/json" + "fmt" + + "go.mau.fi/util/dbutil" + "go.mau.fi/util/exerrors" + + "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/crypto/canonicaljson" + "maunium.net/go/mautrix/id" +) + +type GhostQuery struct { + BridgeID networkid.BridgeID + MetaType MetaTypeCreator + *dbutil.QueryHelper[*Ghost] +} + +type ExtraProfile map[string]json.RawMessage + +func (ep *ExtraProfile) Set(key string, value any) error { + if key == "displayname" || key == "avatar_url" { + return fmt.Errorf("cannot set reserved profile key %q", key) + } + marshaled, err := json.Marshal(value) + if err != nil { + return err + } + if *ep == nil { + *ep = make(ExtraProfile) + } + (*ep)[key] = canonicaljson.CanonicalJSONAssumeValid(marshaled) + return nil +} + +func (ep *ExtraProfile) With(key string, value any) *ExtraProfile { + exerrors.PanicIfNotNil(ep.Set(key, value)) + return ep +} + +func canonicalizeIfObject(data json.RawMessage) json.RawMessage { + if len(data) > 0 && (data[0] == '{' || data[0] == '[') { + return canonicaljson.CanonicalJSONAssumeValid(data) + } + return data +} + +func (ep *ExtraProfile) CopyTo(dest *ExtraProfile) (changed bool) { + if len(*ep) == 0 { + return + } + if *dest == nil { + *dest = make(ExtraProfile) + } + for key, val := range *ep { + if key == "displayname" || key == "avatar_url" { + continue + } + existing, exists := (*dest)[key] + if !exists || !bytes.Equal(canonicalizeIfObject(existing), val) { + (*dest)[key] = val + changed = true + } + } + return +} + +type Ghost struct { + BridgeID networkid.BridgeID + ID networkid.UserID + + Name string + AvatarID networkid.AvatarID + AvatarHash [32]byte + AvatarMXC id.ContentURIString + NameSet bool + AvatarSet bool + ContactInfoSet bool + IsBot bool + Identifiers []string + ExtraProfile ExtraProfile + Metadata any +} + +const ( + getGhostBaseQuery = ` + SELECT bridge_id, id, name, avatar_id, avatar_hash, avatar_mxc, + name_set, avatar_set, contact_info_set, is_bot, identifiers, extra_profile, metadata + FROM ghost + ` + getGhostByIDQuery = getGhostBaseQuery + `WHERE bridge_id=$1 AND id=$2` + getGhostByMetadataQuery = getGhostBaseQuery + `WHERE bridge_id=$1 AND metadata->>$2=$3` + insertGhostQuery = ` + INSERT INTO ghost ( + bridge_id, id, name, avatar_id, avatar_hash, avatar_mxc, + name_set, avatar_set, contact_info_set, is_bot, identifiers, extra_profile, metadata + ) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13) + ` + updateGhostQuery = ` + UPDATE ghost SET name=$3, avatar_id=$4, avatar_hash=$5, avatar_mxc=$6, + name_set=$7, avatar_set=$8, contact_info_set=$9, is_bot=$10, + identifiers=$11, extra_profile=$12, metadata=$13 + WHERE bridge_id=$1 AND id=$2 + ` +) + +func (gq *GhostQuery) GetByID(ctx context.Context, id networkid.UserID) (*Ghost, error) { + return gq.QueryOne(ctx, getGhostByIDQuery, gq.BridgeID, id) +} + +// GetByMetadata returns the ghosts whose metadata field at the given JSON key +// matches the given value. +func (gq *GhostQuery) GetByMetadata(ctx context.Context, key string, value any) ([]*Ghost, error) { + return gq.QueryMany(ctx, getGhostByMetadataQuery, gq.BridgeID, key, value) +} + +func (gq *GhostQuery) Insert(ctx context.Context, ghost *Ghost) error { + ensureBridgeIDMatches(&ghost.BridgeID, gq.BridgeID) + return gq.Exec(ctx, insertGhostQuery, ghost.ensureHasMetadata(gq.MetaType).sqlVariables()...) +} + +func (gq *GhostQuery) Update(ctx context.Context, ghost *Ghost) error { + ensureBridgeIDMatches(&ghost.BridgeID, gq.BridgeID) + return gq.Exec(ctx, updateGhostQuery, ghost.ensureHasMetadata(gq.MetaType).sqlVariables()...) +} + +func (g *Ghost) Scan(row dbutil.Scannable) (*Ghost, error) { + var avatarHash string + err := row.Scan( + &g.BridgeID, &g.ID, + &g.Name, &g.AvatarID, &avatarHash, &g.AvatarMXC, + &g.NameSet, &g.AvatarSet, &g.ContactInfoSet, &g.IsBot, + dbutil.JSON{Data: &g.Identifiers}, dbutil.JSON{Data: &g.ExtraProfile}, dbutil.JSON{Data: g.Metadata}, + ) + if err != nil { + return nil, err + } + if avatarHash != "" { + data, _ := hex.DecodeString(avatarHash) + if len(data) == 32 { + g.AvatarHash = *(*[32]byte)(data) + } + } + return g, nil +} + +func (g *Ghost) ensureHasMetadata(metaType MetaTypeCreator) *Ghost { + if g.Metadata == nil { + g.Metadata = metaType() + } + return g +} + +func (g *Ghost) sqlVariables() []any { + var avatarHash string + if g.AvatarHash != [32]byte{} { + avatarHash = hex.EncodeToString(g.AvatarHash[:]) + } + return []any{ + g.BridgeID, g.ID, + g.Name, g.AvatarID, avatarHash, g.AvatarMXC, + g.NameSet, g.AvatarSet, g.ContactInfoSet, g.IsBot, + dbutil.JSON{Data: &g.Identifiers}, dbutil.JSON{Data: g.ExtraProfile}, dbutil.JSON{Data: g.Metadata}, + } +} diff --git a/bridgev2/database/kvstore.go b/bridgev2/database/kvstore.go new file mode 100644 index 00000000..bca26ed5 --- /dev/null +++ b/bridgev2/database/kvstore.go @@ -0,0 +1,59 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package database + +import ( + "context" + "database/sql" + "errors" + + "github.com/rs/zerolog" + "go.mau.fi/util/dbutil" + + "maunium.net/go/mautrix/bridgev2/networkid" +) + +type Key string + +const ( + KeySplitPortalsEnabled Key = "split_portals_enabled" + KeyBridgeInfoVersion Key = "bridge_info_version" + KeyEncryptionStateResynced Key = "encryption_state_resynced" + KeyRecoveryKey Key = "recovery_key" +) + +type KVQuery struct { + BridgeID networkid.BridgeID + *dbutil.Database +} + +const ( + getKVQuery = `SELECT value FROM kv_store WHERE bridge_id = $1 AND key = $2` + setKVQuery = ` + INSERT INTO kv_store (bridge_id, key, value) VALUES ($1, $2, $3) + ON CONFLICT (bridge_id, key) DO UPDATE SET value = $3 + ` +) + +func (kvq *KVQuery) Get(ctx context.Context, key Key) string { + var value string + err := kvq.QueryRow(ctx, getKVQuery, kvq.BridgeID, key).Scan(&value) + if err != nil && !errors.Is(err, sql.ErrNoRows) { + zerolog.Ctx(ctx).Err(err).Str("key", string(key)).Msg("Failed to get key from kvstore") + } + return value +} + +func (kvq *KVQuery) Set(ctx context.Context, key Key, value string) { + _, err := kvq.Exec(ctx, setKVQuery, kvq.BridgeID, key, value) + if err != nil { + zerolog.Ctx(ctx).Err(err). + Str("key", string(key)). + Str("value", value). + Msg("Failed to set key in kvstore") + } +} diff --git a/bridgev2/database/message.go b/bridgev2/database/message.go new file mode 100644 index 00000000..4fd599a8 --- /dev/null +++ b/bridgev2/database/message.go @@ -0,0 +1,334 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package database + +import ( + "context" + "crypto/sha256" + "database/sql" + "encoding/base64" + "fmt" + "strings" + "sync" + "time" + + "github.com/rs/zerolog" + "go.mau.fi/util/dbutil" + + "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/id" +) + +type MessageQuery struct { + BridgeID networkid.BridgeID + MetaType MetaTypeCreator + *dbutil.QueryHelper[*Message] + chunkDeleteLock sync.Mutex +} + +type Message struct { + RowID int64 + BridgeID networkid.BridgeID + ID networkid.MessageID + PartID networkid.PartID + MXID id.EventID + + Room networkid.PortalKey + SenderID networkid.UserID + SenderMXID id.UserID + Timestamp time.Time + EditCount int + IsDoublePuppeted bool + + ThreadRoot networkid.MessageID + ReplyTo networkid.MessageOptionalPartID + + SendTxnID networkid.RawTransactionID + + Metadata any +} + +const ( + getMessageBaseQuery = ` + SELECT rowid, bridge_id, id, part_id, mxid, room_id, room_receiver, sender_id, sender_mxid, + timestamp, edit_count, double_puppeted, thread_root_id, reply_to_id, reply_to_part_id, + send_txn_id, metadata + FROM message + ` + getAllMessagePartsByIDQuery = getMessageBaseQuery + `WHERE bridge_id=$1 AND (room_receiver=$2 OR room_receiver='') AND id=$3` + getMessagePartByIDQuery = getMessageBaseQuery + `WHERE bridge_id=$1 AND (room_receiver=$2 OR room_receiver='') AND id=$3 AND part_id=$4` + getMessagePartByRowIDQuery = getMessageBaseQuery + `WHERE bridge_id=$1 AND rowid=$2` + getMessageByMXIDQuery = getMessageBaseQuery + `WHERE bridge_id=$1 AND mxid=$2` + getMessageByTxnIDQuery = getMessageBaseQuery + `WHERE bridge_id=$1 AND (room_receiver=$2 OR room_receiver='') AND (mxid=$3 OR send_txn_id=$4)` + getLastMessagePartByIDQuery = getMessageBaseQuery + `WHERE bridge_id=$1 AND (room_receiver=$2 OR room_receiver='') AND id=$3 ORDER BY part_id DESC LIMIT 1` + getFirstMessagePartByIDQuery = getMessageBaseQuery + `WHERE bridge_id=$1 AND (room_receiver=$2 OR room_receiver='') AND id=$3 ORDER BY part_id ASC LIMIT 1` + getMessagesBetweenTimeQuery = getMessageBaseQuery + `WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 AND timestamp>$4 AND timestamp<=$5` + getOldestMessageInPortal = getMessageBaseQuery + `WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 ORDER BY timestamp ASC, part_id ASC LIMIT 1` + getFirstMessageInThread = getMessageBaseQuery + `WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 AND (id=$4 OR thread_root_id=$4) ORDER BY thread_root_id NULLS FIRST, timestamp ASC, part_id ASC LIMIT 1` + getLastMessageInThread = getMessageBaseQuery + `WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 AND (id=$4 OR thread_root_id=$4) ORDER BY thread_root_id NULLS LAST, timestamp DESC, part_id DESC LIMIT 1` + getLastNInPortal = getMessageBaseQuery + `WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 ORDER BY timestamp DESC, part_id DESC LIMIT $4` + + getLastMessagePartAtOrBeforeTimeQuery = getMessageBaseQuery + `WHERE bridge_id = $1 AND room_id=$2 AND room_receiver=$3 AND timestamp<=$4 ORDER BY timestamp DESC, part_id DESC LIMIT 1` + getLastNonFakeMessagePartAtOrBeforeTimeQuery = getMessageBaseQuery + `WHERE bridge_id = $1 AND room_id=$2 AND room_receiver=$3 AND timestamp<=$4 AND mxid NOT LIKE '~fake:%' ORDER BY timestamp DESC, part_id DESC LIMIT 1` + + countMessagesInPortalQuery = ` + SELECT COUNT(*) FROM message WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 + ` + + insertMessageQuery = ` + INSERT INTO message ( + bridge_id, id, part_id, mxid, room_id, room_receiver, sender_id, sender_mxid, + timestamp, edit_count, double_puppeted, thread_root_id, reply_to_id, reply_to_part_id, + send_txn_id, metadata + ) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16) + RETURNING rowid + ` + updateMessageQuery = ` + UPDATE message SET id=$2, part_id=$3, mxid=$4, room_id=$5, room_receiver=$6, sender_id=$7, sender_mxid=$8, + timestamp=$9, edit_count=$10, double_puppeted=$11, thread_root_id=$12, reply_to_id=$13, + reply_to_part_id=$14, send_txn_id=$15, metadata=$16 + WHERE bridge_id=$1 AND rowid=$17 + ` + deleteAllMessagePartsByIDQuery = ` + DELETE FROM message WHERE bridge_id=$1 AND (room_receiver=$2 OR room_receiver='') AND id=$3 + ` + deleteMessagePartByRowIDQuery = ` + DELETE FROM message WHERE bridge_id=$1 AND rowid=$2 + ` + deleteMessageChunkQuery = ` + DELETE FROM message WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 AND rowid > $4 AND rowid <= $5 + ` + getMaxMessageRowIDQuery = `SELECT MAX(rowid) FROM message WHERE bridge_id=$1` +) + +func (mq *MessageQuery) GetAllPartsByID(ctx context.Context, receiver networkid.UserLoginID, id networkid.MessageID) ([]*Message, error) { + return mq.QueryMany(ctx, getAllMessagePartsByIDQuery, mq.BridgeID, receiver, id) +} + +func (mq *MessageQuery) GetPartByID(ctx context.Context, receiver networkid.UserLoginID, id networkid.MessageID, partID networkid.PartID) (*Message, error) { + return mq.QueryOne(ctx, getMessagePartByIDQuery, mq.BridgeID, receiver, id, partID) +} + +func (mq *MessageQuery) GetPartByMXID(ctx context.Context, mxid id.EventID) (*Message, error) { + return mq.QueryOne(ctx, getMessageByMXIDQuery, mq.BridgeID, mxid) +} + +func (mq *MessageQuery) GetPartByTxnID(ctx context.Context, receiver networkid.UserLoginID, mxid id.EventID, txnID networkid.RawTransactionID) (*Message, error) { + return mq.QueryOne(ctx, getMessageByTxnIDQuery, mq.BridgeID, receiver, mxid, txnID) +} + +func (mq *MessageQuery) GetLastPartByID(ctx context.Context, receiver networkid.UserLoginID, id networkid.MessageID) (*Message, error) { + return mq.QueryOne(ctx, getLastMessagePartByIDQuery, mq.BridgeID, receiver, id) +} + +func (mq *MessageQuery) GetFirstPartByID(ctx context.Context, receiver networkid.UserLoginID, id networkid.MessageID) (*Message, error) { + return mq.QueryOne(ctx, getFirstMessagePartByIDQuery, mq.BridgeID, receiver, id) +} + +func (mq *MessageQuery) GetByRowID(ctx context.Context, rowID int64) (*Message, error) { + return mq.QueryOne(ctx, getMessagePartByRowIDQuery, mq.BridgeID, rowID) +} + +func (mq *MessageQuery) GetFirstOrSpecificPartByID(ctx context.Context, receiver networkid.UserLoginID, id networkid.MessageOptionalPartID) (*Message, error) { + if id.PartID == nil { + return mq.GetFirstPartByID(ctx, receiver, id.MessageID) + } else { + return mq.GetPartByID(ctx, receiver, id.MessageID, *id.PartID) + } +} + +func (mq *MessageQuery) GetLastPartAtOrBeforeTime(ctx context.Context, portal networkid.PortalKey, maxTS time.Time) (*Message, error) { + return mq.QueryOne(ctx, getLastMessagePartAtOrBeforeTimeQuery, mq.BridgeID, portal.ID, portal.Receiver, maxTS.UnixNano()) +} + +func (mq *MessageQuery) GetLastNonFakePartAtOrBeforeTime(ctx context.Context, portal networkid.PortalKey, maxTS time.Time) (*Message, error) { + return mq.QueryOne(ctx, getLastNonFakeMessagePartAtOrBeforeTimeQuery, mq.BridgeID, portal.ID, portal.Receiver, maxTS.UnixNano()) +} + +func (mq *MessageQuery) GetMessagesBetweenTimeQuery(ctx context.Context, portal networkid.PortalKey, start, end time.Time) ([]*Message, error) { + return mq.QueryMany(ctx, getMessagesBetweenTimeQuery, mq.BridgeID, portal.ID, portal.Receiver, start.UnixNano(), end.UnixNano()) +} + +func (mq *MessageQuery) GetFirstPortalMessage(ctx context.Context, portal networkid.PortalKey) (*Message, error) { + return mq.QueryOne(ctx, getOldestMessageInPortal, mq.BridgeID, portal.ID, portal.Receiver) +} + +func (mq *MessageQuery) GetFirstThreadMessage(ctx context.Context, portal networkid.PortalKey, threadRoot networkid.MessageID) (*Message, error) { + return mq.QueryOne(ctx, getFirstMessageInThread, mq.BridgeID, portal.ID, portal.Receiver, threadRoot) +} + +func (mq *MessageQuery) GetLastThreadMessage(ctx context.Context, portal networkid.PortalKey, threadRoot networkid.MessageID) (*Message, error) { + return mq.QueryOne(ctx, getLastMessageInThread, mq.BridgeID, portal.ID, portal.Receiver, threadRoot) +} + +func (mq *MessageQuery) GetLastNInPortal(ctx context.Context, portal networkid.PortalKey, n int) ([]*Message, error) { + return mq.QueryMany(ctx, getLastNInPortal, mq.BridgeID, portal.ID, portal.Receiver, n) +} + +func (mq *MessageQuery) Insert(ctx context.Context, msg *Message) error { + ensureBridgeIDMatches(&msg.BridgeID, mq.BridgeID) + return mq.GetDB().QueryRow(ctx, insertMessageQuery, msg.ensureHasMetadata(mq.MetaType).sqlVariables()...).Scan(&msg.RowID) +} + +func (mq *MessageQuery) Update(ctx context.Context, msg *Message) error { + ensureBridgeIDMatches(&msg.BridgeID, mq.BridgeID) + return mq.Exec(ctx, updateMessageQuery, msg.ensureHasMetadata(mq.MetaType).updateSQLVariables()...) +} + +func (mq *MessageQuery) DeleteAllParts(ctx context.Context, receiver networkid.UserLoginID, id networkid.MessageID) error { + return mq.Exec(ctx, deleteAllMessagePartsByIDQuery, mq.BridgeID, receiver, id) +} + +func (mq *MessageQuery) Delete(ctx context.Context, rowID int64) error { + return mq.Exec(ctx, deleteMessagePartByRowIDQuery, mq.BridgeID, rowID) +} + +func (mq *MessageQuery) deleteChunk(ctx context.Context, portal networkid.PortalKey, minRowID, maxRowID int64) (int64, error) { + res, err := mq.GetDB().Exec(ctx, deleteMessageChunkQuery, mq.BridgeID, portal.ID, portal.Receiver, minRowID, maxRowID) + if err != nil { + return 0, err + } + return res.RowsAffected() +} + +func (mq *MessageQuery) getMaxRowID(ctx context.Context) (maxRowID int64, err error) { + err = mq.GetDB().QueryRow(ctx, getMaxMessageRowIDQuery, mq.BridgeID).Scan(&maxRowID) + return +} + +const deleteChunkSize = 100_000 + +func (mq *MessageQuery) DeleteInChunks(ctx context.Context, portal networkid.PortalKey) error { + if mq.GetDB().Dialect != dbutil.SQLite { + return nil + } + log := zerolog.Ctx(ctx).With(). + Str("action", "delete messages in chunks"). + Stringer("portal_key", portal). + Logger() + if !mq.chunkDeleteLock.TryLock() { + log.Warn().Msg("Portal deletion lock is being held, waiting...") + mq.chunkDeleteLock.Lock() + log.Debug().Msg("Acquired portal deletion lock after waiting") + } + defer mq.chunkDeleteLock.Unlock() + total, err := mq.CountMessagesInPortal(ctx, portal) + if err != nil { + return fmt.Errorf("failed to count messages in portal: %w", err) + } else if total < deleteChunkSize/3 { + return nil + } + globalMaxRowID, err := mq.getMaxRowID(ctx) + if err != nil { + return fmt.Errorf("failed to get max row ID: %w", err) + } + log.Debug(). + Int("total_count", total). + Int64("global_max_row_id", globalMaxRowID). + Msg("Portal has lots of messages, deleting in chunks to avoid database locks") + maxRowID := int64(deleteChunkSize) + globalMaxRowID += deleteChunkSize * 1.2 + var dbTimeUsed time.Duration + globalStart := time.Now() + for total > 500 && maxRowID < globalMaxRowID { + start := time.Now() + count, err := mq.deleteChunk(ctx, portal, maxRowID-deleteChunkSize, maxRowID) + duration := time.Since(start) + dbTimeUsed += duration + if err != nil { + return fmt.Errorf("failed to delete chunk of messages before %d: %w", maxRowID, err) + } + total -= int(count) + maxRowID += deleteChunkSize + sleepTime := max(10*time.Millisecond, min(250*time.Millisecond, time.Duration(count/100)*time.Millisecond)) + log.Debug(). + Int64("max_row_id", maxRowID). + Int64("deleted_count", count). + Int("remaining_count", total). + Dur("duration", duration). + Dur("sleep_time", sleepTime). + Msg("Deleted chunk of messages") + select { + case <-time.After(sleepTime): + case <-ctx.Done(): + return ctx.Err() + } + } + log.Debug(). + Int("remaining_count", total). + Dur("db_time_used", dbTimeUsed). + Dur("total_duration", time.Since(globalStart)). + Msg("Finished chunked delete of messages in portal") + return nil +} + +func (mq *MessageQuery) CountMessagesInPortal(ctx context.Context, key networkid.PortalKey) (count int, err error) { + err = mq.GetDB().QueryRow(ctx, countMessagesInPortalQuery, mq.BridgeID, key.ID, key.Receiver).Scan(&count) + return +} + +func (m *Message) Scan(row dbutil.Scannable) (*Message, error) { + var timestamp int64 + var threadRootID, replyToID, replyToPartID, sendTxnID sql.NullString + var doublePuppeted sql.NullBool + err := row.Scan( + &m.RowID, &m.BridgeID, &m.ID, &m.PartID, &m.MXID, &m.Room.ID, &m.Room.Receiver, &m.SenderID, &m.SenderMXID, + ×tamp, &m.EditCount, &doublePuppeted, &threadRootID, &replyToID, &replyToPartID, &sendTxnID, + dbutil.JSON{Data: m.Metadata}, + ) + if err != nil { + return nil, err + } + m.Timestamp = time.Unix(0, timestamp) + m.ThreadRoot = networkid.MessageID(threadRootID.String) + m.IsDoublePuppeted = doublePuppeted.Valid + if replyToID.Valid { + m.ReplyTo.MessageID = networkid.MessageID(replyToID.String) + if replyToPartID.Valid { + m.ReplyTo.PartID = (*networkid.PartID)(&replyToPartID.String) + } + } + if sendTxnID.Valid { + m.SendTxnID = networkid.RawTransactionID(sendTxnID.String) + } + return m, nil +} + +func (m *Message) ensureHasMetadata(metaType MetaTypeCreator) *Message { + if m.Metadata == nil { + m.Metadata = metaType() + } + return m +} + +func (m *Message) sqlVariables() []any { + return []any{ + m.BridgeID, m.ID, m.PartID, m.MXID, m.Room.ID, m.Room.Receiver, m.SenderID, m.SenderMXID, + m.Timestamp.UnixNano(), m.EditCount, m.IsDoublePuppeted, dbutil.StrPtr(m.ThreadRoot), + dbutil.StrPtr(m.ReplyTo.MessageID), m.ReplyTo.PartID, dbutil.StrPtr(m.SendTxnID), + dbutil.JSON{Data: m.Metadata}, + } +} + +func (m *Message) updateSQLVariables() []any { + return append(m.sqlVariables(), m.RowID) +} + +const FakeMXIDPrefix = "~fake:" +const TxnMXIDPrefix = "~txn:" +const NetworkTxnMXIDPrefix = TxnMXIDPrefix + "network:" +const RandomTxnMXIDPrefix = TxnMXIDPrefix + "random:" + +func (m *Message) SetFakeMXID() { + hash := sha256.Sum256([]byte(m.ID)) + m.MXID = id.EventID(FakeMXIDPrefix + base64.RawURLEncoding.EncodeToString(hash[:])) +} + +func (m *Message) HasFakeMXID() bool { + return strings.HasPrefix(m.MXID.String(), FakeMXIDPrefix) +} diff --git a/bridgev2/database/portal.go b/bridgev2/database/portal.go new file mode 100644 index 00000000..0e6be286 --- /dev/null +++ b/bridgev2/database/portal.go @@ -0,0 +1,296 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package database + +import ( + "context" + "database/sql" + "encoding/hex" + "errors" + "time" + + "go.mau.fi/util/dbutil" + + "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +type RoomType string + +const ( + RoomTypeDefault RoomType = "" + RoomTypeDM RoomType = "dm" + RoomTypeGroupDM RoomType = "group_dm" + RoomTypeSpace RoomType = "space" +) + +type PortalQuery struct { + BridgeID networkid.BridgeID + MetaType MetaTypeCreator + *dbutil.QueryHelper[*Portal] +} + +type CapStateFlags uint32 + +func (csf CapStateFlags) Has(flag CapStateFlags) bool { + return csf&flag != 0 +} + +const ( + CapStateFlagDisappearingTimerSet CapStateFlags = 1 << iota +) + +type CapabilityState struct { + Source networkid.UserLoginID `json:"source"` + ID string `json:"id"` + Flags CapStateFlags `json:"flags"` +} + +type Portal struct { + BridgeID networkid.BridgeID + networkid.PortalKey + MXID id.RoomID + + ParentKey networkid.PortalKey + RelayLoginID networkid.UserLoginID + OtherUserID networkid.UserID + Name string + Topic string + AvatarID networkid.AvatarID + AvatarHash [32]byte + AvatarMXC id.ContentURIString + NameSet bool + TopicSet bool + AvatarSet bool + NameIsCustom bool + InSpace bool + MessageRequest bool + RoomType RoomType + Disappear DisappearingSetting + CapState CapabilityState + Metadata any +} + +const ( + getPortalBaseQuery = ` + SELECT bridge_id, id, receiver, mxid, parent_id, parent_receiver, relay_login_id, other_user_id, + name, topic, avatar_id, avatar_hash, avatar_mxc, + name_set, topic_set, avatar_set, name_is_custom, in_space, message_request, + room_type, disappear_type, disappear_timer, cap_state, + metadata + FROM portal + ` + getPortalByKeyQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND id=$2 AND receiver=$3` + getPortalByIDWithUncertainReceiverQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND id=$2 AND (receiver=$3 OR receiver='')` + getPortalByMXIDQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND mxid=$2` + getAllPortalsWithMXIDQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND mxid IS NOT NULL` + getAllPortalsWithoutReceiver = getPortalBaseQuery + `WHERE bridge_id=$1 AND (receiver='' OR (parent_id<>'' AND parent_receiver='')) ORDER BY parent_id DESC` + getAllDMPortalsQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND room_type='dm' AND other_user_id=$2` + getDMPortalQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND room_type='dm' AND receiver=$2 AND other_user_id=$3` + getAllPortalsQuery = getPortalBaseQuery + `WHERE bridge_id=$1` + getChildPortalsQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND parent_id=$2 AND parent_receiver=$3` + + findPortalReceiverQuery = `SELECT id, receiver FROM portal WHERE bridge_id=$1 AND id=$2 AND (receiver=$3 OR receiver='') LIMIT 1` + + insertPortalQuery = ` + INSERT INTO portal ( + bridge_id, id, receiver, mxid, + parent_id, parent_receiver, relay_login_id, other_user_id, + name, topic, avatar_id, avatar_hash, avatar_mxc, + name_set, avatar_set, topic_set, name_is_custom, in_space, message_request, + room_type, disappear_type, disappear_timer, cap_state, + metadata, relay_bridge_id + ) VALUES ( + $1, $2, $3, $4, $5, $6, cast($7 AS TEXT), $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24, + CASE WHEN cast($7 AS TEXT) IS NULL THEN NULL ELSE $1 END + ) + ` + updatePortalQuery = ` + UPDATE portal + SET mxid=$4, parent_id=$5, parent_receiver=$6, + relay_login_id=cast($7 AS TEXT), relay_bridge_id=CASE WHEN cast($7 AS TEXT) IS NULL THEN NULL ELSE bridge_id END, + other_user_id=$8, name=$9, topic=$10, avatar_id=$11, avatar_hash=$12, avatar_mxc=$13, + name_set=$14, avatar_set=$15, topic_set=$16, name_is_custom=$17, in_space=$18, message_request=$19, + room_type=$20, disappear_type=$21, disappear_timer=$22, cap_state=$23, metadata=$24 + WHERE bridge_id=$1 AND id=$2 AND receiver=$3 + ` + deletePortalQuery = ` + DELETE FROM portal + WHERE bridge_id=$1 AND id=$2 AND receiver=$3 + ` + reIDPortalQuery = `UPDATE portal SET id=$4, receiver=$5 WHERE bridge_id=$1 AND id=$2 AND receiver=$3` + migrateToSplitPortalsQuery = ` + UPDATE portal + SET receiver=new_receiver + FROM ( + SELECT bridge_id, id, COALESCE(( + SELECT login_id + FROM user_portal + WHERE bridge_id=portal.bridge_id AND portal_id=portal.id AND portal_receiver='' + LIMIT 1 + ), ( + SELECT login_id + FROM user_portal + WHERE portal.parent_id<>'' AND bridge_id=portal.bridge_id AND portal_id=portal.parent_id + LIMIT 1 + ), ( + SELECT id FROM user_login WHERE bridge_id=portal.bridge_id LIMIT 1 + ), '') AS new_receiver + FROM portal + WHERE receiver='' AND bridge_id=$1 + ) updates + WHERE portal.bridge_id=updates.bridge_id AND portal.id=updates.id AND portal.receiver='' AND NOT EXISTS ( + SELECT 1 FROM portal p2 WHERE p2.bridge_id=updates.bridge_id AND p2.id=updates.id AND p2.receiver=updates.new_receiver + ) + ` + fixParentsAfterSplitPortalMigrationQuery = ` + UPDATE portal + SET parent_receiver=receiver + WHERE bridge_id=$1 AND parent_receiver='' AND receiver<>'' AND parent_id<>'' + AND EXISTS(SELECT 1 FROM portal pp WHERE pp.bridge_id=$1 AND pp.id=portal.parent_id AND pp.receiver=portal.receiver); + ` +) + +func (pq *PortalQuery) GetByKey(ctx context.Context, key networkid.PortalKey) (*Portal, error) { + return pq.QueryOne(ctx, getPortalByKeyQuery, pq.BridgeID, key.ID, key.Receiver) +} + +func (pq *PortalQuery) FindReceiver(ctx context.Context, id networkid.PortalID, maybeReceiver networkid.UserLoginID) (key networkid.PortalKey, err error) { + err = pq.GetDB().QueryRow(ctx, findPortalReceiverQuery, pq.BridgeID, id, maybeReceiver).Scan(&key.ID, &key.Receiver) + if errors.Is(err, sql.ErrNoRows) { + err = nil + } + return +} + +func (pq *PortalQuery) GetByIDWithUncertainReceiver(ctx context.Context, key networkid.PortalKey) (*Portal, error) { + return pq.QueryOne(ctx, getPortalByIDWithUncertainReceiverQuery, pq.BridgeID, key.ID, key.Receiver) +} + +func (pq *PortalQuery) GetByMXID(ctx context.Context, mxid id.RoomID) (*Portal, error) { + return pq.QueryOne(ctx, getPortalByMXIDQuery, pq.BridgeID, mxid) +} + +func (pq *PortalQuery) GetAllWithMXID(ctx context.Context) ([]*Portal, error) { + return pq.QueryMany(ctx, getAllPortalsWithMXIDQuery, pq.BridgeID) +} + +func (pq *PortalQuery) GetAllWithoutReceiver(ctx context.Context) ([]*Portal, error) { + return pq.QueryMany(ctx, getAllPortalsWithoutReceiver, pq.BridgeID) +} + +func (pq *PortalQuery) GetAll(ctx context.Context) ([]*Portal, error) { + return pq.QueryMany(ctx, getAllPortalsQuery, pq.BridgeID) +} + +func (pq *PortalQuery) GetAllDMsWith(ctx context.Context, otherUserID networkid.UserID) ([]*Portal, error) { + return pq.QueryMany(ctx, getAllDMPortalsQuery, pq.BridgeID, otherUserID) +} + +func (pq *PortalQuery) GetDM(ctx context.Context, receiver networkid.UserLoginID, otherUserID networkid.UserID) (*Portal, error) { + return pq.QueryOne(ctx, getDMPortalQuery, pq.BridgeID, receiver, otherUserID) +} + +func (pq *PortalQuery) GetChildren(ctx context.Context, parentKey networkid.PortalKey) ([]*Portal, error) { + return pq.QueryMany(ctx, getChildPortalsQuery, pq.BridgeID, parentKey.ID, parentKey.Receiver) +} + +func (pq *PortalQuery) ReID(ctx context.Context, oldID, newID networkid.PortalKey) error { + return pq.Exec(ctx, reIDPortalQuery, pq.BridgeID, oldID.ID, oldID.Receiver, newID.ID, newID.Receiver) +} + +func (pq *PortalQuery) Insert(ctx context.Context, p *Portal) error { + ensureBridgeIDMatches(&p.BridgeID, pq.BridgeID) + return pq.Exec(ctx, insertPortalQuery, p.ensureHasMetadata(pq.MetaType).sqlVariables()...) +} + +func (pq *PortalQuery) Update(ctx context.Context, p *Portal) error { + ensureBridgeIDMatches(&p.BridgeID, pq.BridgeID) + return pq.Exec(ctx, updatePortalQuery, p.ensureHasMetadata(pq.MetaType).sqlVariables()...) +} + +func (pq *PortalQuery) Delete(ctx context.Context, key networkid.PortalKey) error { + return pq.Exec(ctx, deletePortalQuery, pq.BridgeID, key.ID, key.Receiver) +} + +func (pq *PortalQuery) MigrateToSplitPortals(ctx context.Context) (int64, error) { + res, err := pq.GetDB().Exec(ctx, migrateToSplitPortalsQuery, pq.BridgeID) + if err != nil { + return 0, err + } + return res.RowsAffected() +} + +func (pq *PortalQuery) FixParentsAfterSplitPortalMigration(ctx context.Context) (int64, error) { + res, err := pq.GetDB().Exec(ctx, fixParentsAfterSplitPortalMigrationQuery, pq.BridgeID) + if err != nil { + return 0, err + } + return res.RowsAffected() +} + +func (p *Portal) Scan(row dbutil.Scannable) (*Portal, error) { + var mxid, parentID, parentReceiver, relayLoginID, otherUserID, disappearType sql.NullString + var disappearTimer sql.NullInt64 + var avatarHash string + err := row.Scan( + &p.BridgeID, &p.ID, &p.Receiver, &mxid, + &parentID, &parentReceiver, &relayLoginID, &otherUserID, + &p.Name, &p.Topic, &p.AvatarID, &avatarHash, &p.AvatarMXC, + &p.NameSet, &p.TopicSet, &p.AvatarSet, &p.NameIsCustom, &p.InSpace, &p.MessageRequest, + &p.RoomType, &disappearType, &disappearTimer, + dbutil.JSON{Data: &p.CapState}, dbutil.JSON{Data: p.Metadata}, + ) + if err != nil { + return nil, err + } + if avatarHash != "" { + data, _ := hex.DecodeString(avatarHash) + if len(data) == 32 { + p.AvatarHash = *(*[32]byte)(data) + } + } + if disappearType.Valid { + p.Disappear = DisappearingSetting{ + Type: event.DisappearingType(disappearType.String), + Timer: time.Duration(disappearTimer.Int64), + } + } + p.MXID = id.RoomID(mxid.String) + p.OtherUserID = networkid.UserID(otherUserID.String) + if parentID.Valid { + p.ParentKey = networkid.PortalKey{ + ID: networkid.PortalID(parentID.String), + Receiver: networkid.UserLoginID(parentReceiver.String), + } + } + p.RelayLoginID = networkid.UserLoginID(relayLoginID.String) + return p, nil +} + +func (p *Portal) ensureHasMetadata(metaType MetaTypeCreator) *Portal { + if p.Metadata == nil { + p.Metadata = metaType() + } + return p +} + +func (p *Portal) sqlVariables() []any { + var avatarHash string + if p.AvatarHash != [32]byte{} { + avatarHash = hex.EncodeToString(p.AvatarHash[:]) + } + return []any{ + p.BridgeID, p.ID, p.Receiver, dbutil.StrPtr(p.MXID), + dbutil.StrPtr(p.ParentKey.ID), p.ParentKey.Receiver, dbutil.StrPtr(p.RelayLoginID), dbutil.StrPtr(p.OtherUserID), + p.Name, p.Topic, p.AvatarID, avatarHash, p.AvatarMXC, + p.NameSet, p.TopicSet, p.AvatarSet, p.NameIsCustom, p.InSpace, p.MessageRequest, + p.RoomType, dbutil.StrPtr(p.Disappear.Type), dbutil.NumPtr(p.Disappear.Timer), + dbutil.JSON{Data: p.CapState}, dbutil.JSON{Data: p.Metadata}, + } +} diff --git a/bridgev2/database/publicmedia.go b/bridgev2/database/publicmedia.go new file mode 100644 index 00000000..b667399c --- /dev/null +++ b/bridgev2/database/publicmedia.go @@ -0,0 +1,72 @@ +// Copyright (c) 2025 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package database + +import ( + "context" + "database/sql" + "time" + + "go.mau.fi/util/dbutil" + + "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/crypto/attachment" + "maunium.net/go/mautrix/id" +) + +type PublicMediaQuery struct { + BridgeID networkid.BridgeID + *dbutil.QueryHelper[*PublicMedia] +} + +type PublicMedia struct { + BridgeID networkid.BridgeID + PublicID string + MXC id.ContentURI + Keys *attachment.EncryptedFile + MimeType string + Expiry time.Time +} + +const ( + upsertPublicMediaQuery = ` + INSERT INTO public_media (bridge_id, public_id, mxc, keys, mimetype, expiry) + VALUES ($1, $2, $3, $4, $5, $6) + ON CONFLICT (bridge_id, public_id) DO UPDATE SET expiry=EXCLUDED.expiry + ` + getPublicMediaQuery = ` + SELECT bridge_id, public_id, mxc, keys, mimetype, expiry + FROM public_media WHERE bridge_id=$1 AND public_id=$2 + ` +) + +func (pmq *PublicMediaQuery) Put(ctx context.Context, pm *PublicMedia) error { + ensureBridgeIDMatches(&pm.BridgeID, pmq.BridgeID) + return pmq.Exec(ctx, upsertPublicMediaQuery, pm.sqlVariables()...) +} + +func (pmq *PublicMediaQuery) Get(ctx context.Context, publicID string) (*PublicMedia, error) { + return pmq.QueryOne(ctx, getPublicMediaQuery, pmq.BridgeID, publicID) +} + +func (pm *PublicMedia) Scan(row dbutil.Scannable) (*PublicMedia, error) { + var expiry sql.NullInt64 + var mimetype sql.NullString + err := row.Scan(&pm.BridgeID, &pm.PublicID, &pm.MXC, dbutil.JSON{Data: &pm.Keys}, &mimetype, &expiry) + if err != nil { + return nil, err + } + if expiry.Valid { + pm.Expiry = time.Unix(0, expiry.Int64) + } + pm.MimeType = mimetype.String + return pm, nil +} + +func (pm *PublicMedia) sqlVariables() []any { + return []any{pm.BridgeID, pm.PublicID, &pm.MXC, dbutil.JSONPtr(pm.Keys), dbutil.StrPtr(pm.MimeType), dbutil.ConvertedPtr(pm.Expiry, time.Time.UnixNano)} +} diff --git a/bridgev2/database/reaction.go b/bridgev2/database/reaction.go new file mode 100644 index 00000000..b65a5c38 --- /dev/null +++ b/bridgev2/database/reaction.go @@ -0,0 +1,120 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package database + +import ( + "context" + "time" + + "go.mau.fi/util/dbutil" + + "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/id" +) + +type ReactionQuery struct { + BridgeID networkid.BridgeID + MetaType MetaTypeCreator + *dbutil.QueryHelper[*Reaction] +} + +type Reaction struct { + BridgeID networkid.BridgeID + Room networkid.PortalKey + MessageID networkid.MessageID + MessagePartID networkid.PartID + SenderID networkid.UserID + SenderMXID id.UserID + EmojiID networkid.EmojiID + MXID id.EventID + + Timestamp time.Time + Emoji string + Metadata any +} + +const ( + getReactionBaseQuery = ` + SELECT bridge_id, message_id, message_part_id, sender_id, sender_mxid, emoji_id, emoji, room_id, room_receiver, mxid, timestamp, metadata FROM reaction + ` + getReactionByIDQuery = getReactionBaseQuery + `WHERE bridge_id=$1 AND room_receiver=$2 AND message_id=$3 AND message_part_id=$4 AND sender_id=$5 AND emoji_id=$6` + getReactionByIDWithoutMessagePartQuery = getReactionBaseQuery + `WHERE bridge_id=$1 AND room_receiver=$2 AND message_id=$3 AND sender_id=$4 AND emoji_id=$5 ORDER BY message_part_id ASC LIMIT 1` + getAllReactionsToMessageBySenderQuery = getReactionBaseQuery + `WHERE bridge_id=$1 AND room_receiver=$2 AND message_id=$3 AND sender_id=$4 ORDER BY timestamp DESC` + getAllReactionsToMessageQuery = getReactionBaseQuery + `WHERE bridge_id=$1 AND room_receiver=$2 AND message_id=$3` + getAllReactionsToMessagePartQuery = getReactionBaseQuery + `WHERE bridge_id=$1 AND room_receiver=$2 AND message_id=$3 AND message_part_id=$4` + getReactionByMXIDQuery = getReactionBaseQuery + `WHERE bridge_id=$1 AND mxid=$2` + upsertReactionQuery = ` + INSERT INTO reaction (bridge_id, message_id, message_part_id, sender_id, sender_mxid, emoji_id, emoji, room_id, room_receiver, mxid, timestamp, metadata) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12) + ON CONFLICT (bridge_id, room_receiver, message_id, message_part_id, sender_id, emoji_id) + DO UPDATE SET sender_mxid=excluded.sender_mxid, mxid=excluded.mxid, timestamp=excluded.timestamp, emoji=excluded.emoji, metadata=excluded.metadata + ` + deleteReactionQuery = ` + DELETE FROM reaction WHERE bridge_id=$1 AND room_receiver=$2 AND message_id=$3 AND message_part_id=$4 AND sender_id=$5 AND emoji_id=$6 + ` +) + +func (rq *ReactionQuery) GetByID(ctx context.Context, receiver networkid.UserLoginID, messageID networkid.MessageID, messagePartID networkid.PartID, senderID networkid.UserID, emojiID networkid.EmojiID) (*Reaction, error) { + return rq.QueryOne(ctx, getReactionByIDQuery, rq.BridgeID, receiver, messageID, messagePartID, senderID, emojiID) +} + +func (rq *ReactionQuery) GetByIDWithoutMessagePart(ctx context.Context, receiver networkid.UserLoginID, messageID networkid.MessageID, senderID networkid.UserID, emojiID networkid.EmojiID) (*Reaction, error) { + return rq.QueryOne(ctx, getReactionByIDWithoutMessagePartQuery, rq.BridgeID, receiver, messageID, senderID, emojiID) +} + +func (rq *ReactionQuery) GetAllToMessageBySender(ctx context.Context, receiver networkid.UserLoginID, messageID networkid.MessageID, senderID networkid.UserID) ([]*Reaction, error) { + return rq.QueryMany(ctx, getAllReactionsToMessageBySenderQuery, rq.BridgeID, receiver, messageID, senderID) +} + +func (rq *ReactionQuery) GetAllToMessage(ctx context.Context, receiver networkid.UserLoginID, messageID networkid.MessageID) ([]*Reaction, error) { + return rq.QueryMany(ctx, getAllReactionsToMessageQuery, rq.BridgeID, receiver, messageID) +} + +func (rq *ReactionQuery) GetAllToMessagePart(ctx context.Context, receiver networkid.UserLoginID, messageID networkid.MessageID, partID networkid.PartID) ([]*Reaction, error) { + return rq.QueryMany(ctx, getAllReactionsToMessagePartQuery, rq.BridgeID, receiver, messageID, partID) +} + +func (rq *ReactionQuery) GetByMXID(ctx context.Context, mxid id.EventID) (*Reaction, error) { + return rq.QueryOne(ctx, getReactionByMXIDQuery, rq.BridgeID, mxid) +} + +func (rq *ReactionQuery) Upsert(ctx context.Context, reaction *Reaction) error { + ensureBridgeIDMatches(&reaction.BridgeID, rq.BridgeID) + return rq.Exec(ctx, upsertReactionQuery, reaction.ensureHasMetadata(rq.MetaType).sqlVariables()...) +} + +func (rq *ReactionQuery) Delete(ctx context.Context, reaction *Reaction) error { + ensureBridgeIDMatches(&reaction.BridgeID, rq.BridgeID) + return rq.Exec(ctx, deleteReactionQuery, reaction.BridgeID, reaction.Room.Receiver, reaction.MessageID, reaction.MessagePartID, reaction.SenderID, reaction.EmojiID) +} + +func (r *Reaction) Scan(row dbutil.Scannable) (*Reaction, error) { + var timestamp int64 + err := row.Scan( + &r.BridgeID, &r.MessageID, &r.MessagePartID, &r.SenderID, &r.SenderMXID, &r.EmojiID, &r.Emoji, + &r.Room.ID, &r.Room.Receiver, &r.MXID, ×tamp, dbutil.JSON{Data: r.Metadata}, + ) + if err != nil { + return nil, err + } + r.Timestamp = time.Unix(0, timestamp) + return r, nil +} + +func (r *Reaction) ensureHasMetadata(metaType MetaTypeCreator) *Reaction { + if r.Metadata == nil { + r.Metadata = metaType() + } + return r +} + +func (r *Reaction) sqlVariables() []any { + return []any{ + r.BridgeID, r.MessageID, r.MessagePartID, r.SenderID, r.SenderMXID, r.EmojiID, r.Emoji, + r.Room.ID, r.Room.Receiver, r.MXID, r.Timestamp.UnixNano(), dbutil.JSON{Data: r.Metadata}, + } +} diff --git a/bridgev2/database/upgrades/00-latest.sql b/bridgev2/database/upgrades/00-latest.sql new file mode 100644 index 00000000..6092dc24 --- /dev/null +++ b/bridgev2/database/upgrades/00-latest.sql @@ -0,0 +1,233 @@ +-- v0 -> v27 (compatible with v9+): Latest revision +CREATE TABLE "user" ( + bridge_id TEXT NOT NULL, + mxid TEXT NOT NULL, + + management_room TEXT, + access_token TEXT, + + PRIMARY KEY (bridge_id, mxid) +); + +CREATE TABLE user_login ( + bridge_id TEXT NOT NULL, + user_mxid TEXT NOT NULL, + id TEXT NOT NULL, + remote_name TEXT NOT NULL, + remote_profile jsonb, + space_room TEXT, + metadata jsonb NOT NULL, + + PRIMARY KEY (bridge_id, id), + CONSTRAINT user_login_user_fkey FOREIGN KEY (bridge_id, user_mxid) + REFERENCES "user" (bridge_id, mxid) + ON DELETE CASCADE ON UPDATE CASCADE +); + +CREATE TABLE portal ( + bridge_id TEXT NOT NULL, + id TEXT NOT NULL, + receiver TEXT NOT NULL, + mxid TEXT, + + parent_id TEXT, + parent_receiver TEXT NOT NULL DEFAULT '', + + relay_bridge_id TEXT, + relay_login_id TEXT, + + other_user_id TEXT, + + name TEXT NOT NULL, + topic TEXT NOT NULL, + avatar_id TEXT NOT NULL, + avatar_hash TEXT NOT NULL, + avatar_mxc TEXT NOT NULL, + name_set BOOLEAN NOT NULL, + avatar_set BOOLEAN NOT NULL, + topic_set BOOLEAN NOT NULL, + name_is_custom BOOLEAN NOT NULL DEFAULT false, + in_space BOOLEAN NOT NULL, + message_request BOOLEAN NOT NULL DEFAULT false, + room_type TEXT NOT NULL, + disappear_type TEXT, + disappear_timer BIGINT, + cap_state jsonb, + metadata jsonb NOT NULL, + + PRIMARY KEY (bridge_id, id, receiver), + CONSTRAINT portal_parent_fkey FOREIGN KEY (bridge_id, parent_id, parent_receiver) + -- Deletes aren't allowed to cascade here: + -- children should be re-parented or cleaned up manually + REFERENCES portal (bridge_id, id, receiver) ON UPDATE CASCADE, + CONSTRAINT portal_relay_fkey FOREIGN KEY (relay_bridge_id, relay_login_id) + REFERENCES user_login (bridge_id, id) + ON DELETE SET NULL ON UPDATE CASCADE +); +CREATE UNIQUE INDEX portal_bridge_mxid_idx ON portal (bridge_id, mxid); +CREATE INDEX portal_parent_idx ON portal (bridge_id, parent_id, parent_receiver); + +CREATE TABLE ghost ( + bridge_id TEXT NOT NULL, + id TEXT NOT NULL, + + name TEXT NOT NULL, + avatar_id TEXT NOT NULL, + avatar_hash TEXT NOT NULL, + avatar_mxc TEXT NOT NULL, + name_set BOOLEAN NOT NULL, + avatar_set BOOLEAN NOT NULL, + contact_info_set BOOLEAN NOT NULL, + is_bot BOOLEAN NOT NULL, + identifiers jsonb NOT NULL, + extra_profile jsonb, + metadata jsonb NOT NULL, + + PRIMARY KEY (bridge_id, id) +); + +CREATE TABLE message ( + -- Messages have an extra rowid to allow a single relates_to column with ON DELETE SET NULL + -- If the foreign key used (bridge_id, relates_to), then deleting the target column + -- would try to set bridge_id to null as well. + + -- only: sqlite (line commented) +-- rowid INTEGER PRIMARY KEY, + -- only: postgres + rowid BIGINT PRIMARY KEY GENERATED BY DEFAULT AS IDENTITY, + + bridge_id TEXT NOT NULL, + id TEXT NOT NULL, + part_id TEXT NOT NULL, + mxid TEXT NOT NULL, + + room_id TEXT NOT NULL, + room_receiver TEXT NOT NULL, + sender_id TEXT NOT NULL, + sender_mxid TEXT NOT NULL, + timestamp BIGINT NOT NULL, + edit_count INTEGER NOT NULL, + double_puppeted BOOLEAN, + thread_root_id TEXT, + reply_to_id TEXT, + reply_to_part_id TEXT, + send_txn_id TEXT, + metadata jsonb NOT NULL, + + CONSTRAINT message_room_fkey FOREIGN KEY (bridge_id, room_id, room_receiver) + REFERENCES portal (bridge_id, id, receiver) + ON DELETE CASCADE ON UPDATE CASCADE, + CONSTRAINT message_sender_fkey FOREIGN KEY (bridge_id, sender_id) + REFERENCES ghost (bridge_id, id) + ON DELETE CASCADE ON UPDATE CASCADE, + CONSTRAINT message_real_pkey UNIQUE (bridge_id, room_receiver, id, part_id), + CONSTRAINT message_mxid_unique UNIQUE (bridge_id, mxid), + CONSTRAINT message_txn_id_unique UNIQUE (bridge_id, room_receiver, send_txn_id) +); +CREATE INDEX message_room_idx ON message (bridge_id, room_id, room_receiver); + +CREATE TABLE disappearing_message ( + bridge_id TEXT NOT NULL, + mx_room TEXT NOT NULL, + mxid TEXT NOT NULL, + timestamp BIGINT NOT NULL DEFAULT 0, + type TEXT NOT NULL, + timer BIGINT NOT NULL, + disappear_at BIGINT, + + PRIMARY KEY (bridge_id, mxid), + CONSTRAINT disappearing_message_portal_fkey + FOREIGN KEY (bridge_id, mx_room) + REFERENCES portal (bridge_id, mxid) + ON DELETE CASCADE +); +CREATE INDEX disappearing_message_portal_idx ON disappearing_message (bridge_id, mx_room); + +CREATE TABLE reaction ( + bridge_id TEXT NOT NULL, + message_id TEXT NOT NULL, + message_part_id TEXT NOT NULL, + sender_id TEXT NOT NULL, + sender_mxid TEXT NOT NULL DEFAULT '', + emoji_id TEXT NOT NULL, + room_id TEXT NOT NULL, + room_receiver TEXT NOT NULL, + mxid TEXT NOT NULL, + + timestamp BIGINT NOT NULL, + emoji TEXT NOT NULL, + metadata jsonb NOT NULL, + + PRIMARY KEY (bridge_id, room_receiver, message_id, message_part_id, sender_id, emoji_id), + CONSTRAINT reaction_room_fkey FOREIGN KEY (bridge_id, room_id, room_receiver) + REFERENCES portal (bridge_id, id, receiver) + ON DELETE CASCADE ON UPDATE CASCADE, + CONSTRAINT reaction_message_fkey FOREIGN KEY (bridge_id, room_receiver, message_id, message_part_id) + REFERENCES message (bridge_id, room_receiver, id, part_id) + ON DELETE CASCADE ON UPDATE CASCADE, + CONSTRAINT reaction_sender_fkey FOREIGN KEY (bridge_id, sender_id) + REFERENCES ghost (bridge_id, id) + ON DELETE CASCADE ON UPDATE CASCADE, + CONSTRAINT reaction_mxid_unique UNIQUE (bridge_id, mxid) +); +CREATE INDEX reaction_room_idx ON reaction (bridge_id, room_id, room_receiver); + +CREATE TABLE user_portal ( + bridge_id TEXT NOT NULL, + user_mxid TEXT NOT NULL, + login_id TEXT NOT NULL, + portal_id TEXT NOT NULL, + portal_receiver TEXT NOT NULL, + in_space BOOLEAN NOT NULL, + preferred BOOLEAN NOT NULL, + last_read BIGINT, + + PRIMARY KEY (bridge_id, user_mxid, login_id, portal_id, portal_receiver), + CONSTRAINT user_portal_user_login_fkey FOREIGN KEY (bridge_id, login_id) + REFERENCES user_login (bridge_id, id) + ON DELETE CASCADE ON UPDATE CASCADE, + CONSTRAINT user_portal_portal_fkey FOREIGN KEY (bridge_id, portal_id, portal_receiver) + REFERENCES portal (bridge_id, id, receiver) + ON DELETE CASCADE ON UPDATE CASCADE +); +CREATE INDEX user_portal_login_idx ON user_portal (bridge_id, login_id); +CREATE INDEX user_portal_portal_idx ON user_portal (bridge_id, portal_id, portal_receiver); + +CREATE TABLE backfill_task ( + bridge_id TEXT NOT NULL, + portal_id TEXT NOT NULL, + portal_receiver TEXT NOT NULL, + user_login_id TEXT NOT NULL, + + batch_count INTEGER NOT NULL, + is_done BOOLEAN NOT NULL, + cursor TEXT, + oldest_message_id TEXT, + dispatched_at BIGINT, + completed_at BIGINT, + next_dispatch_min_ts BIGINT NOT NULL, + + PRIMARY KEY (bridge_id, portal_id, portal_receiver), + CONSTRAINT backfill_queue_portal_fkey FOREIGN KEY (bridge_id, portal_id, portal_receiver) + REFERENCES portal (bridge_id, id, receiver) + ON DELETE CASCADE ON UPDATE CASCADE +); + +CREATE TABLE kv_store ( + bridge_id TEXT NOT NULL, + key TEXT NOT NULL, + value TEXT NOT NULL, + + PRIMARY KEY (bridge_id, key) +); + +CREATE TABLE public_media ( + bridge_id TEXT NOT NULL, + public_id TEXT NOT NULL, + mxc TEXT NOT NULL, + keys jsonb, + mimetype TEXT, + expiry BIGINT, + + PRIMARY KEY (bridge_id, public_id) +); diff --git a/bridgev2/database/upgrades/02-disappearing-messages.sql b/bridgev2/database/upgrades/02-disappearing-messages.sql new file mode 100644 index 00000000..e1425e75 --- /dev/null +++ b/bridgev2/database/upgrades/02-disappearing-messages.sql @@ -0,0 +1,11 @@ +-- v2 (compatible with v1+): Add disappearing messages table +CREATE TABLE disappearing_message ( + bridge_id TEXT NOT NULL, + mx_room TEXT NOT NULL, + mxid TEXT NOT NULL, + type TEXT NOT NULL, + timer BIGINT NOT NULL, + disappear_at BIGINT, + + PRIMARY KEY (bridge_id, mxid) +); diff --git a/bridgev2/database/upgrades/03-portal-relay-postgres.sql b/bridgev2/database/upgrades/03-portal-relay-postgres.sql new file mode 100644 index 00000000..4ea52ac6 --- /dev/null +++ b/bridgev2/database/upgrades/03-portal-relay-postgres.sql @@ -0,0 +1,13 @@ +-- v3 (compatible with v1+): Add relay column for portals (Postgres) +-- only: postgres +ALTER TABLE portal ADD COLUMN relay_bridge_id TEXT; +ALTER TABLE portal ADD COLUMN relay_login_id TEXT; +ALTER TABLE user_portal DROP CONSTRAINT user_portal_user_login_fkey; +ALTER TABLE user_login DROP CONSTRAINT user_login_pkey; +ALTER TABLE user_login ADD CONSTRAINT user_login_pkey PRIMARY KEY (bridge_id, id); +ALTER TABLE user_portal ADD CONSTRAINT user_portal_user_login_fkey FOREIGN KEY (bridge_id, login_id) + REFERENCES user_login (bridge_id, id) + ON DELETE CASCADE ON UPDATE CASCADE; +ALTER TABLE portal ADD CONSTRAINT portal_relay_fkey FOREIGN KEY (relay_bridge_id, relay_login_id) + REFERENCES user_login (bridge_id, id) + ON DELETE CASCADE ON UPDATE CASCADE; diff --git a/bridgev2/database/upgrades/04-portal-relay-sqlite.sql b/bridgev2/database/upgrades/04-portal-relay-sqlite.sql new file mode 100644 index 00000000..04385958 --- /dev/null +++ b/bridgev2/database/upgrades/04-portal-relay-sqlite.sql @@ -0,0 +1,100 @@ +-- v4 (compatible with v1+): Add relay column for portals (SQLite) +-- transaction: off +-- only: sqlite + +PRAGMA foreign_keys = OFF; +BEGIN; + +CREATE TABLE user_login_new ( + bridge_id TEXT NOT NULL, + user_mxid TEXT NOT NULL, + id TEXT NOT NULL, + space_room TEXT, + metadata jsonb NOT NULL, + + PRIMARY KEY (bridge_id, id), + CONSTRAINT user_login_user_fkey FOREIGN KEY (bridge_id, user_mxid) + REFERENCES "user" (bridge_id, mxid) + ON DELETE CASCADE ON UPDATE CASCADE +); + +INSERT INTO user_login_new +SELECT bridge_id, user_mxid, id, space_room, metadata +FROM user_login; + +DROP TABLE user_login; +ALTER TABLE user_login_new RENAME TO user_login; + + +CREATE TABLE user_portal_new ( + bridge_id TEXT NOT NULL, + user_mxid TEXT NOT NULL, + login_id TEXT NOT NULL, + portal_id TEXT NOT NULL, + portal_receiver TEXT NOT NULL, + in_space BOOLEAN NOT NULL, + preferred BOOLEAN NOT NULL, + last_read BIGINT, + + PRIMARY KEY (bridge_id, user_mxid, login_id, portal_id, portal_receiver), + CONSTRAINT user_portal_user_login_fkey FOREIGN KEY (bridge_id, login_id) + REFERENCES user_login (bridge_id, id) + ON DELETE CASCADE ON UPDATE CASCADE, + CONSTRAINT user_portal_portal_fkey FOREIGN KEY (bridge_id, portal_id, portal_receiver) + REFERENCES portal (bridge_id, id, receiver) + ON DELETE CASCADE ON UPDATE CASCADE +); + +INSERT INTO user_portal_new +SELECT bridge_id, user_mxid, login_id, portal_id, portal_receiver, in_space, preferred, last_read +FROM user_portal; + +DROP TABLE user_portal; +ALTER TABLE user_portal_new RENAME TO user_portal; + +CREATE TABLE portal_new ( + bridge_id TEXT NOT NULL, + id TEXT NOT NULL, + receiver TEXT NOT NULL, + mxid TEXT, + + parent_id TEXT, + -- This is not accessed by the bridge, it's only used for the portal parent foreign key. + -- Parent groups are probably never DMs, so they don't need a receiver. + parent_receiver TEXT NOT NULL DEFAULT '', + + relay_bridge_id TEXT, + relay_login_id TEXT, + + name TEXT NOT NULL, + topic TEXT NOT NULL, + avatar_id TEXT NOT NULL, + avatar_hash TEXT NOT NULL, + avatar_mxc TEXT NOT NULL, + name_set BOOLEAN NOT NULL, + avatar_set BOOLEAN NOT NULL, + topic_set BOOLEAN NOT NULL, + in_space BOOLEAN NOT NULL, + metadata jsonb NOT NULL, + + PRIMARY KEY (bridge_id, id, receiver), + CONSTRAINT portal_parent_fkey FOREIGN KEY (bridge_id, parent_id, parent_receiver) + -- Deletes aren't allowed to cascade here: + -- children should be re-parented or cleaned up manually + REFERENCES portal (bridge_id, id, receiver) ON UPDATE CASCADE, + CONSTRAINT portal_relay_fkey FOREIGN KEY (relay_bridge_id, relay_login_id) + REFERENCES user_login (bridge_id, id) + ON DELETE SET NULL ON UPDATE CASCADE +); + +INSERT INTO portal_new +SELECT bridge_id, id, receiver, mxid, parent_id, parent_receiver, NULL, NULL, + name, topic, avatar_id, avatar_hash, avatar_mxc, name_set, avatar_set, topic_set, in_space, metadata +FROM portal; + +DROP TABLE portal; +ALTER TABLE portal_new RENAME TO portal; + +PRAGMA foreign_key_check; +COMMIT; +PRAGMA foreign_keys = ON; diff --git a/bridgev2/database/upgrades/05-message-receiver-pkey-postgres.sql b/bridgev2/database/upgrades/05-message-receiver-pkey-postgres.sql new file mode 100644 index 00000000..1cdbcccf --- /dev/null +++ b/bridgev2/database/upgrades/05-message-receiver-pkey-postgres.sql @@ -0,0 +1,10 @@ +-- v5 (compatible with v1+): Add room_receiver to message unique key (Postgres) +-- only: postgres +ALTER TABLE reaction DROP CONSTRAINT reaction_message_fkey; +ALTER TABLE reaction DROP CONSTRAINT reaction_pkey1; +ALTER TABLE reaction ADD PRIMARY KEY (bridge_id, room_receiver, message_id, message_part_id, sender_id, emoji_id); +ALTER TABLE message DROP CONSTRAINT message_real_pkey; +ALTER TABLE message ADD CONSTRAINT message_real_pkey UNIQUE (bridge_id, room_receiver, id, part_id); +ALTER TABLE reaction ADD CONSTRAINT reaction_message_fkey FOREIGN KEY (bridge_id, room_receiver, message_id, message_part_id) + REFERENCES message (bridge_id, room_receiver, id, part_id) + ON DELETE CASCADE ON UPDATE CASCADE; diff --git a/bridgev2/database/upgrades/06-message-receiver-pkey-sqlite.sql b/bridgev2/database/upgrades/06-message-receiver-pkey-sqlite.sql new file mode 100644 index 00000000..b88c5052 --- /dev/null +++ b/bridgev2/database/upgrades/06-message-receiver-pkey-sqlite.sql @@ -0,0 +1,75 @@ +-- v6 (compatible with v1+): Add room_receiver to message unique key (SQLite) +-- transaction: off +-- only: sqlite + +PRAGMA foreign_keys = OFF; +BEGIN; + +CREATE TABLE message_new ( + rowid INTEGER PRIMARY KEY, + + bridge_id TEXT NOT NULL, + id TEXT NOT NULL, + part_id TEXT NOT NULL, + mxid TEXT NOT NULL, + + room_id TEXT NOT NULL, + room_receiver TEXT NOT NULL, + sender_id TEXT NOT NULL, + timestamp BIGINT NOT NULL, + relates_to BIGINT, + metadata jsonb NOT NULL, + + CONSTRAINT message_relation_fkey FOREIGN KEY (relates_to) + REFERENCES message (rowid) ON DELETE SET NULL, + CONSTRAINT message_room_fkey FOREIGN KEY (bridge_id, room_id, room_receiver) + REFERENCES portal (bridge_id, id, receiver) + ON DELETE CASCADE ON UPDATE CASCADE, + CONSTRAINT message_sender_fkey FOREIGN KEY (bridge_id, sender_id) + REFERENCES ghost (bridge_id, id) + ON DELETE CASCADE ON UPDATE CASCADE, + CONSTRAINT message_real_pkey UNIQUE (bridge_id, room_receiver, id, part_id) +); + +INSERT INTO message_new (rowid, bridge_id, id, part_id, mxid, room_id, room_receiver, sender_id, timestamp, relates_to, metadata) +SELECT rowid, bridge_id, id, part_id, mxid, room_id, room_receiver, sender_id, timestamp, relates_to, metadata +FROM message; + +DROP TABLE message; +ALTER TABLE message_new RENAME TO message; + +CREATE TABLE reaction_new ( + bridge_id TEXT NOT NULL, + message_id TEXT NOT NULL, + message_part_id TEXT NOT NULL, + sender_id TEXT NOT NULL, + emoji_id TEXT NOT NULL, + room_id TEXT NOT NULL, + room_receiver TEXT NOT NULL, + mxid TEXT NOT NULL, + + timestamp BIGINT NOT NULL, + metadata jsonb NOT NULL, + + PRIMARY KEY (bridge_id, room_receiver, message_id, message_part_id, sender_id, emoji_id), + CONSTRAINT reaction_room_fkey FOREIGN KEY (bridge_id, room_id, room_receiver) + REFERENCES portal (bridge_id, id, receiver) + ON DELETE CASCADE ON UPDATE CASCADE, + CONSTRAINT reaction_message_fkey FOREIGN KEY (bridge_id, room_receiver, message_id, message_part_id) + REFERENCES message (bridge_id, room_receiver, id, part_id) + ON DELETE CASCADE ON UPDATE CASCADE, + CONSTRAINT reaction_sender_fkey FOREIGN KEY (bridge_id, sender_id) + REFERENCES ghost (bridge_id, id) + ON DELETE CASCADE ON UPDATE CASCADE +); + +INSERT INTO reaction_new +SELECT bridge_id, message_id, message_part_id, sender_id, emoji_id, room_id, room_receiver, mxid, timestamp, metadata +FROM reaction; + +DROP TABLE reaction; +ALTER TABLE reaction_new RENAME TO reaction; + +PRAGMA foreign_key_check; +COMMIT; +PRAGMA foreign_keys = ON; diff --git a/bridgev2/database/upgrades/07-message-relation-without-fkey.sql b/bridgev2/database/upgrades/07-message-relation-without-fkey.sql new file mode 100644 index 00000000..9c4c9fd5 --- /dev/null +++ b/bridgev2/database/upgrades/07-message-relation-without-fkey.sql @@ -0,0 +1,4 @@ +-- v7: Add new relation columns to messages +ALTER TABLE message ADD COLUMN thread_root_id TEXT; +ALTER TABLE message ADD COLUMN reply_to_id TEXT; +ALTER TABLE message ADD COLUMN reply_to_part_id TEXT; diff --git a/bridgev2/database/upgrades/08-drop-message-relates-to.postgres.sql b/bridgev2/database/upgrades/08-drop-message-relates-to.postgres.sql new file mode 100644 index 00000000..284f6b0e --- /dev/null +++ b/bridgev2/database/upgrades/08-drop-message-relates-to.postgres.sql @@ -0,0 +1,3 @@ +-- v8: Drop relates_to column in messages +-- transaction: off +ALTER TABLE message DROP COLUMN relates_to; diff --git a/bridgev2/database/upgrades/08-drop-message-relates-to.sqlite.sql b/bridgev2/database/upgrades/08-drop-message-relates-to.sqlite.sql new file mode 100644 index 00000000..307a876e --- /dev/null +++ b/bridgev2/database/upgrades/08-drop-message-relates-to.sqlite.sql @@ -0,0 +1,41 @@ +-- v8: Drop relates_to column in messages +-- transaction: off +PRAGMA foreign_keys = OFF; +BEGIN; + +CREATE TABLE message_new ( + rowid INTEGER PRIMARY KEY, + + bridge_id TEXT NOT NULL, + id TEXT NOT NULL, + part_id TEXT NOT NULL, + mxid TEXT NOT NULL, + + room_id TEXT NOT NULL, + room_receiver TEXT NOT NULL, + sender_id TEXT NOT NULL, + timestamp BIGINT NOT NULL, + thread_root_id TEXT, + reply_to_id TEXT, + reply_to_part_id TEXT, + metadata jsonb NOT NULL, + + CONSTRAINT message_room_fkey FOREIGN KEY (bridge_id, room_id, room_receiver) + REFERENCES portal (bridge_id, id, receiver) + ON DELETE CASCADE ON UPDATE CASCADE, + CONSTRAINT message_sender_fkey FOREIGN KEY (bridge_id, sender_id) + REFERENCES ghost (bridge_id, id) + ON DELETE CASCADE ON UPDATE CASCADE, + CONSTRAINT message_real_pkey UNIQUE (bridge_id, room_receiver, id, part_id) +); + +INSERT INTO message_new (rowid, bridge_id, id, part_id, mxid, room_id, room_receiver, sender_id, timestamp, metadata) +SELECT rowid, bridge_id, id, part_id, mxid, room_id, room_receiver, sender_id, timestamp, metadata +FROM message; + +DROP TABLE message; +ALTER TABLE message_new RENAME TO message; + +PRAGMA foreign_key_check; +COMMIT; +PRAGMA foreign_keys = ON; diff --git a/bridgev2/database/upgrades/09-remove-standard-metadata.sql b/bridgev2/database/upgrades/09-remove-standard-metadata.sql new file mode 100644 index 00000000..3f348007 --- /dev/null +++ b/bridgev2/database/upgrades/09-remove-standard-metadata.sql @@ -0,0 +1,45 @@ +-- v9: Move standard metadata to separate columns +ALTER TABLE message ADD COLUMN sender_mxid TEXT NOT NULL DEFAULT ''; +UPDATE message SET sender_mxid=COALESCE((metadata->>'sender_mxid'), ''); + +ALTER TABLE message ADD COLUMN edit_count INTEGER NOT NULL DEFAULT 0; +UPDATE message SET edit_count=COALESCE(CAST((metadata->>'edit_count') AS INTEGER), 0); + +ALTER TABLE portal ADD COLUMN disappear_type TEXT; +UPDATE portal SET disappear_type=(metadata->>'disappear_type'); + +ALTER TABLE portal ADD COLUMN disappear_timer BIGINT; +-- only: postgres +UPDATE portal SET disappear_timer=(metadata->>'disappear_timer')::BIGINT; +-- only: sqlite +UPDATE portal SET disappear_timer=CAST(metadata->>'disappear_timer' AS INTEGER); + +ALTER TABLE portal ADD COLUMN room_type TEXT NOT NULL DEFAULT ''; +UPDATE portal SET room_type='dm' WHERE CAST(metadata->>'is_direct' AS BOOLEAN) IS true; +UPDATE portal SET room_type='space' WHERE CAST(metadata->>'is_space' AS BOOLEAN) IS true; + +ALTER TABLE reaction ADD COLUMN emoji TEXT NOT NULL DEFAULT ''; +UPDATE reaction SET emoji=COALESCE((metadata->>'emoji'), ''); + +ALTER TABLE user_login ADD COLUMN remote_name TEXT NOT NULL DEFAULT ''; +UPDATE user_login SET remote_name=COALESCE((metadata->>'remote_name'), ''); + +ALTER TABLE ghost ADD COLUMN contact_info_set BOOLEAN NOT NULL DEFAULT false; +UPDATE ghost SET contact_info_set=COALESCE(CAST((metadata->>'contact_info_set') AS BOOLEAN), false); + +ALTER TABLE ghost ADD COLUMN is_bot BOOLEAN NOT NULL DEFAULT false; +UPDATE ghost SET is_bot=COALESCE(CAST((metadata->>'is_bot') AS BOOLEAN), false); + +ALTER TABLE ghost ADD COLUMN identifiers jsonb NOT NULL DEFAULT '[]'; +UPDATE ghost SET identifiers=COALESCE((metadata->'identifiers'), '[]'); + +-- only: postgres until "end only" +ALTER TABLE message ALTER COLUMN sender_mxid DROP DEFAULT; +ALTER TABLE message ALTER COLUMN edit_count DROP DEFAULT; +ALTER TABLE portal ALTER COLUMN room_type DROP DEFAULT; +ALTER TABLE reaction ALTER COLUMN emoji DROP DEFAULT; +ALTER TABLE user_login ALTER COLUMN remote_name DROP DEFAULT; +ALTER TABLE ghost ALTER COLUMN contact_info_set DROP DEFAULT; +ALTER TABLE ghost ALTER COLUMN is_bot DROP DEFAULT; +ALTER TABLE ghost ALTER COLUMN identifiers DROP DEFAULT; +-- end only postgres diff --git a/bridgev2/database/upgrades/10-fix-signal-portal-revision.postgres.sql b/bridgev2/database/upgrades/10-fix-signal-portal-revision.postgres.sql new file mode 100644 index 00000000..f42402f3 --- /dev/null +++ b/bridgev2/database/upgrades/10-fix-signal-portal-revision.postgres.sql @@ -0,0 +1,4 @@ +-- v10 (compatible with v9+): Fix Signal portal revisions +UPDATE portal +SET metadata=jsonb_set(metadata, '{revision}', CAST((metadata->>'revision') AS jsonb)) +WHERE jsonb_typeof(metadata->'revision')='string'; diff --git a/bridgev2/database/upgrades/10-fix-signal-portal-revision.sqlite.sql b/bridgev2/database/upgrades/10-fix-signal-portal-revision.sqlite.sql new file mode 100644 index 00000000..0fd67c80 --- /dev/null +++ b/bridgev2/database/upgrades/10-fix-signal-portal-revision.sqlite.sql @@ -0,0 +1,4 @@ +-- v10 (compatible with v9+): Fix Signal portal revisions +UPDATE portal +SET metadata=json_set(metadata, '$.revision', CAST(json_extract(metadata, '$.revision') AS INTEGER)) +WHERE json_type(metadata, '$.revision')='text'; diff --git a/bridgev2/database/upgrades/11-room-fkey-idx.sql b/bridgev2/database/upgrades/11-room-fkey-idx.sql new file mode 100644 index 00000000..d6a67713 --- /dev/null +++ b/bridgev2/database/upgrades/11-room-fkey-idx.sql @@ -0,0 +1,5 @@ +-- v11: Add indexes for some foreign keys +CREATE INDEX message_room_idx ON message (bridge_id, room_id, room_receiver); +CREATE INDEX reaction_room_idx ON reaction (bridge_id, room_id, room_receiver); +CREATE INDEX user_portal_portal_idx ON user_portal (bridge_id, portal_id, portal_receiver); +CREATE INDEX user_portal_login_idx ON user_portal (bridge_id, login_id); diff --git a/bridgev2/database/upgrades/12-dm-portal-other-user.sql b/bridgev2/database/upgrades/12-dm-portal-other-user.sql new file mode 100644 index 00000000..2d2cb900 --- /dev/null +++ b/bridgev2/database/upgrades/12-dm-portal-other-user.sql @@ -0,0 +1,2 @@ +-- v12 (compatible with v9+): Save other user ID in DM portals +ALTER TABLE portal ADD COLUMN other_user_id TEXT; diff --git a/bridgev2/database/upgrades/13-backfill-queue.sql b/bridgev2/database/upgrades/13-backfill-queue.sql new file mode 100644 index 00000000..dada993c --- /dev/null +++ b/bridgev2/database/upgrades/13-backfill-queue.sql @@ -0,0 +1,20 @@ +-- v13 (compatible with v9+): Add backfill queue +CREATE TABLE backfill_task ( + bridge_id TEXT NOT NULL, + portal_id TEXT NOT NULL, + portal_receiver TEXT NOT NULL, + user_login_id TEXT NOT NULL, + + batch_count INTEGER NOT NULL, + is_done BOOLEAN NOT NULL, + cursor TEXT, + oldest_message_id TEXT, + dispatched_at BIGINT, + completed_at BIGINT, + next_dispatch_min_ts BIGINT NOT NULL, + + PRIMARY KEY (bridge_id, portal_id, portal_receiver), + CONSTRAINT backfill_queue_portal_fkey FOREIGN KEY (bridge_id, portal_id, portal_receiver) + REFERENCES portal (bridge_id, id, receiver) + ON DELETE CASCADE ON UPDATE CASCADE +); diff --git a/bridgev2/database/upgrades/14-portal-name-custom.sql b/bridgev2/database/upgrades/14-portal-name-custom.sql new file mode 100644 index 00000000..2c8dfc8f --- /dev/null +++ b/bridgev2/database/upgrades/14-portal-name-custom.sql @@ -0,0 +1,2 @@ +-- v14 (compatible with v9+): Save whether name is custom in portals +ALTER TABLE portal ADD COLUMN name_is_custom BOOLEAN NOT NULL DEFAULT false; diff --git a/bridgev2/database/upgrades/15-reaction-sender-mxid.sql b/bridgev2/database/upgrades/15-reaction-sender-mxid.sql new file mode 100644 index 00000000..e32bd832 --- /dev/null +++ b/bridgev2/database/upgrades/15-reaction-sender-mxid.sql @@ -0,0 +1,2 @@ +-- v15 (compatible with v9+): Save sender MXID for reactions +ALTER TABLE reaction ADD COLUMN sender_mxid TEXT NOT NULL DEFAULT ''; diff --git a/bridgev2/database/upgrades/16-user-login-profile.sql b/bridgev2/database/upgrades/16-user-login-profile.sql new file mode 100644 index 00000000..e143fcee --- /dev/null +++ b/bridgev2/database/upgrades/16-user-login-profile.sql @@ -0,0 +1,2 @@ +-- v16 (compatible with v9+): Save remote profile in user logins +ALTER TABLE user_login ADD COLUMN remote_profile jsonb; diff --git a/bridgev2/database/upgrades/17-message-mxid-unique.sql b/bridgev2/database/upgrades/17-message-mxid-unique.sql new file mode 100644 index 00000000..ee53b3f0 --- /dev/null +++ b/bridgev2/database/upgrades/17-message-mxid-unique.sql @@ -0,0 +1,8 @@ +-- v17 (compatible with v9+): Add unique constraint for message and reaction mxids +DELETE FROM message WHERE mxid IN (SELECT mxid FROM message GROUP BY mxid HAVING COUNT(*) > 1); +-- only: postgres for next 2 lines +ALTER TABLE message ADD CONSTRAINT message_mxid_unique UNIQUE (bridge_id, mxid); +ALTER TABLE reaction ADD CONSTRAINT reaction_mxid_unique UNIQUE (bridge_id, mxid); +-- only: sqlite for next 2 lines +CREATE UNIQUE INDEX message_mxid_unique ON message (bridge_id, mxid); +CREATE UNIQUE INDEX reaction_mxid_unique ON reaction (bridge_id, mxid); diff --git a/bridgev2/database/upgrades/18-kv-store.sql b/bridgev2/database/upgrades/18-kv-store.sql new file mode 100644 index 00000000..9d233095 --- /dev/null +++ b/bridgev2/database/upgrades/18-kv-store.sql @@ -0,0 +1,8 @@ +-- v18 (compatible with v9+): Add generic key-value store +CREATE TABLE kv_store ( + bridge_id TEXT NOT NULL, + key TEXT NOT NULL, + value TEXT NOT NULL, + + PRIMARY KEY (bridge_id, key) +); diff --git a/bridgev2/database/upgrades/19-add-double-puppeted-to-message.sql b/bridgev2/database/upgrades/19-add-double-puppeted-to-message.sql new file mode 100644 index 00000000..ec6fe836 --- /dev/null +++ b/bridgev2/database/upgrades/19-add-double-puppeted-to-message.sql @@ -0,0 +1,2 @@ +-- v19 (compatible with v9+): Add double puppeted state to messages +ALTER TABLE message ADD COLUMN double_puppeted BOOLEAN; diff --git a/bridgev2/database/upgrades/20-portal-capabilities.sql b/bridgev2/database/upgrades/20-portal-capabilities.sql new file mode 100644 index 00000000..00bd96ca --- /dev/null +++ b/bridgev2/database/upgrades/20-portal-capabilities.sql @@ -0,0 +1,2 @@ +-- v20 (compatible with v9+): Add portal capability state +ALTER TABLE portal ADD COLUMN cap_state jsonb; diff --git a/bridgev2/database/upgrades/21-disappearing-message-fkey.postgres.sql b/bridgev2/database/upgrades/21-disappearing-message-fkey.postgres.sql new file mode 100644 index 00000000..d1c1ad9a --- /dev/null +++ b/bridgev2/database/upgrades/21-disappearing-message-fkey.postgres.sql @@ -0,0 +1,8 @@ +-- v21 (compatible with v9+): Add foreign key constraint from disappearing_message.mx_room to portals.mxid +CREATE UNIQUE INDEX portal_bridge_mxid_idx ON portal (bridge_id, mxid); +DELETE FROM disappearing_message WHERE mx_room NOT IN (SELECT mxid FROM portal WHERE mxid IS NOT NULL); +ALTER TABLE disappearing_message + ADD CONSTRAINT disappearing_message_portal_fkey + FOREIGN KEY (bridge_id, mx_room) + REFERENCES portal (bridge_id, mxid) + ON DELETE CASCADE; diff --git a/bridgev2/database/upgrades/21-disappearing-message-fkey.sqlite.sql b/bridgev2/database/upgrades/21-disappearing-message-fkey.sqlite.sql new file mode 100644 index 00000000..f5468c6b --- /dev/null +++ b/bridgev2/database/upgrades/21-disappearing-message-fkey.sqlite.sql @@ -0,0 +1,24 @@ +-- v21 (compatible with v9+): Add foreign key constraint from disappearing_message.mx_room to portals.mxid +CREATE UNIQUE INDEX portal_bridge_mxid_idx ON portal (bridge_id, mxid); +CREATE TABLE disappearing_message_new ( + bridge_id TEXT NOT NULL, + mx_room TEXT NOT NULL, + mxid TEXT NOT NULL, + type TEXT NOT NULL, + timer BIGINT NOT NULL, + disappear_at BIGINT, + + PRIMARY KEY (bridge_id, mxid), + CONSTRAINT disappearing_message_portal_fkey + FOREIGN KEY (bridge_id, mx_room) + REFERENCES portal (bridge_id, mxid) + ON DELETE CASCADE +); + +WITH portal_mxids AS (SELECT mxid FROM portal WHERE mxid IS NOT NULL) +INSERT INTO disappearing_message_new (bridge_id, mx_room, mxid, type, timer, disappear_at) +SELECT bridge_id, mx_room, mxid, type, timer, disappear_at +FROM disappearing_message WHERE mx_room IN portal_mxids; + +DROP TABLE disappearing_message; +ALTER TABLE disappearing_message_new RENAME TO disappearing_message; diff --git a/bridgev2/database/upgrades/22-message-send-txn-id.sql b/bridgev2/database/upgrades/22-message-send-txn-id.sql new file mode 100644 index 00000000..8933984e --- /dev/null +++ b/bridgev2/database/upgrades/22-message-send-txn-id.sql @@ -0,0 +1,6 @@ +-- v22 (compatible with v9+): Add message send transaction ID column +ALTER TABLE message ADD COLUMN send_txn_id TEXT; +-- only: postgres +ALTER TABLE message ADD CONSTRAINT message_txn_id_unique UNIQUE (bridge_id, room_receiver, send_txn_id); +-- only: sqlite +CREATE UNIQUE INDEX message_txn_id_unique ON message (bridge_id, room_receiver, send_txn_id); diff --git a/bridgev2/database/upgrades/23-disappearing-timer-ts.sql b/bridgev2/database/upgrades/23-disappearing-timer-ts.sql new file mode 100644 index 00000000..ecd00b8d --- /dev/null +++ b/bridgev2/database/upgrades/23-disappearing-timer-ts.sql @@ -0,0 +1,2 @@ +-- v23 (compatible with v9+): Add event timestamp for disappearing messages +ALTER TABLE disappearing_message ADD COLUMN timestamp BIGINT NOT NULL DEFAULT 0; diff --git a/bridgev2/database/upgrades/24-public-media.sql b/bridgev2/database/upgrades/24-public-media.sql new file mode 100644 index 00000000..c4290090 --- /dev/null +++ b/bridgev2/database/upgrades/24-public-media.sql @@ -0,0 +1,11 @@ +-- v24 (compatible with v9+): Custom URLs for public media +CREATE TABLE public_media ( + bridge_id TEXT NOT NULL, + public_id TEXT NOT NULL, + mxc TEXT NOT NULL, + keys jsonb, + mimetype TEXT, + expiry BIGINT, + + PRIMARY KEY (bridge_id, public_id) +); diff --git a/bridgev2/database/upgrades/25-message-requests.sql b/bridgev2/database/upgrades/25-message-requests.sql new file mode 100644 index 00000000..b9d82a7a --- /dev/null +++ b/bridgev2/database/upgrades/25-message-requests.sql @@ -0,0 +1,2 @@ +-- v25 (compatible with v9+): Flag for message request portals +ALTER TABLE portal ADD COLUMN message_request BOOLEAN NOT NULL DEFAULT false; diff --git a/bridgev2/database/upgrades/26-disappearing-message-portal-index.sql b/bridgev2/database/upgrades/26-disappearing-message-portal-index.sql new file mode 100644 index 00000000..ae5d8cad --- /dev/null +++ b/bridgev2/database/upgrades/26-disappearing-message-portal-index.sql @@ -0,0 +1,3 @@ +-- v26 (compatible with v9+): Add room index for disappearing message table and portal parents +CREATE INDEX disappearing_message_portal_idx ON disappearing_message (bridge_id, mx_room); +CREATE INDEX portal_parent_idx ON portal (bridge_id, parent_id, parent_receiver); diff --git a/bridgev2/database/upgrades/27-ghost-extra-profile.sql b/bridgev2/database/upgrades/27-ghost-extra-profile.sql new file mode 100644 index 00000000..e8e0549a --- /dev/null +++ b/bridgev2/database/upgrades/27-ghost-extra-profile.sql @@ -0,0 +1,2 @@ +-- v27 (compatible with v9+): Add column for extra ghost profile metadata +ALTER TABLE ghost ADD COLUMN extra_profile jsonb; diff --git a/bridgev2/database/upgrades/upgrades.go b/bridgev2/database/upgrades/upgrades.go new file mode 100644 index 00000000..4fef472e --- /dev/null +++ b/bridgev2/database/upgrades/upgrades.go @@ -0,0 +1,22 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package upgrades + +import ( + "embed" + + "go.mau.fi/util/dbutil" +) + +var Table dbutil.UpgradeTable + +//go:embed *.sql +var rawUpgrades embed.FS + +func init() { + Table.RegisterFS(rawUpgrades) +} diff --git a/bridgev2/database/user.go b/bridgev2/database/user.go new file mode 100644 index 00000000..00eae7ca --- /dev/null +++ b/bridgev2/database/user.go @@ -0,0 +1,74 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package database + +import ( + "context" + "database/sql" + + "go.mau.fi/util/dbutil" + + "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/id" +) + +type UserQuery struct { + BridgeID networkid.BridgeID + *dbutil.QueryHelper[*User] +} + +type User struct { + BridgeID networkid.BridgeID + MXID id.UserID + + ManagementRoom id.RoomID + AccessToken string +} + +const ( + getUserBaseQuery = ` + SELECT bridge_id, mxid, management_room, access_token FROM "user" + ` + getUserByMXIDQuery = getUserBaseQuery + `WHERE bridge_id=$1 AND mxid=$2` + insertUserQuery = ` + INSERT INTO "user" (bridge_id, mxid, management_room, access_token) + VALUES ($1, $2, $3, $4) + ` + updateUserQuery = ` + UPDATE "user" SET management_room=$3, access_token=$4 + WHERE bridge_id=$1 AND mxid=$2 + ` +) + +func (uq *UserQuery) GetByMXID(ctx context.Context, userID id.UserID) (*User, error) { + return uq.QueryOne(ctx, getUserByMXIDQuery, uq.BridgeID, userID) +} + +func (uq *UserQuery) Insert(ctx context.Context, user *User) error { + ensureBridgeIDMatches(&user.BridgeID, uq.BridgeID) + return uq.Exec(ctx, insertUserQuery, user.sqlVariables()...) +} + +func (uq *UserQuery) Update(ctx context.Context, user *User) error { + ensureBridgeIDMatches(&user.BridgeID, uq.BridgeID) + return uq.Exec(ctx, updateUserQuery, user.sqlVariables()...) +} + +func (u *User) Scan(row dbutil.Scannable) (*User, error) { + var managementRoom, accessToken sql.NullString + err := row.Scan(&u.BridgeID, &u.MXID, &managementRoom, &accessToken) + if err != nil { + return nil, err + } + u.ManagementRoom = id.RoomID(managementRoom.String) + u.AccessToken = accessToken.String + return u, nil +} + +func (u *User) sqlVariables() []any { + return []any{u.BridgeID, u.MXID, dbutil.StrPtr(u.ManagementRoom), dbutil.StrPtr(u.AccessToken)} +} diff --git a/bridgev2/database/userlogin.go b/bridgev2/database/userlogin.go new file mode 100644 index 00000000..00ff01c9 --- /dev/null +++ b/bridgev2/database/userlogin.go @@ -0,0 +1,123 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package database + +import ( + "context" + "database/sql" + + "go.mau.fi/util/dbutil" + + "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/bridgev2/status" + "maunium.net/go/mautrix/id" +) + +type UserLoginQuery struct { + BridgeID networkid.BridgeID + MetaType MetaTypeCreator + *dbutil.QueryHelper[*UserLogin] +} + +type UserLogin struct { + BridgeID networkid.BridgeID + UserMXID id.UserID + ID networkid.UserLoginID + RemoteName string + RemoteProfile status.RemoteProfile + SpaceRoom id.RoomID + Metadata any +} + +const ( + getUserLoginBaseQuery = ` + SELECT bridge_id, user_mxid, id, remote_name, remote_profile, space_room, metadata FROM user_login + ` + getLoginByIDQuery = getUserLoginBaseQuery + `WHERE bridge_id=$1 AND id=$2` + getAllUsersWithLoginsQuery = `SELECT DISTINCT user_mxid FROM user_login WHERE bridge_id=$1` + getAllLoginsForUserQuery = getUserLoginBaseQuery + `WHERE bridge_id=$1 AND user_mxid=$2` + getAllLoginsInPortalQuery = ` + SELECT ul.bridge_id, ul.user_mxid, ul.id, ul.remote_name, ul.remote_profile, ul.space_room, ul.metadata FROM user_portal + LEFT JOIN user_login ul ON user_portal.bridge_id=ul.bridge_id AND user_portal.user_mxid=ul.user_mxid AND user_portal.login_id=ul.id + WHERE user_portal.bridge_id=$1 AND user_portal.portal_id=$2 AND user_portal.portal_receiver=$3 + ` + insertUserLoginQuery = ` + INSERT INTO user_login (bridge_id, user_mxid, id, remote_name, remote_profile, space_room, metadata) + VALUES ($1, $2, $3, $4, $5, $6, $7) + ` + updateUserLoginQuery = ` + UPDATE user_login SET remote_name=$4, remote_profile=$5, space_room=$6, metadata=$7 + WHERE bridge_id=$1 AND user_mxid=$2 AND id=$3 + ` + deleteUserLoginQuery = ` + DELETE FROM user_login WHERE bridge_id=$1 AND id=$2 + ` +) + +func (uq *UserLoginQuery) GetByID(ctx context.Context, id networkid.UserLoginID) (*UserLogin, error) { + return uq.QueryOne(ctx, getLoginByIDQuery, uq.BridgeID, id) +} + +func (uq *UserLoginQuery) GetAllUserIDsWithLogins(ctx context.Context) ([]id.UserID, error) { + rows, err := uq.GetDB().Query(ctx, getAllUsersWithLoginsQuery, uq.BridgeID) + return dbutil.NewRowIterWithError(rows, dbutil.ScanSingleColumn[id.UserID], err).AsList() +} + +func (uq *UserLoginQuery) GetAllInPortal(ctx context.Context, portal networkid.PortalKey) ([]*UserLogin, error) { + return uq.QueryMany(ctx, getAllLoginsInPortalQuery, uq.BridgeID, portal.ID, portal.Receiver) +} + +func (uq *UserLoginQuery) GetAllForUser(ctx context.Context, userID id.UserID) ([]*UserLogin, error) { + return uq.QueryMany(ctx, getAllLoginsForUserQuery, uq.BridgeID, userID) +} + +func (uq *UserLoginQuery) Insert(ctx context.Context, login *UserLogin) error { + ensureBridgeIDMatches(&login.BridgeID, uq.BridgeID) + return uq.Exec(ctx, insertUserLoginQuery, login.ensureHasMetadata(uq.MetaType).sqlVariables()...) +} + +func (uq *UserLoginQuery) Update(ctx context.Context, login *UserLogin) error { + ensureBridgeIDMatches(&login.BridgeID, uq.BridgeID) + return uq.Exec(ctx, updateUserLoginQuery, login.ensureHasMetadata(uq.MetaType).sqlVariables()...) +} + +func (uq *UserLoginQuery) Delete(ctx context.Context, loginID networkid.UserLoginID) error { + return uq.Exec(ctx, deleteUserLoginQuery, uq.BridgeID, loginID) +} + +func (u *UserLogin) Scan(row dbutil.Scannable) (*UserLogin, error) { + var spaceRoom sql.NullString + err := row.Scan( + &u.BridgeID, + &u.UserMXID, + &u.ID, + &u.RemoteName, + dbutil.JSON{Data: &u.RemoteProfile}, + &spaceRoom, + dbutil.JSON{Data: u.Metadata}, + ) + if err != nil { + return nil, err + } + u.SpaceRoom = id.RoomID(spaceRoom.String) + return u, nil +} + +func (u *UserLogin) ensureHasMetadata(metaType MetaTypeCreator) *UserLogin { + if u.Metadata == nil { + u.Metadata = metaType() + } + return u +} + +func (u *UserLogin) sqlVariables() []any { + var remoteProfile dbutil.JSON + if !u.RemoteProfile.IsZero() { + remoteProfile.Data = &u.RemoteProfile + } + return []any{u.BridgeID, u.UserMXID, u.ID, u.RemoteName, remoteProfile, dbutil.StrPtr(u.SpaceRoom), dbutil.JSON{Data: u.Metadata}} +} diff --git a/bridgev2/database/userportal.go b/bridgev2/database/userportal.go new file mode 100644 index 00000000..e928a4c7 --- /dev/null +++ b/bridgev2/database/userportal.go @@ -0,0 +1,155 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package database + +import ( + "context" + "database/sql" + "time" + + "go.mau.fi/util/dbutil" + + "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/id" +) + +type UserPortalQuery struct { + BridgeID networkid.BridgeID + *dbutil.QueryHelper[*UserPortal] +} + +type UserPortal struct { + BridgeID networkid.BridgeID + UserMXID id.UserID + LoginID networkid.UserLoginID + Portal networkid.PortalKey + InSpace *bool + Preferred *bool + LastRead time.Time +} + +const ( + getUserPortalBaseQuery = ` + SELECT bridge_id, user_mxid, login_id, portal_id, portal_receiver, in_space, preferred, last_read + FROM user_portal + ` + getUserPortalQuery = getUserPortalBaseQuery + ` + WHERE bridge_id=$1 AND user_mxid=$2 AND login_id=$3 AND portal_id=$4 AND portal_receiver=$5 + ` + findUserLoginsOfUserByPortalIDQuery = getUserPortalBaseQuery + ` + WHERE bridge_id=$1 AND user_mxid=$2 AND portal_id=$3 AND portal_receiver=$4 + ORDER BY CASE WHEN preferred THEN 0 ELSE 1 END, login_id + ` + getAllUserLoginsInPortalQuery = getUserPortalBaseQuery + ` + WHERE bridge_id=$1 AND portal_id=$2 AND portal_receiver=$3 + ` + getAllPortalsForLoginQuery = getUserPortalBaseQuery + ` + WHERE bridge_id=$1 AND user_mxid=$2 AND login_id=$3 + ` + getOrCreateUserPortalQuery = ` + INSERT INTO user_portal (bridge_id, user_mxid, login_id, portal_id, portal_receiver, in_space, preferred) + VALUES ($1, $2, $3, $4, $5, false, false) + ON CONFLICT (bridge_id, user_mxid, login_id, portal_id, portal_receiver) DO UPDATE SET portal_id=user_portal.portal_id + RETURNING bridge_id, user_mxid, login_id, portal_id, portal_receiver, in_space, preferred, last_read + ` + upsertUserPortalQuery = ` + INSERT INTO user_portal (bridge_id, user_mxid, login_id, portal_id, portal_receiver, in_space, preferred, last_read) + VALUES ($1, $2, $3, $4, $5, COALESCE($6, false), COALESCE($7, false), $8) + ON CONFLICT (bridge_id, user_mxid, login_id, portal_id, portal_receiver) DO UPDATE + SET in_space=COALESCE($6, user_portal.in_space), + preferred=COALESCE($7, user_portal.preferred), + last_read=COALESCE($8, user_portal.last_read) + ` + markLoginAsPreferredQuery = ` + UPDATE user_portal SET preferred=(login_id=$3) WHERE bridge_id=$1 AND user_mxid=$2 AND portal_id=$4 AND portal_receiver=$5 + ` + markAllNotInSpaceQuery = ` + UPDATE user_portal SET in_space=false WHERE bridge_id=$1 AND portal_id=$2 AND portal_receiver=$3 + ` + deleteUserPortalQuery = ` + DELETE FROM user_portal WHERE bridge_id=$1 AND user_mxid=$2 AND login_id=$3 AND portal_id=$4 AND portal_receiver=$5 + ` +) + +func UserPortalFor(ul *UserLogin, portal networkid.PortalKey) *UserPortal { + return &UserPortal{ + BridgeID: ul.BridgeID, + UserMXID: ul.UserMXID, + LoginID: ul.ID, + Portal: portal, + } +} + +func (upq *UserPortalQuery) GetAllForUserInPortal(ctx context.Context, userID id.UserID, portal networkid.PortalKey) ([]*UserPortal, error) { + return upq.QueryMany(ctx, findUserLoginsOfUserByPortalIDQuery, upq.BridgeID, userID, portal.ID, portal.Receiver) +} + +func (upq *UserPortalQuery) GetAllForLogin(ctx context.Context, login *UserLogin) ([]*UserPortal, error) { + return upq.QueryMany(ctx, getAllPortalsForLoginQuery, upq.BridgeID, login.UserMXID, login.ID) +} + +func (upq *UserPortalQuery) GetAllInPortal(ctx context.Context, portal networkid.PortalKey) ([]*UserPortal, error) { + return upq.QueryMany(ctx, getAllUserLoginsInPortalQuery, upq.BridgeID, portal.ID, portal.Receiver) +} + +func (upq *UserPortalQuery) Get(ctx context.Context, login *UserLogin, portal networkid.PortalKey) (*UserPortal, error) { + return upq.QueryOne(ctx, getUserPortalQuery, upq.BridgeID, login.UserMXID, login.ID, portal.ID, portal.Receiver) +} + +func (upq *UserPortalQuery) GetOrCreate(ctx context.Context, login *UserLogin, portal networkid.PortalKey) (*UserPortal, error) { + return upq.QueryOne(ctx, getOrCreateUserPortalQuery, upq.BridgeID, login.UserMXID, login.ID, portal.ID, portal.Receiver) +} + +func (upq *UserPortalQuery) Put(ctx context.Context, up *UserPortal) error { + ensureBridgeIDMatches(&up.BridgeID, upq.BridgeID) + return upq.Exec(ctx, upsertUserPortalQuery, up.sqlVariables()...) +} + +func (upq *UserPortalQuery) MarkAsPreferred(ctx context.Context, login *UserLogin, portal networkid.PortalKey) error { + return upq.Exec(ctx, markLoginAsPreferredQuery, upq.BridgeID, login.UserMXID, login.ID, portal.ID, portal.Receiver) +} + +func (upq *UserPortalQuery) MarkAllNotInSpace(ctx context.Context, portal networkid.PortalKey) error { + return upq.Exec(ctx, markAllNotInSpaceQuery, upq.BridgeID, portal.ID, portal.Receiver) +} + +func (upq *UserPortalQuery) Delete(ctx context.Context, up *UserPortal) error { + return upq.Exec(ctx, deleteUserPortalQuery, up.BridgeID, up.UserMXID, up.LoginID, up.Portal.ID, up.Portal.Receiver) +} + +func (up *UserPortal) Scan(row dbutil.Scannable) (*UserPortal, error) { + var lastRead sql.NullInt64 + err := row.Scan( + &up.BridgeID, &up.UserMXID, &up.LoginID, &up.Portal.ID, &up.Portal.Receiver, + &up.InSpace, &up.Preferred, &lastRead, + ) + if err != nil { + return nil, err + } + if lastRead.Valid { + up.LastRead = time.Unix(0, lastRead.Int64) + } + return up, nil +} + +func (up *UserPortal) sqlVariables() []any { + return []any{ + up.BridgeID, up.UserMXID, up.LoginID, up.Portal.ID, up.Portal.Receiver, + up.InSpace, + up.Preferred, + dbutil.ConvertedPtr(up.LastRead, time.Time.UnixNano), + } +} + +func (up *UserPortal) CopyWithoutValues() *UserPortal { + return &UserPortal{ + BridgeID: up.BridgeID, + UserMXID: up.UserMXID, + LoginID: up.LoginID, + Portal: up.Portal, + } +} diff --git a/bridgev2/disappear.go b/bridgev2/disappear.go new file mode 100644 index 00000000..b5c37e8f --- /dev/null +++ b/bridgev2/disappear.go @@ -0,0 +1,153 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package bridgev2 + +import ( + "context" + "sync/atomic" + "time" + + "github.com/rs/zerolog" + "golang.org/x/exp/slices" + + "maunium.net/go/mautrix/bridgev2/database" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +type DisappearLoop struct { + br *Bridge + nextCheck atomic.Pointer[time.Time] + stop atomic.Pointer[context.CancelFunc] +} + +const DisappearCheckInterval = 1 * time.Hour + +func (dl *DisappearLoop) Start() { + log := dl.br.Log.With().Str("component", "disappear loop").Logger() + ctx, stop := context.WithCancel(log.WithContext(context.Background())) + if oldStop := dl.stop.Swap(&stop); oldStop != nil { + (*oldStop)() + } + log.Debug().Msg("Disappearing message loop starting") + for { + nextCheck := time.Now().Add(DisappearCheckInterval) + dl.nextCheck.Store(&nextCheck) + const MessageLimit = 200 + messages, err := dl.br.DB.DisappearingMessage.GetUpcoming(ctx, DisappearCheckInterval, MessageLimit) + if err != nil { + log.Err(err).Msg("Failed to get upcoming disappearing messages") + } else if len(messages) > 0 { + if len(messages) >= MessageLimit { + lastDisappearTime := messages[len(messages)-1].DisappearAt + log.Debug(). + Int("message_count", len(messages)). + Time("last_due", lastDisappearTime). + Msg("Deleting disappearing messages synchronously and checking again immediately") + // Store the expected next check time to avoid Add spawning unnecessary goroutines. + // This can be in the past, in which case Add will put everything in the db, which is also fine. + dl.nextCheck.Store(&lastDisappearTime) + // If there are many messages, process them synchronously and then check again. + dl.sleepAndDisappear(ctx, messages...) + continue + } + go dl.sleepAndDisappear(ctx, messages...) + } + select { + case <-time.After(time.Until(dl.GetNextCheck())): + case <-ctx.Done(): + log.Debug().Msg("Disappearing message loop stopping") + return + } + } +} + +func (dl *DisappearLoop) GetNextCheck() time.Time { + if dl == nil { + return time.Time{} + } + nextCheck := dl.nextCheck.Load() + if nextCheck == nil { + return time.Time{} + } + return *nextCheck +} + +func (dl *DisappearLoop) Stop() { + if dl == nil { + return + } + if stop := dl.stop.Load(); stop != nil { + (*stop)() + } +} + +func (dl *DisappearLoop) StartAllBefore(ctx context.Context, roomID id.RoomID, beforeTS time.Time) { + startedMessages, err := dl.br.DB.DisappearingMessage.StartAllBefore(ctx, roomID, beforeTS) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to start disappearing messages") + return + } + startedMessages = slices.DeleteFunc(startedMessages, func(dm *database.DisappearingMessage) bool { + return dm.DisappearAt.After(dl.GetNextCheck()) + }) + slices.SortFunc(startedMessages, func(a, b *database.DisappearingMessage) int { + return a.DisappearAt.Compare(b.DisappearAt) + }) + if len(startedMessages) > 0 { + go dl.sleepAndDisappear(ctx, startedMessages...) + } +} + +func (dl *DisappearLoop) Add(ctx context.Context, dm *database.DisappearingMessage) { + err := dl.br.DB.DisappearingMessage.Put(ctx, dm) + if err != nil { + zerolog.Ctx(ctx).Err(err). + Stringer("event_id", dm.EventID). + Msg("Failed to save disappearing message") + } + if !dm.DisappearAt.IsZero() && dm.DisappearAt.Before(dl.GetNextCheck()) { + go dl.sleepAndDisappear(zerolog.Ctx(ctx).WithContext(dl.br.BackgroundCtx), dm) + } +} + +func (dl *DisappearLoop) sleepAndDisappear(ctx context.Context, dms ...*database.DisappearingMessage) { + for _, msg := range dms { + timeUntilDisappear := time.Until(msg.DisappearAt) + if timeUntilDisappear <= 0 { + if ctx.Err() != nil { + return + } + } else { + select { + case <-time.After(timeUntilDisappear): + case <-ctx.Done(): + return + } + } + resp, err := dl.br.Bot.SendMessage(ctx, msg.RoomID, event.EventRedaction, &event.Content{ + Parsed: &event.RedactionEventContent{ + Redacts: msg.EventID, + Reason: "Message disappeared", + }, + }, nil) + if err != nil { + zerolog.Ctx(ctx).Err(err).Stringer("target_event_id", msg.EventID).Msg("Failed to disappear message") + } else { + zerolog.Ctx(ctx).Debug(). + Stringer("target_event_id", msg.EventID). + Stringer("redaction_event_id", resp.EventID). + Msg("Disappeared message") + } + err = dl.br.DB.DisappearingMessage.Delete(ctx, msg.EventID) + if err != nil { + zerolog.Ctx(ctx).Err(err). + Stringer("event_id", msg.EventID). + Msg("Failed to delete disappearing message entry from database") + } + } +} diff --git a/bridgev2/errors.go b/bridgev2/errors.go new file mode 100644 index 00000000..f6677d2e --- /dev/null +++ b/bridgev2/errors.go @@ -0,0 +1,133 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package bridgev2 + +import ( + "errors" + "fmt" + "net/http" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/event" +) + +// ErrIgnoringRemoteEvent can be returned by [RemoteMessage.ConvertMessage] or [RemoteEdit.ConvertEdit] +// to indicate that the event should be ignored after all. Handling the event will be cancelled immediately. +var ErrIgnoringRemoteEvent = errors.New("ignoring remote event") + +// ErrNoStatus can be returned by [MatrixMessageResponse.HandleEcho] to indicate that the message is still in-flight +// and a status should not be sent yet. The message will still be saved into the database. +var ErrNoStatus = errors.New("omit message status") + +// ErrResolveIdentifierTryNext can be returned by ResolveIdentifier or CreateChatWithGhost to signal that +// the identifier is valid, but can't be reached by the current login, and the caller should try the next +// login if there are more. +// +// This should generally only be returned when resolving internal IDs (which happens when initiating chats via Matrix). +// For example, Google Messages would return this when trying to resolve another login's user ID, +// and Telegram would return this when the access hash isn't available. +var ErrResolveIdentifierTryNext = errors.New("that identifier is not available via this login") + +var ErrNotLoggedIn = errors.New("not logged in") + +// ErrDirectMediaNotEnabled may be returned by Matrix connectors if [MatrixConnector.GenerateContentURI] is called, +// but direct media is not enabled. +var ErrDirectMediaNotEnabled = errors.New("direct media is not enabled") + +var ErrPortalIsDeleted = errors.New("portal is deleted") +var ErrPortalNotFoundInEventHandler = errors.New("portal not found to handle remote event") + +// Common message status errors +var ( + ErrPanicInEventHandler error = WrapErrorInStatus(errors.New("panic in event handler")).WithSendNotice(true).WithErrorAsMessage() + ErrNoPortal error = WrapErrorInStatus(errors.New("room is not a portal")).WithIsCertain(true).WithSendNotice(false) + ErrIgnoringReactionFromRelayedUser error = WrapErrorInStatus(errors.New("ignoring reaction event from relayed user")).WithIsCertain(true).WithSendNotice(false) + ErrIgnoringPollFromRelayedUser error = WrapErrorInStatus(errors.New("ignoring poll event from relayed user")).WithIsCertain(true).WithSendNotice(false) + ErrIgnoringDeleteChatRelayedUser error = WrapErrorInStatus(errors.New("ignoring delete chat event from relayed user")).WithIsCertain(true).WithSendNotice(false) + ErrIgnoringAcceptRequestRelayedUser error = WrapErrorInStatus(errors.New("ignoring accept message request event from relayed user")).WithIsCertain(true).WithSendNotice(false) + ErrEditsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support edits")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported) + ErrEditsNotSupportedInPortal error = WrapErrorInStatus(errors.New("edits are not allowed in this chat")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported) + ErrCaptionsNotAllowed error = WrapErrorInStatus(errors.New("captions are not supported here")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported) + ErrLocationMessagesNotAllowed error = WrapErrorInStatus(errors.New("location messages are not supported here")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported) + ErrEditTargetTooOld error = WrapErrorInStatus(errors.New("the message is too old to be edited")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported) + ErrEditTargetTooManyEdits error = WrapErrorInStatus(errors.New("the message has been edited too many times")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported) + ErrReactionsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support reactions")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported) + ErrPollsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support polls")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported) + ErrRoomMetadataNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support changing room metadata")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false).WithErrorReason(event.MessageStatusUnsupported) + ErrRoomMetadataNotAllowed error = WrapErrorInStatus(errors.New("changes are not allowed here")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false).WithErrorReason(event.MessageStatusUnsupported) + ErrRedactionsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support deleting messages")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported) + ErrUnexpectedParsedContentType error = WrapErrorInStatus(errors.New("unexpected parsed content type")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(true) + ErrInvalidStateKey error = WrapErrorInStatus(errors.New("room metadata state key is unset or non-empty")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(false) + ErrDatabaseError error = WrapErrorInStatus(errors.New("database error")).WithMessage("internal database error").WithIsCertain(true).WithSendNotice(true) + ErrTargetMessageNotFound error = WrapErrorInStatus(errors.New("target message not found")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(false) + ErrUnsupportedMessageType error = WrapErrorInStatus(errors.New("unsupported message type")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(true).WithErrorReason(event.MessageStatusUnsupported) + ErrUnsupportedMediaType error = WrapErrorInStatus(errors.New("unsupported media type")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(true).WithErrorReason(event.MessageStatusUnsupported) + ErrMediaDurationTooLong error = WrapErrorInStatus(errors.New("media duration too long")).WithErrorAsMessage().WithSendNotice(true).WithErrorReason(event.MessageStatusUnsupported) + ErrVoiceMessageDurationTooLong error = WrapErrorInStatus(errors.New("voice message too long")).WithErrorAsMessage().WithSendNotice(true).WithErrorReason(event.MessageStatusUnsupported) + ErrMediaTooLarge error = WrapErrorInStatus(errors.New("media too large")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(true).WithErrorReason(event.MessageStatusUnsupported) + ErrIgnoringMNotice error = WrapErrorInStatus(errors.New("ignoring m.notice message")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false) + ErrMediaDownloadFailed error = WrapErrorInStatus(errors.New("failed to download media")).WithMessage("failed to download media").WithIsCertain(true).WithSendNotice(true) + ErrMediaReuploadFailed error = WrapErrorInStatus(errors.New("failed to reupload media")).WithMessage("failed to reupload media").WithIsCertain(true).WithSendNotice(true) + ErrMediaConvertFailed error = WrapErrorInStatus(errors.New("failed to convert media")).WithMessage("failed to convert media").WithIsCertain(true).WithSendNotice(true) + ErrMembershipNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support changing group membership")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false).WithErrorReason(event.MessageStatusUnsupported) + ErrDeleteChatNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support deleting chats")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false).WithErrorReason(event.MessageStatusUnsupported) + ErrBeeperAIStreamNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support Beeper AI stream events")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false).WithErrorReason(event.MessageStatusUnsupported) + ErrPowerLevelsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support changing group power levels")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false).WithErrorReason(event.MessageStatusUnsupported) + ErrRemoteEchoTimeout = WrapErrorInStatus(errors.New("remote echo timed out")).WithIsCertain(false).WithSendNotice(true).WithErrorReason(event.MessageStatusTooOld) + ErrRemoteAckTimeout = WrapErrorInStatus(errors.New("remote ack timed out")).WithIsCertain(false).WithSendNotice(true).WithErrorReason(event.MessageStatusTooOld) + + ErrPublicMediaDisabled = WrapErrorInStatus(errors.New("public media is not enabled in the bridge config")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported).WithSendNotice(true) + ErrPublicMediaDatabaseDisabled = WrapErrorInStatus(errors.New("public media database storage is disabled")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported).WithSendNotice(true) + ErrPublicMediaGenerateFailed = WrapErrorInStatus(errors.New("failed to generate public media URL")).WithIsCertain(true).WithMessage("failed to generate public media URL").WithErrorReason(event.MessageStatusUnsupported).WithSendNotice(true) + + ErrDisappearingTimerUnsupported error = WrapErrorInStatus(errors.New("invalid disappearing timer")).WithIsCertain(true) +) + +// Common login interface errors +var ( + ErrInvalidLoginFlowID error = RespError(mautrix.MNotFound.WithMessage("Invalid login flow ID")) +) + +// RespError is a class of error that certain network interface methods can return to ensure that the error +// is properly translated into an HTTP error when the method is called via the provisioning API. +// +// However, unlike mautrix.RespError, this does not include the error code +// in the message shown to users when used outside HTTP contexts. +type RespError mautrix.RespError + +func (re RespError) Error() string { + return re.Err +} + +func (re RespError) Is(err error) bool { + var e2 RespError + if errors.As(err, &e2) { + return e2.Err == re.Err + } + return errors.Is(err, mautrix.RespError(re)) +} + +func (re RespError) Write(w http.ResponseWriter) { + mautrix.RespError(re).Write(w) +} + +func (re RespError) WithMessage(msg string, args ...any) RespError { + return RespError(mautrix.RespError(re).WithMessage(msg, args...)) +} + +func (re RespError) AppendMessage(append string, args ...any) RespError { + re.Err += fmt.Sprintf(append, args...) + return re +} + +func WrapRespErrManual(err error, code string, status int) RespError { + return RespError{ErrCode: code, Err: err.Error(), StatusCode: status} +} + +func WrapRespErr(err error, target mautrix.RespError) RespError { + return RespError{ErrCode: target.ErrCode, Err: err.Error(), StatusCode: target.StatusCode} +} diff --git a/bridgev2/ghost.go b/bridgev2/ghost.go new file mode 100644 index 00000000..590dd1dc --- /dev/null +++ b/bridgev2/ghost.go @@ -0,0 +1,321 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package bridgev2 + +import ( + "context" + "crypto/sha256" + "encoding/json" + "fmt" + "maps" + "net/http" + "slices" + + "github.com/rs/zerolog" + "go.mau.fi/util/exerrors" + "go.mau.fi/util/exmime" + + "maunium.net/go/mautrix/bridgev2/database" + "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +type Ghost struct { + *database.Ghost + Bridge *Bridge + Log zerolog.Logger + Intent MatrixAPI +} + +func (br *Bridge) loadGhost(ctx context.Context, dbGhost *database.Ghost, queryErr error, id *networkid.UserID) (*Ghost, error) { + if queryErr != nil { + return nil, fmt.Errorf("failed to query db: %w", queryErr) + } + if dbGhost == nil { + if id == nil { + return nil, nil + } + dbGhost = &database.Ghost{ + BridgeID: br.ID, + ID: *id, + } + err := br.DB.Ghost.Insert(ctx, dbGhost) + if err != nil { + return nil, fmt.Errorf("failed to insert new ghost: %w", err) + } + } + ghost := &Ghost{ + Ghost: dbGhost, + Bridge: br, + Log: br.Log.With().Str("ghost_id", string(dbGhost.ID)).Logger(), + Intent: br.Matrix.GhostIntent(dbGhost.ID), + } + br.ghostsByID[ghost.ID] = ghost + return ghost, nil +} + +func (br *Bridge) unlockedGetGhostByID(ctx context.Context, id networkid.UserID, onlyIfExists bool) (*Ghost, error) { + cached, ok := br.ghostsByID[id] + if ok { + return cached, nil + } + idPtr := &id + if onlyIfExists { + idPtr = nil + } + db, err := br.DB.Ghost.GetByID(ctx, id) + return br.loadGhost(ctx, db, err, idPtr) +} + +func (br *Bridge) IsGhostMXID(userID id.UserID) bool { + _, isGhost := br.Matrix.ParseGhostMXID(userID) + return isGhost +} + +func (br *Bridge) GetGhostByMXID(ctx context.Context, mxid id.UserID) (*Ghost, error) { + ghostID, ok := br.Matrix.ParseGhostMXID(mxid) + if !ok { + return nil, nil + } + return br.GetGhostByID(ctx, ghostID) +} + +func (br *Bridge) GetGhostByID(ctx context.Context, id networkid.UserID) (*Ghost, error) { + br.cacheLock.Lock() + defer br.cacheLock.Unlock() + ghost, err := br.unlockedGetGhostByID(ctx, id, false) + if err != nil { + return nil, err + } else if ghost == nil { + panic(fmt.Errorf("unlockedGetGhostByID(ctx, %q, false) returned nil", id)) + } + return ghost, nil +} + +func (br *Bridge) GetExistingGhostByID(ctx context.Context, id networkid.UserID) (*Ghost, error) { + br.cacheLock.Lock() + defer br.cacheLock.Unlock() + return br.unlockedGetGhostByID(ctx, id, true) +} + +type Avatar struct { + ID networkid.AvatarID + Get func(ctx context.Context) ([]byte, error) + Remove bool + + // For pre-uploaded avatars, the MXC URI and hash can be provided directly + MXC id.ContentURIString + Hash [32]byte +} + +func (a *Avatar) Reupload(ctx context.Context, intent MatrixAPI, currentHash [32]byte, currentMXC id.ContentURIString) (id.ContentURIString, [32]byte, error) { + if a.MXC != "" || a.Hash != [32]byte{} { + return a.MXC, a.Hash, nil + } else if a.Get == nil { + return "", [32]byte{}, fmt.Errorf("no Get function provided for avatar") + } + data, err := a.Get(ctx) + if err != nil { + return "", [32]byte{}, err + } + hash := sha256.Sum256(data) + if hash == currentHash && currentMXC != "" { + return currentMXC, hash, nil + } + mime := http.DetectContentType(data) + fileName := "avatar" + exmime.ExtensionFromMimetype(mime) + uri, _, err := intent.UploadMedia(ctx, "", data, fileName, mime) + if err != nil { + return "", hash, err + } + return uri, hash, nil +} + +type UserInfo struct { + Identifiers []string + Name *string + Avatar *Avatar + IsBot *bool + ExtraProfile database.ExtraProfile + + ExtraUpdates ExtraUpdater[*Ghost] +} + +func (ghost *Ghost) UpdateName(ctx context.Context, name string) bool { + if ghost.Name == name && ghost.NameSet { + return false + } + ghost.Name = name + ghost.NameSet = false + err := ghost.Intent.SetDisplayName(ctx, name) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to set display name") + } else { + ghost.NameSet = true + } + return true +} + +func (ghost *Ghost) UpdateAvatar(ctx context.Context, avatar *Avatar) bool { + if ghost.AvatarID == avatar.ID && (avatar.Remove || ghost.AvatarMXC != "") && ghost.AvatarSet { + return false + } + ghost.AvatarID = avatar.ID + if !avatar.Remove { + newMXC, newHash, err := avatar.Reupload(ctx, ghost.Intent, ghost.AvatarHash, ghost.AvatarMXC) + if err != nil { + ghost.AvatarSet = false + zerolog.Ctx(ctx).Err(err).Msg("Failed to reupload avatar") + return true + } else if newHash == ghost.AvatarHash && ghost.AvatarMXC != "" && ghost.AvatarSet { + return true + } + ghost.AvatarHash = newHash + ghost.AvatarMXC = newMXC + } else { + ghost.AvatarMXC = "" + } + ghost.AvatarSet = false + if err := ghost.Intent.SetAvatarURL(ctx, ghost.AvatarMXC); err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to set avatar URL") + } else { + ghost.AvatarSet = true + } + return true +} + +func (ghost *Ghost) getExtraProfileMeta() any { + bridgeName := ghost.Bridge.Network.GetName() + baseExtra := &event.BeeperProfileExtra{ + RemoteID: string(ghost.ID), + Identifiers: ghost.Identifiers, + Service: bridgeName.BeeperBridgeType, + Network: bridgeName.NetworkID, + IsBridgeBot: false, + IsNetworkBot: ghost.IsBot, + } + if len(ghost.ExtraProfile) == 0 { + return baseExtra + } + mergedExtra := maps.Clone(ghost.ExtraProfile) + baseExtraMarshaled := exerrors.Must(json.Marshal(baseExtra)) + exerrors.PanicIfNotNil(json.Unmarshal(baseExtraMarshaled, &mergedExtra)) + return mergedExtra +} + +func (ghost *Ghost) UpdateContactInfo(ctx context.Context, identifiers []string, isBot *bool, extraProfile database.ExtraProfile) bool { + if !ghost.Bridge.Matrix.GetCapabilities().ExtraProfileMeta { + ghost.ContactInfoSet = false + return false + } + if identifiers != nil { + slices.Sort(identifiers) + } + changed := extraProfile.CopyTo(&ghost.ExtraProfile) + if identifiers != nil { + changed = changed || !slices.Equal(identifiers, ghost.Identifiers) + ghost.Identifiers = identifiers + } + if isBot != nil { + changed = changed || *isBot != ghost.IsBot + ghost.IsBot = *isBot + } + if ghost.ContactInfoSet && !changed { + return false + } + err := ghost.Intent.SetExtraProfileMeta(ctx, ghost.getExtraProfileMeta()) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to set extra profile metadata") + } else { + ghost.ContactInfoSet = true + } + return true +} + +func (br *Bridge) allowAggressiveUpdateForType(evtType RemoteEventType) bool { + if !br.Network.GetCapabilities().AggressiveUpdateInfo { + return false + } + switch evtType { + case RemoteEventUnknown, RemoteEventMessage, RemoteEventEdit, RemoteEventReaction: + return true + default: + return false + } +} + +func (ghost *Ghost) UpdateInfoIfNecessary(ctx context.Context, source *UserLogin, evtType RemoteEventType) { + if ghost.Name != "" && ghost.NameSet && ghost.AvatarSet && !ghost.Bridge.allowAggressiveUpdateForType(evtType) { + return + } + info, err := source.Client.GetUserInfo(ctx, ghost) + if err != nil { + zerolog.Ctx(ctx).Err(err).Str("ghost_id", string(ghost.ID)).Msg("Failed to get info to update ghost") + } else if info != nil { + zerolog.Ctx(ctx).Debug(). + Bool("has_name", ghost.Name != ""). + Bool("name_set", ghost.NameSet). + Bool("has_avatar", ghost.AvatarMXC != ""). + Bool("avatar_set", ghost.AvatarSet). + Msg("Updating ghost info in IfNecessary call") + ghost.UpdateInfo(ctx, info) + } else { + zerolog.Ctx(ctx).Trace(). + Bool("has_name", ghost.Name != ""). + Bool("name_set", ghost.NameSet). + Bool("has_avatar", ghost.AvatarMXC != ""). + Bool("avatar_set", ghost.AvatarSet). + Msg("No ghost info received in IfNecessary call") + } +} + +func (ghost *Ghost) updateDMPortals(ctx context.Context) { + if !ghost.Bridge.Config.PrivateChatPortalMeta { + return + } + dmPortals, err := ghost.Bridge.GetDMPortalsWith(ctx, ghost.ID) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to get DM portals to update info") + return + } + for _, portal := range dmPortals { + go portal.lockedUpdateInfoFromGhost(ctx, ghost) + } +} + +func (ghost *Ghost) UpdateInfo(ctx context.Context, info *UserInfo) { + update := false + oldName := ghost.Name + oldAvatar := ghost.AvatarMXC + if info.Name != nil { + update = ghost.UpdateName(ctx, *info.Name) || update + } + if info.Avatar != nil { + update = ghost.UpdateAvatar(ctx, info.Avatar) || update + } else if oldAvatar == "" && !ghost.AvatarSet { + // Special case: nil avatar means we're not expecting one ever, if we don't currently have + // one we flag it as set to avoid constantly refetching in UpdateInfoIfNecessary. + ghost.AvatarSet = true + update = true + } + if info.Identifiers != nil || info.IsBot != nil || info.ExtraProfile != nil { + update = ghost.UpdateContactInfo(ctx, info.Identifiers, info.IsBot, info.ExtraProfile) || update + } + if info.ExtraUpdates != nil { + update = info.ExtraUpdates(ctx, ghost) || update + } + if oldName != ghost.Name || oldAvatar != ghost.AvatarMXC { + ghost.updateDMPortals(ctx) + } + if update { + err := ghost.Bridge.DB.Ghost.Update(ctx, ghost.Ghost) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to update ghost in database after updating info") + } + } +} diff --git a/bridgev2/login.go b/bridgev2/login.go new file mode 100644 index 00000000..b8321719 --- /dev/null +++ b/bridgev2/login.go @@ -0,0 +1,301 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package bridgev2 + +import ( + "context" + "fmt" + "regexp" + "strings" + + "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/event" +) + +// LoginProcess represents a single occurrence of a user logging into the remote network. +type LoginProcess interface { + // Start starts the process and returns the first step. + // + // For example, a network using QR login may connect to the network, fetch a QR code, + // and return a DisplayAndWait-type step. + // + // This will only ever be called once. + Start(ctx context.Context) (*LoginStep, error) + // Cancel stops the login process and cleans up any resources. + // No other methods will be called after cancel. + // + // Cancel will not be called if any other method returned an error: + // errors are always treated as fatal and the process is assumed to be automatically cancelled. + Cancel() +} + +type LoginProcessWithOverride interface { + LoginProcess + // StartWithOverride starts the process with the intent of re-authenticating an existing login. + // + // The call to this is mutually exclusive with the call to the default Start method. + // + // The user login being overridden will still be logged out automatically + // in case the complete step returns a different login. + StartWithOverride(ctx context.Context, override *UserLogin) (*LoginStep, error) +} + +type LoginProcessDisplayAndWait interface { + LoginProcess + Wait(ctx context.Context) (*LoginStep, error) +} + +type LoginProcessUserInput interface { + LoginProcess + SubmitUserInput(ctx context.Context, input map[string]string) (*LoginStep, error) +} + +type LoginProcessCookies interface { + LoginProcess + SubmitCookies(ctx context.Context, cookies map[string]string) (*LoginStep, error) +} + +type LoginFlow struct { + Name string `json:"name"` + Description string `json:"description"` + ID string `json:"id"` +} + +type LoginStepType string + +const ( + LoginStepTypeUserInput LoginStepType = "user_input" + LoginStepTypeCookies LoginStepType = "cookies" + LoginStepTypeDisplayAndWait LoginStepType = "display_and_wait" + LoginStepTypeComplete LoginStepType = "complete" +) + +type LoginDisplayType string + +const ( + LoginDisplayTypeQR LoginDisplayType = "qr" + LoginDisplayTypeEmoji LoginDisplayType = "emoji" + LoginDisplayTypeCode LoginDisplayType = "code" + LoginDisplayTypeNothing LoginDisplayType = "nothing" +) + +type LoginStep struct { + // The type of login step + Type LoginStepType `json:"type"` + // A unique ID for this step. The ID should be same for every login using the same flow, + // but it should be different for different bridges and step types. + // + // For example, Telegram's QR scan followed by a 2-factor password + // might use the IDs `fi.mau.telegram.qr` and `fi.mau.telegram.2fa_password`. + StepID string `json:"step_id"` + // Instructions contains human-readable instructions for completing the login step. + Instructions string `json:"instructions"` + + // Exactly one of the following structs must be filled depending on the step type. + + DisplayAndWaitParams *LoginDisplayAndWaitParams `json:"display_and_wait,omitempty"` + CookiesParams *LoginCookiesParams `json:"cookies,omitempty"` + UserInputParams *LoginUserInputParams `json:"user_input,omitempty"` + CompleteParams *LoginCompleteParams `json:"complete,omitempty"` +} + +type LoginDisplayAndWaitParams struct { + // The type of thing to display (QR, emoji or text code) + Type LoginDisplayType `json:"type"` + // The thing to display (raw data for QR, unicode emoji for emoji, plain string for code, omitted for nothing) + Data string `json:"data,omitempty"` + // An image containing the thing to display. If present, this is recommended over using data directly. + // For emojis, the URL to the canonical image representation of the emoji + ImageURL string `json:"image_url,omitempty"` +} + +type LoginCookieFieldSourceType string + +const ( + LoginCookieTypeCookie LoginCookieFieldSourceType = "cookie" + LoginCookieTypeLocalStorage LoginCookieFieldSourceType = "local_storage" + LoginCookieTypeRequestHeader LoginCookieFieldSourceType = "request_header" + LoginCookieTypeRequestBody LoginCookieFieldSourceType = "request_body" + LoginCookieTypeSpecial LoginCookieFieldSourceType = "special" +) + +type LoginCookieFieldSource struct { + // The type of source. + Type LoginCookieFieldSourceType `json:"type"` + // The name of the field. The exact meaning depends on the type of source. + // Cookie: cookie name + // Local storage: key in local storage + // Request header: header name + // Request body: field name inside body after it's parsed (as JSON or multipart form data) + // Special: a namespaced identifier that clients can implement special handling for + Name string `json:"name"` + + // For request header & body types, a regex matching request URLs where the value can be extracted from. + RequestURLRegex string `json:"request_url_regex,omitempty"` + // For cookie types, the domain the cookie is present on. + CookieDomain string `json:"cookie_domain,omitempty"` +} + +type LoginCookieField struct { + // The key in the map that is submitted to the connector. + ID string `json:"id"` + Required bool `json:"required"` + // The sources that can be used to acquire the field value. Only one of these needs to be used. + Sources []LoginCookieFieldSource `json:"sources"` + // A regex pattern that the client can use to validate value client-side. + Pattern string `json:"pattern,omitempty"` +} + +type LoginCookiesParams struct { + URL string `json:"url"` + UserAgent string `json:"user_agent,omitempty"` + + // The fields that are needed for this cookie login. + Fields []LoginCookieField `json:"fields"` + // A JavaScript snippet that can extract some or all of the fields. + // The snippet will evaluate to a promise that resolves when the relevant fields are found. + // Fields that are not present in the promise result must be extracted another way. + ExtractJS string `json:"extract_js,omitempty"` + // A regex pattern that the URL should match before the client closes the webview. + // + // The client may submit the login if the user closes the webview after all cookies are collected + // even if this URL is not reached, but it should only automatically close the webview after + // both cookies and the URL match. + WaitForURLPattern string `json:"wait_for_url_pattern,omitempty"` +} + +type LoginInputFieldType string + +const ( + LoginInputFieldTypeUsername LoginInputFieldType = "username" + LoginInputFieldTypePassword LoginInputFieldType = "password" + LoginInputFieldTypePhoneNumber LoginInputFieldType = "phone_number" + LoginInputFieldTypeEmail LoginInputFieldType = "email" + LoginInputFieldType2FACode LoginInputFieldType = "2fa_code" + LoginInputFieldTypeToken LoginInputFieldType = "token" + LoginInputFieldTypeURL LoginInputFieldType = "url" + LoginInputFieldTypeDomain LoginInputFieldType = "domain" + LoginInputFieldTypeSelect LoginInputFieldType = "select" + LoginInputFieldTypeCaptchaCode LoginInputFieldType = "captcha_code" +) + +type LoginInputDataField struct { + // The type of input field as a hint for the client. + Type LoginInputFieldType `json:"type"` + // The ID of the field to be used as the key in the map that is submitted to the connector. + ID string `json:"id"` + // The name of the field shown to the user. + Name string `json:"name"` + // The description of the field shown to the user. + Description string `json:"description"` + // A default value that the client can pre-fill the field with. + DefaultValue string `json:"default_value,omitempty"` + // A regex pattern that the client can use to validate input client-side. + Pattern string `json:"pattern,omitempty"` + // For fields of type select, the valid options. + // Pattern may also be filled with a regex that matches the same options. + Options []string `json:"options,omitempty"` + // A function that validates the input and optionally cleans it up before it's submitted to the connector. + Validate func(string) (string, error) `json:"-"` +} + +var numberCleaner = strings.NewReplacer("-", "", " ", "", "(", "", ")", "") + +func isOnlyNumbers(input string) bool { + for _, r := range input { + if r < '0' || r > '9' { + return false + } + } + return true +} + +func CleanNonInternationalPhoneNumber(phone string) (string, error) { + phone = numberCleaner.Replace(phone) + if !isOnlyNumbers(strings.TrimPrefix(phone, "+")) { + return "", fmt.Errorf("phone number must only contain numbers") + } + return phone, nil +} + +func CleanPhoneNumber(phone string) (string, error) { + phone = numberCleaner.Replace(phone) + if len(phone) < 2 { + return "", fmt.Errorf("phone number must start with + and contain numbers") + } else if phone[0] != '+' { + return "", fmt.Errorf("phone number must start with +") + } else if !isOnlyNumbers(phone[1:]) { + return "", fmt.Errorf("phone number must only contain numbers") + } + return phone, nil +} + +func noopValidate(input string) (string, error) { + return input, nil +} + +func (f *LoginInputDataField) FillDefaultValidate() { + if f.Validate != nil { + return + } + switch f.Type { + case LoginInputFieldTypePhoneNumber: + f.Validate = CleanPhoneNumber + case LoginInputFieldTypeEmail: + f.Validate = func(email string) (string, error) { + if !strings.ContainsRune(email, '@') { + return "", fmt.Errorf("invalid email") + } + return email, nil + } + default: + if f.Pattern != "" { + f.Validate = func(s string) (string, error) { + match, err := regexp.MatchString(f.Pattern, s) + if err != nil { + return "", err + } else if !match { + return "", fmt.Errorf("doesn't match regex `%s`", f.Pattern) + } else { + return s, nil + } + } + } else { + f.Validate = noopValidate + } + } +} + +type LoginUserInputParams struct { + // The fields that the user needs to fill in. + Fields []LoginInputDataField `json:"fields"` + + // Attachments to display alongside the input fields. + Attachments []*LoginUserInputAttachment `json:"attachments"` +} + +type LoginUserInputAttachment struct { + Type event.MessageType `json:"type,omitempty"` + FileName string `json:"filename,omitempty"` + Content []byte `json:"content,omitempty"` + Info LoginUserInputAttachmentInfo `json:"info,omitempty"` +} + +type LoginUserInputAttachmentInfo struct { + MimeType string `json:"mimetype,omitempty"` + Width int `json:"w,omitempty"` + Height int `json:"h,omitempty"` + Size int `json:"size,omitempty"` +} + +type LoginCompleteParams struct { + UserLoginID networkid.UserLoginID `json:"user_login_id"` + UserLogin *UserLogin `json:"-"` +} + +type LoginSubmit struct { +} diff --git a/bridgev2/matrix/analytics.go b/bridgev2/matrix/analytics.go new file mode 100644 index 00000000..7eb2a33a --- /dev/null +++ b/bridgev2/matrix/analytics.go @@ -0,0 +1,62 @@ +package matrix + +import ( + "bytes" + "encoding/json" + "fmt" + "net/http" + + "maunium.net/go/mautrix/id" +) + +func (br *Connector) trackSync(userID id.UserID, event string, properties map[string]any) error { + var buf bytes.Buffer + var analyticsUserID string + if br.Config.Analytics.UserID != "" { + analyticsUserID = br.Config.Analytics.UserID + } else { + analyticsUserID = userID.String() + } + err := json.NewEncoder(&buf).Encode(map[string]any{ + "userId": analyticsUserID, + "event": event, + "properties": properties, + }) + if err != nil { + return err + } + + req, err := http.NewRequest(http.MethodPost, br.Config.Analytics.URL, &buf) + if err != nil { + return err + } + req.SetBasicAuth(br.Config.Analytics.Token, "") + resp, err := br.AS.HTTPClient.Do(req) + if err != nil { + return err + } + _ = resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("unexpected status code %d", resp.StatusCode) + } + return nil +} + +func (br *Connector) TrackAnalytics(userID id.UserID, event string, props map[string]any) { + if br.Config.Analytics.Token == "" || br.Config.Analytics.URL == "" { + return + } + + if props == nil { + props = map[string]any{} + } + props["bridge"] = br.Bridge.Network.GetName().BeeperBridgeType + go func() { + err := br.trackSync(userID, event, props) + if err != nil { + br.Log.Err(err).Str("component", "analytics").Str("event", event).Msg("Error tracking event") + } else { + br.Log.Debug().Str("component", "analytics").Str("event", event).Msg("Tracked event") + } + }() +} diff --git a/bridge/commands/admin.go b/bridgev2/matrix/cmdadmin.go similarity index 70% rename from bridge/commands/admin.go rename to bridgev2/matrix/cmdadmin.go index ff3340e3..0bd3eb82 100644 --- a/bridge/commands/admin.go +++ b/bridgev2/matrix/cmdadmin.go @@ -1,36 +1,38 @@ -// Copyright (c) 2022 Tulir Asokan +// Copyright (c) 2024 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. -package commands +package matrix import ( "strconv" + "maunium.net/go/mautrix/bridgev2/commands" "maunium.net/go/mautrix/id" ) -var CommandDiscardMegolmSession = &FullHandler{ - Func: func(ce *Event) { - if ce.Bridge.Crypto == nil { +var CommandDiscardMegolmSession = &commands.FullHandler{ + Func: func(ce *commands.Event) { + matrix := ce.Bridge.Matrix.(*Connector) + if matrix.Crypto == nil { ce.Reply("This bridge instance doesn't have end-to-bridge encryption enabled") } else { - ce.Bridge.Crypto.ResetSession(ce.Ctx, ce.RoomID) + matrix.Crypto.ResetSession(ce.Ctx, ce.RoomID) ce.Reply("Successfully reset Megolm session in this room. New decryption keys will be shared the next time a message is sent from the remote network.") } }, Name: "discard-megolm-session", Aliases: []string{"discard-session"}, - Help: HelpMeta{ - Section: HelpSectionAdmin, + Help: commands.HelpMeta{ + Section: commands.HelpSectionAdmin, Description: "Discard the Megolm session in the room", }, RequiresAdmin: true, } -func fnSetPowerLevel(ce *Event) { +func fnSetPowerLevel(ce *commands.Event) { var level int var userID id.UserID var err error @@ -40,7 +42,7 @@ func fnSetPowerLevel(ce *Event) { ce.Reply("Invalid power level \"%s\"", ce.Args[0]) return } - userID = ce.User.GetMXID() + userID = ce.User.MXID } else if len(ce.Args) == 2 { userID = id.UserID(ce.Args[0]) _, _, err := userID.Parse() @@ -57,18 +59,18 @@ func fnSetPowerLevel(ce *Event) { ce.Reply("**Usage:** `set-pl [user] `") return } - _, err = ce.Portal.MainIntent().SetPowerLevel(ce.Ctx, ce.RoomID, userID, level) + _, err = ce.Bot.(*ASIntent).Matrix.SetPowerLevel(ce.Ctx, ce.RoomID, userID, level) if err != nil { ce.Reply("Failed to set power levels: %v", err) } } -var CommandSetPowerLevel = &FullHandler{ +var CommandSetPowerLevel = &commands.FullHandler{ Func: fnSetPowerLevel, Name: "set-pl", Aliases: []string{"set-power-level"}, - Help: HelpMeta{ - Section: HelpSectionAdmin, + Help: commands.HelpMeta{ + Section: commands.HelpSectionAdmin, Description: "Change the power level in a portal room.", Args: "[_user ID_] <_power level_>", }, diff --git a/bridgev2/matrix/cmddoublepuppet.go b/bridgev2/matrix/cmddoublepuppet.go new file mode 100644 index 00000000..2f3a3dc2 --- /dev/null +++ b/bridgev2/matrix/cmddoublepuppet.go @@ -0,0 +1,90 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package matrix + +import ( + "maunium.net/go/mautrix/bridgev2/commands" +) + +var CommandLoginMatrix = &commands.FullHandler{ + Func: fnLoginMatrix, + Name: "login-matrix", + Help: commands.HelpMeta{ + Section: commands.HelpSectionAuth, + Description: "Enable double puppeting.", + Args: "<_access token_>", + }, + RequiresLogin: true, +} + +func fnLoginMatrix(ce *commands.Event) { + if !ce.User.Permissions.DoublePuppet { + ce.Reply("You don't have permission to manage double puppeting.") + return + } + if len(ce.Args) == 0 { + ce.Reply("**Usage:** `login-matrix `") + return + } + err := ce.User.LoginDoublePuppet(ce.Ctx, ce.Args[0]) + if err != nil { + ce.Reply("Failed to enable double puppeting: %v", err) + } else { + ce.Reply("Successfully switched puppets") + } +} + +var CommandPingMatrix = &commands.FullHandler{ + Func: fnPingMatrix, + Name: "ping-matrix", + Help: commands.HelpMeta{ + Section: commands.HelpSectionAuth, + Description: "Ping the Matrix server with the double puppet.", + }, +} + +func fnPingMatrix(ce *commands.Event) { + intent := ce.User.DoublePuppet(ce.Ctx) + if intent == nil { + ce.Reply("You don't have double puppeting enabled.") + return + } + asIntent := intent.(*ASIntent) + resp, err := asIntent.Matrix.Whoami(ce.Ctx) + if err != nil { + ce.Reply("Failed to validate Matrix login: %v", err) + } else { + if asIntent.Matrix.SetAppServiceUserID && resp.DeviceID == "" { + ce.Reply("Confirmed valid access token for %s (appservice double puppeting)", resp.UserID) + } else { + ce.Reply("Confirmed valid access token for %s / %s", resp.UserID, resp.DeviceID) + } + } +} + +var CommandLogoutMatrix = &commands.FullHandler{ + Func: fnLogoutMatrix, + Name: "logout-matrix", + Help: commands.HelpMeta{ + Section: commands.HelpSectionAuth, + Description: "Disable double puppeting.", + }, + RequiresLogin: true, +} + +func fnLogoutMatrix(ce *commands.Event) { + if !ce.User.Permissions.DoublePuppet { + ce.Reply("You don't have permission to manage double puppeting.") + return + } + if ce.User.AccessToken == "" { + ce.Reply("You don't have double puppeting enabled.") + return + } + ce.User.LogoutDoublePuppet(ce.Ctx) + ce.Reply("Successfully disabled double puppeting.") +} diff --git a/bridgev2/matrix/connector.go b/bridgev2/matrix/connector.go new file mode 100644 index 00000000..5a2df953 --- /dev/null +++ b/bridgev2/matrix/connector.go @@ -0,0 +1,758 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package matrix + +import ( + "context" + "crypto/sha256" + "encoding/base64" + "errors" + "fmt" + "net/http" + "net/url" + "os" + "regexp" + "strings" + "sync" + "time" + + _ "github.com/lib/pq" + "github.com/rs/zerolog" + "go.mau.fi/util/dbutil" + _ "go.mau.fi/util/dbutil/litestream" + "go.mau.fi/util/exbytes" + "go.mau.fi/util/exsync" + "go.mau.fi/util/ptr" + "go.mau.fi/util/random" + "golang.org/x/sync/semaphore" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/appservice" + "maunium.net/go/mautrix/bridgev2" + "maunium.net/go/mautrix/bridgev2/bridgeconfig" + "maunium.net/go/mautrix/bridgev2/commands" + "maunium.net/go/mautrix/bridgev2/database" + "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/bridgev2/status" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" + "maunium.net/go/mautrix/mediaproxy" + "maunium.net/go/mautrix/sqlstatestore" +) + +type Crypto interface { + HandleMemberEvent(context.Context, *event.Event) + Decrypt(context.Context, *event.Event) (*event.Event, error) + Encrypt(context.Context, id.RoomID, event.Type, *event.Content) error + WaitForSession(context.Context, id.RoomID, id.SenderKey, id.SessionID, time.Duration) bool + RequestSession(context.Context, id.RoomID, id.SenderKey, id.SessionID, id.UserID, id.DeviceID) + ResetSession(context.Context, id.RoomID) + Init(ctx context.Context) error + Start() + Stop() + Reset(ctx context.Context, startAfterReset bool) + Client() *mautrix.Client + ShareKeys(context.Context) error +} + +type Connector struct { + AS *appservice.AppService + Bot *appservice.IntentAPI + StateStore *sqlstatestore.SQLStateStore + Crypto Crypto + Log *zerolog.Logger + Config *bridgeconfig.Config + Bridge *bridgev2.Bridge + Provisioning *ProvisioningAPI + DoublePuppet *doublePuppetUtil + MediaProxy *mediaproxy.MediaProxy + + uploadSema *semaphore.Weighted + dmaSigKey [32]byte + pubMediaSigKey []byte + + doublePuppetIntents *exsync.Map[id.UserID, *appservice.IntentAPI] + + deterministicEventIDServer string + + MediaConfig mautrix.RespMediaConfig + SpecVersions *mautrix.RespVersions + SpecCaps *mautrix.RespCapabilities + specCapsLock sync.Mutex + Capabilities *bridgev2.MatrixCapabilities + IgnoreUnsupportedServer bool + + EventProcessor *appservice.EventProcessor + + userIDRegex *regexp.Regexp + + Websocket bool + wsStopPinger chan struct{} + wsStarted chan struct{} + wsStopped chan struct{} + wsShortCircuitReconnectBackoff chan struct{} + wsStartupWait *sync.WaitGroup + stopping bool + hasSentAnyStates bool + OnWebsocketReplaced func() +} + +var ( + _ bridgev2.MatrixConnector = (*Connector)(nil) + _ bridgev2.MatrixConnectorWithServer = (*Connector)(nil) + _ bridgev2.MatrixConnectorWithArbitraryRoomState = (*Connector)(nil) + _ bridgev2.MatrixConnectorWithPostRoomBridgeHandling = (*Connector)(nil) + _ bridgev2.MatrixConnectorWithPublicMedia = (*Connector)(nil) + _ bridgev2.MatrixConnectorWithNameDisambiguation = (*Connector)(nil) + _ bridgev2.MatrixConnectorWithURLPreviews = (*Connector)(nil) + _ bridgev2.MatrixConnectorWithAnalytics = (*Connector)(nil) +) + +func NewConnector(cfg *bridgeconfig.Config) *Connector { + c := &Connector{} + c.Config = cfg + c.userIDRegex = cfg.MakeUserIDRegex("(.+)") + c.MediaConfig.UploadSize = 50 * 1024 * 1024 + c.uploadSema = semaphore.NewWeighted(c.MediaConfig.UploadSize + 1) + c.Capabilities = &bridgev2.MatrixCapabilities{} + c.doublePuppetIntents = exsync.NewMap[id.UserID, *appservice.IntentAPI]() + return c +} + +func (br *Connector) Init(bridge *bridgev2.Bridge) { + br.Bridge = bridge + br.Log = &bridge.Log + br.StateStore = sqlstatestore.NewSQLStateStore(bridge.DB.Database, dbutil.ZeroLogger(br.Log.With().Str("db_section", "matrix_state").Logger()), false) + br.AS = br.Config.MakeAppService() + br.AS.Log = bridge.Log + br.AS.StateStore = br.StateStore + br.EventProcessor = appservice.NewEventProcessor(br.AS) + if !br.Config.AppService.AsyncTransactions { + br.EventProcessor.ExecMode = appservice.Sync + } + for evtType := range status.CheckpointTypes { + br.EventProcessor.On(evtType, br.sendBridgeCheckpoint) + } + br.EventProcessor.On(event.EventMessage, br.handleRoomEvent) + br.EventProcessor.On(event.EventSticker, br.handleRoomEvent) + br.EventProcessor.On(event.EventUnstablePollStart, br.handleRoomEvent) + br.EventProcessor.On(event.EventUnstablePollResponse, br.handleRoomEvent) + br.EventProcessor.On(event.EventReaction, br.handleRoomEvent) + br.EventProcessor.On(event.EventRedaction, br.handleRoomEvent) + br.EventProcessor.On(event.EventEncrypted, br.handleEncryptedEvent) + br.EventProcessor.On(event.EphemeralEventEncrypted, br.handleEncryptedEvent) + br.EventProcessor.On(event.StateMember, br.handleRoomEvent) + br.EventProcessor.On(event.StatePowerLevels, br.handleRoomEvent) + br.EventProcessor.On(event.StateRoomName, br.handleRoomEvent) + br.EventProcessor.On(event.BeeperSendState, br.handleRoomEvent) + br.EventProcessor.On(event.StateRoomAvatar, br.handleRoomEvent) + br.EventProcessor.On(event.StateTopic, br.handleRoomEvent) + br.EventProcessor.On(event.StateTombstone, br.handleRoomEvent) + br.EventProcessor.On(event.StateBeeperDisappearingTimer, br.handleRoomEvent) + br.EventProcessor.On(event.BeeperDeleteChat, br.handleRoomEvent) + br.EventProcessor.On(event.BeeperAcceptMessageRequest, br.handleRoomEvent) + br.EventProcessor.On(event.EphemeralEventReceipt, br.handleEphemeralEvent) + br.EventProcessor.On(event.EphemeralEventTyping, br.handleEphemeralEvent) + br.EventProcessor.On(event.BeeperEphemeralEventAIStream, br.handleEphemeralEvent) + br.Bot = br.AS.BotIntent() + br.Crypto = NewCryptoHelper(br) + br.Bridge.Commands.(*commands.Processor).AddHandlers( + CommandDiscardMegolmSession, CommandSetPowerLevel, + CommandLoginMatrix, CommandPingMatrix, CommandLogoutMatrix, + ) + br.Provisioning = &ProvisioningAPI{br: br} + br.DoublePuppet = newDoublePuppetUtil(br) + br.deterministicEventIDServer = "backfill." + br.Config.Homeserver.Domain +} + +func (br *Connector) Start(ctx context.Context) error { + br.Provisioning.Init() + err := br.initDirectMedia() + if err != nil { + return err + } + err = br.initPublicMedia() + if err != nil { + return err + } + needsStateResync := br.Config.Encryption.Default && + br.Bridge.DB.KV.Get(ctx, database.KeyEncryptionStateResynced) != "true" + if needsStateResync { + dbExists, err := br.StateStore.TableExists(ctx, "mx_version") + if err != nil { + return fmt.Errorf("failed to check if mx_version table exists: %w", err) + } else if !dbExists { + needsStateResync = false + br.Bridge.DB.KV.Set(ctx, database.KeyEncryptionStateResynced, "true") + } + } + err = br.StateStore.Upgrade(ctx) + if err != nil { + return bridgev2.DBUpgradeError{Section: "matrix_state", Err: err} + } + if br.Config.Homeserver.Websocket || len(br.Config.Homeserver.WSProxy) > 0 { + br.Websocket = true + br.Log.Debug().Msg("Starting appservice websocket") + var wg sync.WaitGroup + wg.Add(1) + br.wsStartupWait = &wg + br.wsShortCircuitReconnectBackoff = make(chan struct{}) + go br.startWebsocket(&wg) + } else if br.AS.Host.IsConfigured() { + br.Log.Debug().Msg("Starting appservice HTTP server") + go br.AS.Start() + } else { + br.Log.WithLevel(zerolog.FatalLevel).Msg("Neither appservice HTTP listener nor websocket is enabled") + os.Exit(23) + } + + br.Log.Debug().Msg("Checking connection to homeserver") + br.ensureConnection(ctx) + go br.fetchMediaConfig(ctx) + if br.Crypto != nil { + err = br.Crypto.Init(ctx) + if err != nil { + return err + } + } + br.EventProcessor.Start(ctx) + go br.UpdateBotProfile(ctx) + if br.Crypto != nil { + go br.Crypto.Start() + } + parsed, _ := url.Parse(br.Bridge.Network.GetName().NetworkURL) + if parsed != nil { + br.deterministicEventIDServer = strings.TrimPrefix(parsed.Hostname(), "www.") + } + br.AS.Ready = true + if br.Websocket && br.Config.Homeserver.WSPingInterval > 0 { + br.wsStopPinger = make(chan struct{}, 1) + go br.websocketServerPinger() + } + if needsStateResync { + br.ResyncEncryptionState(ctx) + } + return nil +} + +func (br *Connector) ResyncEncryptionState(ctx context.Context) { + log := zerolog.Ctx(ctx) + roomIDScanner := dbutil.ConvertRowFn[id.RoomID](dbutil.ScanSingleColumn[id.RoomID]) + rooms, err := roomIDScanner.NewRowIter(br.Bridge.DB.Query(ctx, ` + SELECT rooms.room_id + FROM (SELECT DISTINCT(room_id) FROM mx_user_profile WHERE room_id<>'') rooms + LEFT JOIN mx_room_state ON rooms.room_id = mx_room_state.room_id + WHERE mx_room_state.encryption IS NULL + `)).AsList() + if err != nil { + log.Err(err).Msg("Failed to get room list to resync state") + return + } + var failedCount, successCount, forbiddenCount int + for _, roomID := range rooms { + if roomID == "" { + continue + } + var outContent *event.EncryptionEventContent + err = br.Bot.Client.StateEvent(ctx, roomID, event.StateEncryption, "", &outContent) + if errors.Is(err, mautrix.MForbidden) { + // Most likely non-existent room + log.Debug().Err(err).Stringer("room_id", roomID).Msg("Failed to get state for room") + forbiddenCount++ + } else if err != nil { + log.Err(err).Stringer("room_id", roomID).Msg("Failed to get state for room") + failedCount++ + } else { + successCount++ + } + } + br.Bridge.DB.KV.Set(ctx, database.KeyEncryptionStateResynced, "true") + log.Info(). + Int("success_count", successCount). + Int("forbidden_count", forbiddenCount). + Int("failed_count", failedCount). + Msg("Resynced rooms") +} + +func (br *Connector) GetPublicAddress() string { + if br.Config.AppService.PublicAddress == "https://bridge.example.com" { + return "" + } + return strings.TrimRight(br.Config.AppService.PublicAddress, "/") +} + +func (br *Connector) GetRouter() *http.ServeMux { + if br.GetPublicAddress() != "" { + return br.AS.Router + } + return nil +} + +func (br *Connector) GetCapabilities() *bridgev2.MatrixCapabilities { + return br.Capabilities +} + +func sendStopSignal(ch chan struct{}) { + if ch != nil { + select { + case ch <- struct{}{}: + default: + } + } +} + +func (br *Connector) PreStop() { + br.stopping = true + br.AS.Stop() + if stopWebsocket := br.AS.StopWebsocket; stopWebsocket != nil { + stopWebsocket(appservice.ErrWebsocketManualStop) + } + sendStopSignal(br.wsStopPinger) + sendStopSignal(br.wsShortCircuitReconnectBackoff) +} + +func (br *Connector) Stop() { + br.EventProcessor.Stop() + if br.Crypto != nil { + br.Crypto.Stop() + } + if wsStopChan := br.wsStopped; wsStopChan != nil { + select { + case <-wsStopChan: + case <-time.After(4 * time.Second): + br.Log.Warn().Msg("Timed out waiting for websocket to close") + } + } +} + +var MinSpecVersion = mautrix.SpecV14 + +func (br *Connector) logInitialRequestError(err error, defaultMessage string) { + if errors.Is(err, mautrix.MUnknownToken) { + br.Log.WithLevel(zerolog.FatalLevel).Msg("The as_token was not accepted. Is the registration file installed in your homeserver correctly?") + br.Log.Info().Msg("See https://docs.mau.fi/faq/as-token for more info") + } else if errors.Is(err, mautrix.MExclusive) { + br.Log.WithLevel(zerolog.FatalLevel).Msg("The as_token was accepted, but the /register request was not. Are the homeserver domain, bot username and username template in the config correct, and do they match the values in the registration?") + br.Log.Info().Msg("See https://docs.mau.fi/faq/as-register for more info") + } else { + br.Log.WithLevel(zerolog.FatalLevel).Err(err).Msg(defaultMessage) + } +} + +func (br *Connector) ensureConnection(ctx context.Context) { + triedToRegister := false + for { + versions, err := br.Bot.Versions(ctx) + if err != nil { + if errors.Is(err, mautrix.MForbidden) && !triedToRegister { + br.Log.Debug().Msg("M_FORBIDDEN in /versions, trying to register before retrying") + err = br.Bot.EnsureRegistered(ctx) + if err != nil { + br.logInitialRequestError(err, "Failed to register after /versions failed with M_FORBIDDEN") + os.Exit(16) + } + triedToRegister = true + } else if errors.Is(err, mautrix.MUnknownToken) || errors.Is(err, mautrix.MExclusive) { + br.logInitialRequestError(err, "/versions request failed with auth error") + os.Exit(16) + } else { + br.Log.Err(err).Msg("Failed to connect to homeserver, retrying in 10 seconds...") + time.Sleep(10 * time.Second) + } + } else { + br.SpecVersions = versions + *br.AS.SpecVersions = *versions + br.Capabilities.AutoJoinInvites = br.SpecVersions.Supports(mautrix.BeeperFeatureAutojoinInvites) + br.Capabilities.BatchSending = br.SpecVersions.Supports(mautrix.BeeperFeatureBatchSending) + br.Capabilities.ArbitraryMemberChange = br.SpecVersions.Supports(mautrix.BeeperFeatureArbitraryMemberChange) + br.Capabilities.ExtraProfileMeta = br.SpecVersions.Supports(mautrix.BeeperFeatureArbitraryProfileMeta) || + (br.SpecVersions.Supports(mautrix.FeatureArbitraryProfileFields) && br.Config.Matrix.GhostExtraProfileInfo) + break + } + } + + unsupportedServerLogLevel := zerolog.FatalLevel + if br.IgnoreUnsupportedServer { + unsupportedServerLogLevel = zerolog.ErrorLevel + } + if br.Config.Homeserver.Software == bridgeconfig.SoftwareHungry && !br.SpecVersions.Supports(mautrix.BeeperFeatureHungry) { + br.Log.WithLevel(zerolog.FatalLevel).Msg("The config claims the homeserver is hungryserv, but the /versions response didn't confirm it") + os.Exit(18) + } else if !br.SpecVersions.ContainsGreaterOrEqual(MinSpecVersion) { + br.Log.WithLevel(unsupportedServerLogLevel). + Stringer("server_supports", br.SpecVersions.GetLatest()). + Stringer("bridge_requires", MinSpecVersion). + Msg("The homeserver is outdated (supported spec versions are below minimum required by bridge)") + if !br.IgnoreUnsupportedServer { + os.Exit(18) + } + } + + resp, err := br.Bot.Whoami(ctx) + if err != nil { + br.logInitialRequestError(err, "/whoami request failed with unknown error") + os.Exit(16) + } else if resp.UserID != br.Bot.UserID { + br.Log.WithLevel(zerolog.FatalLevel). + Stringer("got_user_id", resp.UserID). + Stringer("expected_user_id", br.Bot.UserID). + Msg("Unexpected user ID in whoami call") + os.Exit(17) + } + + if br.Websocket { + br.Log.Debug().Msg("Websocket mode: no need to check status of homeserver -> bridge connection") + return + } else if !br.SpecVersions.Supports(mautrix.FeatureAppservicePing) { + br.Log.Debug().Msg("Homeserver does not support checking status of homeserver -> bridge connection") + return + } + + br.Bot.EnsureAppserviceConnection(ctx) +} + +func (br *Connector) fetchCapabilities(ctx context.Context) *mautrix.RespCapabilities { + br.specCapsLock.Lock() + defer br.specCapsLock.Unlock() + if br.SpecCaps != nil { + return br.SpecCaps + } + caps, err := br.Bot.Capabilities(ctx) + if err != nil { + br.Log.Err(err).Msg("Failed to fetch capabilities from homeserver") + return nil + } + br.SpecCaps = caps + return caps +} + +func (br *Connector) fetchMediaConfig(ctx context.Context) { + cfg, err := br.Bot.GetMediaConfig(ctx) + if err != nil { + br.Log.Warn().Err(err).Msg("Failed to fetch media config") + } else { + if cfg.UploadSize == 0 { + cfg.UploadSize = 50 * 1024 * 1024 + } + br.MediaConfig = *cfg + mfsn, ok := br.Bridge.Network.(bridgev2.MaxFileSizeingNetwork) + if ok { + mfsn.SetMaxFileSize(br.MediaConfig.UploadSize) + } + br.uploadSema = semaphore.NewWeighted(br.MediaConfig.UploadSize + 1) + } +} + +func (br *Connector) UpdateBotProfile(ctx context.Context) { + br.Log.Debug().Msg("Updating bot profile") + botConfig := &br.Config.AppService.Bot + + var err error + var mxc id.ContentURI + if botConfig.Avatar == "remove" { + err = br.Bot.SetAvatarURL(ctx, mxc) + } else if !botConfig.ParsedAvatar.IsEmpty() { + err = br.Bot.SetAvatarURL(ctx, botConfig.ParsedAvatar) + } + if err != nil { + br.Log.Warn().Err(err).Msg("Failed to update bot avatar") + } + + if botConfig.Displayname == "remove" { + err = br.Bot.SetDisplayName(ctx, "") + } else if len(botConfig.Displayname) > 0 { + err = br.Bot.SetDisplayName(ctx, botConfig.Displayname) + } + if err != nil { + br.Log.Warn().Err(err).Msg("Failed to update bot displayname") + } + + if br.SpecVersions.Supports(mautrix.BeeperFeatureArbitraryProfileMeta) { + br.Log.Debug().Msg("Setting contact info on the appservice bot") + netName := br.Bridge.Network.GetName() + err = br.Bot.BeeperUpdateProfile(ctx, event.BeeperProfileExtra{ + Service: netName.BeeperBridgeType, + Network: netName.NetworkID, + IsBridgeBot: true, + }) + if err != nil { + br.Log.Warn().Err(err).Msg("Failed to update bot contact info") + } + } +} + +func (br *Connector) GhostIntent(userID networkid.UserID) bridgev2.MatrixAPI { + return &ASIntent{ + Matrix: br.AS.Intent(br.FormatGhostMXID(userID)), + Connector: br, + } +} + +func (br *Connector) SendBridgeStatus(ctx context.Context, state *status.BridgeState) error { + if br.Websocket { + br.hasSentAnyStates = true + return br.AS.SendWebsocket(ctx, &appservice.WebsocketRequest{ + Command: "bridge_status", + Data: state, + }) + } else if br.Config.Homeserver.StatusEndpoint != "" { + // Connecting states aren't really relevant unless the bridge runs somewhere with an unreliable network + if state.StateEvent == status.StateConnecting { + return nil + } + return state.SendHTTP(ctx, br.Config.Homeserver.StatusEndpoint, br.Config.AppService.ASToken) + } else { + return nil + } +} + +func (br *Connector) SendMessageStatus(ctx context.Context, ms *bridgev2.MessageStatus, evt *bridgev2.MessageStatusEventInfo) { + go br.internalSendMessageStatus(ctx, ms, evt, "") +} + +func (br *Connector) internalSendMessageStatus(ctx context.Context, ms *bridgev2.MessageStatus, evt *bridgev2.MessageStatusEventInfo, editEvent id.EventID) id.EventID { + if evt.EventType.IsEphemeral() || evt.SourceEventID == "" { + return "" + } + log := zerolog.Ctx(ctx) + + if !evt.IsSourceEventDoublePuppeted { + err := br.SendMessageCheckpoints(ctx, []*status.MessageCheckpoint{ms.ToCheckpoint(evt)}) + if err != nil { + log.Err(err).Msg("Failed to send message checkpoint") + } + } + + if !ms.DisableMSS && br.Config.Matrix.MessageStatusEvents { + mssEvt := ms.ToMSSEvent(evt) + _, err := br.Bot.SendMessageEvent(ctx, evt.RoomID, event.BeeperMessageStatus, mssEvt) + if err != nil { + log.Err(err). + Stringer("room_id", evt.RoomID). + Stringer("event_id", evt.SourceEventID). + Any("mss_content", mssEvt). + Msg("Failed to send MSS event") + } + } + if ms.SendNotice && br.Config.Matrix.MessageErrorNotices && evt.MessageType != event.MsgNotice && + (ms.Status == event.MessageStatusFail || ms.Status == event.MessageStatusRetriable || ms.Step == status.MsgStepDecrypted) { + content := ms.ToNoticeEvent(evt) + if editEvent != "" { + content.SetEdit(editEvent) + } + resp, err := br.Bot.SendMessageEvent(ctx, evt.RoomID, event.EventMessage, content) + if err != nil { + log.Err(err). + Stringer("room_id", evt.RoomID). + Stringer("event_id", evt.SourceEventID). + Str("notice_message", content.Body). + Msg("Failed to send notice event") + } else { + return resp.EventID + } + } + if ms.Status == event.MessageStatusSuccess && br.Config.Matrix.DeliveryReceipts { + err := br.Bot.SendReceipt(ctx, evt.RoomID, evt.SourceEventID, event.ReceiptTypeRead, nil) + if err != nil { + log.Err(err). + Stringer("room_id", evt.RoomID). + Stringer("event_id", evt.SourceEventID). + Msg("Failed to send Matrix delivery receipt") + } + } + return "" +} + +func (br *Connector) SendMessageCheckpoints(ctx context.Context, checkpoints []*status.MessageCheckpoint) error { + checkpointsJSON := status.CheckpointsJSON{Checkpoints: checkpoints} + + if br.Websocket { + return br.AS.SendWebsocket(ctx, &appservice.WebsocketRequest{ + Command: "message_checkpoint", + Data: checkpointsJSON, + }) + } + + endpoint := br.Config.Homeserver.MessageSendCheckpointEndpoint + if endpoint == "" { + return nil + } + + return checkpointsJSON.SendHTTP(ctx, br.AS.HTTPClient, endpoint, br.AS.Registration.AppToken) +} + +func (br *Connector) ParseGhostMXID(userID id.UserID) (networkid.UserID, bool) { + match := br.userIDRegex.FindStringSubmatch(string(userID)) + if match == nil || userID == br.Bot.UserID { + return "", false + } + decoded, err := id.DecodeUserLocalpart(match[1]) + if err != nil { + return "", false + } + return networkid.UserID(decoded), true +} + +func (br *Connector) FormatGhostMXID(userID networkid.UserID) id.UserID { + localpart := br.Config.AppService.FormatUsername(id.EncodeUserLocalpart(string(userID))) + return id.NewUserID(localpart, br.Config.Homeserver.Domain) +} + +func (br *Connector) NewUserIntent(ctx context.Context, userID id.UserID, accessToken string) (bridgev2.MatrixAPI, string, error) { + intent, newToken, err := br.DoublePuppet.Setup(ctx, userID, accessToken) + if err != nil { + if errors.Is(err, ErrNoAccessToken) { + err = nil + } + return nil, accessToken, err + } + br.doublePuppetIntents.Set(userID, intent) + return &ASIntent{Connector: br, Matrix: intent}, newToken, nil +} + +func (br *Connector) BotIntent() bridgev2.MatrixAPI { + return &ASIntent{Connector: br, Matrix: br.Bot} +} + +func (br *Connector) GetPowerLevels(ctx context.Context, roomID id.RoomID) (*event.PowerLevelsEventContent, error) { + return br.Bot.PowerLevels(ctx, roomID) +} + +func (br *Connector) GetStateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string) (*event.Event, error) { + if stateKey == "" { + switch eventType { + case event.StateCreate: + createEvt, err := br.Bot.StateStore.GetCreate(ctx, roomID) + if err != nil || createEvt != nil { + return createEvt, err + } + case event.StateJoinRules: + joinRulesContent, err := br.Bot.StateStore.GetJoinRules(ctx, roomID) + if err != nil { + return nil, err + } else if joinRulesContent != nil { + return &event.Event{ + Type: event.StateJoinRules, + RoomID: roomID, + StateKey: ptr.Ptr(""), + Content: event.Content{Parsed: joinRulesContent}, + }, nil + } + } + } + return br.Bot.FullStateEvent(ctx, roomID, eventType, stateKey) +} + +func (br *Connector) GetMembers(ctx context.Context, roomID id.RoomID) (map[id.UserID]*event.MemberEventContent, error) { + fetched, err := br.Bot.StateStore.HasFetchedMembers(ctx, roomID) + if err != nil { + return nil, err + } else if fetched { + return br.Bot.StateStore.GetAllMembers(ctx, roomID) + } + members, err := br.Bot.Members(ctx, roomID) + if err != nil { + return nil, err + } + output := make(map[id.UserID]*event.MemberEventContent, len(members.Chunk)) + for _, evt := range members.Chunk { + output[id.UserID(evt.GetStateKey())] = evt.Content.AsMember() + } + return output, nil +} + +func (br *Connector) GetMemberInfo(ctx context.Context, roomID id.RoomID, userID id.UserID) (*event.MemberEventContent, error) { + // TODO fetch from network sometimes? + return br.AS.StateStore.GetMember(ctx, roomID, userID) +} + +func (br *Connector) IsConfusableName(ctx context.Context, roomID id.RoomID, userID id.UserID, name string) ([]id.UserID, error) { + return br.AS.StateStore.IsConfusableName(ctx, roomID, userID, name) +} + +func (br *Connector) GetUniqueBridgeID() string { + return fmt.Sprintf("%s/%s", br.Config.Homeserver.Domain, br.Config.AppService.ID) +} + +func (br *Connector) BatchSend(ctx context.Context, roomID id.RoomID, req *mautrix.ReqBeeperBatchSend, extras []*bridgev2.MatrixSendExtra) (*mautrix.RespBeeperBatchSend, error) { + if encrypted, err := br.StateStore.IsEncrypted(ctx, roomID); err != nil { + return nil, fmt.Errorf("failed to check if room is encrypted: %w", err) + } else if encrypted { + for _, evt := range req.Events { + intent, _ := br.doublePuppetIntents.Get(evt.Sender) + if intent != nil { + intent.AddDoublePuppetValueWithTS(&evt.Content, evt.Timestamp) + } + if evt.Type != event.EventEncrypted && evt.Type != event.EventReaction { + err = br.Crypto.Encrypt(ctx, roomID, evt.Type, &evt.Content) + if err != nil { + return nil, err + } + evt.Type = event.EventEncrypted + if intent != nil { + intent.AddDoublePuppetValueWithTS(&evt.Content, evt.Timestamp) + } + } + } + } + return br.Bot.BeeperBatchSend(ctx, roomID, req) +} + +func (br *Connector) GenerateDeterministicEventID(roomID id.RoomID, _ networkid.PortalKey, messageID networkid.MessageID, partID networkid.PartID) id.EventID { + data := make([]byte, 0, len(roomID)+1+len(messageID)+1+len(partID)) + data = append(data, roomID...) + data = append(data, 0) + data = append(data, messageID...) + data = append(data, 0) + data = append(data, partID...) + + hash := sha256.Sum256(data) + hashB64Len := base64.RawURLEncoding.EncodedLen(len(hash)) + + eventID := make([]byte, 1+hashB64Len+1+len(br.deterministicEventIDServer)) + eventID[0] = '$' + base64.RawURLEncoding.Encode(eventID[1:1+hashB64Len], hash[:]) + eventID[1+hashB64Len] = ':' + copy(eventID[1+hashB64Len+1:], br.deterministicEventIDServer) + + return id.EventID(exbytes.UnsafeString(eventID)) +} + +func (br *Connector) GenerateDeterministicRoomID(key networkid.PortalKey) id.RoomID { + return id.RoomID(fmt.Sprintf("!%s.%s:%s", key.ID, key.Receiver, br.ServerName())) +} + +func (br *Connector) GenerateReactionEventID(roomID id.RoomID, targetMessage *database.Message, sender networkid.UserID, emojiID networkid.EmojiID) id.EventID { + // We don't care about determinism for reactions + return id.EventID(fmt.Sprintf("$%s:%s", base64.RawURLEncoding.EncodeToString(random.Bytes(32)), br.deterministicEventIDServer)) +} + +func (br *Connector) ServerName() string { + return br.Config.Homeserver.Domain +} + +func (br *Connector) HandleNewlyBridgedRoom(ctx context.Context, roomID id.RoomID) error { + _, err := br.Bot.Members(ctx, roomID) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to fetch members in newly bridged room") + } + if !br.Config.Encryption.Default { + return nil + } + _, err = br.Bot.SendStateEvent(ctx, roomID, event.StateEncryption, "", &event.Content{ + Parsed: br.getDefaultEncryptionEvent(), + }) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to enable encryption in newly bridged room") + return fmt.Errorf("failed to enable encryption") + } + return nil +} + +func (br *Connector) GetURLPreview(ctx context.Context, url string) (*event.LinkPreview, error) { + return br.Bot.GetURLPreview(ctx, url) +} diff --git a/bridge/crypto.go b/bridgev2/matrix/crypto.go similarity index 70% rename from bridge/crypto.go rename to bridgev2/matrix/crypto.go index f0b90056..7f18f1f5 100644 --- a/bridge/crypto.go +++ b/bridgev2/matrix/crypto.go @@ -6,7 +6,7 @@ //go:build cgo && !nocrypto -package bridge +package matrix import ( "context" @@ -14,14 +14,17 @@ import ( "fmt" "os" "runtime/debug" + "strings" "sync" "time" + "github.com/lib/pq" "github.com/rs/zerolog" "go.mau.fi/util/dbutil" "maunium.net/go/mautrix" - "maunium.net/go/mautrix/bridge/bridgeconfig" + "maunium.net/go/mautrix/bridgev2" + "maunium.net/go/mautrix/bridgev2/database" "maunium.net/go/mautrix/crypto" "maunium.net/go/mautrix/crypto/olm" "maunium.net/go/mautrix/event" @@ -29,14 +32,18 @@ import ( "maunium.net/go/mautrix/sqlstatestore" ) +func init() { + crypto.PostgresArrayWrapper = pq.Array +} + var _ crypto.StateStore = (*sqlstatestore.SQLStateStore)(nil) -var NoSessionFound = crypto.NoSessionFound -var DuplicateMessageIndex = crypto.DuplicateMessageIndex -var UnknownMessageIndex = olm.UnknownMessageIndex +var NoSessionFound = crypto.ErrNoSessionFound +var DuplicateMessageIndex = crypto.ErrDuplicateMessageIndex +var UnknownMessageIndex = olm.ErrUnknownMessageIndex type CryptoHelper struct { - bridge *Bridge + bridge *Connector client *mautrix.Client mach *crypto.OlmMachine store *SQLCryptoStore @@ -49,35 +56,36 @@ type CryptoHelper struct { cancelPeriodicDeleteLoop func() } -func NewCryptoHelper(bridge *Bridge) Crypto { - if !bridge.Config.Bridge.GetEncryptionConfig().Allow { - bridge.ZLog.Debug().Msg("Bridge built with end-to-bridge encryption, but disabled in config") +func NewCryptoHelper(c *Connector) Crypto { + if !c.Config.Encryption.Allow { + c.Log.Debug().Msg("Bridge built with end-to-bridge encryption, but disabled in config") return nil } - log := bridge.ZLog.With().Str("component", "crypto").Logger() + log := c.Log.With().Str("component", "crypto").Logger() return &CryptoHelper{ - bridge: bridge, + bridge: c, log: &log, } } func (helper *CryptoHelper) Init(ctx context.Context) error { - if len(helper.bridge.CryptoPickleKey) == 0 { + if len(helper.bridge.Config.Encryption.PickleKey) == 0 { panic("CryptoPickleKey not set") } helper.log.Debug().Msg("Initializing end-to-bridge encryption...") helper.store = NewSQLCryptoStore( - helper.bridge.DB, - dbutil.ZeroLogger(helper.bridge.ZLog.With().Str("db_section", "crypto").Logger()), + helper.bridge.Bridge.DB.Database, + dbutil.ZeroLogger(helper.bridge.Log.With().Str("db_section", "crypto").Logger()), + string(helper.bridge.Bridge.ID), helper.bridge.AS.BotMXID(), - fmt.Sprintf("@%s:%s", helper.bridge.Config.Bridge.FormatUsername("%"), helper.bridge.AS.HomeserverDomain), - helper.bridge.CryptoPickleKey, + fmt.Sprintf("@%s:%s", strings.ReplaceAll(helper.bridge.Config.AppService.FormatUsername("%"), "_", `\_`), helper.bridge.AS.HomeserverDomain), + helper.bridge.Config.Encryption.PickleKey, ) err := helper.store.DB.Upgrade(ctx) if err != nil { - helper.bridge.LogDBUpgradeErrorAndExit("crypto", err) + return bridgev2.DBUpgradeError{Section: "crypto", Err: err} } var isExistingDevice bool @@ -89,11 +97,11 @@ func (helper *CryptoHelper) Init(ctx context.Context) error { helper.log.Debug(). Str("device_id", helper.client.DeviceID.String()). Msg("Logged in as bridge bot") - stateStore := &cryptoStateStore{helper.bridge} - helper.mach = crypto.NewOlmMachine(helper.client, helper.log, helper.store, stateStore) + helper.mach = crypto.NewOlmMachine(helper.client, helper.log, helper.store, helper.bridge.StateStore) + helper.mach.DisableSharedGroupSessionTracking = true helper.mach.AllowKeyShare = helper.allowKeyShare - encryptionConfig := helper.bridge.Config.Bridge.GetEncryptionConfig() + encryptionConfig := helper.bridge.Config.Encryption helper.mach.SendKeysMinTrust = encryptionConfig.VerificationLevels.Receive helper.mach.PlaintextMentions = encryptionConfig.PlaintextMentions @@ -128,7 +136,19 @@ func (helper *CryptoHelper) Init(ctx context.Context) error { return err } if isExistingDevice { - helper.verifyKeysAreOnServer(ctx) + if !helper.verifyKeysAreOnServer(ctx) { + return nil + } + } else { + err = helper.ShareKeys(ctx) + if err != nil { + return fmt.Errorf("failed to share device keys: %w", err) + } + } + if helper.bridge.Config.Encryption.SelfSign { + if !helper.doSelfSign(ctx) { + os.Exit(34) + } } go helper.resyncEncryptionInfo(context.TODO()) @@ -136,30 +156,66 @@ func (helper *CryptoHelper) Init(ctx context.Context) error { return nil } +func (helper *CryptoHelper) doSelfSign(ctx context.Context) bool { + log := zerolog.Ctx(ctx) + hasKeys, isVerified, err := helper.mach.GetOwnVerificationStatus(ctx) + if err != nil { + log.WithLevel(zerolog.FatalLevel).Err(err).Msg("Failed to check verification status") + return false + } + log.Debug().Bool("has_keys", hasKeys).Bool("is_verified", isVerified).Msg("Checked verification status") + keyInDB := helper.bridge.Bridge.DB.KV.Get(ctx, database.KeyRecoveryKey) + if !hasKeys || keyInDB == "overwrite" { + if keyInDB != "" && keyInDB != "overwrite" { + log.WithLevel(zerolog.FatalLevel). + Msg("No keys on server, but database already has recovery key. Delete `recovery_key` from `kv_store` manually to continue.") + return false + } + recoveryKey, err := helper.mach.GenerateAndVerifyWithRecoveryKey(ctx) + if recoveryKey != "" { + helper.bridge.Bridge.DB.KV.Set(ctx, database.KeyRecoveryKey, recoveryKey) + } + if err != nil { + log.WithLevel(zerolog.FatalLevel).Err(err).Msg("Failed to generate recovery key and self-sign") + return false + } + log.Info().Msg("Generated new recovery key and self-signed bot device") + } else if !isVerified { + if keyInDB == "" { + log.WithLevel(zerolog.FatalLevel). + Msg("Server already has cross-signing keys, but no key in database. Add `recovery_key` to `kv_store`, or set it to `overwrite` to generate new keys.") + return false + } + err = helper.mach.VerifyWithRecoveryKey(ctx, keyInDB) + if err != nil { + log.WithLevel(zerolog.FatalLevel).Err(err).Msg("Failed to verify with recovery key") + return false + } + log.Info().Msg("Verified bot device with existing recovery key") + } + return true +} + func (helper *CryptoHelper) resyncEncryptionInfo(ctx context.Context) { log := helper.log.With().Str("action", "resync encryption event").Logger() - rows, err := helper.bridge.DB.Query(ctx, `SELECT room_id FROM mx_room_state WHERE encryption='{"resync":true}'`) + rows, err := helper.store.DB.Query(ctx, `SELECT room_id FROM mx_room_state WHERE encryption='{"resync":true}'`) + roomIDs, err := dbutil.NewRowIterWithError(rows, dbutil.ScanSingleColumn[id.RoomID], err).AsList() if err != nil { log.Err(err).Msg("Failed to query rooms for resync") return } - roomIDs, err := dbutil.NewRowIter(rows, dbutil.ScanSingleColumn[id.RoomID]).AsList() - if err != nil { - log.Err(err).Msg("Failed to scan rooms for resync") - return - } if len(roomIDs) > 0 { log.Debug().Interface("room_ids", roomIDs).Msg("Resyncing rooms") for _, roomID := range roomIDs { var evt event.EncryptionEventContent err = helper.client.StateEvent(ctx, roomID, event.StateEncryption, "", &evt) if err != nil { - log.Err(err).Str("room_id", roomID.String()).Msg("Failed to get encryption event") - _, err = helper.bridge.DB.Exec(ctx, ` + log.Err(err).Stringer("room_id", roomID).Msg("Failed to get encryption event") + _, err = helper.store.DB.Exec(ctx, ` UPDATE mx_room_state SET encryption=NULL WHERE room_id=$1 AND encryption='{"resync":true}' `, roomID) if err != nil { - log.Err(err).Str("room_id", roomID.String()).Msg("Failed to unmark room for resync after failed sync") + log.Err(err).Stringer("room_id", roomID).Msg("Failed to unmark room for resync after failed sync") } } else { maxAge := evt.RotationPeriodMillis @@ -176,15 +232,15 @@ func (helper *CryptoHelper) resyncEncryptionInfo(ctx context.Context) { Int("max_messages", maxMessages). Interface("content", &evt). Msg("Resynced encryption event") - _, err = helper.bridge.DB.Exec(ctx, ` + _, err = helper.store.DB.Exec(ctx, ` UPDATE crypto_megolm_inbound_session SET max_age=$1, max_messages=$2 WHERE room_id=$3 AND max_age IS NULL AND max_messages IS NULL `, maxAge, maxMessages, roomID) if err != nil { - log.Err(err).Str("room_id", roomID.String()).Msg("Failed to update megolm session table") + log.Err(err).Stringer("room_id", roomID).Msg("Failed to update megolm session table") } else { - log.Debug().Str("room_id", roomID.String()).Msg("Updated megolm session table") + log.Debug().Stringer("room_id", roomID).Msg("Updated megolm session table") } } } @@ -192,22 +248,31 @@ func (helper *CryptoHelper) resyncEncryptionInfo(ctx context.Context) { } func (helper *CryptoHelper) allowKeyShare(ctx context.Context, device *id.Device, info event.RequestedKeyInfo) *crypto.KeyShareRejection { - cfg := helper.bridge.Config.Bridge.GetEncryptionConfig() + cfg := helper.bridge.Config.Encryption if !cfg.AllowKeySharing { return &crypto.KeyShareRejectNoResponse } else if device.Trust == id.TrustStateBlacklisted { return &crypto.KeyShareRejectBlacklisted - } else if trustState := helper.mach.ResolveTrust(device); trustState >= cfg.VerificationLevels.Share { - portal := helper.bridge.Child.GetIPortal(info.RoomID) - if portal == nil { + } else if trustState, _ := helper.mach.ResolveTrustContext(ctx, device); trustState >= cfg.VerificationLevels.Share { + portal, err := helper.bridge.Bridge.GetPortalByMXID(ctx, info.RoomID) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to get portal to handle key request") + return &crypto.KeyShareRejectNoResponse + } else if portal == nil { zerolog.Ctx(ctx).Debug().Msg("Rejecting key request: room is not a portal") return &crypto.KeyShareRejection{Code: event.RoomKeyWithheldUnavailable, Reason: "Requested room is not a portal room"} } - user := helper.bridge.Child.GetIUser(device.UserID, true) - // FIXME reimplement IsInPortal - if user.GetPermissionLevel() < bridgeconfig.PermissionLevelAdmin /*&& !user.IsInPortal(portal.Key)*/ { - zerolog.Ctx(ctx).Debug().Msg("Rejecting key request: user is not in portal") - return &crypto.KeyShareRejection{Code: event.RoomKeyWithheldUnauthorized, Reason: "You're not in that portal"} + user, err := helper.bridge.Bridge.GetExistingUserByMXID(ctx, device.UserID) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to get user to handle key request") + return &crypto.KeyShareRejectNoResponse + } else if user == nil { + zerolog.Ctx(ctx).Debug().Msg("Couldn't find user to handle key request") + return &crypto.KeyShareRejectNoResponse + } else if !user.Permissions.Admin { + zerolog.Ctx(ctx).Debug().Msg("Rejecting key request: user is not admin") + // TODO is in room check? + return &crypto.KeyShareRejection{Code: event.RoomKeyWithheldUnauthorized, Reason: "Key sharing for non-admins is not yet implemented"} } zerolog.Ctx(ctx).Debug().Msg("Accepting key request") return nil @@ -221,27 +286,39 @@ func (helper *CryptoHelper) loginBot(ctx context.Context) (*mautrix.Client, bool if err != nil { return nil, false, fmt.Errorf("failed to find existing device ID: %w", err) } else if len(deviceID) > 0 { - helper.log.Debug().Str("device_id", deviceID.String()).Msg("Found existing device ID for bot in database") + helper.log.Debug().Stringer("device_id", deviceID).Msg("Found existing device ID for bot in database") } // Create a new client instance with the default AS settings (including as_token), // the Login call will then override the access token in the client. client := helper.bridge.AS.NewMautrixClient(helper.bridge.AS.BotMXID()) + + initialDeviceDisplayName := fmt.Sprintf("%s bridge", helper.bridge.Bridge.Network.GetName().DisplayName) + if helper.bridge.Config.Encryption.MSC4190 { + helper.log.Debug().Msg("Creating bot device with MSC4190") + err = client.CreateDeviceMSC4190(ctx, deviceID, initialDeviceDisplayName) + if err != nil { + return nil, deviceID != "", fmt.Errorf("failed to create device for bridge bot: %w", err) + } + helper.store.DeviceID = client.DeviceID + return client, deviceID != "", nil + } + flows, err := client.GetLoginFlows(ctx) if err != nil { return nil, deviceID != "", fmt.Errorf("failed to get supported login flows: %w", err) } else if !flows.HasFlow(mautrix.AuthTypeAppservice) { return nil, deviceID != "", fmt.Errorf("homeserver does not support appservice login") } + resp, err := client.Login(ctx, &mautrix.ReqLogin{ Type: mautrix.AuthTypeAppservice, Identifier: mautrix.UserIdentifier{ Type: mautrix.IdentifierTypeUser, User: string(helper.bridge.AS.BotMXID()), }, - DeviceID: deviceID, - StoreCredentials: true, - - InitialDeviceDisplayName: fmt.Sprintf("%s bridge", helper.bridge.ProtocolName), + DeviceID: deviceID, + StoreCredentials: true, + InitialDeviceDisplayName: initialDeviceDisplayName, }) if err != nil { return nil, deviceID != "", fmt.Errorf("failed to log in as bridge bot: %w", err) @@ -250,7 +327,7 @@ func (helper *CryptoHelper) loginBot(ctx context.Context) (*mautrix.Client, bool return client, deviceID != "", nil } -func (helper *CryptoHelper) verifyKeysAreOnServer(ctx context.Context) { +func (helper *CryptoHelper) verifyKeysAreOnServer(ctx context.Context) bool { helper.log.Debug().Msg("Making sure keys are still on server") resp, err := helper.client.QueryKeys(ctx, &mautrix.ReqQueryKeys{ DeviceKeys: map[id.UserID]mautrix.DeviceIDList{ @@ -263,14 +340,15 @@ func (helper *CryptoHelper) verifyKeysAreOnServer(ctx context.Context) { } device, ok := resp.DeviceKeys[helper.client.UserID][helper.client.DeviceID] if ok && len(device.Keys) > 0 { - return + return true } helper.log.Warn().Msg("Existing device doesn't have keys on server, resetting crypto") helper.Reset(ctx, false) + return false } func (helper *CryptoHelper) Start() { - if helper.bridge.Config.Bridge.GetEncryptionConfig().Appservice { + if helper.bridge.Config.Encryption.Appservice { helper.log.Debug().Msg("End-to-bridge encryption is in appservice mode, registering event listeners and not starting syncer") helper.bridge.AS.Registration.EphemeralEvents = true helper.mach.AddAppserviceListener(helper.bridge.EventProcessor) @@ -361,7 +439,7 @@ func (helper *CryptoHelper) Encrypt(ctx context.Context, roomID id.RoomID, evtTy var encrypted *event.EncryptedEventContent encrypted, err = helper.mach.EncryptMegolmEvent(ctx, roomID, evtType, content) if err != nil { - if !errors.Is(err, crypto.SessionExpired) && !errors.Is(err, crypto.SessionNotShared) && !errors.Is(err, crypto.NoGroupSession) { + if !errors.Is(err, crypto.ErrSessionExpired) && !errors.Is(err, crypto.ErrSessionNotShared) && !errors.Is(err, crypto.ErrNoGroupSession) { return } helper.log.Debug().Err(err). @@ -476,36 +554,14 @@ func (syncer *cryptoSyncer) OnFailedSync(_ *mautrix.RespSync, err error) (time.D func (syncer *cryptoSyncer) GetFilterJSON(_ id.UserID) *mautrix.Filter { everything := []event.Type{{Type: "*"}} return &mautrix.Filter{ - Presence: mautrix.FilterPart{NotTypes: everything}, - AccountData: mautrix.FilterPart{NotTypes: everything}, - Room: mautrix.RoomFilter{ + Presence: &mautrix.FilterPart{NotTypes: everything}, + AccountData: &mautrix.FilterPart{NotTypes: everything}, + Room: &mautrix.RoomFilter{ IncludeLeave: false, - Ephemeral: mautrix.FilterPart{NotTypes: everything}, - AccountData: mautrix.FilterPart{NotTypes: everything}, - State: mautrix.FilterPart{NotTypes: everything}, - Timeline: mautrix.FilterPart{NotTypes: everything}, + Ephemeral: &mautrix.FilterPart{NotTypes: everything}, + AccountData: &mautrix.FilterPart{NotTypes: everything}, + State: &mautrix.FilterPart{NotTypes: everything}, + Timeline: &mautrix.FilterPart{NotTypes: everything}, }, } } - -type cryptoStateStore struct { - bridge *Bridge -} - -var _ crypto.StateStore = (*cryptoStateStore)(nil) - -func (c *cryptoStateStore) IsEncrypted(ctx context.Context, id id.RoomID) (bool, error) { - portal := c.bridge.Child.GetIPortal(id) - if portal != nil { - return portal.IsEncrypted(), nil - } - return c.bridge.StateStore.IsEncrypted(ctx, id) -} - -func (c *cryptoStateStore) FindSharedRooms(ctx context.Context, id id.UserID) ([]id.RoomID, error) { - return c.bridge.StateStore.FindSharedRooms(ctx, id) -} - -func (c *cryptoStateStore) GetEncryptionEvent(ctx context.Context, id id.RoomID) (*event.EncryptionEventContent, error) { - return c.bridge.StateStore.GetEncryptionEvent(ctx, id) -} diff --git a/bridgev2/matrix/cryptoerror.go b/bridgev2/matrix/cryptoerror.go new file mode 100644 index 00000000..ea29703a --- /dev/null +++ b/bridgev2/matrix/cryptoerror.go @@ -0,0 +1,94 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package matrix + +import ( + "context" + "errors" + "fmt" + + "maunium.net/go/mautrix/bridgev2" + "maunium.net/go/mautrix/bridgev2/status" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +var ( + errDeviceNotTrusted = errors.New("your device is not trusted") + errMessageNotEncrypted = errors.New("unencrypted message") + errNoDecryptionKeys = errors.New("the bridge hasn't received the decryption keys") + errNoCrypto = errors.New("this bridge has not been configured to support encryption") +) + +func errorToHumanMessage(err error) string { + var withheld *event.RoomKeyWithheldEventContent + switch { + case errors.Is(err, errDeviceNotTrusted), errors.Is(err, errNoDecryptionKeys), errors.Is(err, errNoCrypto): + return err.Error() + case errors.Is(err, UnknownMessageIndex): + return "the keys received by the bridge can't decrypt the message" + case errors.Is(err, DuplicateMessageIndex): + return "your client encrypted multiple messages with the same key" + case errors.As(err, &withheld): + if withheld.Code == event.RoomKeyWithheldBeeperRedacted { + return "your client used an outdated encryption session" + } + return "your client refused to share decryption keys with the bridge" + case errors.Is(err, errMessageNotEncrypted): + return "the message is not encrypted" + default: + return "the bridge failed to decrypt the message" + } +} + +func deviceUnverifiedErrorWithExplanation(trust id.TrustState) error { + var explanation string + switch trust { + case id.TrustStateBlacklisted: + explanation = "device is blacklisted" + case id.TrustStateUnset: + explanation = "unverified" + case id.TrustStateUnknownDevice: + explanation = "device info not found" + case id.TrustStateForwarded: + explanation = "keys were forwarded from an unknown device" + case id.TrustStateCrossSignedUntrusted: + explanation = "cross-signing keys changed after setting up the bridge" + default: + return errDeviceNotTrusted + } + return fmt.Errorf("%w (%s)", errDeviceNotTrusted, explanation) +} + +func (br *Connector) sendCryptoStatusError(ctx context.Context, evt *event.Event, err error, errorEventID *id.EventID, retryNum int, isFinal bool) { + ms := &bridgev2.MessageStatus{ + Step: status.MsgStepDecrypted, + Status: event.MessageStatusRetriable, + ErrorReason: event.MessageStatusUndecryptable, + InternalError: err, + Message: errorToHumanMessage(err), + IsCertain: true, + SendNotice: true, + RetryNum: retryNum, + } + if !isFinal { + ms.Status = event.MessageStatusPending + // Don't send notice for first error + if retryNum == 0 { + ms.SendNotice = false + ms.DisableMSS = true + } + } + var editEventID id.EventID + if errorEventID != nil { + editEventID = *errorEventID + } + respEventID := br.internalSendMessageStatus(ctx, ms, bridgev2.StatusEventInfoFromEvent(evt), editEventID) + if errorEventID != nil && *errorEventID == "" { + *errorEventID = respEventID + } +} diff --git a/bridge/cryptostore.go b/bridgev2/matrix/cryptostore.go similarity index 85% rename from bridge/cryptostore.go rename to bridgev2/matrix/cryptostore.go index dde48a25..4c3b5d30 100644 --- a/bridge/cryptostore.go +++ b/bridgev2/matrix/cryptostore.go @@ -6,7 +6,7 @@ //go:build cgo && !nocrypto -package bridge +package matrix import ( "context" @@ -30,9 +30,9 @@ type SQLCryptoStore struct { var _ crypto.Store = (*SQLCryptoStore)(nil) -func NewSQLCryptoStore(db *dbutil.Database, log dbutil.DatabaseLogger, userID id.UserID, ghostIDFormat, pickleKey string) *SQLCryptoStore { +func NewSQLCryptoStore(db *dbutil.Database, log dbutil.DatabaseLogger, accountID string, userID id.UserID, ghostIDFormat, pickleKey string) *SQLCryptoStore { return &SQLCryptoStore{ - SQLCryptoStore: crypto.NewSQLCryptoStore(db, log, "", "", []byte(pickleKey)), + SQLCryptoStore: crypto.NewSQLCryptoStore(db, log, accountID, "", []byte(pickleKey)), UserID: userID, GhostIDFormat: ghostIDFormat, } @@ -45,7 +45,7 @@ func (store *SQLCryptoStore) GetRoomJoinedOrInvitedMembers(ctx context.Context, WHERE room_id=$1 AND (membership='join' OR membership='invite') AND user_id<>$2 - AND user_id NOT LIKE $3 + AND user_id NOT LIKE $3 ESCAPE '\' `, roomID, store.UserID, store.GhostIDFormat) if err != nil { return diff --git a/bridgev2/matrix/directmedia.go b/bridgev2/matrix/directmedia.go new file mode 100644 index 00000000..0667981a --- /dev/null +++ b/bridgev2/matrix/directmedia.go @@ -0,0 +1,86 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package matrix + +import ( + "bytes" + "context" + "crypto/hmac" + "crypto/sha256" + "encoding/base64" + "fmt" + "strings" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/bridgev2" + "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/id" + "maunium.net/go/mautrix/mediaproxy" +) + +const MediaIDPrefix = "\U0001F408" +const MediaIDTruncatedHashLength = 16 +const ContentURIMaxLength = 255 + +func (br *Connector) initDirectMedia() error { + if !br.Config.DirectMedia.Enabled { + return nil + } + dmn, ok := br.Bridge.Network.(bridgev2.DirectMediableNetwork) + if !ok { + return fmt.Errorf("direct media is enabled in config, but the network connector does not support it") + } + var err error + br.MediaProxy, err = mediaproxy.NewFromConfig(br.Config.DirectMedia.BasicConfig, br.getDirectMedia) + if err != nil { + return fmt.Errorf("failed to initialize media proxy: %w", err) + } + br.MediaProxy.RegisterRoutes(br.AS.Router, br.Log.With().Str("component", "media proxy").Logger()) + br.dmaSigKey = sha256.Sum256(br.MediaProxy.GetServerKey().Priv.Seed()) + dmn.SetUseDirectMedia() + br.Log.Debug().Str("server_name", br.MediaProxy.GetServerName()).Msg("Enabled direct media access") + return nil +} + +func (br *Connector) hashMediaID(data []byte) []byte { + hasher := hmac.New(sha256.New, br.dmaSigKey[:]) + hasher.Write(data) + return hasher.Sum(nil)[:MediaIDTruncatedHashLength] +} + +func (br *Connector) GenerateContentURI(ctx context.Context, mediaID networkid.MediaID) (id.ContentURIString, error) { + if br.MediaProxy == nil { + return "", bridgev2.ErrDirectMediaNotEnabled + } + buf := make([]byte, len(MediaIDPrefix)+len(mediaID)+MediaIDTruncatedHashLength) + copy(buf, MediaIDPrefix) + copy(buf[len(MediaIDPrefix):], mediaID) + truncatedHash := br.hashMediaID(buf[:len(MediaIDPrefix)+len(mediaID)]) + copy(buf[len(MediaIDPrefix)+len(mediaID):], truncatedHash) + mxc := id.ContentURI{ + Homeserver: br.MediaProxy.GetServerName(), + FileID: br.Config.DirectMedia.MediaIDPrefix + base64.RawURLEncoding.EncodeToString(buf), + }.CUString() + if len(mxc) > ContentURIMaxLength { + return "", fmt.Errorf("content URI too long (%d > %d)", len(mxc), ContentURIMaxLength) + } + return mxc, nil +} + +func (br *Connector) getDirectMedia(ctx context.Context, mediaIDStr string, params map[string]string) (response mediaproxy.GetMediaResponse, err error) { + mediaID, err := base64.RawURLEncoding.DecodeString(strings.TrimPrefix(mediaIDStr, br.Config.DirectMedia.MediaIDPrefix)) + if err != nil || !bytes.HasPrefix(mediaID, []byte(MediaIDPrefix)) || len(mediaID) < len(MediaIDPrefix)+MediaIDTruncatedHashLength+1 { + return nil, mediaproxy.ErrInvalidMediaIDSyntax + } + receivedHash := mediaID[len(mediaID)-MediaIDTruncatedHashLength:] + expectedHash := br.hashMediaID(mediaID[:len(mediaID)-MediaIDTruncatedHashLength]) + if !hmac.Equal(receivedHash, expectedHash) { + return nil, mautrix.MNotFound.WithMessage("Invalid checksum in media ID part") + } + remoteMediaID := networkid.MediaID(mediaID[len(MediaIDPrefix) : len(mediaID)-MediaIDTruncatedHashLength]) + return br.Bridge.Network.(bridgev2.DirectMediableNetwork).Download(ctx, remoteMediaID, params) +} diff --git a/bridgev2/matrix/doublepuppet.go b/bridgev2/matrix/doublepuppet.go new file mode 100644 index 00000000..ace33f30 --- /dev/null +++ b/bridgev2/matrix/doublepuppet.go @@ -0,0 +1,131 @@ +// Copyright (c) 2023 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package matrix + +import ( + "context" + "errors" + "fmt" + "strings" + "sync" + + "github.com/rs/zerolog" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/appservice" + "maunium.net/go/mautrix/id" +) + +type doublePuppetUtil struct { + br *Connector + + discoveryCache map[string]string + discoveryCacheLock sync.Mutex +} + +func newDoublePuppetUtil(br *Connector) *doublePuppetUtil { + return &doublePuppetUtil{ + br: br, + discoveryCache: make(map[string]string), + } +} + +func (dp *doublePuppetUtil) newClient(ctx context.Context, mxid id.UserID, accessToken string) (*mautrix.Client, error) { + _, homeserver, err := mxid.Parse() + if err != nil { + return nil, err + } + homeserverURL, found := dp.br.Config.DoublePuppet.Servers[homeserver] + if !found { + if homeserver == dp.br.AS.HomeserverDomain { + homeserverURL = "" + } else if dp.br.Config.DoublePuppet.AllowDiscovery { + dp.discoveryCacheLock.Lock() + defer dp.discoveryCacheLock.Unlock() + if homeserverURL, found = dp.discoveryCache[homeserver]; !found { + resp, err := mautrix.DiscoverClientAPI(ctx, homeserver) + if err != nil { + return nil, fmt.Errorf("failed to find homeserver URL for %s: %v", homeserver, err) + } + homeserverURL = resp.Homeserver.BaseURL + dp.discoveryCache[homeserver] = homeserverURL + zerolog.Ctx(ctx).Debug(). + Str("homeserver", homeserver). + Str("url", homeserverURL). + Str("user_id", mxid.String()). + Msg("Discovered URL to enable double puppeting for user") + } + } else { + return nil, fmt.Errorf("double puppeting from %s is not allowed", homeserver) + } + } + return dp.br.AS.NewExternalMautrixClient(mxid, accessToken, homeserverURL) +} + +func (dp *doublePuppetUtil) newIntent(ctx context.Context, mxid id.UserID, accessToken string) (*appservice.IntentAPI, error) { + client, err := dp.newClient(ctx, mxid, accessToken) + if err != nil { + return nil, err + } + + ia := dp.br.AS.NewIntentAPI("custom") + ia.Client = client + ia.Localpart, _, _ = mxid.Parse() + ia.UserID = mxid + ia.IsCustomPuppet = true + return ia, nil +} + +var ( + ErrMismatchingMXID = errors.New("whoami result does not match custom mxid") + ErrNoAccessToken = errors.New("no access token provided") + ErrNoMXID = errors.New("no mxid provided") +) + +const useConfigASToken = "appservice-config" +const asTokenModePrefix = "as_token:" + +func (dp *doublePuppetUtil) Setup(ctx context.Context, mxid id.UserID, savedAccessToken string) (intent *appservice.IntentAPI, newAccessToken string, err error) { + if len(mxid) == 0 { + err = ErrNoMXID + return + } + _, homeserver, _ := mxid.Parse() + loginSecret, hasSecret := dp.br.Config.DoublePuppet.Secrets[homeserver] + if hasSecret && strings.HasPrefix(loginSecret, asTokenModePrefix) { + intent, err = dp.newIntent(ctx, mxid, strings.TrimPrefix(loginSecret, asTokenModePrefix)) + if err != nil { + return + } + intent.SetAppServiceUserID = true + if savedAccessToken != useConfigASToken { + var resp *mautrix.RespWhoami + resp, err = intent.Whoami(ctx) + if err == nil && resp.UserID != mxid { + err = ErrMismatchingMXID + } + } + return intent, useConfigASToken, err + } else if savedAccessToken == "" || savedAccessToken == useConfigASToken { + err = ErrNoAccessToken + return + } + intent, err = dp.newIntent(ctx, mxid, savedAccessToken) + if err != nil { + return + } + var resp *mautrix.RespWhoami + resp, err = intent.Whoami(ctx) + if err == nil { + if resp.UserID != mxid { + err = ErrMismatchingMXID + } else { + newAccessToken = savedAccessToken + } + } + return +} diff --git a/bridgev2/matrix/intent.go b/bridgev2/matrix/intent.go new file mode 100644 index 00000000..f7254bd4 --- /dev/null +++ b/bridgev2/matrix/intent.go @@ -0,0 +1,795 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package matrix + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "os" + "strings" + "sync" + "time" + + "github.com/rs/zerolog" + "go.mau.fi/util/fallocate" + "go.mau.fi/util/ptr" + "golang.org/x/exp/slices" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/appservice" + "maunium.net/go/mautrix/bridgev2" + "maunium.net/go/mautrix/bridgev2/bridgeconfig" + "maunium.net/go/mautrix/crypto/attachment" + "maunium.net/go/mautrix/crypto/canonicaljson" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" + "maunium.net/go/mautrix/pushrules" +) + +// ASIntent implements the bridge ghost API interface using a real Matrix homeserver as the backend. +type ASIntent struct { + Matrix *appservice.IntentAPI + Connector *Connector + + dmUpdateLock sync.Mutex + directChatsCache event.DirectChatsEventContent +} + +var _ bridgev2.MatrixAPI = (*ASIntent)(nil) +var _ bridgev2.MarkAsDMMatrixAPI = (*ASIntent)(nil) +var _ bridgev2.EphemeralSendingMatrixAPI = (*ASIntent)(nil) + +func (as *ASIntent) SendMessage(ctx context.Context, roomID id.RoomID, eventType event.Type, content *event.Content, extra *bridgev2.MatrixSendExtra) (*mautrix.RespSendEvent, error) { + if extra == nil { + extra = &bridgev2.MatrixSendExtra{} + } + if eventType == event.EventRedaction && !as.Connector.SpecVersions.Supports(mautrix.FeatureRedactSendAsEvent) { + parsedContent := content.Parsed.(*event.RedactionEventContent) + as.Matrix.AddDoublePuppetValue(content) + return as.Matrix.RedactEvent(ctx, roomID, parsedContent.Redacts, mautrix.ReqRedact{ + Reason: parsedContent.Reason, + Extra: content.Raw, + }) + } + if (eventType != event.EventReaction || as.Connector.Config.Encryption.MSC4392) && eventType != event.EventRedaction { + msgContent, ok := content.Parsed.(*event.MessageEventContent) + if ok { + msgContent.AddPerMessageProfileFallback() + } + if encrypted, err := as.Matrix.StateStore.IsEncrypted(ctx, roomID); err != nil { + return nil, fmt.Errorf("failed to check if room is encrypted: %w", err) + } else if encrypted { + if as.Connector.Crypto == nil { + return nil, fmt.Errorf("room is encrypted, but bridge isn't configured to support encryption") + } + if as.Matrix.IsCustomPuppet { + if extra.Timestamp.IsZero() { + as.Matrix.AddDoublePuppetValue(content) + } else { + as.Matrix.AddDoublePuppetValueWithTS(content, extra.Timestamp.UnixMilli()) + } + } + err = as.Connector.Crypto.Encrypt(ctx, roomID, eventType, content) + if err != nil { + return nil, err + } + eventType = event.EventEncrypted + } + } + return as.Matrix.SendMessageEvent(ctx, roomID, eventType, content, mautrix.ReqSendEvent{Timestamp: extra.Timestamp.UnixMilli()}) +} + +func (as *ASIntent) BeeperSendEphemeralEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, content *event.Content, txnID string) (*mautrix.RespSendEvent, error) { + if !as.Connector.SpecVersions.Supports(mautrix.BeeperFeatureEphemeralEvents) { + return nil, mautrix.MUnrecognized.WithMessage("Homeserver does not advertise com.beeper.ephemeral support") + } + if encrypted, err := as.Matrix.StateStore.IsEncrypted(ctx, roomID); err != nil { + return nil, fmt.Errorf("failed to check if room is encrypted: %w", err) + } else if encrypted && as.Connector.Crypto != nil { + if err = as.Connector.Crypto.Encrypt(ctx, roomID, eventType, content); err != nil { + return nil, err + } + eventType = event.EventEncrypted + } + return as.Matrix.BeeperSendEphemeralEvent(ctx, roomID, eventType, content, mautrix.ReqSendEvent{TransactionID: txnID}) +} + +func (as *ASIntent) fillMemberEvent(ctx context.Context, roomID id.RoomID, userID id.UserID, content *event.Content) { + targetContent, ok := content.Parsed.(*event.MemberEventContent) + if !ok || targetContent.Displayname != "" || targetContent.AvatarURL != "" { + return + } + memberContent, err := as.Matrix.StateStore.TryGetMember(ctx, roomID, userID) + if err != nil { + zerolog.Ctx(ctx).Err(err). + Stringer("target_user_id", userID). + Str("membership", string(targetContent.Membership)). + Msg("Failed to get old member content from state store to fill new membership event") + } else if memberContent != nil { + targetContent.Displayname = memberContent.Displayname + targetContent.AvatarURL = memberContent.AvatarURL + } else if ghost, err := as.Connector.Bridge.GetGhostByMXID(ctx, userID); err != nil { + zerolog.Ctx(ctx).Err(err). + Stringer("target_user_id", userID). + Str("membership", string(targetContent.Membership)). + Msg("Failed to get ghost to fill new membership event") + } else if ghost != nil { + targetContent.Displayname = ghost.Name + targetContent.AvatarURL = ghost.AvatarMXC + } else if profile, err := as.Matrix.GetProfile(ctx, userID); err != nil { + zerolog.Ctx(ctx).Debug().Err(err). + Stringer("target_user_id", userID). + Str("membership", string(targetContent.Membership)). + Msg("Failed to get profile to fill new membership event") + } else if profile != nil { + targetContent.Displayname = profile.DisplayName + targetContent.AvatarURL = profile.AvatarURL.CUString() + } +} + +func (as *ASIntent) SendState(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, content *event.Content, ts time.Time) (resp *mautrix.RespSendEvent, err error) { + if eventType == event.StateMember { + as.fillMemberEvent(ctx, roomID, id.UserID(stateKey), content) + } + resp, err = as.Matrix.SendStateEvent(ctx, roomID, eventType, stateKey, content, mautrix.ReqSendEvent{Timestamp: ts.UnixMilli()}) + if err != nil && eventType == event.StateMember { + var httpErr mautrix.HTTPError + if errors.As(err, &httpErr) && httpErr.RespError != nil && + (strings.Contains(httpErr.RespError.Err, "is already in the room") || strings.Contains(httpErr.RespError.Err, "is already joined to room")) { + err = as.Matrix.StateStore.SetMembership(ctx, roomID, id.UserID(stateKey), event.MembershipJoin) + } + } + return resp, err +} + +func (as *ASIntent) MarkRead(ctx context.Context, roomID id.RoomID, eventID id.EventID, ts time.Time) (err error) { + extraData := map[string]any{} + if !ts.IsZero() { + extraData["ts"] = ts.UnixMilli() + } + as.Matrix.AddDoublePuppetValue(extraData) + req := mautrix.ReqSetReadMarkers{ + Read: eventID, + BeeperReadExtra: extraData, + } + if as.Matrix.IsCustomPuppet { + req.FullyRead = eventID + req.BeeperFullyReadExtra = extraData + } + if as.Matrix.IsCustomPuppet && as.Connector.SpecVersions.Supports(mautrix.BeeperFeatureInboxState) && as.Connector.Config.Homeserver.Software != bridgeconfig.SoftwareHungry { + err = as.Matrix.SetBeeperInboxState(ctx, roomID, &mautrix.ReqSetBeeperInboxState{ + //MarkedUnread: ptr.Ptr(false), + ReadMarkers: &req, + }) + } else { + err = as.Matrix.SetReadMarkers(ctx, roomID, &req) + if err == nil && as.Matrix.IsCustomPuppet && as.Connector.Config.Homeserver.Software != bridgeconfig.SoftwareHungry { + err = as.Matrix.SetRoomAccountData(ctx, roomID, event.AccountDataMarkedUnread.Type, &event.MarkedUnreadEventContent{ + Unread: false, + }) + } + } + return +} + +func (as *ASIntent) MarkUnread(ctx context.Context, roomID id.RoomID, unread bool) error { + if as.Connector.Config.Homeserver.Software == bridgeconfig.SoftwareHungry { + return nil + } + if as.Matrix.IsCustomPuppet && as.Connector.SpecVersions.Supports(mautrix.BeeperFeatureInboxState) { + return as.Matrix.SetBeeperInboxState(ctx, roomID, &mautrix.ReqSetBeeperInboxState{ + MarkedUnread: ptr.Ptr(unread), + }) + } else { + return as.Matrix.SetRoomAccountData(ctx, roomID, event.AccountDataMarkedUnread.Type, &event.MarkedUnreadEventContent{ + Unread: unread, + }) + } +} + +func (as *ASIntent) MarkTyping(ctx context.Context, roomID id.RoomID, typingType bridgev2.TypingType, timeout time.Duration) error { + if typingType != bridgev2.TypingTypeText { + return nil + } else if as.Matrix.IsCustomPuppet { + // Don't send double puppeted typing notifications, there's no good way to prevent echoing them + return nil + } + _, err := as.Matrix.UserTyping(ctx, roomID, timeout > 0, timeout) + return err +} + +func (as *ASIntent) DownloadMedia(ctx context.Context, uri id.ContentURIString, file *event.EncryptedFileInfo) ([]byte, error) { + if file != nil { + uri = file.URL + } + parsedURI, err := uri.Parse() + if err != nil { + return nil, err + } + data, err := as.Matrix.DownloadBytes(ctx, parsedURI) + if err != nil { + return nil, err + } + if file != nil { + err = file.DecryptInPlace(data) + if err != nil { + return nil, err + } + } + return data, nil +} + +func (as *ASIntent) DownloadMediaToFile(ctx context.Context, uri id.ContentURIString, file *event.EncryptedFileInfo, writable bool, callback func(*os.File) error) error { + if file != nil { + uri = file.URL + err := file.PrepareForDecryption() + if err != nil { + return err + } + } + parsedURI, err := uri.Parse() + if err != nil { + return err + } + tempFile, err := os.CreateTemp("", "mautrix-download-*") + if err != nil { + return fmt.Errorf("failed to create temp file: %w", err) + } + defer func() { + _ = tempFile.Close() + _ = os.Remove(tempFile.Name()) + }() + resp, err := as.Matrix.Download(ctx, parsedURI) + if err != nil { + return fmt.Errorf("failed to send download request: %w", err) + } + defer resp.Body.Close() + reader := resp.Body + if file != nil { + reader = file.DecryptStream(reader) + } + if resp.ContentLength > 0 { + err = fallocate.Fallocate(tempFile, int(resp.ContentLength)) + if err != nil { + return fmt.Errorf("failed to preallocate file: %w", err) + } + } + _, err = io.Copy(tempFile, reader) + if err != nil { + return fmt.Errorf("failed to read response: %w", err) + } + err = reader.Close() + if err != nil { + return fmt.Errorf("failed to close response body: %w", err) + } + _, err = tempFile.Seek(0, io.SeekStart) + if err != nil { + return fmt.Errorf("failed to seek to start of temp file: %w", err) + } + err = callback(tempFile) + if err != nil { + return bridgev2.CallbackError{Type: "read", Wrapped: err} + } + return nil +} + +func (as *ASIntent) UploadMedia(ctx context.Context, roomID id.RoomID, data []byte, fileName, mimeType string) (url id.ContentURIString, file *event.EncryptedFileInfo, err error) { + if int64(len(data)) > as.Connector.MediaConfig.UploadSize { + return "", nil, fmt.Errorf("file too large (%.2f MB > %.2f MB)", float64(len(data))/1000/1000, float64(as.Connector.MediaConfig.UploadSize)/1000/1000) + } + if roomID != "" { + var encrypted bool + if encrypted, err = as.Matrix.StateStore.IsEncrypted(ctx, roomID); err != nil { + err = fmt.Errorf("failed to check if room is encrypted: %w", err) + return + } else if encrypted { + file = &event.EncryptedFileInfo{ + EncryptedFile: *attachment.NewEncryptedFile(), + } + file.EncryptInPlace(data) + mimeType = "application/octet-stream" + fileName = "" + } + } + url, err = as.doUploadReq(ctx, file, mautrix.ReqUploadMedia{ + ContentBytes: data, + ContentType: mimeType, + FileName: fileName, + }) + return +} + +func (as *ASIntent) UploadMediaStream( + ctx context.Context, + roomID id.RoomID, + size int64, + requireFile bool, + cb bridgev2.FileStreamCallback, +) (url id.ContentURIString, file *event.EncryptedFileInfo, err error) { + if size > as.Connector.MediaConfig.UploadSize { + return "", nil, fmt.Errorf("file too large (%.2f MB > %.2f MB)", float64(size)/1000/1000, float64(as.Connector.MediaConfig.UploadSize)/1000/1000) + } + if !requireFile && 0 < size && size < as.Connector.Config.Matrix.UploadFileThreshold { + var buf bytes.Buffer + res, err := cb(&buf) + if err != nil { + return "", nil, err + } else if res.ReplacementFile != "" { + panic(fmt.Errorf("logic error: replacement path must only be returned if requireFile is true")) + } + return as.UploadMedia(ctx, roomID, buf.Bytes(), res.FileName, res.MimeType) + } + var tempFile *os.File + tempFile, err = os.CreateTemp("", "mautrix-upload-*") + if err != nil { + err = fmt.Errorf("failed to create temp file: %w", err) + return + } + removeAndClose := func(f *os.File) { + _ = f.Close() + _ = os.Remove(f.Name()) + } + startedAsyncUpload := false + defer func() { + if !startedAsyncUpload { + removeAndClose(tempFile) + } + }() + if size > 0 { + err = fallocate.Fallocate(tempFile, int(size)) + if err != nil { + err = fmt.Errorf("failed to preallocate file: %w", err) + return + } + } + if roomID != "" { + var encrypted bool + if encrypted, err = as.Matrix.StateStore.IsEncrypted(ctx, roomID); err != nil { + err = fmt.Errorf("failed to check if room is encrypted: %w", err) + return + } else if encrypted { + file = &event.EncryptedFileInfo{ + EncryptedFile: *attachment.NewEncryptedFile(), + } + } + } + var res *bridgev2.FileStreamResult + res, err = cb(tempFile) + if err != nil { + err = bridgev2.CallbackError{Type: "write", Wrapped: err} + return + } + var replFile *os.File + if res.ReplacementFile != "" { + replFile, err = os.OpenFile(res.ReplacementFile, os.O_RDWR, 0) + if err != nil { + err = fmt.Errorf("failed to open replacement file: %w", err) + return + } + defer func() { + if !startedAsyncUpload { + removeAndClose(replFile) + } + }() + } else { + replFile = tempFile + _, err = replFile.Seek(0, io.SeekStart) + if err != nil { + err = fmt.Errorf("failed to seek to start of temp file: %w", err) + return + } + } + if file != nil { + res.FileName = "" + res.MimeType = "application/octet-stream" + err = file.EncryptFile(replFile) + if err != nil { + err = fmt.Errorf("failed to encrypt file: %w", err) + return + } + _, err = replFile.Seek(0, io.SeekStart) + if err != nil { + err = fmt.Errorf("failed to seek to start of temp file after encrypting: %w", err) + return + } + } + info, err := replFile.Stat() + if err != nil { + err = fmt.Errorf("failed to get temp file info: %w", err) + return + } + size = info.Size() + if size > as.Connector.MediaConfig.UploadSize { + return "", nil, fmt.Errorf("file too large (%.2f MB > %.2f MB)", float64(size)/1000/1000, float64(as.Connector.MediaConfig.UploadSize)/1000/1000) + } + req := mautrix.ReqUploadMedia{ + Content: replFile, + ContentLength: size, + ContentType: res.MimeType, + FileName: res.FileName, + } + if as.Connector.Config.Homeserver.AsyncMedia { + req.DoneCallback = func() { + removeAndClose(replFile) + removeAndClose(tempFile) + } + req.AsyncContext = zerolog.Ctx(ctx).WithContext(as.Connector.Bridge.BackgroundCtx) + startedAsyncUpload = true + var resp *mautrix.RespCreateMXC + resp, err = as.Matrix.UploadAsync(ctx, req) + if resp != nil { + url = resp.ContentURI.CUString() + } + } else { + var resp *mautrix.RespMediaUpload + resp, err = as.Matrix.UploadMedia(ctx, req) + if resp != nil { + url = resp.ContentURI.CUString() + } + } + if file != nil { + file.URL = url + url = "" + } + return +} + +func (as *ASIntent) doUploadReq(ctx context.Context, file *event.EncryptedFileInfo, req mautrix.ReqUploadMedia) (url id.ContentURIString, err error) { + if as.Connector.Config.Homeserver.AsyncMedia { + if req.ContentBytes != nil { + // Prevent too many background uploads at once + err = as.Connector.uploadSema.Acquire(ctx, int64(len(req.ContentBytes))) + if err != nil { + return + } + req.DoneCallback = func() { + as.Connector.uploadSema.Release(int64(len(req.ContentBytes))) + } + } + req.AsyncContext = zerolog.Ctx(ctx).WithContext(as.Connector.Bridge.BackgroundCtx) + var resp *mautrix.RespCreateMXC + resp, err = as.Matrix.UploadAsync(ctx, req) + if resp != nil { + url = resp.ContentURI.CUString() + } + } else { + var resp *mautrix.RespMediaUpload + resp, err = as.Matrix.UploadMedia(ctx, req) + if resp != nil { + url = resp.ContentURI.CUString() + } + } + if file != nil { + file.URL = url + url = "" + } + return +} + +func (as *ASIntent) SetDisplayName(ctx context.Context, name string) error { + return as.Matrix.SetDisplayName(ctx, name) +} + +func (as *ASIntent) SetAvatarURL(ctx context.Context, avatarURL id.ContentURIString) error { + parsedAvatarURL, err := avatarURL.Parse() + if err != nil { + return err + } + return as.Matrix.SetAvatarURL(ctx, parsedAvatarURL) +} + +func dataToFields(data any) (map[string]json.RawMessage, error) { + fields, ok := data.(map[string]json.RawMessage) + if ok { + return fields, nil + } + d, err := json.Marshal(data) + if err != nil { + return nil, err + } + d = canonicaljson.CanonicalJSONAssumeValid(d) + err = json.Unmarshal(d, &fields) + return fields, err +} + +func marshalField(val any) json.RawMessage { + data, _ := json.Marshal(val) + if len(data) > 0 && (data[0] == '{' || data[0] == '[') { + return canonicaljson.CanonicalJSONAssumeValid(data) + } + return data +} + +var nullJSON = json.RawMessage("null") + +func (as *ASIntent) SetExtraProfileMeta(ctx context.Context, data any) error { + if as.Connector.SpecVersions.Supports(mautrix.BeeperFeatureArbitraryProfileMeta) { + return as.Matrix.BeeperUpdateProfile(ctx, data) + } else if as.Connector.SpecVersions.Supports(mautrix.FeatureArbitraryProfileFields) && as.Connector.Config.Matrix.GhostExtraProfileInfo { + fields, err := dataToFields(data) + if err != nil { + return fmt.Errorf("failed to marshal fields: %w", err) + } + currentProfile, err := as.Matrix.GetProfile(ctx, as.Matrix.UserID) + if err != nil { + return fmt.Errorf("failed to get current profile: %w", err) + } + for key, val := range fields { + existing, ok := currentProfile.Extra[key] + if !ok { + if bytes.Equal(val, nullJSON) { + continue + } + err = as.Matrix.SetProfileField(ctx, key, val) + } else if !bytes.Equal(marshalField(existing), val) { + if bytes.Equal(val, nullJSON) { + err = as.Matrix.DeleteProfileField(ctx, key) + } else { + err = as.Matrix.SetProfileField(ctx, key, val) + } + } + if err != nil { + return fmt.Errorf("failed to set profile field %q: %w", key, err) + } + } + } + return nil +} + +func (as *ASIntent) GetMXID() id.UserID { + return as.Matrix.UserID +} + +func (as *ASIntent) IsDoublePuppet() bool { + return as.Matrix.IsDoublePuppet() +} + +func (as *ASIntent) EnsureJoined(ctx context.Context, roomID id.RoomID, extra ...bridgev2.EnsureJoinedParams) error { + var params bridgev2.EnsureJoinedParams + if len(extra) > 0 { + params = extra[0] + } + err := as.Matrix.EnsureJoined(ctx, roomID, appservice.EnsureJoinedParams{Via: params.Via}) + if err != nil { + return err + } + if as.Connector.Bot.UserID == as.Matrix.UserID { + _, err = as.Matrix.State(ctx, roomID) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to get state after joining room with bot") + } + } + return nil +} + +func (as *ASIntent) EnsureInvited(ctx context.Context, roomID id.RoomID, userID id.UserID) error { + return as.Matrix.EnsureInvited(ctx, roomID, userID) +} + +func (br *Connector) getDefaultEncryptionEvent() *event.EncryptionEventContent { + content := &event.EncryptionEventContent{Algorithm: id.AlgorithmMegolmV1} + if rot := br.Config.Encryption.Rotation; rot.EnableCustom { + content.RotationPeriodMillis = rot.Milliseconds + content.RotationPeriodMessages = rot.Messages + } + return content +} + +func (as *ASIntent) filterCreateRequestForV12(ctx context.Context, req *mautrix.ReqCreateRoom) { + if as.Connector.Config.Homeserver.Software == bridgeconfig.SoftwareHungry { + // Hungryserv doesn't override the capabilities endpoint nor do room versions + return + } + caps := as.Connector.fetchCapabilities(ctx) + roomVer := req.RoomVersion + if roomVer == "" && caps != nil && caps.RoomVersions != nil { + roomVer = id.RoomVersion(caps.RoomVersions.Default) + } + if roomVer != "" && !roomVer.PrivilegedRoomCreators() { + return + } + creators, _ := req.CreationContent["additional_creators"].([]id.UserID) + creators = append(slices.Clone(creators), as.GetMXID()) + if req.PowerLevelOverride != nil { + for _, creator := range creators { + delete(req.PowerLevelOverride.Users, creator) + } + } + for _, evt := range req.InitialState { + if evt.Type != event.StatePowerLevels { + continue + } + content, ok := evt.Content.Parsed.(*event.PowerLevelsEventContent) + if ok { + for _, creator := range creators { + delete(content.Users, creator) + } + } + } +} + +func (as *ASIntent) CreateRoom(ctx context.Context, req *mautrix.ReqCreateRoom) (id.RoomID, error) { + if as.Connector.Config.Encryption.Default { + req.InitialState = append(req.InitialState, &event.Event{ + Type: event.StateEncryption, + Content: event.Content{ + Parsed: as.Connector.getDefaultEncryptionEvent(), + }, + }) + } + if !as.Connector.Config.Matrix.FederateRooms { + if req.CreationContent == nil { + req.CreationContent = make(map[string]any) + } + req.CreationContent["m.federate"] = false + } + as.filterCreateRequestForV12(ctx, req) + resp, err := as.Matrix.CreateRoom(ctx, req) + if err != nil { + return "", err + } + return resp.RoomID, nil +} + +func (as *ASIntent) MarkAsDM(ctx context.Context, roomID id.RoomID, withUser id.UserID) error { + if !as.Connector.Config.Matrix.SyncDirectChatList { + return nil + } + as.dmUpdateLock.Lock() + defer as.dmUpdateLock.Unlock() + cached, ok := as.directChatsCache[withUser] + if ok && slices.Contains(cached, roomID) { + return nil + } + var directChats event.DirectChatsEventContent + err := as.Matrix.GetAccountData(ctx, event.AccountDataDirectChats.Type, &directChats) + if err != nil { + return err + } + as.directChatsCache = directChats + rooms := directChats[withUser] + if slices.Contains(rooms, roomID) { + return nil + } + directChats[withUser] = append(rooms, roomID) + err = as.Matrix.SetAccountData(ctx, event.AccountDataDirectChats.Type, &directChats) + if err != nil { + if rooms == nil { + delete(directChats, withUser) + } else { + directChats[withUser] = rooms + } + return fmt.Errorf("failed to set direct chats account data: %w", err) + } + return nil +} + +func (as *ASIntent) DeleteRoom(ctx context.Context, roomID id.RoomID, puppetsOnly bool) error { + if roomID == "" { + return nil + } + if as.Connector.SpecVersions.Supports(mautrix.BeeperFeatureRoomYeeting) { + err := as.Matrix.BeeperDeleteRoom(ctx, roomID) + if err != nil { + return err + } + err = as.Matrix.StateStore.ClearCachedMembers(ctx, roomID) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to clear cached members while cleaning up portal") + } + return nil + } + members, err := as.Matrix.JoinedMembers(ctx, roomID) + if err != nil { + return fmt.Errorf("failed to get portal members for cleanup: %w", err) + } + for member := range members.Joined { + if member == as.Matrix.UserID { + continue + } + if as.Connector.Bridge.IsGhostMXID(member) { + _, err = as.Connector.AS.Intent(member).LeaveRoom(ctx, roomID) + if err != nil { + zerolog.Ctx(ctx).Err(err).Stringer("user_id", member).Msg("Failed to leave room while cleaning up portal") + } + } else if !puppetsOnly { + _, err = as.Matrix.KickUser(ctx, roomID, &mautrix.ReqKickUser{UserID: member, Reason: "Deleting portal"}) + if err != nil { + zerolog.Ctx(ctx).Err(err).Stringer("user_id", member).Msg("Failed to kick user while cleaning up portal") + } + } + } + _, err = as.Matrix.LeaveRoom(ctx, roomID) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to leave room while cleaning up portal") + } + err = as.Matrix.StateStore.ClearCachedMembers(ctx, roomID) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to clear cached members while cleaning up portal") + } + return nil +} + +func (as *ASIntent) TagRoom(ctx context.Context, roomID id.RoomID, tag event.RoomTag, isTagged bool) error { + tags, err := as.Matrix.GetTags(ctx, roomID) + if err != nil { + return fmt.Errorf("failed to get room tags: %w", err) + } + if isTagged { + _, alreadyTagged := tags.Tags[tag] + if alreadyTagged { + return nil + } + err = as.Matrix.AddTagWithCustomData(ctx, roomID, tag, &event.TagMetadata{ + MauDoublePuppetSource: as.Connector.AS.DoublePuppetValue, + }) + if err != nil { + return err + } + } + for extraTag := range tags.Tags { + if extraTag == event.RoomTagFavourite || extraTag == event.RoomTagLowPriority { + err = as.Matrix.RemoveTag(ctx, roomID, extraTag) + if err != nil { + return fmt.Errorf("failed to remove extra tag %s: %w", extraTag, err) + } + } + } + return nil +} + +func (as *ASIntent) MuteRoom(ctx context.Context, roomID id.RoomID, until time.Time) error { + var mutedUntil int64 + if until.Before(time.Now()) { + mutedUntil = 0 + } else if until == event.MutedForever { + mutedUntil = -1 + } else { + mutedUntil = until.UnixMilli() + } + if as.Connector.SpecVersions.Supports(mautrix.BeeperFeatureAccountDataMute) { + return as.Matrix.SetRoomAccountData(ctx, roomID, event.AccountDataBeeperMute.Type, &event.BeeperMuteEventContent{ + MutedUntil: mutedUntil, + }) + } + if mutedUntil == 0 { + err := as.Matrix.DeletePushRule(ctx, "global", pushrules.RoomRule, string(roomID)) + // If the push rule doesn't exist, everything is fine + if errors.Is(err, mautrix.MNotFound) { + err = nil + } + return err + } else { + return as.Matrix.PutPushRule(ctx, "global", pushrules.RoomRule, string(roomID), &mautrix.ReqPutPushRule{ + Actions: []pushrules.PushActionType{pushrules.ActionDontNotify}, + }) + } +} + +func (as *ASIntent) GetEvent(ctx context.Context, roomID id.RoomID, eventID id.EventID) (*event.Event, error) { + evt, err := as.Matrix.Client.GetEvent(ctx, roomID, eventID) + if err != nil { + return nil, err + } + err = evt.Content.ParseRaw(evt.Type) + if err != nil { + zerolog.Ctx(ctx).Err(err).Stringer("room_id", roomID).Stringer("event_id", eventID).Msg("failed to parse event content") + } + + if evt.Type == event.EventEncrypted { + if as.Connector.Crypto == nil || as.Connector.Config.Encryption.DeleteKeys.RatchetOnDecrypt { + return nil, errors.New("can't decrypt the event") + } + return as.Connector.Crypto.Decrypt(ctx, evt) + } + + return evt, nil +} diff --git a/bridgev2/matrix/matrix.go b/bridgev2/matrix/matrix.go new file mode 100644 index 00000000..954d0ad9 --- /dev/null +++ b/bridgev2/matrix/matrix.go @@ -0,0 +1,242 @@ +// Copyright (c) 2023 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package matrix + +import ( + "context" + "errors" + "fmt" + "slices" + "time" + + "github.com/rs/zerolog" + "go.mau.fi/util/jsontime" + + "maunium.net/go/mautrix/appservice" + "maunium.net/go/mautrix/bridgev2" + "maunium.net/go/mautrix/bridgev2/status" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +func (br *Connector) handleRoomEvent(ctx context.Context, evt *event.Event) { + if br.shouldIgnoreEvent(evt) { + return + } + if !br.Config.Bridge.Permissions.Get(evt.Sender).SendEvents && evt.Type != event.StateMember { + zerolog.Ctx(ctx).Debug().Msg("Dropping event from user with no permission to send events") + br.SendMessageStatus(ctx, &bridgev2.ErrNoPermissionToInteract, bridgev2.StatusEventInfoFromEvent(evt)) + return + } + if (evt.Type == event.EventMessage || evt.Type == event.EventSticker) && !evt.Mautrix.WasEncrypted && br.Config.Encryption.Require { + zerolog.Ctx(ctx).Warn().Msg("Dropping unencrypted event as encryption is configured to be required") + br.sendCryptoStatusError(ctx, evt, errMessageNotEncrypted, nil, 0, true) + return + } + if evt.Type == event.StateMember && br.Crypto != nil { + br.Crypto.HandleMemberEvent(ctx, evt) + } + br.Bridge.QueueMatrixEvent(ctx, evt) +} + +func (br *Connector) handleEphemeralEvent(ctx context.Context, evt *event.Event) { + switch evt.Type { + case event.EphemeralEventReceipt: + receiptContent := *evt.Content.AsReceipt() + for eventID, receipts := range receiptContent { + for receiptType, userReceipts := range receipts { + for userID, receipt := range userReceipts { + if br.shouldIgnoreEventFromUser(userID) || (br.AS.DoublePuppetValue != "" && receipt.Extra[appservice.DoublePuppetKey] == br.AS.DoublePuppetValue) { + delete(userReceipts, userID) + } + } + if len(userReceipts) == 0 { + delete(receipts, receiptType) + } + } + if len(receipts) == 0 { + delete(receiptContent, eventID) + } + } + if len(receiptContent) == 0 { + return + } + case event.EphemeralEventTyping: + typingContent := evt.Content.AsTyping() + typingContent.UserIDs = slices.DeleteFunc(typingContent.UserIDs, br.shouldIgnoreEventFromUser) + case event.BeeperEphemeralEventAIStream: + if br.shouldIgnoreEvent(evt) { + return + } + } + br.Bridge.QueueMatrixEvent(ctx, evt) +} + +func (br *Connector) handleEncryptedEvent(ctx context.Context, evt *event.Event) { + if br.shouldIgnoreEvent(evt) { + return + } + content := evt.Content.AsEncrypted() + log := zerolog.Ctx(ctx).With(). + Str("event_id", evt.ID.String()). + Str("session_id", content.SessionID.String()). + Logger() + if !br.Config.Bridge.Permissions.Get(evt.Sender).SendEvents { + log.Debug().Msg("Dropping event from user with no permission to send events") + br.SendMessageStatus(ctx, &bridgev2.ErrNoPermissionToInteract, bridgev2.StatusEventInfoFromEvent(evt)) + return + } + ctx = log.WithContext(ctx) + if br.Crypto == nil { + br.sendCryptoStatusError(ctx, evt, errNoCrypto, nil, 0, true) + log.Error().Msg("Can't decrypt message: no crypto") + return + } + log.Debug().Msg("Decrypting received event") + + decryptionStart := time.Now() + decrypted, err := br.Crypto.Decrypt(ctx, evt) + decryptionRetryCount := 0 + var errorEventID id.EventID + if errors.Is(err, NoSessionFound) { + decryptionRetryCount = 1 + log.Debug(). + Int("wait_seconds", int(initialSessionWaitTimeout.Seconds())). + Msg("Couldn't find session, waiting for keys to arrive...") + go br.sendCryptoStatusError(ctx, evt, err, &errorEventID, 0, false) + if br.Crypto.WaitForSession(ctx, evt.RoomID, content.SenderKey, content.SessionID, initialSessionWaitTimeout) { + log.Debug().Msg("Got keys after waiting, trying to decrypt event again") + decrypted, err = br.Crypto.Decrypt(ctx, evt) + } else { + go br.waitLongerForSession(ctx, evt, decryptionStart, &errorEventID) + return + } + } + if err != nil { + log.Warn().Err(err).Msg("Failed to decrypt event") + go br.sendCryptoStatusError(ctx, evt, err, nil, decryptionRetryCount, true) + return + } + br.postDecrypt(ctx, evt, decrypted, decryptionRetryCount, &errorEventID, time.Since(decryptionStart)) +} + +func (br *Connector) waitLongerForSession(ctx context.Context, evt *event.Event, decryptionStart time.Time, errorEventID *id.EventID) { + log := zerolog.Ctx(ctx) + content := evt.Content.AsEncrypted() + log.Debug(). + Int("wait_seconds", int(extendedSessionWaitTimeout.Seconds())). + Msg("Couldn't find session, requesting keys and waiting longer...") + + //lint:ignore SA1019 RequestSession will gracefully request from all devices if DeviceID is blank + go br.Crypto.RequestSession(ctx, evt.RoomID, content.SenderKey, content.SessionID, evt.Sender, content.DeviceID) + go br.sendCryptoStatusError(ctx, evt, fmt.Errorf("%w. The bridge will retry for %d seconds", errNoDecryptionKeys, int(extendedSessionWaitTimeout.Seconds())), errorEventID, 1, false) + + if !br.Crypto.WaitForSession(ctx, evt.RoomID, content.SenderKey, content.SessionID, extendedSessionWaitTimeout) { + log.Debug().Msg("Didn't get session, giving up trying to decrypt event") + go br.sendCryptoStatusError(ctx, evt, errNoDecryptionKeys, errorEventID, 2, true) + return + } + + log.Debug().Msg("Got keys after waiting longer, trying to decrypt event again") + decrypted, err := br.Crypto.Decrypt(ctx, evt) + if err != nil { + log.Error().Err(err).Msg("Failed to decrypt event") + go br.sendCryptoStatusError(ctx, evt, err, errorEventID, 2, true) + return + } + + br.postDecrypt(ctx, evt, decrypted, 2, errorEventID, time.Since(decryptionStart)) +} + +type CommandProcessor interface { + Handle(ctx context.Context, roomID id.RoomID, eventID id.EventID, user bridgev2.User, message string, replyTo id.EventID) +} + +func (br *Connector) sendSuccessCheckpoint(ctx context.Context, evt *event.Event, step status.MessageCheckpointStep, retryNum int) { + err := br.SendMessageCheckpoints(ctx, []*status.MessageCheckpoint{{ + RoomID: evt.RoomID, + EventID: evt.ID, + EventType: evt.Type, + MessageType: evt.Content.AsMessage().MsgType, + Step: step, + Timestamp: jsontime.UnixMilliNow(), + Status: status.MsgStatusSuccess, + ReportedBy: status.MsgReportedByBridge, + RetryNum: retryNum, + }}) + if err != nil { + zerolog.Ctx(ctx).Err(err).Str("checkpoint_step", string(step)).Msg("Failed to send checkpoint") + } +} + +func (br *Connector) sendBridgeCheckpoint(ctx context.Context, evt *event.Event) { + if !evt.Mautrix.CheckpointSent { + go br.sendSuccessCheckpoint(ctx, evt, status.MsgStepBridge, 0) + } +} + +func (br *Connector) shouldIgnoreEventFromUser(userID id.UserID) bool { + return userID == br.Bot.UserID || br.Bridge.IsGhostMXID(userID) +} + +func (br *Connector) shouldIgnoreEvent(evt *event.Event) bool { + if br.shouldIgnoreEventFromUser(evt.Sender) && evt.Type != event.StateTombstone { + return true + } + dpVal, ok := evt.Content.Raw[appservice.DoublePuppetKey] + if ok && dpVal == br.AS.DoublePuppetValue { + dpTS, ok := evt.Content.Raw[appservice.DoublePuppetTSKey].(float64) + if !ok || int64(dpTS) == evt.Timestamp { + return true + } + } + return false +} + +const initialSessionWaitTimeout = 3 * time.Second +const extendedSessionWaitTimeout = 22 * time.Second + +func copySomeKeys(original, decrypted *event.Event) { + isScheduled, _ := original.Content.Raw["com.beeper.scheduled"].(bool) + _, alreadyExists := decrypted.Content.Raw["com.beeper.scheduled"] + if isScheduled && !alreadyExists { + decrypted.Content.Raw["com.beeper.scheduled"] = true + } +} + +func (br *Connector) postDecrypt(ctx context.Context, original, decrypted *event.Event, retryCount int, errorEventID *id.EventID, duration time.Duration) { + log := zerolog.Ctx(ctx) + minLevel := br.Config.Encryption.VerificationLevels.Send + if decrypted.Mautrix.TrustState < minLevel { + logEvt := log.Warn(). + Str("user_id", decrypted.Sender.String()). + Bool("forwarded_keys", decrypted.Mautrix.ForwardedKeys). + Stringer("device_trust", decrypted.Mautrix.TrustState). + Stringer("min_trust", minLevel) + if decrypted.Mautrix.TrustSource != nil { + dev := decrypted.Mautrix.TrustSource + logEvt. + Str("device_id", dev.DeviceID.String()). + Str("device_signing_key", dev.SigningKey.String()) + } else { + logEvt.Str("device_id", "unknown") + } + logEvt.Msg("Dropping event due to insufficient verification level") + err := deviceUnverifiedErrorWithExplanation(decrypted.Mautrix.TrustState) + go br.sendCryptoStatusError(ctx, decrypted, err, errorEventID, retryCount, true) + return + } + copySomeKeys(original, decrypted) + + go br.sendSuccessCheckpoint(ctx, decrypted, status.MsgStepDecrypted, retryCount) + decrypted.Mautrix.CheckpointSent = true + decrypted.Mautrix.DecryptionDuration = duration + br.EventProcessor.Dispatch(ctx, decrypted) + if errorEventID != nil && *errorEventID != "" { + _, _ = br.Bot.RedactEvent(ctx, decrypted.RoomID, *errorEventID) + } +} diff --git a/bridgev2/matrix/mxmain/config.go b/bridgev2/matrix/mxmain/config.go new file mode 100644 index 00000000..a684d8a2 --- /dev/null +++ b/bridgev2/matrix/mxmain/config.go @@ -0,0 +1,36 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package mxmain + +import ( + _ "embed" + "strings" + "text/template" + + "go.mau.fi/util/exerrors" +) + +//go:embed example-config.yaml +var MatrixExampleConfigBase string + +var matrixExampleConfigBaseTemplate = exerrors.Must(template.New("example-config.yaml"). + Delims("$<<", ">>"). + Parse(MatrixExampleConfigBase)) + +func (br *BridgeMain) makeFullExampleConfig(networkExample string) string { + var buf strings.Builder + buf.WriteString("# Network-specific config options\n") + buf.WriteString("network:\n") + for _, line := range strings.Split(networkExample, "\n") { + buf.WriteString(" ") + buf.WriteString(line) + buf.WriteRune('\n') + } + buf.WriteRune('\n') + exerrors.PanicIfNotNil(matrixExampleConfigBaseTemplate.Execute(&buf, br.Connector.GetName())) + return buf.String() +} diff --git a/bridgev2/matrix/mxmain/dberror.go b/bridgev2/matrix/mxmain/dberror.go new file mode 100644 index 00000000..f5e438de --- /dev/null +++ b/bridgev2/matrix/mxmain/dberror.go @@ -0,0 +1,79 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package mxmain + +import ( + "errors" + "os" + + "github.com/lib/pq" + "github.com/mattn/go-sqlite3" + "github.com/rs/zerolog" + + "go.mau.fi/util/dbutil" +) + +type zerologPQError pq.Error + +func (zpe *zerologPQError) MarshalZerologObject(evt *zerolog.Event) { + maybeStr := func(field, value string) { + if value != "" { + evt.Str(field, value) + } + } + maybeStr("severity", zpe.Severity) + if name := zpe.Code.Name(); name != "" { + evt.Str("code", name) + } else if zpe.Code != "" { + evt.Str("code", string(zpe.Code)) + } + //maybeStr("message", zpe.Message) + maybeStr("detail", zpe.Detail) + maybeStr("hint", zpe.Hint) + maybeStr("position", zpe.Position) + maybeStr("internal_position", zpe.InternalPosition) + maybeStr("internal_query", zpe.InternalQuery) + maybeStr("where", zpe.Where) + maybeStr("schema", zpe.Schema) + maybeStr("table", zpe.Table) + maybeStr("column", zpe.Column) + maybeStr("data_type_name", zpe.DataTypeName) + maybeStr("constraint", zpe.Constraint) + maybeStr("file", zpe.File) + maybeStr("line", zpe.Line) + maybeStr("routine", zpe.Routine) +} + +func (br *BridgeMain) LogDBUpgradeErrorAndExit(name string, err error, message string) { + logEvt := br.Log.WithLevel(zerolog.FatalLevel). + Err(err). + Str("db_section", name) + var errWithLine *dbutil.PQErrorWithLine + if errors.As(err, &errWithLine) { + logEvt.Str("sql_line", errWithLine.Line) + } + var pqe *pq.Error + if errors.As(err, &pqe) { + logEvt.Object("pq_error", (*zerologPQError)(pqe)) + } + logEvt.Msg(message) + if sqlError := (&sqlite3.Error{}); errors.As(err, sqlError) && sqlError.Code == sqlite3.ErrCorrupt { + os.Exit(18) + } else if errors.Is(err, dbutil.ErrForeignTables) { + br.Log.Info().Msg("See https://docs.mau.fi/faq/foreign-tables for more info") + } else if errors.Is(err, dbutil.ErrNotOwned) { + var noe dbutil.NotOwnedError + if errors.As(err, &noe) && noe.Owner == br.Name { + br.Log.Info().Msg("The database appears to be on a very old pre-megabridge schema. Perhaps you need to run an older version of the bridge with migration support first?") + } else { + br.Log.Info().Msg("Sharing the same database with different programs is not supported") + } + } else if errors.Is(err, dbutil.ErrUnsupportedDatabaseVersion) { + br.Log.Info().Msg("Downgrading the bridge is not supported") + } + os.Exit(15) +} diff --git a/bridgev2/matrix/mxmain/envconfig.go b/bridgev2/matrix/mxmain/envconfig.go new file mode 100644 index 00000000..1b4f1467 --- /dev/null +++ b/bridgev2/matrix/mxmain/envconfig.go @@ -0,0 +1,161 @@ +// Copyright (c) 2025 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package mxmain + +import ( + "fmt" + "iter" + "os" + "reflect" + "strconv" + "strings" + + "go.mau.fi/util/random" +) + +var randomParseFilePrefix = random.String(16) + "READFILE:" + +func parseEnv(prefix string) iter.Seq2[[]string, string] { + return func(yield func([]string, string) bool) { + for _, s := range os.Environ() { + if !strings.HasPrefix(s, prefix) { + continue + } + kv := strings.SplitN(s, "=", 2) + key := strings.TrimPrefix(kv[0], prefix) + value := kv[1] + if strings.HasSuffix(key, "_FILE") { + key = strings.TrimSuffix(key, "_FILE") + value = randomParseFilePrefix + value + } + key = strings.ToLower(key) + if !strings.ContainsRune(key, '.') { + key = strings.ReplaceAll(key, "__", ".") + } + if !yield(strings.Split(key, "."), value) { + return + } + } + } +} + +func reflectYAMLFieldName(f *reflect.StructField) string { + parts := strings.SplitN(f.Tag.Get("yaml"), ",", 2) + fieldName := parts[0] + if fieldName == "-" && len(parts) == 1 { + return "" + } + if fieldName == "" { + return strings.ToLower(f.Name) + } + return fieldName +} + +type reflectGetResult struct { + val reflect.Value + valKind reflect.Kind + remainingPath []string +} + +func reflectGetYAML(rv reflect.Value, path []string) (*reflectGetResult, bool) { + if len(path) == 0 { + return &reflectGetResult{val: rv, valKind: rv.Kind()}, true + } + if rv.Kind() == reflect.Ptr { + rv = rv.Elem() + } + switch rv.Kind() { + case reflect.Map: + return &reflectGetResult{val: rv, remainingPath: path, valKind: rv.Type().Elem().Kind()}, true + case reflect.Struct: + fields := reflect.VisibleFields(rv.Type()) + for _, field := range fields { + fieldName := reflectYAMLFieldName(&field) + if fieldName != "" && fieldName == path[0] { + return reflectGetYAML(rv.FieldByIndex(field.Index), path[1:]) + } + } + default: + } + return nil, false +} + +func reflectGetFromMainOrNetwork(main, network reflect.Value, path []string) (*reflectGetResult, bool) { + if len(path) > 0 && path[0] == "network" { + return reflectGetYAML(network, path[1:]) + } + return reflectGetYAML(main, path) +} + +func formatKeyString(key []string) string { + return strings.Join(key, "->") +} + +func UpdateConfigFromEnv(cfg, networkData any, prefix string) error { + cfgVal := reflect.ValueOf(cfg) + networkVal := reflect.ValueOf(networkData) + for key, value := range parseEnv(prefix) { + field, ok := reflectGetFromMainOrNetwork(cfgVal, networkVal, key) + if !ok { + return fmt.Errorf("%s not found", formatKeyString(key)) + } + if strings.HasPrefix(value, randomParseFilePrefix) { + filepath := strings.TrimPrefix(value, randomParseFilePrefix) + fileData, err := os.ReadFile(filepath) + if err != nil { + return fmt.Errorf("failed to read file %s for %s: %w", filepath, formatKeyString(key), err) + } + value = strings.TrimSpace(string(fileData)) + } + var parsedVal any + var err error + switch field.valKind { + case reflect.String: + parsedVal = value + case reflect.Bool: + parsedVal, err = strconv.ParseBool(value) + if err != nil { + return fmt.Errorf("invalid value for %s: %w", formatKeyString(key), err) + } + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + parsedVal, err = strconv.ParseInt(value, 10, 64) + if err != nil { + return fmt.Errorf("invalid value for %s: %w", formatKeyString(key), err) + } + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + parsedVal, err = strconv.ParseUint(value, 10, 64) + if err != nil { + return fmt.Errorf("invalid value for %s: %w", formatKeyString(key), err) + } + case reflect.Float32, reflect.Float64: + parsedVal, err = strconv.ParseFloat(value, 64) + if err != nil { + return fmt.Errorf("invalid value for %s: %w", formatKeyString(key), err) + } + default: + return fmt.Errorf("unsupported type %s in %s", field.valKind, formatKeyString(key)) + } + if field.val.Kind() == reflect.Ptr { + if field.val.IsNil() { + field.val.Set(reflect.New(field.val.Type().Elem())) + } + field.val = field.val.Elem() + } + if field.val.Kind() == reflect.Map { + key = key[:len(key)-len(field.remainingPath)] + mapKeyStr := strings.Join(field.remainingPath, ".") + key = append(key, mapKeyStr) + if field.val.Type().Key().Kind() != reflect.String { + return fmt.Errorf("unsupported map key type %s in %s", field.val.Type().Key().Kind(), formatKeyString(key)) + } + field.val.SetMapIndex(reflect.ValueOf(mapKeyStr), reflect.ValueOf(parsedVal)) + } else { + field.val.Set(reflect.ValueOf(parsedVal)) + } + } + return nil +} diff --git a/bridgev2/matrix/mxmain/example-config.yaml b/bridgev2/matrix/mxmain/example-config.yaml new file mode 100644 index 00000000..ccc81c4b --- /dev/null +++ b/bridgev2/matrix/mxmain/example-config.yaml @@ -0,0 +1,476 @@ +# Config options that affect the central bridge module. +bridge: + # The prefix for commands. Only required in non-management rooms. + command_prefix: '$<>' + # Should the bridge create a space for each login containing the rooms that account is in? + personal_filtering_spaces: true + # Whether the bridge should set names and avatars explicitly for DM portals. + # This is only necessary when using clients that don't support MSC4171. + private_chat_portal_meta: true + # Should events be handled asynchronously within portal rooms? + # If true, events may end up being out of order, but slow events won't block other ones. + # This is not yet safe to use. + async_events: false + # Should every user have their own portals rather than sharing them? + # By default, users who are in the same group on the remote network will be + # in the same Matrix room bridged to that group. If this is set to true, + # every user will get their own Matrix room instead. + # SETTING THIS IS IRREVERSIBLE AND POTENTIALLY DESTRUCTIVE IF PORTALS ALREADY EXIST. + split_portals: false + # Should the bridge resend `m.bridge` events to all portals on startup? + resend_bridge_info: false + # Should `m.bridge` events be sent without a state key? + # By default, the bridge uses a unique key that won't conflict with other bridges. + no_bridge_info_state_key: false + # Should bridge connection status be sent to the management room as `m.notice` events? + # These contain the same data that can be posted to an external HTTP server using homeserver -> status_endpoint. + # Allowed values: none, errors, all + bridge_status_notices: errors + # How long after an unknown error should the bridge attempt a full reconnect? + # Must be at least 1 minute. The bridge will add an extra ±20% jitter to this value. + unknown_error_auto_reconnect: null + # Maximum number of times to do the auto-reconnect above. + # The counter is per login, but is never reset except on logout and restart. + unknown_error_max_auto_reconnects: 10 + + # Should leaving Matrix rooms be bridged as leaving groups on the remote network? + bridge_matrix_leave: false + # Should `m.notice` messages be bridged? + bridge_notices: false + # Should room tags only be synced when creating the portal? Tags mean things like favorite/pin and archive/low priority. + # Tags currently can't be synced back to the remote network, so a continuous sync means tagging from Matrix will be undone. + tag_only_on_create: true + # List of tags to allow bridging. If empty, no tags will be bridged. + only_bridge_tags: [m.favourite, m.lowpriority] + # Should room mute status only be synced when creating the portal? + # Like tags, mutes can't currently be synced back to the remote network. + mute_only_on_create: true + # Should the bridge check the db to ensure that incoming events haven't been handled before + deduplicate_matrix_messages: false + # Should cross-room reply metadata be bridged? + # Most Matrix clients don't support this and servers may reject such messages too. + cross_room_replies: false + # If a state event fails to bridge, should the bridge revert any state changes made by that event? + revert_failed_state_changes: false + # In portals with no relay set, should Matrix users be kicked if they're + # not logged into an account that's in the remote chat? + kick_matrix_users: true + + # What should be done to portal rooms when a user logs out or is logged out? + # Permitted values: + # nothing - Do nothing, let the user stay in the portals + # kick - Remove the user from the portal rooms, but don't delete them + # unbridge - Remove all ghosts in the room and disassociate it from the remote chat + # delete - Remove all ghosts and users from the room (i.e. delete it) + cleanup_on_logout: + # Should cleanup on logout be enabled at all? + enabled: false + # Settings for manual logouts (explicitly initiated by the Matrix user) + manual: + # Action for private portals which will never be shared with other Matrix users. + private: nothing + # Action for portals with a relay user configured. + relayed: nothing + # Action for portals which may be shared, but don't currently have any other Matrix users. + shared_no_users: nothing + # Action for portals which have other logged-in Matrix users. + shared_has_users: nothing + # Settings for credentials being invalidated (initiated by the remote network, possibly through user action). + # Keys have the same meanings as in the manual section. + bad_credentials: + private: nothing + relayed: nothing + shared_no_users: nothing + shared_has_users: nothing + + # Settings for relay mode + relay: + # Whether relay mode should be allowed. If allowed, the set-relay command can be used to turn any + # authenticated user into a relaybot for that chat. + enabled: false + # Should only admins be allowed to set themselves as relay users? + # If true, non-admins can only set users listed in default_relays as relays in a room. + admin_only: true + # List of user login IDs which anyone can set as a relay, as long as the relay user is in the room. + default_relays: [] + # The formats to use when sending messages via the relaybot. + # Available variables: + # .Sender.UserID - The Matrix user ID of the sender. + # .Sender.Displayname - The display name of the sender (if set). + # .Sender.RequiresDisambiguation - Whether the sender's name may be confused with the name of another user in the room. + # .Sender.DisambiguatedName - The disambiguated name of the sender. This will be the displayname if set, + # plus the user ID in parentheses if the displayname is not unique. + # If the displayname is not set, this is just the user ID. + # .Message - The `formatted_body` field of the message. + # .Caption - The `formatted_body` field of the message, if it's a caption. Otherwise an empty string. + # .FileName - The name of the file being sent. + message_formats: + m.text: "{{ .Sender.DisambiguatedName }}: {{ .Message }}" + m.notice: "{{ .Sender.DisambiguatedName }}: {{ .Message }}" + m.emote: "* {{ .Sender.DisambiguatedName }} {{ .Message }}" + m.file: "{{ .Sender.DisambiguatedName }} sent a file{{ if .Caption }}: {{ .Caption }}{{ end }}" + m.image: "{{ .Sender.DisambiguatedName }} sent an image{{ if .Caption }}: {{ .Caption }}{{ end }}" + m.audio: "{{ .Sender.DisambiguatedName }} sent an audio file{{ if .Caption }}: {{ .Caption }}{{ end }}" + m.video: "{{ .Sender.DisambiguatedName }} sent a video{{ if .Caption }}: {{ .Caption }}{{ end }}" + m.location: "{{ .Sender.DisambiguatedName }} sent a location{{ if .Caption }}: {{ .Caption }}{{ end }}" + # For networks that support per-message displaynames (i.e. Slack and Discord), the template for those names. + # This has all the Sender variables available under message_formats (but without the .Sender prefix). + # Note that you need to manually remove the displayname from message_formats above. + displayname_format: "{{ .DisambiguatedName }}" + + # Permissions for using the bridge. + # Permitted values: + # relay - Talk through the relaybot (if enabled), no access otherwise + # commands - Access to use commands in the bridge, but not login. + # user - Access to use the bridge with puppeting. + # admin - Full access, user level with some additional administration tools. + # Permitted keys: + # * - All Matrix users + # domain - All users on that homeserver + # mxid - Specific user + permissions: + "*": relay + "example.com": user + "@admin:example.com": admin + +# Config for the bridge's database. +database: + # The database type. "sqlite3-fk-wal" and "postgres" are supported. + type: postgres + # The database URI. + # SQLite: A raw file path is supported, but `file:?_txlock=immediate` is recommended. + # https://github.com/mattn/go-sqlite3#connection-string + # Postgres: Connection string. For example, postgres://user:password@host/database?sslmode=disable + # To connect via Unix socket, use something like postgres:///dbname?host=/var/run/postgresql + uri: postgres://user:password@host/database?sslmode=disable + # Maximum number of connections. + max_open_conns: 5 + max_idle_conns: 1 + # Maximum connection idle time and lifetime before they're closed. Disabled if null. + # Parsed with https://pkg.go.dev/time#ParseDuration + max_conn_idle_time: null + max_conn_lifetime: null + +# Homeserver details. +homeserver: + # The address that this appservice can use to connect to the homeserver. + # Local addresses without HTTPS are generally recommended when the bridge is running on the same machine, + # but https also works if they run on different machines. + address: http://example.localhost:8008 + # The domain of the homeserver (also known as server_name, used for MXIDs, etc). + domain: example.com + + # What software is the homeserver running? + # Standard Matrix homeservers like Synapse, Dendrite and Conduit should just use "standard" here. + software: standard + # The URL to push real-time bridge status to. + # If set, the bridge will make POST requests to this URL whenever a user's remote network connection state changes. + # The bridge will use the appservice as_token to authorize requests. + status_endpoint: + # Endpoint for reporting per-message status. + # If set, the bridge will make POST requests to this URL when processing a message from Matrix. + # It will make one request when receiving the message (step BRIDGE), one after decrypting if applicable + # (step DECRYPTED) and one after sending to the remote network (step REMOTE). Errors will also be reported. + # The bridge will use the appservice as_token to authorize requests. + message_send_checkpoint_endpoint: + # Does the homeserver support https://github.com/matrix-org/matrix-spec-proposals/pull/2246? + async_media: false + + # Should the bridge use a websocket for connecting to the homeserver? + # The server side is currently not documented anywhere and is only implemented by mautrix-wsproxy, + # mautrix-asmux (deprecated), and hungryserv (proprietary). + websocket: false + # How often should the websocket be pinged? Pinging will be disabled if this is zero. + ping_interval_seconds: 0 + +# Application service host/registration related details. +# Changing these values requires regeneration of the registration (except when noted otherwise) +appservice: + # The address that the homeserver can use to connect to this appservice. + # Like the homeserver address, a local non-https address is recommended when the bridge is on the same machine. + # If the bridge is elsewhere, you must secure the connection yourself (e.g. with https or wireguard) + # If you want to use https, you need to use a reverse proxy. The bridge does not have TLS support built in. + address: http://localhost:$<> + # A public address that external services can use to reach this appservice. + # This is only needed for things like public media. A reverse proxy is generally necessary when using this field. + # This value doesn't affect the registration file. + public_address: https://bridge.example.com + + # The hostname and port where this appservice should listen. + # For Docker, you generally have to change the hostname to 0.0.0.0. + hostname: 127.0.0.1 + port: $<> + + # The unique ID of this appservice. + id: $<<.NetworkID>> + # Appservice bot details. + bot: + # Username of the appservice bot. + username: $<<.NetworkID>>bot + # Display name and avatar for bot. Set to "remove" to remove display name/avatar, leave empty + # to leave display name/avatar as-is. + displayname: $<<.DisplayName>> bridge bot + avatar: $<<.NetworkIcon>> + + # Whether to receive ephemeral events via appservice transactions. + ephemeral_events: true + # Should incoming events be handled asynchronously? + # This may be necessary for large public instances with lots of messages going through. + # However, messages will not be guaranteed to be bridged in the same order they were sent in. + # This value doesn't affect the registration file. + async_transactions: false + + # Authentication tokens for AS <-> HS communication. Autogenerated; do not modify. + as_token: "This value is generated when generating the registration" + hs_token: "This value is generated when generating the registration" + + # Localpart template of MXIDs for remote users. + # {{.}} is replaced with the internal ID of the user. + username_template: $<<.NetworkID>>_{{.}} + +# Config options that affect the Matrix connector of the bridge. +matrix: + # Whether the bridge should send the message status as a custom com.beeper.message_send_status event. + message_status_events: false + # Whether the bridge should send a read receipt after successfully bridging a message. + delivery_receipts: false + # Whether the bridge should send error notices via m.notice events when a message fails to bridge. + message_error_notices: true + # Whether the bridge should update the m.direct account data event when double puppeting is enabled. + sync_direct_chat_list: true + # Whether created rooms should have federation enabled. If false, created portal rooms + # will never be federated. Changing this option requires recreating rooms. + federate_rooms: true + # The threshold as bytes after which the bridge should roundtrip uploads via the disk + # rather than keeping the whole file in memory. + upload_file_threshold: 5242880 + # Should the bridge set additional custom profile info for ghosts? + # This can make a lot of requests, as there's no batch profile update endpoint. + ghost_extra_profile_info: false + +# Segment-compatible analytics endpoint for tracking some events, like provisioning API login and encryption errors. +analytics: + # API key to send with tracking requests. Tracking is disabled if this is null. + token: null + # Address to send tracking requests to. + url: https://api.segment.io/v1/track + # Optional user ID for tracking events. If null, defaults to using Matrix user ID. + user_id: null + +# Settings for provisioning API +provisioning: + # Shared secret for authentication. If set to "generate" or null, a random secret will be generated, + # or if set to "disable", the provisioning API will be disabled. Must be at least 16 characters. + shared_secret: generate + # Whether to allow provisioning API requests to be authed using Matrix access tokens. + # This follows the same rules as double puppeting to determine which server to contact to check the token, + # which means that by default, it only works for users on the same server as the bridge. + allow_matrix_auth: true + # Enable debug API at /debug with provisioning authentication. + debug_endpoints: false + # Enable session transfers between bridges. Note that this only validates Matrix or shared secret + # auth before passing live network client credentials down in the response. + enable_session_transfers: false + +# Some networks require publicly accessible media download links (e.g. for user avatars when using Discord webhooks). +# These settings control whether the bridge will provide such public media access. +public_media: + # Should public media be enabled at all? + # The public_address field under the appservice section MUST be set when enabling public media. + enabled: false + # A key for signing public media URLs. + # If set to "generate", a random key will be generated. + signing_key: generate + # Number of seconds that public media URLs are valid for. + # If set to 0, URLs will never expire. + expiry: 0 + # Length of hash to use for public media URLs. Must be between 0 and 32. + hash_length: 32 + # The path prefix for generated URLs. Note that this will NOT change the path where media is actually served. + # If you change this, you must configure your reverse proxy to rewrite the path accordingly. + path_prefix: /_mautrix/publicmedia + # Should the bridge store media metadata in the database in order to support encrypted media and generate shorter URLs? + # If false, the generated URLs will just have the MXC URI and a HMAC signature. + # The hash_length field will be used to decide the length of the generated URL. + # This also allows invalidating URLs by deleting the database entry. + use_database: false + +# Settings for converting remote media to custom mxc:// URIs instead of reuploading. +# More details can be found at https://docs.mau.fi/bridges/go/discord/direct-media.html +direct_media: + # Should custom mxc:// URIs be used instead of reuploading media? + enabled: false + # The server name to use for the custom mxc:// URIs. + # This server name will effectively be a real Matrix server, it just won't implement anything other than media. + # You must either set up .well-known delegation from this domain to the bridge, or proxy the domain directly to the bridge. + server_name: discord-media.example.com + # Optionally a custom .well-known response. This defaults to `server_name:443` + well_known_response: + # Optionally specify a custom prefix for the media ID part of the MXC URI. + media_id_prefix: + # If the remote network supports media downloads over HTTP, then the bridge will use MSC3860/MSC3916 + # media download redirects if the requester supports it. Optionally, you can force redirects + # and not allow proxying at all by setting this to false. + # This option does nothing if the remote network does not support media downloads over HTTP. + allow_proxy: true + # Matrix server signing key to make the federation tester pass, same format as synapse's .signing.key file. + # This key is also used to sign the mxc:// URIs to ensure only the bridge can generate them. + server_key: generate + +# Settings for backfilling messages. +# Note that the exact way settings are applied depends on the network connector. +# See https://docs.mau.fi/bridges/general/backfill.html for more details. +backfill: + # Whether to do backfilling at all. + enabled: false + # Maximum number of messages to backfill in empty rooms. + max_initial_messages: 50 + # Maximum number of missed messages to backfill after bridge restarts. + max_catchup_messages: 500 + # If a backfilled chat is older than this number of hours, + # mark it as read even if it's unread on the remote network. + unread_hours_threshold: 720 + # Settings for backfilling threads within other backfills. + threads: + # Maximum number of messages to backfill in a new thread. + max_initial_messages: 50 + # Settings for the backwards backfill queue. This only applies when connecting to + # Beeper as standard Matrix servers don't support inserting messages into history. + queue: + # Should the backfill queue be enabled? + enabled: false + # Number of messages to backfill in one batch. + batch_size: 100 + # Delay between batches in seconds. + batch_delay: 20 + # Maximum number of batches to backfill per portal. + # If set to -1, all available messages will be backfilled. + max_batches: -1 + # Optional network-specific overrides for max batches. + # Interpretation of this field depends on the network connector. + max_batches_override: {} + +# Settings for enabling double puppeting +double_puppet: + # Servers to always allow double puppeting from. + # This is only for other servers and should NOT contain the server the bridge is on. + servers: + anotherserver.example.org: https://matrix.anotherserver.example.org + # Whether to allow client API URL discovery for other servers. When using this option, + # users on other servers can use double puppeting even if their server URLs aren't + # explicitly added to the servers map above. + allow_discovery: false + # Shared secrets for automatic double puppeting. + # See https://docs.mau.fi/bridges/general/double-puppeting.html for instructions. + secrets: + example.com: as_token:foobar + +# End-to-bridge encryption support options. +# +# See https://docs.mau.fi/bridges/general/end-to-bridge-encryption.html for more info. +encryption: + # Whether to enable encryption at all. If false, the bridge will not function in encrypted rooms. + allow: false + # Whether to force-enable encryption in all bridged rooms. + default: false + # Whether to require all messages to be encrypted and drop any unencrypted messages. + require: false + # Whether to use MSC3202/MSC4203 instead of /sync long polling for receiving encryption-related data. + # This option is not yet compatible with standard Matrix servers like Synapse and should not be used. + # Changing this option requires updating the appservice registration file. + appservice: false + # Whether to use MSC4190 instead of appservice login to create the bridge bot device. + # Requires the homeserver to support MSC4190 and the device masquerading parts of MSC3202. + # Only relevant when using end-to-bridge encryption, required when using encryption with next-gen auth (MSC3861). + # Changing this option requires updating the appservice registration file. + msc4190: false + # Whether to encrypt reactions and reply metadata as per MSC4392. + msc4392: false + # Should the bridge bot generate a recovery key and cross-signing keys and verify itself? + # Note that without the latest version of MSC4190, this will fail if you reset the bridge database. + # The generated recovery key will be saved in the kv_store table under `recovery_key`. + self_sign: false + # Enable key sharing? If enabled, key requests for rooms where users are in will be fulfilled. + # You must use a client that supports requesting keys from other users to use this feature. + allow_key_sharing: true + # Pickle key for encrypting encryption keys in the bridge database. + # If set to generate, a random key will be generated. + pickle_key: generate + # Options for deleting megolm sessions from the bridge. + delete_keys: + # Beeper-specific: delete outbound sessions when hungryserv confirms + # that the user has uploaded the key to key backup. + delete_outbound_on_ack: false + # Don't store outbound sessions in the inbound table. + dont_store_outbound: false + # Ratchet megolm sessions forward after decrypting messages. + ratchet_on_decrypt: false + # Delete fully used keys (index >= max_messages) after decrypting messages. + delete_fully_used_on_decrypt: false + # Delete previous megolm sessions from same device when receiving a new one. + delete_prev_on_new_session: false + # Delete megolm sessions received from a device when the device is deleted. + delete_on_device_delete: false + # Periodically delete megolm sessions when 2x max_age has passed since receiving the session. + periodically_delete_expired: false + # Delete inbound megolm sessions that don't have the received_at field used for + # automatic ratcheting and expired session deletion. This is meant as a migration + # to delete old keys prior to the bridge update. + delete_outdated_inbound: false + # What level of device verification should be required from users? + # + # Valid levels: + # unverified - Send keys to all device in the room. + # cross-signed-untrusted - Require valid cross-signing, but trust all cross-signing keys. + # cross-signed-tofu - Require valid cross-signing, trust cross-signing keys on first use (and reject changes). + # cross-signed-verified - Require valid cross-signing, plus a valid user signature from the bridge bot. + # Note that creating user signatures from the bridge bot is not currently possible. + # verified - Require manual per-device verification + # (currently only possible by modifying the `trust` column in the `crypto_device` database table). + verification_levels: + # Minimum level for which the bridge should send keys to when bridging messages from the remote network to Matrix. + receive: unverified + # Minimum level that the bridge should accept for incoming Matrix messages. + send: unverified + # Minimum level that the bridge should require for accepting key requests. + share: cross-signed-tofu + # Options for Megolm room key rotation. These options allow you to configure the m.room.encryption event content. + # See https://spec.matrix.org/v1.10/client-server-api/#mroomencryption for more information about that event. + rotation: + # Enable custom Megolm room key rotation settings. Note that these + # settings will only apply to rooms created after this option is set. + enable_custom: false + # The maximum number of milliseconds a session should be used + # before changing it. The Matrix spec recommends 604800000 (a week) + # as the default. + milliseconds: 604800000 + # The maximum number of messages that should be sent with a given a + # session before changing it. The Matrix spec recommends 100 as the + # default. + messages: 100 + # Disable rotating keys when a user's devices change? + # You should not enable this option unless you understand all the implications. + disable_device_change_key_rotation: false + +# Prefix for environment variables. All variables with this prefix must map to valid config fields. +# Nesting in variable names is represented with a dot (.). +# If there are no dots in the name, two underscores (__) are replaced with a dot. +# +# e.g. if the prefix is set to `BRIDGE_`, then `BRIDGE_APPSERVICE__AS_TOKEN` will set appservice.as_token. +# `BRIDGE_appservice.as_token` would work as well, but can't be set in a shell as easily. +# +# If this is null, reading config fields from environment will be disabled. +env_config_prefix: null + +# Logging config. See https://github.com/tulir/zeroconfig for details. +logging: + min_level: debug + writers: + - type: stdout + format: pretty-colored + - type: file + format: json + filename: ./logs/bridge.log + max_size: 100 + max_backups: 10 + compress: false diff --git a/bridgev2/matrix/mxmain/legacymigrate.go b/bridgev2/matrix/mxmain/legacymigrate.go new file mode 100644 index 00000000..97cdeddf --- /dev/null +++ b/bridgev2/matrix/mxmain/legacymigrate.go @@ -0,0 +1,263 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package mxmain + +import ( + "bytes" + "context" + "database/sql" + "errors" + "fmt" + + "github.com/rs/zerolog" + "go.mau.fi/util/dbutil" + + "maunium.net/go/mautrix/appservice" + "maunium.net/go/mautrix/bridgev2" + "maunium.net/go/mautrix/bridgev2/database" + "maunium.net/go/mautrix/bridgev2/matrix" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +func (br *BridgeMain) LegacyMigrateWithAnotherUpgrader(renameTablesQuery, copyDataQuery string, newDBVersion int, otherTable dbutil.UpgradeTable, otherTableName string, otherNewVersion int) func(ctx context.Context) error { + return func(ctx context.Context) error { + // Unique constraints must have globally unique names on postgres, and renaming the table doesn't rename them, + // so just drop the ones that may conflict with the new schema. + if br.DB.Dialect == dbutil.Postgres { + _, err := br.DB.Exec(ctx, "ALTER TABLE message DROP CONSTRAINT IF EXISTS message_mxid_unique") + if err != nil { + return fmt.Errorf("failed to drop potentially conflicting constraint on message: %w", err) + } + _, err = br.DB.Exec(ctx, "ALTER TABLE reaction DROP CONSTRAINT IF EXISTS reaction_mxid_unique") + if err != nil { + return fmt.Errorf("failed to drop potentially conflicting constraint on reaction: %w", err) + } + } + err := dbutil.DangerousInternalUpgradeVersionTable(ctx, br.DB) + if err != nil { + return err + } + _, err = br.DB.Exec(ctx, renameTablesQuery) + if err != nil { + return err + } + upgradesTo, compat, err := br.DB.UpgradeTable[0].DangerouslyRun(ctx, br.DB) + if err != nil { + return err + } + if upgradesTo < newDBVersion || compat > newDBVersion { + return fmt.Errorf("unexpected new database version (%d/c:%d, expected %d)", upgradesTo, compat, newDBVersion) + } + if otherTable != nil { + _, err = br.DB.Exec(ctx, fmt.Sprintf("CREATE TABLE %s (version INTEGER, compat INTEGER)", otherTableName)) + if err != nil { + return err + } + otherUpgradesTo, otherCompat, err := otherTable[0].DangerouslyRun(ctx, br.DB) + if err != nil { + return err + } else if otherUpgradesTo < otherNewVersion || otherCompat > otherNewVersion { + return fmt.Errorf("unexpected new database version for %s (%d/c:%d, expected %d)", otherTableName, otherUpgradesTo, otherCompat, otherNewVersion) + } + _, err = br.DB.Exec(ctx, fmt.Sprintf("INSERT INTO %s (version, compat) VALUES ($1, $2)", otherTableName), otherUpgradesTo, otherCompat) + if err != nil { + return err + } + } + copyDataQuery, err = br.DB.Internals().FilterSQLUpgrade(bytes.Split([]byte(copyDataQuery), []byte("\n"))) + if err != nil { + return err + } + _, err = br.DB.Exec(ctx, copyDataQuery) + if err != nil { + return err + } + _, err = br.DB.Exec(ctx, "DELETE FROM database_owner") + if err != nil { + return err + } + _, err = br.DB.Exec(ctx, "INSERT INTO database_owner (key, owner) VALUES (0, $1)", br.DB.Owner) + if err != nil { + return err + } + _, err = br.DB.Exec(ctx, "DELETE FROM version") + if err != nil { + return err + } + _, err = br.DB.Exec(ctx, "INSERT INTO version (version, compat) VALUES ($1, $2)", upgradesTo, compat) + if err != nil { + return err + } + _, err = br.DB.Exec(ctx, "CREATE TABLE database_was_migrated(empty INTEGER)") + if err != nil { + return err + } + + return nil + } +} + +func (br *BridgeMain) LegacyMigrateSimple(renameTablesQuery, copyDataQuery string, newDBVersion int) func(ctx context.Context) error { + return br.LegacyMigrateWithAnotherUpgrader(renameTablesQuery, copyDataQuery, newDBVersion, nil, "", 0) +} + +func (br *BridgeMain) CheckLegacyDB( + expectedVersion int, + minBridgeVersion, + firstMegaVersion string, + migrator func(context.Context) error, + transaction bool, +) { + log := br.Log.With().Str("action", "migrate legacy db").Logger() + ctx := log.WithContext(context.Background()) + exists, err := br.DB.TableExists(ctx, "database_owner") + if err != nil { + log.Err(err).Msg("Failed to check if database_owner table exists") + return + } else if !exists { + return + } + var owner string + err = br.DB.QueryRow(ctx, "SELECT owner FROM database_owner LIMIT 1").Scan(&owner) + if err != nil && !errors.Is(err, sql.ErrNoRows) { + log.Err(err).Msg("Failed to get database owner") + return + } else if owner != br.Name { + if owner != "megabridge/"+br.Name && owner != "" { + log.Warn().Str("db_owner", owner).Msg("Unexpected database owner, not migrating database") + } + return + } + var dbVersion int + err = br.DB.QueryRow(ctx, "SELECT version FROM version").Scan(&dbVersion) + if err != nil { + log.Fatal().Err(err).Msg("Failed to get database version") + return + } else if dbVersion < expectedVersion { + log.Fatal(). + Int("expected_version", expectedVersion). + Int("version", dbVersion). + Msgf("Unsupported database version. Please upgrade to %s %s or higher before upgrading to %s.", br.Name, minBridgeVersion, firstMegaVersion) // zerolog-allow-msgf + return + } else if dbVersion > expectedVersion { + log.Fatal(). + Int("expected_version", expectedVersion). + Int("version", dbVersion). + Msg("Unsupported database version (higher than expected)") + return + } + log.Info().Msg("Detected legacy database, migrating...") + if transaction { + err = br.DB.DoTxn(ctx, nil, migrator) + } else { + err = migrator(ctx) + } + if err != nil { + br.LogDBUpgradeErrorAndExit("main", err, "Failed to migrate legacy database") + } else { + log.Info().Msg("Successfully migrated legacy database") + } +} + +func (br *BridgeMain) postMigrateDMPortal(ctx context.Context, portal *bridgev2.Portal) error { + otherUserID := portal.OtherUserID + if otherUserID == "" { + zerolog.Ctx(ctx).Warn(). + Str("portal_id", string(portal.ID)). + Msg("DM portal has no other user ID") + return nil + } + ghost, err := br.Bridge.GetGhostByID(ctx, otherUserID) + if err != nil { + return fmt.Errorf("failed to get ghost for %s: %w", otherUserID, err) + } + mx := ghost.Intent.(*matrix.ASIntent).Matrix + err = br.Matrix.Bot.EnsureJoined(ctx, portal.MXID, appservice.EnsureJoinedParams{ + BotOverride: mx.Client, + }) + if err != nil { + zerolog.Ctx(ctx).Err(err). + Str("portal_id", string(portal.ID)). + Stringer("room_id", portal.MXID). + Msg("Failed to ensure bot is joined to DM") + } + pls, err := mx.PowerLevels(ctx, portal.MXID) + if err != nil { + zerolog.Ctx(ctx).Err(err). + Str("portal_id", string(portal.ID)). + Stringer("room_id", portal.MXID). + Msg("Failed to get power levels in room") + } else { + userLevel := pls.GetUserLevel(mx.UserID) + pls.EnsureUserLevel(br.Matrix.Bot.UserID, userLevel) + if userLevel > 50 { + pls.SetUserLevel(mx.UserID, 50) + } + _, err = mx.SetPowerLevels(ctx, portal.MXID, pls) + if err != nil { + zerolog.Ctx(ctx).Err(err). + Str("portal_id", string(portal.ID)). + Stringer("room_id", portal.MXID). + Msg("Failed to set power levels") + } + } + portal.UpdateInfoFromGhost(ctx, ghost) + return nil +} + +func (br *BridgeMain) PostMigrate(ctx context.Context) error { + log := br.Log.With().Str("action", "post-migrate").Logger() + wasMigrated, err := br.DB.TableExists(ctx, "database_was_migrated") + if err != nil { + return fmt.Errorf("failed to check if database_was_migrated table exists: %w", err) + } else if !wasMigrated { + return nil + } + log.Info().Msg("Doing post-migration updates to Matrix rooms") + + portals, err := br.Bridge.GetAllPortalsWithMXID(ctx) + if err != nil { + return fmt.Errorf("failed to get all portals: %w", err) + } + for _, portal := range portals { + log := log.With(). + Stringer("room_id", portal.MXID). + Object("portal_key", portal.PortalKey). + Str("room_type", string(portal.RoomType)). + Logger() + log.Debug().Msg("Migrating portal") + if br.PostMigratePortal != nil { + err = br.PostMigratePortal(ctx, portal) + if err != nil { + log.Err(err).Msg("Failed to run post-migrate portal hook") + continue + } + } else { + switch portal.RoomType { + case database.RoomTypeDM: + err = br.postMigrateDMPortal(ctx, portal) + if err != nil { + return fmt.Errorf("failed to update DM portal %s: %w", portal.MXID, err) + } + } + } + _, err = br.Matrix.Bot.SendStateEvent(ctx, portal.MXID, event.StateElementFunctionalMembers, "", &event.ElementFunctionalMembersContent{ + ServiceMembers: []id.UserID{br.Matrix.Bot.UserID}, + }) + if err != nil { + log.Warn().Err(err).Stringer("room_id", portal.MXID).Msg("Failed to set service members") + } + } + + _, err = br.DB.Exec(ctx, "DROP TABLE database_was_migrated") + if err != nil { + return fmt.Errorf("failed to drop database_was_migrated table: %w", err) + } + log.Info().Msg("Post-migration updates complete") + return nil +} diff --git a/bridgev2/matrix/mxmain/main.go b/bridgev2/matrix/mxmain/main.go new file mode 100644 index 00000000..1e8b51d1 --- /dev/null +++ b/bridgev2/matrix/mxmain/main.go @@ -0,0 +1,450 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +// Package mxmain contains initialization code for a single-network Matrix bridge using the bridgev2 package. +package mxmain + +import ( + "context" + _ "embed" + "encoding/json" + "errors" + "fmt" + "os" + "os/signal" + "runtime" + "strings" + "syscall" + "time" + + "github.com/mattn/go-sqlite3" + "github.com/rs/zerolog" + "go.mau.fi/util/configupgrade" + "go.mau.fi/util/dbutil" + "go.mau.fi/util/exerrors" + "go.mau.fi/util/exzerolog" + "go.mau.fi/util/progver" + "gopkg.in/yaml.v3" + flag "maunium.net/go/mauflag" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/bridgev2" + "maunium.net/go/mautrix/bridgev2/bridgeconfig" + "maunium.net/go/mautrix/bridgev2/commands" + "maunium.net/go/mautrix/bridgev2/matrix" +) + +var configPath = flag.MakeFull("c", "config", "The path to your config file.", "config.yaml").String() +var writeExampleConfig = flag.MakeFull("e", "generate-example-config", "Save the example config to the config path and quit.", "false").Bool() +var dontSaveConfig = flag.MakeFull("n", "no-update", "Don't save updated config to disk.", "false").Bool() +var registrationPath = flag.MakeFull("r", "registration", "The path where to save the appservice registration.", "registration.yaml").String() +var generateRegistration = flag.MakeFull("g", "generate-registration", "Generate registration and quit.", "false").Bool() +var version = flag.MakeFull("v", "version", "View bridge version and quit.", "false").Bool() +var versionJSON = flag.Make().LongKey("version-json").Usage("Print a JSON object representing the bridge version and quit.").Default("false").Bool() +var ignoreUnsupportedDatabase = flag.Make().LongKey("ignore-unsupported-database").Usage("Run even if the database schema is too new").Default("false").Bool() +var ignoreForeignTables = flag.Make().LongKey("ignore-foreign-tables").Usage("Run even if the database contains tables from other programs (like Synapse)").Default("false").Bool() +var ignoreUnsupportedServer = flag.Make().LongKey("ignore-unsupported-server").Usage("Run even if the Matrix homeserver is outdated").Default("false").Bool() +var wantHelp, _ = flag.MakeHelpFlag() + +// BridgeMain contains the main function for a Matrix bridge. +type BridgeMain struct { + // Name is the name of the bridge project, e.g. mautrix-signal. + // Note that when making your own bridges that isn't under github.com/mautrix, + // you should invent your own name and not use the mautrix-* naming scheme. + Name string + // Description is a brief description of the bridge, usually of the form "A Matrix-OtherPlatform puppeting bridge." + Description string + // URL is the Git repository address for the bridge. + URL string + // Version is the latest release of the bridge. InitVersion will compare this to the provided + // git tag to see if the built version is the release or a dev build. + // You can either bump this right after a release or right before, as long as it matches on the release commit. + Version string + // SemCalVer defines whether this bridge uses a mix of semantic and calendar versioning, + // such that the Version field is YY.0M.patch, while git tags are major.YY0M.patch. + SemCalVer bool + + // PostInit is a function that will be called after the bridge has been initialized but before it is started. + PostInit func() + PostStart func() + + // PostMigratePortal is a function that will be called during a legacy + // migration for each portal. + PostMigratePortal func(context.Context, *bridgev2.Portal) error + + // Connector is the network connector for the bridge. + Connector bridgev2.NetworkConnector + + // All fields below are set automatically in Run or InitVersion should not be set manually. + + Log *zerolog.Logger + DB *dbutil.Database + Config *bridgeconfig.Config + Matrix *matrix.Connector + Bridge *bridgev2.Bridge + + ConfigPath string + RegistrationPath string + SaveConfig bool + + ver progver.ProgramVersion + + AdditionalShortFlags string + AdditionalLongFlags string + + manualStop chan int +} + +type VersionJSONOutput struct { + progver.ProgramVersion + + OS string + Arch string + + Mautrix struct { + Version string + Commit string + } +} + +// Run runs the bridge and waits for SIGTERM before stopping. +func (br *BridgeMain) Run() { + br.PreInit() + br.Init() + br.Start() + exitCode := br.WaitForInterrupt() + br.Stop() + os.Exit(exitCode) +} + +// PreInit parses CLI flags and loads the config file. This is called by [Run] and does not need to be called manually. +// +// This also handles all flags that cause the bridge to exit immediately (e.g. `--version` and `--generate-registration`). +func (br *BridgeMain) PreInit() { + br.manualStop = make(chan int, 1) + flag.SetHelpTitles( + fmt.Sprintf("%s - %s", br.Name, br.Description), + fmt.Sprintf("%s [-hgvn%s] [-c ] [-r ]%s", br.Name, br.AdditionalShortFlags, br.AdditionalLongFlags)) + err := flag.Parse() + br.ConfigPath = *configPath + br.RegistrationPath = *registrationPath + br.SaveConfig = !*dontSaveConfig + if err != nil { + _, _ = fmt.Fprintln(os.Stderr, err) + flag.PrintHelp() + os.Exit(1) + } else if *wantHelp { + flag.PrintHelp() + os.Exit(0) + } else if *version { + fmt.Println(br.ver.VersionDescription) + os.Exit(0) + } else if *versionJSON { + output := VersionJSONOutput{ + ProgramVersion: br.ver, + + OS: runtime.GOOS, + Arch: runtime.GOARCH, + } + output.Mautrix.Commit = mautrix.Commit + output.Mautrix.Version = mautrix.Version + _ = json.NewEncoder(os.Stdout).Encode(output) + os.Exit(0) + } else if *writeExampleConfig { + if *configPath != "-" && *configPath != "/dev/stdout" && *configPath != "/dev/stderr" { + if _, err = os.Stat(*configPath); !errors.Is(err, os.ErrNotExist) { + _, _ = fmt.Fprintln(os.Stderr, *configPath, "already exists, please remove it if you want to generate a new example") + os.Exit(1) + } + } + networkExample, _, _ := br.Connector.GetConfig() + fullCfg := br.makeFullExampleConfig(networkExample) + if *configPath == "-" { + fmt.Print(fullCfg) + } else { + exerrors.PanicIfNotNil(os.WriteFile(*configPath, []byte(fullCfg), 0600)) + fmt.Println("Wrote example config to", *configPath) + } + os.Exit(0) + } + br.LoadConfig() + if *generateRegistration { + br.GenerateRegistration() + os.Exit(0) + } +} + +func (br *BridgeMain) GenerateRegistration() { + if !br.SaveConfig { + // We need to save the generated as_token and hs_token in the config + _, _ = fmt.Fprintln(os.Stderr, "--no-update is not compatible with --generate-registration") + os.Exit(5) + } else if br.Config.Homeserver.Domain == "example.com" { + _, _ = fmt.Fprintln(os.Stderr, "Homeserver domain is not set") + os.Exit(20) + } + reg := br.Config.GenerateRegistration() + err := reg.Save(br.RegistrationPath) + if err != nil { + _, _ = fmt.Fprintln(os.Stderr, "Failed to save registration:", err) + os.Exit(21) + } + + updateTokens := func(helper configupgrade.Helper) { + helper.Set(configupgrade.Str, reg.AppToken, "appservice", "as_token") + helper.Set(configupgrade.Str, reg.ServerToken, "appservice", "hs_token") + } + upgrader, _ := br.getConfigUpgrader() + _, _, err = configupgrade.Do(br.ConfigPath, true, upgrader, configupgrade.SimpleUpgrader(updateTokens)) + if err != nil { + _, _ = fmt.Fprintln(os.Stderr, "Failed to save config:", err) + os.Exit(22) + } + fmt.Println("Registration generated. See https://docs.mau.fi/bridges/general/registering-appservices.html for instructions on installing the registration.") + os.Exit(0) +} + +// Init sets up logging, database connection and creates the Matrix connector and central Bridge struct. +// This is called by [Run] and does not need to be called manually. +func (br *BridgeMain) Init() { + var err error + br.Log, err = br.Config.Logging.Compile() + if err != nil { + _, _ = fmt.Fprintln(os.Stderr, "Failed to initialize logger:", err) + os.Exit(12) + } + exzerolog.SetupDefaults(br.Log) + err = br.validateConfig() + if err != nil { + br.Log.WithLevel(zerolog.FatalLevel).Err(err).Msg("Configuration error") + br.Log.Info().Msg("See https://docs.mau.fi/faq/field-unconfigured for more info") + os.Exit(11) + } + + br.Log.Info(). + Str("name", br.Name). + Str("version", br.ver.FormattedVersion). + Time("built_at", br.ver.BuildTime). + Str("go_version", runtime.Version()). + Msg("Initializing bridge") + + br.initDB() + br.Matrix = matrix.NewConnector(br.Config) + br.Matrix.OnWebsocketReplaced = func() { + br.TriggerStop(0) + } + br.Matrix.IgnoreUnsupportedServer = *ignoreUnsupportedServer + br.Bridge = bridgev2.NewBridge("", br.DB, *br.Log, &br.Config.Bridge, br.Matrix, br.Connector, commands.NewProcessor) + br.Matrix.AS.DoublePuppetValue = br.Name + br.Bridge.Commands.(*commands.Processor).AddHandler(&commands.FullHandler{ + Func: func(ce *commands.Event) { + ce.Reply(br.ver.MarkdownDescription()) + }, + Name: "version", + Help: commands.HelpMeta{ + Section: commands.HelpSectionGeneral, + Description: "Get the bridge version.", + }, + }) + if br.PostInit != nil { + br.PostInit() + } +} + +func (br *BridgeMain) initDB() { + br.Log.Debug().Msg("Initializing database connection") + dbConfig := br.Config.Database + if dbConfig.Type == "sqlite3" { + br.Log.WithLevel(zerolog.FatalLevel).Msg("Invalid database type sqlite3. Use sqlite3-fk-wal instead.") + os.Exit(14) + } + if (dbConfig.Type == "sqlite3-fk-wal" || dbConfig.Type == "litestream") && dbConfig.MaxOpenConns != 1 && !strings.Contains(dbConfig.URI, "_txlock=immediate") { + var fixedExampleURI string + if !strings.HasPrefix(dbConfig.URI, "file:") { + fixedExampleURI = fmt.Sprintf("file:%s?_txlock=immediate", dbConfig.URI) + } else if !strings.ContainsRune(dbConfig.URI, '?') { + fixedExampleURI = fmt.Sprintf("%s?_txlock=immediate", dbConfig.URI) + } else { + fixedExampleURI = fmt.Sprintf("%s&_txlock=immediate", dbConfig.URI) + } + br.Log.Warn(). + Str("fixed_uri_example", fixedExampleURI). + Msg("Using SQLite without _txlock=immediate is not recommended") + } + var err error + br.DB, err = dbutil.NewFromConfig("megabridge/"+br.Name, dbConfig, dbutil.ZeroLogger(br.Log.With().Str("db_section", "main").Logger())) + if err != nil { + br.Log.WithLevel(zerolog.FatalLevel).Err(err).Msg("Failed to initialize database connection") + if sqlError := (&sqlite3.Error{}); errors.As(err, sqlError) && sqlError.Code == sqlite3.ErrCorrupt { + os.Exit(18) + } + os.Exit(14) + } + br.DB.IgnoreUnsupportedDatabase = *ignoreUnsupportedDatabase + br.DB.IgnoreForeignTables = *ignoreForeignTables +} + +func (br *BridgeMain) validateConfig() error { + switch { + case br.Config.Homeserver.Address == "http://example.localhost:8008": + return errors.New("homeserver.address not configured") + case br.Config.Homeserver.Domain == "example.com": + return errors.New("homeserver.domain not configured") + case !bridgeconfig.AllowedHomeserverSoftware[br.Config.Homeserver.Software]: + return errors.New("invalid value for homeserver.software (use `standard` if you don't know what the field is for)") + case br.Config.AppService.ASToken == "This value is generated when generating the registration": + return errors.New("appservice.as_token not configured. Did you forget to generate the registration? ") + case br.Config.AppService.HSToken == "This value is generated when generating the registration": + return errors.New("appservice.hs_token not configured. Did you forget to generate the registration? ") + case br.Config.Database.URI == "postgres://user:password@host/database?sslmode=disable": + return errors.New("database.uri not configured") + case !br.Config.Bridge.Permissions.IsConfigured(): + return errors.New("bridge.permissions not configured") + case !strings.Contains(br.Config.AppService.FormatUsername("1234567890"), "1234567890"): + return errors.New("username template is missing user ID placeholder") + default: + cfgValidator, ok := br.Connector.(bridgev2.ConfigValidatingNetwork) + if ok { + err := cfgValidator.ValidateConfig() + if err != nil { + return err + } + } + return nil + } +} + +func (br *BridgeMain) getConfigUpgrader() (configupgrade.BaseUpgrader, any) { + networkExample, networkData, networkUpgrader := br.Connector.GetConfig() + baseConfig := br.makeFullExampleConfig(networkExample) + if networkUpgrader == nil { + networkUpgrader = configupgrade.NoopUpgrader + } + networkUpgraderProxied := &configupgrade.ProxyUpgrader{Target: networkUpgrader, Prefix: []string{"network"}} + upgrader := configupgrade.MergeUpgraders(baseConfig, networkUpgraderProxied, bridgeconfig.Upgrader) + return upgrader, networkData +} + +// LoadConfig upgrades and loads the config file. +// This is called by [Run] and does not need to be called manually. +func (br *BridgeMain) LoadConfig() { + upgrader, networkData := br.getConfigUpgrader() + configData, upgraded, err := configupgrade.Do(br.ConfigPath, br.SaveConfig, upgrader) + if err != nil { + _, _ = fmt.Fprintln(os.Stderr, "Error updating config:", err) + if !upgraded { + os.Exit(10) + } + } + + var cfg bridgeconfig.Config + err = yaml.Unmarshal(configData, &cfg) + if err != nil { + _, _ = fmt.Fprintln(os.Stderr, "Failed to parse config:", err) + os.Exit(10) + } + if networkData != nil { + err = cfg.Network.Decode(networkData) + if err != nil { + _, _ = fmt.Fprintln(os.Stderr, "Failed to parse network config:", err) + os.Exit(10) + } + } + cfg.Bridge.Backfill = cfg.Backfill + if cfg.EnvConfigPrefix != "" { + err = UpdateConfigFromEnv(&cfg, networkData, cfg.EnvConfigPrefix) + if err != nil { + _, _ = fmt.Fprintln(os.Stderr, "Failed to parse environment variables:", err) + os.Exit(10) + } + } + br.Config = &cfg +} + +// Start starts the bridge after everything has been initialized. +// This is called by [Run] and does not need to be called manually. +func (br *BridgeMain) Start() { + ctx := br.Log.WithContext(context.Background()) + err := br.Bridge.StartConnectors(ctx) + if err != nil { + var dbUpgradeErr bridgev2.DBUpgradeError + if errors.As(err, &dbUpgradeErr) { + br.LogDBUpgradeErrorAndExit(dbUpgradeErr.Section, dbUpgradeErr.Err, "Failed to initialize database") + } else { + br.Log.Fatal().Err(err).Msg("Failed to start bridge") + } + } + err = br.PostMigrate(ctx) + if err != nil { + br.Log.Fatal().Err(err).Msg("Failed to run post-migration updates") + } + err = br.Bridge.StartLogins(ctx) + if err != nil { + br.Log.Fatal().Err(err).Msg("Failed to start existing user logins") + } + br.Bridge.PostStart(ctx) + if br.PostStart != nil { + br.PostStart() + } +} + +// WaitForInterrupt waits for a SIGINT or SIGTERM signal. +func (br *BridgeMain) WaitForInterrupt() int { + c := make(chan os.Signal, 1) + signal.Notify(c, os.Interrupt, syscall.SIGTERM) + select { + case <-c: + br.Log.Info().Msg("Interrupt signal received from OS") + return 0 + case exitCode := <-br.manualStop: + br.Log.Info().Msg("Internal stop signal received") + return exitCode + } +} + +func (br *BridgeMain) TriggerStop(exitCode int) { + select { + case br.manualStop <- exitCode: + default: + } +} + +// Stop cleanly stops the bridge. This is called by [Run] and does not need to be called manually. +func (br *BridgeMain) Stop() { + br.Bridge.StopWithTimeout(5 * time.Second) +} + +// InitVersion formats the bridge version and build time nicely for things like +// the `version` bridge command on Matrix and the `--version` CLI flag. +// +// The values should generally be set by the build system. For example, assuming you have +// +// var ( +// Tag = "unknown" +// Commit = "unknown" +// BuildTime = "unknown" +// ) +// +// in your main package, then you'd use the following ldflags to fill them appropriately: +// +// go build -ldflags "-X main.Tag=$(git describe --exact-match --tags 2>/dev/null) -X main.Commit=$(git rev-parse HEAD) -X 'main.BuildTime=`date -Iseconds`'" +// +// You may additionally want to fill the mautrix-go version using another ldflag: +// +// export MAUTRIX_VERSION=$(cat go.mod | grep 'maunium.net/go/mautrix ' | head -n1 | awk '{ print $2 }') +// go build -ldflags "-X 'maunium.net/go/mautrix.GoModVersion=$MAUTRIX_VERSION'" +// +// (to use both at the same time, simply merge the ldflags into one, `-ldflags "-X '...' -X ..."`) +func (br *BridgeMain) InitVersion(tag, commit, rawBuildTime string) { + br.ver = progver.ProgramVersion{ + Name: br.Name, + URL: br.URL, + BaseVersion: br.Version, + SemCalVer: br.SemCalVer, + }.Init(tag, commit, rawBuildTime) + mautrix.DefaultUserAgent = fmt.Sprintf("%s/%s %s", br.Name, br.ver.FormattedVersion, mautrix.DefaultUserAgent) + br.Version = br.ver.FormattedVersion +} diff --git a/bridgev2/matrix/mxmain/main_test.go b/bridgev2/matrix/mxmain/main_test.go new file mode 100644 index 00000000..9a71344d --- /dev/null +++ b/bridgev2/matrix/mxmain/main_test.go @@ -0,0 +1,40 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package mxmain_test + +import ( + "maunium.net/go/mautrix/bridgev2" + "maunium.net/go/mautrix/bridgev2/matrix/mxmain" +) + +// Information to find out exactly which commit the bridge was built from. +// These are filled at build time with the -X linker flag. +var ( + Tag = "unknown" + Commit = "unknown" + BuildTime = "unknown" +) + +func ExampleBridgeMain() { + // Set this yourself + var yourConnector bridgev2.NetworkConnector + + m := mxmain.BridgeMain{ + Name: "example-matrix-bridge", + URL: "https://github.com/octocat/matrix-bridge", + Description: "An example Matrix bridge.", + Version: "1.0.0", + + Connector: yourConnector, + } + m.PostInit = func() { + // If you want some code to run after all the setup is done, but before the bridge is started, + // you can set a function in PostInit. This is not required if you don't need to do anything special. + } + m.InitVersion(Tag, Commit, BuildTime) + m.Run() +} diff --git a/bridge/no-crypto.go b/bridgev2/matrix/no-crypto.go similarity index 54% rename from bridge/no-crypto.go rename to bridgev2/matrix/no-crypto.go index 019ab7c1..fe942f83 100644 --- a/bridge/no-crypto.go +++ b/bridgev2/matrix/no-crypto.go @@ -1,4 +1,4 @@ -// Copyright (c) 2023 Tulir Asokan +// Copyright (c) 2024 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this @@ -6,17 +6,17 @@ //go:build !cgo || nocrypto -package bridge +package matrix import ( "errors" ) -func NewCryptoHelper(bridge *Bridge) Crypto { - if bridge.Config.Bridge.GetEncryptionConfig().Allow { - bridge.ZLog.Warn().Msg("Bridge built without end-to-bridge encryption, but encryption is enabled in config") +func NewCryptoHelper(c *Connector) Crypto { + if c.Config.Encryption.Allow { + c.Log.Warn().Msg("Bridge built without end-to-bridge encryption, but encryption is enabled in config") } else { - bridge.ZLog.Debug().Msg("Bridge built without end-to-bridge encryption") + c.Log.Debug().Msg("Bridge built without end-to-bridge encryption") } return nil } diff --git a/bridgev2/matrix/provisioning.go b/bridgev2/matrix/provisioning.go new file mode 100644 index 00000000..243b91da --- /dev/null +++ b/bridgev2/matrix/provisioning.go @@ -0,0 +1,797 @@ +// Copyright (c) 2025 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package matrix + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "net/http/pprof" + "strings" + "sync" + "time" + + "github.com/rs/xid" + "github.com/rs/zerolog" + "github.com/rs/zerolog/hlog" + "go.mau.fi/util/exerrors" + "go.mau.fi/util/exhttp" + "go.mau.fi/util/exstrings" + "go.mau.fi/util/jsontime" + "go.mau.fi/util/ptr" + "go.mau.fi/util/requestlog" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/bridgev2" + "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/bridgev2/provisionutil" + "maunium.net/go/mautrix/bridgev2/status" + "maunium.net/go/mautrix/federation" + "maunium.net/go/mautrix/id" +) + +type matrixAuthCacheEntry struct { + Expires time.Time + UserID id.UserID +} + +type ProvisioningAPI struct { + Router *http.ServeMux + + br *Connector + log zerolog.Logger + net bridgev2.NetworkConnector + + fedClient *federation.Client + + logins map[string]*ProvLogin + loginsLock sync.RWMutex + + matrixAuthCache map[string]matrixAuthCacheEntry + matrixAuthCacheLock sync.Mutex + + // Set for a given login once credentials have been exported, once in this state the finish + // API is available which will call logout on the client in question. + sessionTransfers map[networkid.UserLoginID]struct{} + sessionTransfersLock sync.Mutex + + // GetAuthFromRequest is a custom function for getting the auth token from + // the request if the Authorization header is not present. + GetAuthFromRequest func(r *http.Request) string + + // GetUserIDFromRequest is a custom function for getting the user ID to + // authenticate as instead of using the user ID provided in the query + // parameter. + GetUserIDFromRequest func(r *http.Request) id.UserID +} + +type ProvLogin struct { + ID string + Process bridgev2.LoginProcess + NextStep *bridgev2.LoginStep + Override *bridgev2.UserLogin + Lock sync.Mutex +} + +type provisioningContextKey int + +const ( + provisioningUserKey provisioningContextKey = iota + provisioningUserLoginKey + provisioningLoginProcessKey + ProvisioningKeyRequest +) + +func (prov *ProvisioningAPI) GetUser(r *http.Request) *bridgev2.User { + return r.Context().Value(provisioningUserKey).(*bridgev2.User) +} + +func (prov *ProvisioningAPI) GetRouter() *http.ServeMux { + return prov.Router +} + +func (br *Connector) GetProvisioning() bridgev2.IProvisioningAPI { + return br.Provisioning +} + +func (prov *ProvisioningAPI) Init() { + prov.matrixAuthCache = make(map[string]matrixAuthCacheEntry) + prov.logins = make(map[string]*ProvLogin) + prov.sessionTransfers = make(map[networkid.UserLoginID]struct{}) + prov.net = prov.br.Bridge.Network + prov.log = prov.br.Log.With().Str("component", "provisioning").Logger() + prov.fedClient = federation.NewClient("", nil, nil) + prov.fedClient.HTTP.Timeout = 20 * time.Second + tp := prov.fedClient.HTTP.Transport.(*federation.ServerResolvingTransport) + tp.Dialer.Timeout = 10 * time.Second + tp.Transport.ResponseHeaderTimeout = 10 * time.Second + tp.Transport.TLSHandshakeTimeout = 10 * time.Second + prov.Router = http.NewServeMux() + prov.Router.HandleFunc("GET /v3/whoami", prov.GetWhoami) + prov.Router.HandleFunc("GET /v3/capabilities", prov.GetCapabilities) + prov.Router.HandleFunc("GET /v3/login/flows", prov.GetLoginFlows) + prov.Router.HandleFunc("POST /v3/login/start/{flowID}", prov.PostLoginStart) + prov.Router.HandleFunc("POST /v3/login/step/{loginProcessID}/{stepID}/{stepType}", prov.PostLoginStep) + prov.Router.HandleFunc("POST /v3/logout/{loginID}", prov.PostLogout) + prov.Router.HandleFunc("GET /v3/logins", prov.GetLogins) + prov.Router.HandleFunc("GET /v3/contacts", prov.GetContactList) + prov.Router.HandleFunc("POST /v3/search_users", prov.PostSearchUsers) + prov.Router.HandleFunc("GET /v3/resolve_identifier/{identifier}", prov.GetResolveIdentifier) + prov.Router.HandleFunc("POST /v3/create_dm/{identifier}", prov.PostCreateDM) + prov.Router.HandleFunc("POST /v3/create_group/{type}", prov.PostCreateGroup) + + if prov.br.Config.Provisioning.EnableSessionTransfers { + prov.log.Debug().Msg("Enabling session transfer API") + prov.Router.HandleFunc("POST /v3/session_transfer/init", prov.PostInitSessionTransfer) + prov.Router.HandleFunc("POST /v3/session_transfer/finish", prov.PostFinishSessionTransfer) + } + + if prov.br.Config.Provisioning.DebugEndpoints { + prov.log.Debug().Msg("Enabling debug API at /debug") + debugRouter := http.NewServeMux() + debugRouter.HandleFunc("GET /pprof/cmdline", pprof.Cmdline) + debugRouter.HandleFunc("GET /pprof/profile", pprof.Profile) + debugRouter.HandleFunc("GET /pprof/symbol", pprof.Symbol) + debugRouter.HandleFunc("GET /pprof/trace", pprof.Trace) + debugRouter.HandleFunc("/pprof/", pprof.Index) + prov.br.AS.Router.Handle("/debug/", exhttp.ApplyMiddleware( + debugRouter, + exhttp.StripPrefix("/debug"), + hlog.NewHandler(prov.br.Log.With().Str("component", "debug api").Logger()), + requestlog.AccessLogger(requestlog.Options{TrustXForwardedFor: true}), + prov.DebugAuthMiddleware, + )) + } + + errorBodies := exhttp.ErrorBodies{ + NotFound: exerrors.Must(ptr.Ptr(mautrix.MUnrecognized.WithMessage("Unrecognized endpoint")).MarshalJSON()), + MethodNotAllowed: exerrors.Must(ptr.Ptr(mautrix.MUnrecognized.WithMessage("Invalid method for endpoint")).MarshalJSON()), + } + prov.br.AS.Router.Handle("/_matrix/provision/", exhttp.ApplyMiddleware( + prov.Router, + exhttp.StripPrefix("/_matrix/provision"), + hlog.NewHandler(prov.log), + hlog.RequestIDHandler("request_id", "Request-Id"), + exhttp.CORSMiddleware, + requestlog.AccessLogger(requestlog.Options{TrustXForwardedFor: true}), + exhttp.HandleErrors(errorBodies), + prov.AuthMiddleware, + )) +} + +func (prov *ProvisioningAPI) checkMatrixAuth(ctx context.Context, userID id.UserID, token string) error { + prov.matrixAuthCacheLock.Lock() + defer prov.matrixAuthCacheLock.Unlock() + if cached, ok := prov.matrixAuthCache[token]; ok && cached.Expires.After(time.Now()) && cached.UserID == userID { + return nil + } else if client, err := prov.br.DoublePuppet.newClient(ctx, userID, token); err != nil { + return err + } else if whoami, err := client.Whoami(ctx); err != nil { + return err + } else if whoami.UserID != userID { + return fmt.Errorf("mismatching user ID (%q != %q)", whoami.UserID, userID) + } else { + prov.matrixAuthCache[token] = matrixAuthCacheEntry{ + Expires: time.Now().Add(5 * time.Minute), + UserID: whoami.UserID, + } + return nil + } +} + +func (prov *ProvisioningAPI) checkFederatedMatrixAuth(ctx context.Context, userID id.UserID, token string) error { + homeserver := userID.Homeserver() + wrappedToken := fmt.Sprintf("%s:%s", homeserver, token) + // TODO smarter locking + prov.matrixAuthCacheLock.Lock() + defer prov.matrixAuthCacheLock.Unlock() + if cached, ok := prov.matrixAuthCache[wrappedToken]; ok && cached.Expires.After(time.Now()) && cached.UserID == userID { + return nil + } else if validationResult, err := prov.fedClient.GetOpenIDUserInfo(ctx, homeserver, token); err != nil { + return fmt.Errorf("failed to validate OpenID token: %w", err) + } else if validationResult.Sub != userID { + return fmt.Errorf("mismatching user ID (%q != %q)", validationResult, userID) + } else { + prov.matrixAuthCache[wrappedToken] = matrixAuthCacheEntry{ + Expires: time.Now().Add(1 * time.Hour), + UserID: userID, + } + return nil + } +} + +func disabledAuth(w http.ResponseWriter, r *http.Request) { + mautrix.MForbidden.WithMessage("Provisioning API is disabled").Write(w) +} + +func (prov *ProvisioningAPI) DebugAuthMiddleware(h http.Handler) http.Handler { + secret := prov.br.Config.Provisioning.SharedSecret + if len(secret) < 16 { + return http.HandlerFunc(disabledAuth) + } + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + auth := strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer ") + if auth == "" { + mautrix.MMissingToken.WithMessage("Missing auth token").Write(w) + } else if !exstrings.ConstantTimeEqual(auth, secret) { + mautrix.MUnknownToken.WithMessage("Invalid auth token").Write(w) + } else { + h.ServeHTTP(w, r) + } + }) +} + +func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler { + secret := prov.br.Config.Provisioning.SharedSecret + if len(secret) < 16 { + return http.HandlerFunc(disabledAuth) + } + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + auth := strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer ") + if auth == "" && prov.GetAuthFromRequest != nil { + auth = prov.GetAuthFromRequest(r) + } + if auth == "" { + mautrix.MMissingToken.WithMessage("Missing auth token").Write(w) + return + } + userID := id.UserID(r.URL.Query().Get("user_id")) + if userID == "" && prov.GetUserIDFromRequest != nil { + userID = prov.GetUserIDFromRequest(r) + } + if !exstrings.ConstantTimeEqual(auth, secret) { + var err error + if strings.HasPrefix(auth, "openid:") { + err = prov.checkFederatedMatrixAuth(r.Context(), userID, strings.TrimPrefix(auth, "openid:")) + } else { + err = prov.checkMatrixAuth(r.Context(), userID, auth) + } + if err != nil { + zerolog.Ctx(r.Context()).Warn().Err(err). + Msg("Provisioning API request contained invalid auth") + mautrix.MUnknownToken.WithMessage("Invalid auth token").Write(w) + return + } + } + user, err := prov.br.Bridge.GetUserByMXID(r.Context(), userID) + if err != nil { + zerolog.Ctx(r.Context()).Err(err).Msg("Failed to get user") + mautrix.MUnknown.WithMessage("Failed to get user").Write(w) + return + } + // TODO handle user being nil? + // TODO per-endpoint permissions? + if !user.Permissions.Login { + mautrix.MForbidden.WithMessage("User does not have login permissions").Write(w) + return + } + + ctx := context.WithValue(r.Context(), ProvisioningKeyRequest, r) + ctx = context.WithValue(ctx, provisioningUserKey, user) + h.ServeHTTP(w, r.WithContext(ctx)) + }) +} + +type RespWhoami struct { + Network bridgev2.BridgeName `json:"network"` + LoginFlows []bridgev2.LoginFlow `json:"login_flows"` + Homeserver string `json:"homeserver"` + BridgeBot id.UserID `json:"bridge_bot"` + CommandPrefix string `json:"command_prefix"` + + ManagementRoom id.RoomID `json:"management_room,omitempty"` + Logins []RespWhoamiLogin `json:"logins"` +} + +type RespWhoamiLogin struct { + // Deprecated + StateEvent status.BridgeStateEvent `json:"state_event"` + // Deprecated + StateTS jsontime.Unix `json:"state_ts"` + // Deprecated + StateReason string `json:"state_reason,omitempty"` + // Deprecated + StateInfo map[string]any `json:"state_info,omitempty"` + + State status.BridgeState `json:"state"` + ID networkid.UserLoginID `json:"id"` + Name string `json:"name"` + Profile status.RemoteProfile `json:"profile"` + SpaceRoom id.RoomID `json:"space_room,omitempty"` +} + +func (prov *ProvisioningAPI) GetWhoami(w http.ResponseWriter, r *http.Request) { + user := prov.GetUser(r) + resp := &RespWhoami{ + Network: prov.br.Bridge.Network.GetName(), + LoginFlows: prov.br.Bridge.Network.GetLoginFlows(), + Homeserver: prov.br.AS.HomeserverDomain, + BridgeBot: prov.br.Bot.UserID, + CommandPrefix: prov.br.Config.Bridge.CommandPrefix, + ManagementRoom: user.ManagementRoom, + } + logins := user.GetUserLogins() + resp.Logins = make([]RespWhoamiLogin, len(logins)) + for i, login := range logins { + prevState := login.BridgeState.GetPrevUnsent() + // Clear redundant fields + prevState.UserID = "" + prevState.RemoteID = "" + prevState.RemoteName = "" + prevState.RemoteProfile = status.RemoteProfile{} + resp.Logins[i] = RespWhoamiLogin{ + StateEvent: prevState.StateEvent, + StateTS: prevState.Timestamp, + StateReason: prevState.Reason, + StateInfo: prevState.Info, + State: prevState, + + ID: login.ID, + Name: login.RemoteName, + Profile: login.RemoteProfile, + SpaceRoom: login.SpaceRoom, + } + } + exhttp.WriteJSONResponse(w, http.StatusOK, resp) +} + +type RespLoginFlows struct { + Flows []bridgev2.LoginFlow `json:"flows"` +} + +type RespSubmitLogin struct { + LoginID string `json:"login_id"` + *bridgev2.LoginStep +} + +func (prov *ProvisioningAPI) GetLoginFlows(w http.ResponseWriter, r *http.Request) { + exhttp.WriteJSONResponse(w, http.StatusOK, &RespLoginFlows{ + Flows: prov.net.GetLoginFlows(), + }) +} + +func (prov *ProvisioningAPI) GetCapabilities(w http.ResponseWriter, r *http.Request) { + exhttp.WriteJSONResponse(w, http.StatusOK, &prov.net.GetCapabilities().Provisioning) +} + +var ErrNilStep = errors.New("bridge returned nil step with no error") +var ErrTooManyLogins = bridgev2.RespError{ErrCode: "FI.MAU.BRIDGE.TOO_MANY_LOGINS", Err: "Maximum number of logins exceeded"} + +func (prov *ProvisioningAPI) PostLoginStart(w http.ResponseWriter, r *http.Request) { + overrideLogin, failed := prov.GetExplicitLoginForRequest(w, r) + if failed { + return + } + user := prov.GetUser(r) + if overrideLogin == nil && user.HasTooManyLogins() { + ErrTooManyLogins.AppendMessage(" (%d)", user.Permissions.MaxLogins).Write(w) + return + } + login, err := prov.net.CreateLogin(r.Context(), user, r.PathValue("flowID")) + if err != nil { + zerolog.Ctx(r.Context()).Err(err).Msg("Failed to create login process") + RespondWithError(w, err, "Internal error creating login process") + return + } + var firstStep *bridgev2.LoginStep + overridable, ok := login.(bridgev2.LoginProcessWithOverride) + if ok && overrideLogin != nil { + firstStep, err = overridable.StartWithOverride(r.Context(), overrideLogin) + } else { + firstStep, err = login.Start(r.Context()) + } + if err == nil && firstStep == nil { + err = ErrNilStep + } + if err != nil { + zerolog.Ctx(r.Context()).Err(err).Msg("Failed to start login") + RespondWithError(w, err, "Internal error starting login") + return + } + loginID := xid.New().String() + prov.loginsLock.Lock() + prov.logins[loginID] = &ProvLogin{ + ID: loginID, + Process: login, + NextStep: firstStep, + Override: overrideLogin, + } + prov.loginsLock.Unlock() + zerolog.Ctx(r.Context()).Info(). + Any("first_step", firstStep). + Msg("Created login process") + exhttp.WriteJSONResponse(w, http.StatusOK, &RespSubmitLogin{LoginID: loginID, LoginStep: firstStep}) +} + +func (prov *ProvisioningAPI) handleCompleteStep(ctx context.Context, login *ProvLogin, step *bridgev2.LoginStep) { + zerolog.Ctx(ctx).Info(). + Str("step_id", step.StepID). + Str("user_login_id", string(step.CompleteParams.UserLoginID)). + Msg("Login completed successfully") + prov.deleteLogin(login, false) + if login.Override == nil || login.Override.ID == step.CompleteParams.UserLoginID { + return + } + zerolog.Ctx(ctx).Info(). + Str("old_login_id", string(login.Override.ID)). + Str("new_login_id", string(step.CompleteParams.UserLoginID)). + Msg("Login resulted in different remote ID than what was being overridden. Deleting previous login") + login.Override.Delete(ctx, status.BridgeState{ + StateEvent: status.StateLoggedOut, + Reason: "LOGIN_OVERRIDDEN", + }, bridgev2.DeleteOpts{LogoutRemote: true}) +} + +func (prov *ProvisioningAPI) deleteLogin(login *ProvLogin, cancel bool) { + if cancel { + login.Process.Cancel() + } + prov.loginsLock.Lock() + delete(prov.logins, login.ID) + prov.loginsLock.Unlock() +} + +func (prov *ProvisioningAPI) PostLoginStep(w http.ResponseWriter, r *http.Request) { + loginID := r.PathValue("loginProcessID") + prov.loginsLock.RLock() + login, ok := prov.logins[loginID] + prov.loginsLock.RUnlock() + if !ok { + zerolog.Ctx(r.Context()).Warn().Str("login_id", loginID).Msg("Login not found") + mautrix.MNotFound.WithMessage("Login not found").Write(w) + return + } + login.Lock.Lock() + // This will only unlock after the handler runs + defer login.Lock.Unlock() + stepID := r.PathValue("stepID") + if login.NextStep.StepID != stepID { + zerolog.Ctx(r.Context()).Warn(). + Str("request_step_id", stepID). + Str("expected_step_id", login.NextStep.StepID). + Msg("Step ID does not match") + mautrix.MBadState.WithMessage("Step ID does not match").Write(w) + return + } + stepType := r.PathValue("stepType") + if login.NextStep.Type != bridgev2.LoginStepType(stepType) { + zerolog.Ctx(r.Context()).Warn(). + Str("request_step_type", stepType). + Str("expected_step_type", string(login.NextStep.Type)). + Msg("Step type does not match") + mautrix.MBadState.WithMessage("Step type does not match").Write(w) + return + } + ctx := context.WithValue(r.Context(), provisioningLoginProcessKey, login) + r = r.WithContext(ctx) + switch bridgev2.LoginStepType(r.PathValue("stepType")) { + case bridgev2.LoginStepTypeUserInput, bridgev2.LoginStepTypeCookies: + prov.PostLoginSubmitInput(w, r) + case bridgev2.LoginStepTypeDisplayAndWait: + prov.PostLoginWait(w, r) + case bridgev2.LoginStepTypeComplete: + fallthrough + default: + // This is probably impossible because of the above check that the next step type matches the request. + mautrix.MUnrecognized.WithMessage("Invalid step type %q", r.PathValue("stepType")).Write(w) + } +} + +func (prov *ProvisioningAPI) PostLoginSubmitInput(w http.ResponseWriter, r *http.Request) { + var params map[string]string + err := json.NewDecoder(r.Body).Decode(¶ms) + if err != nil { + zerolog.Ctx(r.Context()).Err(err).Msg("Failed to decode request body") + mautrix.MNotJSON.WithMessage("Failed to decode request body").Write(w) + return + } + login := r.Context().Value(provisioningLoginProcessKey).(*ProvLogin) + var nextStep *bridgev2.LoginStep + switch login.NextStep.Type { + case bridgev2.LoginStepTypeUserInput: + nextStep, err = login.Process.(bridgev2.LoginProcessUserInput).SubmitUserInput(r.Context(), params) + case bridgev2.LoginStepTypeCookies: + nextStep, err = login.Process.(bridgev2.LoginProcessCookies).SubmitCookies(r.Context(), params) + default: + panic("Impossible state") + } + if err == nil && nextStep == nil { + err = ErrNilStep + } + if err != nil { + zerolog.Ctx(r.Context()).Err(err).Msg("Failed to submit input") + RespondWithError(w, err, "Internal error submitting input") + prov.deleteLogin(login, true) + return + } + login.NextStep = nextStep + if nextStep.Type == bridgev2.LoginStepTypeComplete { + prov.handleCompleteStep(r.Context(), login, nextStep) + } else { + zerolog.Ctx(r.Context()).Debug().Any("next_step", nextStep).Msg("Returning next login step") + } + exhttp.WriteJSONResponse(w, http.StatusOK, &RespSubmitLogin{LoginID: login.ID, LoginStep: nextStep}) +} + +func (prov *ProvisioningAPI) PostLoginWait(w http.ResponseWriter, r *http.Request) { + login := r.Context().Value(provisioningLoginProcessKey).(*ProvLogin) + nextStep, err := login.Process.(bridgev2.LoginProcessDisplayAndWait).Wait(r.Context()) + if err == nil && nextStep == nil { + err = ErrNilStep + } + if err != nil { + zerolog.Ctx(r.Context()).Err(err).Msg("Failed to wait") + RespondWithError(w, err, "Internal error waiting for login") + prov.deleteLogin(login, true) + return + } + login.NextStep = nextStep + if nextStep.Type == bridgev2.LoginStepTypeComplete { + prov.handleCompleteStep(r.Context(), login, nextStep) + } else { + zerolog.Ctx(r.Context()).Debug().Any("next_step", nextStep).Msg("Returning next login step") + } + exhttp.WriteJSONResponse(w, http.StatusOK, &RespSubmitLogin{LoginID: login.ID, LoginStep: nextStep}) +} + +func (prov *ProvisioningAPI) PostLogout(w http.ResponseWriter, r *http.Request) { + user := prov.GetUser(r) + userLoginID := networkid.UserLoginID(r.PathValue("loginID")) + if userLoginID == "all" { + for { + login := user.GetDefaultLogin() + if login == nil { + break + } + login.Logout(r.Context()) + } + } else { + userLogin := prov.br.Bridge.GetCachedUserLoginByID(userLoginID) + if userLogin == nil || userLogin.UserMXID != user.MXID { + mautrix.MNotFound.WithMessage("Login not found").Write(w) + return + } + userLogin.Logout(r.Context()) + } + exhttp.WriteEmptyJSONResponse(w, http.StatusOK) +} + +type RespGetLogins struct { + LoginIDs []networkid.UserLoginID `json:"login_ids"` +} + +func (prov *ProvisioningAPI) GetLogins(w http.ResponseWriter, r *http.Request) { + user := prov.GetUser(r) + exhttp.WriteJSONResponse(w, http.StatusOK, &RespGetLogins{LoginIDs: user.GetUserLoginIDs()}) +} + +func (prov *ProvisioningAPI) GetExplicitLoginForRequest(w http.ResponseWriter, r *http.Request) (*bridgev2.UserLogin, bool) { + userLoginID := networkid.UserLoginID(r.URL.Query().Get("login_id")) + if userLoginID == "" { + return nil, false + } + userLogin := prov.br.Bridge.GetCachedUserLoginByID(userLoginID) + if userLogin == nil || userLogin.UserMXID != prov.GetUser(r).MXID { + hlog.FromRequest(r).Warn(). + Str("login_id", string(userLoginID)). + Msg("Tried to use non-existent login, returning 404") + mautrix.MNotFound.WithMessage("Login not found").Write(w) + return nil, true + } + return userLogin, false +} + +var ErrNotLoggedIn = mautrix.RespError{ + Err: "Not logged in", + ErrCode: "FI.MAU.NOT_LOGGED_IN", + StatusCode: http.StatusBadRequest, +} + +func (prov *ProvisioningAPI) GetLoginForRequest(w http.ResponseWriter, r *http.Request) *bridgev2.UserLogin { + userLogin, failed := prov.GetExplicitLoginForRequest(w, r) + if userLogin != nil || failed { + return userLogin + } + userLogin = prov.GetUser(r).GetDefaultLogin() + if userLogin == nil { + ErrNotLoggedIn.Write(w) + return nil + } + return userLogin +} + +type WritableError interface { + Write(w http.ResponseWriter) +} + +func RespondWithError(w http.ResponseWriter, err error, message string) { + var we WritableError + if errors.As(err, &we) { + we.Write(w) + } else { + mautrix.MUnknown.WithMessage(message).Write(w) + } +} + +func (prov *ProvisioningAPI) doResolveIdentifier(w http.ResponseWriter, r *http.Request, createChat bool) { + login := prov.GetLoginForRequest(w, r) + if login == nil { + return + } + resp, err := provisionutil.ResolveIdentifier(r.Context(), login, r.PathValue("identifier"), createChat) + if err != nil { + RespondWithError(w, err, "Internal error resolving identifier") + } else if resp == nil { + mautrix.MNotFound.WithMessage("Identifier not found").Write(w) + } else { + status := http.StatusOK + if resp.JustCreated { + status = http.StatusCreated + } + exhttp.WriteJSONResponse(w, status, resp) + } +} + +func (prov *ProvisioningAPI) GetContactList(w http.ResponseWriter, r *http.Request) { + login := prov.GetLoginForRequest(w, r) + if login == nil { + return + } + resp, err := provisionutil.GetContactList(r.Context(), login) + if err != nil { + RespondWithError(w, err, "Internal error getting contact list") + return + } + exhttp.WriteJSONResponse(w, http.StatusOK, resp) +} + +type ReqSearchUsers struct { + Query string `json:"query"` +} + +func (prov *ProvisioningAPI) PostSearchUsers(w http.ResponseWriter, r *http.Request) { + var req ReqSearchUsers + err := json.NewDecoder(r.Body).Decode(&req) + if err != nil { + zerolog.Ctx(r.Context()).Err(err).Msg("Failed to decode request body") + mautrix.MNotJSON.WithMessage("Failed to decode request body").Write(w) + return + } + login := prov.GetLoginForRequest(w, r) + if login == nil { + return + } + resp, err := provisionutil.SearchUsers(r.Context(), login, req.Query) + if err != nil { + RespondWithError(w, err, "Internal error searching users") + return + } + exhttp.WriteJSONResponse(w, http.StatusOK, resp) +} + +func (prov *ProvisioningAPI) GetResolveIdentifier(w http.ResponseWriter, r *http.Request) { + prov.doResolveIdentifier(w, r, false) +} + +func (prov *ProvisioningAPI) PostCreateDM(w http.ResponseWriter, r *http.Request) { + prov.doResolveIdentifier(w, r, true) +} + +func (prov *ProvisioningAPI) PostCreateGroup(w http.ResponseWriter, r *http.Request) { + var req bridgev2.GroupCreateParams + err := json.NewDecoder(r.Body).Decode(&req) + if err != nil { + zerolog.Ctx(r.Context()).Err(err).Msg("Failed to decode request body") + mautrix.MNotJSON.WithMessage("Failed to decode request body").Write(w) + return + } + req.Type = r.PathValue("type") + login := prov.GetLoginForRequest(w, r) + if login == nil { + return + } + resp, err := provisionutil.CreateGroup(r.Context(), login, &req) + if err != nil { + RespondWithError(w, err, "Internal error creating group") + return + } + exhttp.WriteJSONResponse(w, http.StatusOK, resp) +} + +type ReqExportCredentials struct { + RemoteID networkid.UserLoginID `json:"remote_id"` +} + +type RespExportCredentials struct { + Credentials any `json:"credentials"` +} + +func (prov *ProvisioningAPI) PostInitSessionTransfer(w http.ResponseWriter, r *http.Request) { + prov.sessionTransfersLock.Lock() + defer prov.sessionTransfersLock.Unlock() + + var req ReqExportCredentials + err := json.NewDecoder(r.Body).Decode(&req) + if err != nil { + zerolog.Ctx(r.Context()).Err(err).Msg("Failed to decode request body") + mautrix.MNotJSON.WithMessage("Failed to decode request body").Write(w) + return + } + + user := prov.GetUser(r) + logins := user.GetUserLogins() + var loginToExport *bridgev2.UserLogin + for _, login := range logins { + if login.ID == req.RemoteID { + loginToExport = login + break + } + } + if loginToExport == nil { + mautrix.MNotFound.WithMessage("No matching user login found").Write(w) + return + } + + client, ok := loginToExport.Client.(bridgev2.CredentialExportingNetworkAPI) + if !ok { + mautrix.MUnrecognized.WithMessage("This bridge does not support exporting credentials").Write(w) + return + } + + if _, ok := prov.sessionTransfers[loginToExport.ID]; ok { + // Warn, but allow, double exports. This might happen if a client crashes handling creds, + // and should be safe to call multiple times. + zerolog.Ctx(r.Context()).Warn().Msg("Exporting already exported credentials") + } + + // Disconnect now so we don't use the same network session in two places at once + client.Disconnect() + exhttp.WriteJSONResponse(w, http.StatusOK, &RespExportCredentials{ + Credentials: client.ExportCredentials(r.Context()), + }) +} + +func (prov *ProvisioningAPI) PostFinishSessionTransfer(w http.ResponseWriter, r *http.Request) { + prov.sessionTransfersLock.Lock() + defer prov.sessionTransfersLock.Unlock() + + var req ReqExportCredentials + err := json.NewDecoder(r.Body).Decode(&req) + if err != nil { + zerolog.Ctx(r.Context()).Err(err).Msg("Failed to decode request body") + mautrix.MNotJSON.WithMessage("Failed to decode request body").Write(w) + return + } + + user := prov.GetUser(r) + logins := user.GetUserLogins() + var loginToExport *bridgev2.UserLogin + for _, login := range logins { + if login.ID == req.RemoteID { + loginToExport = login + break + } + } + if loginToExport == nil { + mautrix.MNotFound.WithMessage("No matching user login found").Write(w) + return + } else if _, ok := prov.sessionTransfers[loginToExport.ID]; !ok { + mautrix.MBadState.WithMessage("No matching credential export found").Write(w) + return + } + + zerolog.Ctx(r.Context()).Info(). + Str("remote_name", string(req.RemoteID)). + Msg("Logging out remote after finishing credential export") + + loginToExport.Client.LogoutRemote(r.Context()) + delete(prov.sessionTransfers, req.RemoteID) + + exhttp.WriteEmptyJSONResponse(w, http.StatusOK) +} diff --git a/bridgev2/matrix/provisioning.yaml b/bridgev2/matrix/provisioning.yaml new file mode 100644 index 00000000..26068db4 --- /dev/null +++ b/bridgev2/matrix/provisioning.yaml @@ -0,0 +1,1025 @@ +openapi: 3.1.0 +info: + title: Megabridge provisioning + description: |- + This is the provisioning API implemented in mautrix-go's bridgev2 package. + It can be used with any bridge built on that package. + license: + name: Mozilla Public License Version 2.0 + url: https://github.com/mautrix/go/blob/main/LICENSE + version: v0.20.0 +externalDocs: + description: mautrix-go godocs + url: https://pkg.go.dev/maunium.net/go/mautrix/bridgev2 +servers: +- url: http://localhost:8080/_matrix/provision +tags: +- name: auth + description: Manage your logins and log into new remote accounts +- name: snc + description: Starting new chats +paths: + /v3/whoami: + get: + tags: [ auth ] + summary: Get info about the bridge and your logins. + description: | + Get all info that is useful for presenting this bridge in a manager interface. + * Server details: remote network details, available login flows, homeserver name, bridge bot user ID, command prefix + * User details: management room ID, list of logins with current state and info + operationId: whoami + responses: + 200: + description: Successfully fetched info + content: + application/json: + schema: + $ref: '#/components/schemas/Whoami' + 401: + $ref: '#/components/responses/Unauthorized' + 500: + $ref: '#/components/responses/InternalError' + security: + - matrix_auth: [ ] + /v3/login/flows: + get: + tags: [ auth ] + summary: Get the available login flows. + operationId: getLoginFlows + responses: + 200: + description: Successfully fetched flows + content: + application/json: + schema: + type: object + properties: + flows: + type: array + items: + $ref: '#/components/schemas/LoginFlow' + 401: + $ref: '#/components/responses/Unauthorized' + 500: + $ref: '#/components/responses/InternalError' + security: + - matrix_auth: [ ] + /v3/logins: + get: + tags: [ auth ] + summary: Get the login IDs of the current user. + operationId: getLoginIDs + responses: + 200: + description: Successfully fetched list of logins + content: + application/json: + schema: + type: object + properties: + login_ids: + type: array + items: + $ref: '#/components/schemas/UserLoginID' + 401: + $ref: '#/components/responses/Unauthorized' + 500: + $ref: '#/components/responses/InternalError' + security: + - matrix_auth: [ ] + /v3/login/start/{flowID}: + post: + tags: [ auth ] + summary: Start a new login process. + description: | + This endpoint starts a new login process, which is used to log into the bridge. + + The basic flow of the entire login, including calling this endpoint, is: + 1. Call `GET /v3/login/flows` to get the list of available flows. + If there's more than one flow, ask the user to pick which one they want to use. + 2. Call this endpoint with the chosen flow ID to start the login. + The first login step will be returned. + 3. Render the information provided in the step. + 4. Call the `/login/step/...` endpoint corresponding to the step type: + * For `user_input` and `cookies`, acquire the requested fields before calling the endpoint. + * For `display_and_wait`, call the endpoint immediately + (as there's nothing to acquire on the client side). + 5. Handle the data returned by the login step endpoint: + * If an error is returned, the login has failed and must be restarted + (from either step 1 or step 2) if the user wants to try again. + * If step type `complete` is returned, the login finished successfully. + * Otherwise, go to step 3 with the new data. + operationId: startLogin + parameters: + - name: login_id + in: query + description: An existing login ID to re-login as. If this is specified and the user logs into a different account, the provided ID will be logged out. + required: false + schema: + $ref: '#/components/schemas/UserLoginID' + - name: flowID + in: path + description: The login flow ID to use. + required: true + schema: + type: string + examples: [ qr ] + responses: + 200: + description: Login successfully started + content: + application/json: + schema: + $ref: '#/components/schemas/LoginStep' + 401: + $ref: '#/components/responses/Unauthorized' + 404: + $ref: '#/components/responses/LoginNotFound' + 500: + $ref: '#/components/responses/InternalError' + security: + - matrix_auth: [ ] + /v3/login/step/{loginProcessID}/{stepID}/user_input: + post: + tags: [ auth ] + summary: Submit user input in a login process. + operationId: submitLoginStepUserInput + parameters: + - $ref: '#/components/parameters/loginProcessID' + - $ref: '#/components/parameters/stepID' + requestBody: + description: The data entered by the user + content: + application/json: + schema: + type: object + additionalProperties: + type: string + responses: + 200: + $ref: '#/components/responses/LoginStepSubmitted' + 400: + $ref: '#/components/responses/BadRequest' + 401: + $ref: '#/components/responses/Unauthorized' + 404: + $ref: '#/components/responses/LoginProcessNotFound' + 500: + $ref: '#/components/responses/InternalError' + security: + - matrix_auth: [ ] + /v3/login/step/{loginProcessID}/{stepID}/cookies: + post: + tags: [ auth ] + summary: Submit extracted cookies in a login process. + operationId: submitLoginStepCookies + parameters: + - $ref: '#/components/parameters/loginProcessID' + - $ref: '#/components/parameters/stepID' + requestBody: + description: The cookies extracted from the website + content: + application/json: + schema: + type: object + additionalProperties: + type: string + responses: + 200: + $ref: '#/components/responses/LoginStepSubmitted' + 400: + $ref: '#/components/responses/BadRequest' + 401: + $ref: '#/components/responses/Unauthorized' + 404: + $ref: '#/components/responses/LoginProcessNotFound' + 500: + $ref: '#/components/responses/InternalError' + security: + - matrix_auth: [ ] + /v3/login/step/{loginProcessID}/{stepID}/display_and_wait: + post: + tags: [ auth ] + summary: Wait for the next step after displaying data to the user. + operationId: submitLoginStepDisplayAndWait + parameters: + - $ref: '#/components/parameters/loginProcessID' + - $ref: '#/components/parameters/stepID' + responses: + 200: + $ref: '#/components/responses/LoginStepSubmitted' + 400: + $ref: '#/components/responses/BadRequest' + 401: + $ref: '#/components/responses/Unauthorized' + 404: + $ref: '#/components/responses/LoginProcessNotFound' + 500: + $ref: '#/components/responses/InternalError' + security: + - matrix_auth: [ ] + /v3/logout/{loginID}: + post: + tags: [ auth ] + summary: Log out of an existing login. + operationId: logout + parameters: + - name: loginID + in: path + description: The ID of the login to log out. Use `all` to log out of all logins. + required: true + schema: + oneOf: + - $ref: '#/components/schemas/UserLoginID' + - type: string + const: all + description: Log out of all logins + responses: + 200: + description: Login was successfully deleted + content: + application/json: + schema: + type: object + description: Empty object + 401: + $ref: '#/components/responses/Unauthorized' + 404: + $ref: '#/components/responses/LoginNotFound' + 500: + $ref: '#/components/responses/InternalError' + /v3/contacts: + get: + tags: [ snc ] + summary: Get a list of contacts. + operationId: getContacts + parameters: + - $ref: "#/components/parameters/loginID" + responses: + 200: + description: Contact list fetched successfully + content: + application/json: + schema: + type: object + properties: + contacts: + type: array + items: + $ref: '#/components/schemas/ResolvedIdentifier' + 401: + $ref: '#/components/responses/Unauthorized' + 404: + $ref: '#/components/responses/LoginNotFound' + 500: + $ref: '#/components/responses/InternalError' + 501: + $ref: '#/components/responses/NotSupported' + /v3/search_users: + post: + tags: [ snc ] + summary: Search for users on the remote network + operationId: searchUsers + parameters: + - $ref: "#/components/parameters/loginID" + requestBody: + content: + application/json: + schema: + type: object + properties: + query: + type: string + description: The search query to send to the remote network + responses: + 200: + description: Search completed successfully + content: + application/json: + schema: + type: object + properties: + results: + type: array + items: + $ref: '#/components/schemas/ResolvedIdentifier' + 401: + $ref: '#/components/responses/Unauthorized' + 404: + $ref: '#/components/responses/LoginNotFound' + 500: + $ref: '#/components/responses/InternalError' + 501: + $ref: '#/components/responses/NotSupported' + /v3/resolve_identifier/{identifier}: + get: + tags: [ snc ] + summary: Resolve an identifier to a user on the remote network. + operationId: resolveIdentifier + parameters: + - $ref: "#/components/parameters/loginID" + - $ref: "#/components/parameters/sncIdentifier" + responses: + 200: + description: Identifier resolved successfully + content: + application/json: + schema: + $ref: '#/components/schemas/ResolvedIdentifier' + 401: + $ref: '#/components/responses/Unauthorized' + 404: + # TODO identifier not found also returns 404 + $ref: '#/components/responses/LoginNotFound' + 500: + $ref: '#/components/responses/InternalError' + 501: + $ref: '#/components/responses/NotSupported' + /v3/create_dm/{identifier}: + post: + tags: [ snc ] + summary: Create a direct chat with a user on the remote network. + operationId: createDM + parameters: + - $ref: "#/components/parameters/loginID" + - $ref: "#/components/parameters/sncIdentifier" + responses: + 200: + description: Identifier resolved successfully + content: + application/json: + schema: + allOf: + - $ref: '#/components/schemas/ResolvedIdentifier' + - required: [id, mxid, dm_room_mxid] + 401: + $ref: '#/components/responses/Unauthorized' + 404: + # TODO identifier not found also returns 404 + $ref: '#/components/responses/LoginNotFound' + 500: + $ref: '#/components/responses/InternalError' + 501: + $ref: '#/components/responses/NotSupported' + /v3/create_group/{type}: + post: + tags: [ snc ] + summary: Create a group chat on the remote network. + operationId: createGroup + parameters: + - $ref: "#/components/parameters/loginID" + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/GroupCreateParams' + responses: + 200: + description: Identifier resolved successfully + content: + application/json: + schema: + $ref: '#/components/schemas/CreatedGroup' + 401: + $ref: '#/components/responses/Unauthorized' + 404: + $ref: '#/components/responses/LoginNotFound' + 501: + $ref: '#/components/responses/NotSupported' +components: + parameters: + sncIdentifier: + name: identifier + in: path + description: The identifier to resolve or start a chat with. + required: true + schema: + type: string + examples: + - +12345678 + - username + - meow@example.com + loginID: + name: login_id + in: query + description: An optional explicit login ID to do the action through. + required: false + schema: + $ref: '#/components/schemas/UserLoginID' + loginProcessID: + name: loginProcessID + in: path + description: The ID of the login process, as returned in the `login_id` field of the start call. + required: true + schema: + type: string + stepID: + name: stepID + in: path + description: The ID of the step being submitted, as returned in the `step_id` field of the start call or the previous submit call. + required: true + schema: + type: string + stepType: + name: stepType + in: path + description: The type of step being submitted, as returned in the `type` field of the start call or the previous submit call. + required: true + schema: + type: string + enum: [ display_and_wait, user_input, cookies ] + responses: + BadRequest: + description: Something in the request was invalid + content: + application/json: + schema: + type: object + description: A Matrix-like error response + properties: + errcode: + type: string + enum: [ M_NOT_JSON, M_BAD_STATE ] + description: A Matrix-like error code + error: + type: string + description: A human-readable error message + examples: + - Failed to decode request body + - Step type does not match + Unauthorized: + description: The request contained an invalid token + content: + application/json: + schema: + type: object + description: A Matrix-like error response + properties: + errcode: + type: string + enum: [ M_UNKNOWN_TOKEN, M_MISSING_TOKEN ] + description: A Matrix-like error code + error: + type: string + description: A human-readable error message + examples: + - Invalid auth token + - Missing auth token + InternalError: + description: An unexpected error that doesn't have special handling yet + content: + application/json: + schema: + type: object + description: A Matrix-like error response + properties: + errcode: + type: string + enum: [ M_UNKNOWN ] + description: A Matrix-like error code + error: + type: string + description: A human-readable error message + examples: + - Failed to get user + - Failed to start login + LoginProcessNotFound: + description: The specified login process ID is unknown + content: + application/json: + schema: + type: object + description: A Matrix-like error response + properties: + errcode: + type: string + enum: [ M_NOT_FOUND ] + description: A Matrix-like error code + error: + type: string + description: A human-readable error message + examples: + - Login not found + LoginNotFound: + description: When explicitly specifying an existing user login, the specified login ID is unknown + content: + application/json: + schema: + type: object + description: A Matrix-like error response + properties: + errcode: + type: string + enum: [ M_NOT_FOUND ] + description: A Matrix-like error code + error: + type: string + description: A human-readable error message + examples: + - Login not found + NotSupported: + description: The given endpoint is not supported by this network connector. + content: + application/json: + schema: + type: object + description: A Matrix-like error response + properties: + errcode: + type: string + enum: [ M_UNRECOGNIZED ] + description: A Matrix-like error code + error: + type: string + description: A human-readable error message + examples: + - This bridge does not support listing contacts + LoginStepSubmitted: + description: Step submission successful + content: + application/json: + schema: + $ref: '#/components/schemas/LoginStep' + schemas: + ResolvedIdentifier: + type: object + description: A successfully resolved identifier. + required: [id] + properties: + id: + type: string + description: The internal user ID of the resolved user. + examples: + - c443c1a2-e9f7-48aa-890c-80336c300ba9 + name: + type: string + description: The name of the user on the remote network. + avatar_url: + type: string + format: mxc + description: The avatar of the user on the remote network. + pattern: mxc://[a-zA-Z0-9.:-]+/[a-zA-Z0-9_-]+ + examples: + - mxc://t2bot.io/JYDTofsS6V9aYfUiX7JueA36 + identifiers: + type: array + description: A list of identifiers for the user on the remote network. + items: + type: string + format: uri + examples: + - "tel:+123456789" + - "mailto:foo@example.com" + - "signal:username.123" + mxid: + type: string + format: matrix_user_id + description: The Matrix user ID of the ghost representing the user. + examples: + - '@signal_c443c1a2-e9f7-48aa-890c-80336c300ba9:t2bot.io' + dm_room_mxid: + type: string + format: matrix_room_id + description: The Matrix room ID of the direct chat with the user. + examples: + - '!OKhS0I5q2fCzdnl2qgeozDQw:t2bot.io' + GroupCreateParams: + type: object + description: | + Parameters for creating a group chat. + The /capabilities endpoint response must be checked to see which fields are actually allowed. + properties: + type: + type: string + description: The type of group to create. + examples: + - channel + username: + type: string + description: The public username for the created group. + participants: + type: array + description: The users to add to the group initially. + items: + type: string + parent: + type: object + name: + type: object + description: The `m.room.name` event content for the room. + properties: + name: + type: string + avatar: + type: object + description: The `m.room.avatar` event content for the room. + properties: + url: + type: string + format: mxc + topic: + type: object + description: The `m.room.topic` event content for the room. + properties: + topic: + type: string + disappear: + type: object + description: The `com.beeper.disappearing_timer` event content for the room. + properties: + type: + type: string + timer: + type: number + room_id: + type: string + format: matrix_room_id + description: | + An existing Matrix room ID to bridge to. + The other parameters must be already in sync with the room state when using this parameter. + CreatedGroup: + type: object + description: A successfully created group chat. + required: [id, mxid] + properties: + id: + type: string + description: The internal chat ID of the created group. + mxid: + type: string + format: matrix_room_id + description: The Matrix room ID of the portal. + examples: + - '!OKhS0I5q2fCzdnl2qgeozDQw:t2bot.io' + LoginStep: + type: object + description: A step in a login process. + properties: + login_id: + type: string + description: An identifier for the current login process. Must be passed to execute more steps of the login. + type: + type: string + description: The type of login step + enum: [ display_and_wait, user_input, cookies, complete ] + step_id: + type: string + description: An unique ID identifying this step. This can be used to implement special behavior in clients. + examples: [ fi.mau.signal.qr ] + instructions: + type: string + description: Human-readable instructions for completing this login step. + examples: [ Scan the QR code ] + oneOf: + - description: Display and wait login step + required: [ type, display_and_wait ] + properties: + type: + type: string + const: display_and_wait + display_and_wait: + type: object + description: Parameters for the display and wait login step + required: [ type ] + properties: + type: + type: string + description: The type of thing to display + enum: [ qr, emoji, code, nothing ] + data: + type: string + description: The thing to display (raw data for QR, unicode emoji for emoji, plain string for code) + image_url: + type: string + description: An image containing the thing to display. If present, this is recommended over using data directly. For emojis, the URL to the canonical image representation of the emoji + - description: User input login step + required: [ type, user_input ] + properties: + type: + type: string + const: user_input + user_input: + type: object + description: Parameters for the user input login step + required: [ fields ] + properties: + fields: + type: array + description: The list of fields that the user is requested to fill. + items: + type: object + description: A field that the user can fill. + required: [ type, id, name ] + properties: + type: + type: string + description: The type of field. + enum: [ username, phone_number, email, password, 2fa_code, token, url, domain, select ] + id: + type: string + description: The internal ID of the field. This must be used as the key in the object when submitting the data back to the bridge. + examples: [ uid, email, 2fa_password, meow ] + name: + type: string + description: The name of the field shown to the user. + examples: [ Username, Password, Phone number, 2FA code, Meow ] + description: + type: string + description: A more detailed description of the field shown to the user. + examples: + - Include the country code with a + + default_value: + type: string + description: A default value that the client can pre-fill the field with. + pattern: + type: string + format: regex + description: A regular expression that the field value must match. + options: + type: array + description: For fields of type select, the valid options. + items: + type: string + attachments: + type: array + description: A list of media attachments to show the user alongside the form fields. + items: + type: object + description: A media attachment to show the user. + required: [ type, filename, content ] + properties: + type: + type: string + description: The type of media attachment, using the same media type identifiers as Matrix attachments. Only some are supported. + enum: [ m.image, m.audio ] + filename: + type: string + description: The filename for the media attachment. + content: + type: string + description: The raw file content for the attachment encoded in base64. + info: + type: object + description: Optional but recommended metadata for the attachment. Can generally be derived from the raw content if omitted. + properties: + mimetype: + type: string + description: The MIME type for the media content. + examples: [ image/png, audio/mpeg ] + w: + type: number + description: The width of the media in pixels. Only applicable for images and videos. + h: + type: number + description: The height of the media in pixels. Only applicable for images and videos. + size: + type: number + description: The size of the media content in number of bytes. Strongly recommended to include. + - description: Cookie login step + required: [ type, cookies ] + properties: + type: + type: string + const: cookies + cookies: + type: object + description: Parameters for the cookie login step + required: [ url, fields ] + properties: + url: + type: string + format: uri + description: The URL to open when using a webview to extract cookies. + user_agent: + type: string + description: An optional user agent that the webview should use. + wait_for_url_pattern: + type: string + description: | + A regex pattern that the URL should match before the client closes the webview. + + The client may submit the login if the user closes the webview after all cookies are collected + even if this URL is not reached, but it should only automatically close the webview after + both cookies and the URL match. + extract_js: + type: string + description: | + A JavaScript snippet that can extract some or all of the fields. + The snippet will evaluate to a promise that resolves when the relevant fields are found. + Fields that are not present in the promise result must be extracted another way. + fields: + type: array + description: The list of cookies or other stored data that must be extracted. + items: + type: object + description: An individual cookie or other stored data item that must be extracted. + required: [ type, name ] + properties: + type: + type: string + description: The type of data to extract. + enum: [ cookie, local_storage, request_header, request_body, special ] + name: + type: string + description: The name of the item to extract. + request_url_regex: + type: string + description: For the `request_header` and `request_body` types, a regex that matches the URLs from which the values can be extracted. + cookie_domain: + type: string + description: For the `cookie` type, the domain of the cookie. + - description: Login complete + required: [ type, complete ] + properties: + type: + type: string + const: complete + complete: + type: object + description: Information about the completed login + properties: + user_login_id: + $ref: '#/components/schemas/UserLoginID' + LoginFlow: + type: object + description: An individual login flow which can be used to sign into the remote network. + required: [ name, description, id ] + properties: + name: + type: string + description: A human-readable name for the login flow. + examples: + - QR code + description: + type: string + description: A human-readable description of the login flow. + examples: + - Log in by scanning a QR code on the Signal app + id: + type: string + description: An internal ID that is passed to the /login/start call to start a login with this flow. + examples: + - qr + BridgeName: + type: object + description: Info about the network that the bridge is bridging to. + required: [ displayname, network_url, network_icon, network_id, beeper_bridge_type ] + properties: + displayname: + type: string + description: The displayname of the network. + examples: + - Signal + network_url: + type: string + description: The URL to the website of the network. + examples: + - https://signal.org + network_icon: + type: string + description: The icon of the network as a `mxc://` URI. + format: mxc + pattern: mxc://[a-zA-Z0-9.:-]+/[a-zA-Z0-9_-]+ + examples: + - mxc://maunium.net/wPJgTQbZOtpBFmDNkiNEMDUp + network_id: + type: string + description: An identifier uniquely identifying the network. + examples: + - signal + beeper_bridge_type: + type: string + description: An identifier uniquely identifying the bridge software. + examples: + - com.example.fancysignalbridge + BridgeState: + type: object + description: The connection status of an individual login + required: [ state_event, timestamp ] + properties: + state_event: + type: string + description: The current state of this login. + enum: [ "CONNECTING", "CONNECTED", "TRANSIENT_DISCONNECT", "BAD_CREDENTIALS", "UNKNOWN_ERROR" ] + timestamp: + type: number + description: The time when the state was last updated. + format: unix milliseconds + examples: + - 1723294560531 + error: + type: string + description: An error code defined by the network connector. + message: + type: string + description: A human-readable error message defined by the network connector. + reason: + type: string + description: A reason code for non-error states that aren't exactly successes either. + info: + type: object + description: Additional arbitrary info provided by the network connector. + UserLoginID: + type: string + description: The unique ID of a login. Defined by the network connector. + examples: + - bcc68892-b180-414f-9516-b4aadf7d0496 + RemoteProfile: + type: object + description: The profile info of the logged-in user on the remote network. + properties: + phone: + type: string + format: phone + description: The user's phone number + examples: + - +123456789 + email: + type: string + format: email + description: The user's email address + examples: + - foo@example.com + username: + type: string + description: The user's username + examples: + - foo.123 + name: + type: string + description: The user's displayname + examples: + - Foo Bar + avatar: + type: string + format: mxc + description: The user's avatar + pattern: mxc://[a-zA-Z0-9.:-]+/[a-zA-Z0-9_-]+ + examples: + - mxc://t2bot.io/JYDTofsS6V9aYfUiX7JueA36 + WhoamiLogin: + type: object + description: The info of an individual login + required: [ state, id, name, profile ] + properties: + state: + $ref: '#/components/schemas/BridgeState' + id: + $ref: '#/components/schemas/UserLoginID' + name: + type: string + description: A human-readable name for the login. Defined by the network connector. + examples: + - +123456789 + profile: + $ref: '#/components/schemas/RemoteProfile' + space_room: + type: string + format: matrix_room_id + description: The personal filtering space room ID for this login. + examples: + - "!X9l5njn4Mx1BpdoV8MOkyWU1:t2bot.io" + Whoami: + type: object + description: Info about the bridge and user + required: [ network, login_flows, homeserver, bridge_bot, command_prefix, logins ] + properties: + network: + $ref: '#/components/schemas/BridgeName' + login_flows: + type: array + description: The login flows that the bridge supports. + items: + $ref: '#/components/schemas/LoginFlow' + homeserver: + type: string + description: The server name the bridge is running on. + examples: + - t2bot.io + bridge_bot: + type: string + format: matrix_user_id + description: The Matrix user ID of the bridge bot. + examples: + - "@signalbot:t2bot.io" + command_prefix: + type: string + description: The command prefix used by this bridge. + examples: + - "!signal" + management_room: + type: string + format: matrix_room_id + description: The Matrix management room ID of the user who made the /whoami call. + examples: + - '!OKhS0I5q2fCzdnl2qgeozDQw:t2bot.io' + logins: + type: array + description: The logins of the user who made the /whoami call + items: + $ref: '#/components/schemas/WhoamiLogin' + securitySchemes: + matrix_auth: + type: http + scheme: bearer + description: Either a Matrix access token for users on the local server, or a [Matrix OpenID token](https://spec.matrix.org/v1.11/client-server-api/#openid) for users on other servers. diff --git a/bridgev2/matrix/publicmedia.go b/bridgev2/matrix/publicmedia.go new file mode 100644 index 00000000..82ea8c2b --- /dev/null +++ b/bridgev2/matrix/publicmedia.go @@ -0,0 +1,278 @@ +// Copyright (c) 2025 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package matrix + +import ( + "context" + "crypto/hmac" + "crypto/sha256" + "encoding/base64" + "encoding/binary" + "fmt" + "io" + "mime" + "net/http" + "net/url" + "slices" + "strings" + "time" + + "github.com/rs/zerolog" + + "maunium.net/go/mautrix/bridgev2" + "maunium.net/go/mautrix/bridgev2/database" + "maunium.net/go/mautrix/crypto/attachment" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +var _ bridgev2.MatrixConnectorWithPublicMedia = (*Connector)(nil) + +func (br *Connector) initPublicMedia() error { + if !br.Config.PublicMedia.Enabled { + return nil + } else if br.GetPublicAddress() == "" { + return fmt.Errorf("public media is enabled in config, but no public address is set") + } else if br.Config.PublicMedia.HashLength > 32 { + return fmt.Errorf("public media hash length is too long") + } else if br.Config.PublicMedia.HashLength < 0 { + return fmt.Errorf("public media hash length is negative") + } + br.pubMediaSigKey = []byte(br.Config.PublicMedia.SigningKey) + br.AS.Router.HandleFunc("GET /_mautrix/publicmedia/{customID}", br.serveDatabasePublicMedia) + br.AS.Router.HandleFunc("GET /_mautrix/publicmedia/{customID}/{filename}", br.serveDatabasePublicMedia) + br.AS.Router.HandleFunc("GET /_mautrix/publicmedia/{server}/{mediaID}/{checksum}", br.servePublicMedia) + br.AS.Router.HandleFunc("GET /_mautrix/publicmedia/{server}/{mediaID}/{checksum}/{filename}", br.servePublicMedia) + return nil +} + +func (br *Connector) hashContentURI(uri id.ContentURI, expiry []byte) []byte { + hasher := hmac.New(sha256.New, br.pubMediaSigKey) + hasher.Write([]byte(uri.String())) + hasher.Write(expiry) + return hasher.Sum(expiry)[:br.Config.PublicMedia.HashLength+len(expiry)] +} + +func (br *Connector) hashDBPublicMedia(pm *database.PublicMedia) []byte { + hasher := hmac.New(sha256.New, br.pubMediaSigKey) + hasher.Write([]byte(pm.MXC.String())) + hasher.Write([]byte(pm.MimeType)) + if pm.Keys != nil { + hasher.Write([]byte(pm.Keys.Version)) + hasher.Write([]byte(pm.Keys.Key.Algorithm)) + hasher.Write([]byte(pm.Keys.Key.Key)) + hasher.Write([]byte(pm.Keys.InitVector)) + hasher.Write([]byte(pm.Keys.Hashes.SHA256)) + } + return hasher.Sum(nil)[:br.Config.PublicMedia.HashLength] +} + +func (br *Connector) makePublicMediaChecksum(uri id.ContentURI) []byte { + var expiresAt []byte + if br.Config.PublicMedia.Expiry > 0 { + expiresAtInt := time.Now().Add(time.Duration(br.Config.PublicMedia.Expiry) * time.Second).Unix() + expiresAt = binary.BigEndian.AppendUint64(nil, uint64(expiresAtInt)) + } + return br.hashContentURI(uri, expiresAt) +} + +func (br *Connector) verifyPublicMediaChecksum(uri id.ContentURI, checksum []byte) (valid, expired bool) { + var expiryBytes []byte + if br.Config.PublicMedia.Expiry > 0 { + if len(checksum) < 8 { + return + } + expiryBytes = checksum[:8] + expiresAtInt := binary.BigEndian.Uint64(expiryBytes) + expired = time.Now().Unix() > int64(expiresAtInt) + } + valid = hmac.Equal(checksum, br.hashContentURI(uri, expiryBytes)) + return +} + +var proxyHeadersToCopy = []string{ + "Content-Type", "Content-Disposition", "Content-Length", "Content-Security-Policy", + "Access-Control-Allow-Origin", "Access-Control-Allow-Methods", "Access-Control-Allow-Headers", + "Cache-Control", "Cross-Origin-Resource-Policy", +} + +func (br *Connector) servePublicMedia(w http.ResponseWriter, r *http.Request) { + contentURI := id.ContentURI{ + Homeserver: r.PathValue("server"), + FileID: r.PathValue("mediaID"), + } + if !contentURI.IsValid() { + http.Error(w, "invalid content URI", http.StatusBadRequest) + return + } + checksum, err := base64.RawURLEncoding.DecodeString(r.PathValue("checksum")) + if err != nil || !hmac.Equal(checksum, br.makePublicMediaChecksum(contentURI)) { + http.Error(w, "invalid base64 in checksum", http.StatusBadRequest) + return + } else if valid, expired := br.verifyPublicMediaChecksum(contentURI, checksum); !valid { + http.Error(w, "invalid checksum", http.StatusNotFound) + return + } else if expired { + http.Error(w, "checksum expired", http.StatusGone) + return + } + br.doProxyMedia(w, r, contentURI, nil, "") +} + +func (br *Connector) serveDatabasePublicMedia(w http.ResponseWriter, r *http.Request) { + if !br.Config.PublicMedia.UseDatabase { + http.Error(w, "public media short links are disabled", http.StatusNotFound) + return + } + log := zerolog.Ctx(r.Context()) + media, err := br.Bridge.DB.PublicMedia.Get(r.Context(), r.PathValue("customID")) + if err != nil { + log.Err(err).Msg("Failed to get public media from database") + http.Error(w, "failed to get media metadata", http.StatusInternalServerError) + return + } else if media == nil { + http.Error(w, "media ID not found", http.StatusNotFound) + return + } else if !media.Expiry.IsZero() && media.Expiry.Before(time.Now()) { + // This is not gone as it can still be refreshed in the DB + http.Error(w, "media expired", http.StatusNotFound) + return + } else if media.Keys != nil && media.Keys.PrepareForDecryption() != nil { + http.Error(w, "media keys are malformed", http.StatusInternalServerError) + return + } + br.doProxyMedia(w, r, media.MXC, media.Keys, media.MimeType) +} + +var safeMimes = []string{ + "text/css", "text/plain", "text/csv", + "application/json", "application/ld+json", + "image/jpeg", "image/gif", "image/png", "image/apng", "image/webp", "image/avif", + "video/mp4", "video/webm", "video/ogg", "video/quicktime", + "audio/mp4", "audio/webm", "audio/aac", "audio/mpeg", "audio/ogg", "audio/wave", + "audio/wav", "audio/x-wav", "audio/x-pn-wav", "audio/flac", "audio/x-flac", +} + +func (br *Connector) doProxyMedia(w http.ResponseWriter, r *http.Request, contentURI id.ContentURI, encInfo *attachment.EncryptedFile, mimeType string) { + resp, err := br.Bot.Download(r.Context(), contentURI) + if err != nil { + zerolog.Ctx(r.Context()).Warn().Stringer("uri", contentURI).Err(err).Msg("Failed to download media to proxy") + http.Error(w, "failed to download media", http.StatusInternalServerError) + return + } + defer resp.Body.Close() + for _, hdr := range proxyHeadersToCopy { + w.Header()[hdr] = resp.Header[hdr] + } + stream := resp.Body + if encInfo != nil { + if mimeType == "" { + mimeType = "application/octet-stream" + } + contentDisposition := "attachment" + if slices.Contains(safeMimes, mimeType) { + contentDisposition = "inline" + } + dispositionArgs := map[string]string{} + if filename := r.PathValue("filename"); filename != "" { + dispositionArgs["filename"] = filename + } + w.Header().Set("Content-Type", mimeType) + w.Header().Set("Content-Disposition", mime.FormatMediaType(contentDisposition, dispositionArgs)) + // Note: this won't check the Close result like it should, but it's probably not a big deal here + stream = encInfo.DecryptStream(stream) + } else if filename := r.PathValue("filename"); filename != "" { + contentDisposition, _, _ := mime.ParseMediaType(resp.Header.Get("Content-Disposition")) + if contentDisposition == "" { + contentDisposition = "attachment" + } + w.Header().Set("Content-Disposition", mime.FormatMediaType(contentDisposition, map[string]string{ + "filename": filename, + })) + } + w.WriteHeader(http.StatusOK) + _, _ = io.Copy(w, stream) +} + +func (br *Connector) GetPublicMediaAddress(contentURI id.ContentURIString) string { + return br.getPublicMediaAddressWithFileName(contentURI, "") +} + +func (br *Connector) getPublicMediaAddressWithFileName(contentURI id.ContentURIString, fileName string) string { + if br.pubMediaSigKey == nil { + return "" + } + parsed, err := contentURI.Parse() + if err != nil || !parsed.IsValid() { + return "" + } + fileName = url.PathEscape(strings.ReplaceAll(fileName, "/", "_")) + if fileName == ".." { + fileName = "" + } + parts := []string{ + br.GetPublicAddress(), + strings.Trim(br.Config.PublicMedia.PathPrefix, "/"), + parsed.Homeserver, + parsed.FileID, + base64.RawURLEncoding.EncodeToString(br.makePublicMediaChecksum(parsed)), + fileName, + } + if fileName == "" { + parts = parts[:len(parts)-1] + } + return strings.Join(parts, "/") +} + +func (br *Connector) GetPublicMediaAddressForEvent(ctx context.Context, evt *event.MessageEventContent) (string, error) { + if br.pubMediaSigKey == nil { + return "", bridgev2.ErrPublicMediaDisabled + } + if !br.Config.PublicMedia.UseDatabase { + if evt.File != nil { + return "", fmt.Errorf("can't generate address for encrypted file: %w", bridgev2.ErrPublicMediaDatabaseDisabled) + } + return br.getPublicMediaAddressWithFileName(evt.URL, evt.GetFileName()), nil + } + mxc := evt.URL + var keys *attachment.EncryptedFile + if evt.File != nil { + mxc = evt.File.URL + keys = &evt.File.EncryptedFile + } + parsedMXC, err := mxc.Parse() + if err != nil { + return "", fmt.Errorf("%w: failed to parse MXC: %w", bridgev2.ErrPublicMediaGenerateFailed, err) + } + pm := &database.PublicMedia{ + MXC: parsedMXC, + Keys: keys, + MimeType: evt.GetInfo().MimeType, + } + if br.Config.PublicMedia.Expiry > 0 { + pm.Expiry = time.Now().Add(time.Duration(br.Config.PublicMedia.Expiry) * time.Second) + } + pm.PublicID = base64.RawURLEncoding.EncodeToString(br.hashDBPublicMedia(pm)) + err = br.Bridge.DB.PublicMedia.Put(ctx, pm) + if err != nil { + return "", fmt.Errorf("%w: failed to store public media in database: %w", bridgev2.ErrPublicMediaGenerateFailed, err) + } + fileName := url.PathEscape(strings.ReplaceAll(evt.GetFileName(), "/", "_")) + if fileName == ".." { + fileName = "" + } + parts := []string{ + br.GetPublicAddress(), + strings.Trim(br.Config.PublicMedia.PathPrefix, "/"), + pm.PublicID, + fileName, + } + if fileName == "" { + parts = parts[:len(parts)-1] + } + return strings.Join(parts, "/"), nil +} diff --git a/bridge/websocket.go b/bridgev2/matrix/websocket.go similarity index 64% rename from bridge/websocket.go rename to bridgev2/matrix/websocket.go index 44a3d8d8..b498cacd 100644 --- a/bridge/websocket.go +++ b/bridgev2/matrix/websocket.go @@ -1,14 +1,19 @@ -package bridge +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package matrix import ( "context" "errors" "fmt" + "os" "sync" "time" - "go.mau.fi/util/jsontime" - "maunium.net/go/mautrix/appservice" ) @@ -16,24 +21,21 @@ const defaultReconnectBackoff = 2 * time.Second const maxReconnectBackoff = 2 * time.Minute const reconnectBackoffReset = 5 * time.Minute -func (br *Bridge) startWebsocket(wg *sync.WaitGroup) { - log := br.ZLog.With().Str("action", "appservice websocket").Logger() +func (br *Connector) startWebsocket(wg *sync.WaitGroup) { + log := br.Log.With().Str("action", "appservice websocket").Logger() var wgOnce sync.Once onConnect := func() { - wssBr, ok := br.Child.(WebsocketStartingBridge) - if ok { - wssBr.OnWebsocketConnect() - } - if br.latestState != nil { + if br.hasSentAnyStates { go func() { ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() - br.latestState.Timestamp = jsontime.UnixNow() - err := br.SendBridgeState(ctx, br.latestState) - if err != nil { - log.Err(err).Msg("Failed to resend latest bridge state after websocket reconnect") - } else { - log.Debug().Any("bridge_state", br.latestState).Msg("Resent bridge state after websocket reconnect") + for _, state := range br.Bridge.GetCurrentBridgeStates() { + err := br.SendBridgeStatus(ctx, &state) + if err != nil { + log.Err(err).Msg("Failed to resend latest bridge state after websocket reconnect") + } else { + log.Debug().Any("bridge_state", state).Msg("Resent bridge state after websocket reconnect") + } } }() } @@ -55,17 +57,21 @@ func (br *Bridge) startWebsocket(wg *sync.WaitGroup) { addr = br.Config.Homeserver.Address } for { - err := br.AS.StartWebsocket(addr, onConnect) + err := br.AS.StartWebsocket(br.Bridge.BackgroundCtx, addr, onConnect) if errors.Is(err, appservice.ErrWebsocketManualStop) { return } else if closeCommand := (&appservice.CloseCommand{}); errors.As(err, &closeCommand) && closeCommand.Status == appservice.MeowConnectionReplaced { - log.Info().Msg("Appservice websocket closed by another instance of the bridge, shutting down...") - br.ManualStop(0) + log.Warn().Msg("Appservice websocket closed by another instance of the bridge, shutting down...") + if br.OnWebsocketReplaced != nil { + br.OnWebsocketReplaced() + } else { + os.Exit(1) + } return } else if err != nil { log.Err(err).Msg("Error in appservice websocket") } - if br.Stopping { + if br.stopping { return } now := time.Now().UnixNano() @@ -86,7 +92,7 @@ func (br *Bridge) startWebsocket(wg *sync.WaitGroup) { log.Debug().Msg("Reconnect backoff was short-circuited") case <-time.After(reconnectBackoff): } - if br.Stopping { + if br.stopping { return } } @@ -96,30 +102,30 @@ type wsPingData struct { Timestamp int64 `json:"timestamp"` } -func (br *Bridge) PingServer() (start, serverTs, end time.Time) { +func (br *Connector) PingServer() (start, serverTs, end time.Time) { if !br.Websocket { panic(fmt.Errorf("PingServer called without websocket enabled")) } if !br.AS.HasWebsocket() { - br.ZLog.Debug().Msg("Received server ping request, but no websocket connected. Trying to short-circuit backoff sleep") + br.Log.Debug().Msg("Received server ping request, but no websocket connected. Trying to short-circuit backoff sleep") select { case br.wsShortCircuitReconnectBackoff <- struct{}{}: default: - br.ZLog.Warn().Msg("Failed to ping websocket: not connected and no backoff?") + br.Log.Warn().Msg("Failed to ping websocket: not connected and no backoff?") return } select { case <-br.wsStarted: case <-time.After(15 * time.Second): if !br.AS.HasWebsocket() { - br.ZLog.Warn().Msg("Failed to ping websocket: didn't connect after 15 seconds of waiting") + br.Log.Warn().Msg("Failed to ping websocket: didn't connect after 15 seconds of waiting") return } } } start = time.Now() var resp wsPingData - br.ZLog.Debug().Msg("Pinging appservice websocket") + br.Log.Debug().Msg("Pinging appservice websocket") ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() err := br.AS.RequestWebsocket(ctx, &appservice.WebsocketRequest{ @@ -128,11 +134,11 @@ func (br *Bridge) PingServer() (start, serverTs, end time.Time) { }, &resp) end = time.Now() if err != nil { - br.ZLog.Warn().Err(err).Dur("duration", end.Sub(start)).Msg("Websocket ping returned error") + br.Log.Warn().Err(err).Dur("duration", end.Sub(start)).Msg("Websocket ping returned error") br.AS.StopWebsocket(fmt.Errorf("websocket ping returned error in %s: %w", end.Sub(start), err)) } else { serverTs = time.Unix(0, resp.Timestamp*int64(time.Millisecond)) - br.ZLog.Debug(). + br.Log.Debug(). Dur("duration", end.Sub(start)). Dur("req_duration", serverTs.Sub(start)). Dur("resp_duration", end.Sub(serverTs)). @@ -141,14 +147,14 @@ func (br *Bridge) PingServer() (start, serverTs, end time.Time) { return } -func (br *Bridge) websocketServerPinger() { +func (br *Connector) websocketServerPinger() { interval := time.Duration(br.Config.Homeserver.WSPingInterval) * time.Second clock := time.NewTicker(interval) defer func() { - br.ZLog.Info().Msg("Stopping websocket pinger") + br.Log.Info().Msg("Stopping websocket pinger") clock.Stop() }() - br.ZLog.Info().Dur("interval_duration", interval).Msg("Starting websocket pinger") + br.Log.Info().Dur("interval_duration", interval).Msg("Starting websocket pinger") for { select { case <-clock.C: @@ -156,7 +162,7 @@ func (br *Bridge) websocketServerPinger() { case <-br.wsStopPinger: return } - if br.Stopping { + if br.stopping { return } } diff --git a/bridgev2/matrixinterface.go b/bridgev2/matrixinterface.go new file mode 100644 index 00000000..be26db49 --- /dev/null +++ b/bridgev2/matrixinterface.go @@ -0,0 +1,225 @@ +// Copyright (c) 2025 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package bridgev2 + +import ( + "context" + "fmt" + "io" + "net/http" + "os" + "time" + + "go.mau.fi/util/exhttp" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/bridgev2/database" + "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/bridgev2/status" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +type MatrixCapabilities struct { + AutoJoinInvites bool + BatchSending bool + ArbitraryMemberChange bool + ExtraProfileMeta bool +} + +type MatrixConnector interface { + Init(*Bridge) + Start(ctx context.Context) error + PreStop() + Stop() + + GetCapabilities() *MatrixCapabilities + + ParseGhostMXID(userID id.UserID) (networkid.UserID, bool) + GhostIntent(userID networkid.UserID) MatrixAPI + NewUserIntent(ctx context.Context, userID id.UserID, accessToken string) (MatrixAPI, string, error) + BotIntent() MatrixAPI + + SendBridgeStatus(ctx context.Context, state *status.BridgeState) error + SendMessageStatus(ctx context.Context, status *MessageStatus, evt *MessageStatusEventInfo) + + GenerateContentURI(ctx context.Context, mediaID networkid.MediaID) (id.ContentURIString, error) + + GetPowerLevels(ctx context.Context, roomID id.RoomID) (*event.PowerLevelsEventContent, error) + GetMembers(ctx context.Context, roomID id.RoomID) (map[id.UserID]*event.MemberEventContent, error) + GetMemberInfo(ctx context.Context, roomID id.RoomID, userID id.UserID) (*event.MemberEventContent, error) + + BatchSend(ctx context.Context, roomID id.RoomID, req *mautrix.ReqBeeperBatchSend, extras []*MatrixSendExtra) (*mautrix.RespBeeperBatchSend, error) + GenerateDeterministicRoomID(portalKey networkid.PortalKey) id.RoomID + GenerateDeterministicEventID(roomID id.RoomID, portalKey networkid.PortalKey, messageID networkid.MessageID, partID networkid.PartID) id.EventID + GenerateReactionEventID(roomID id.RoomID, targetMessage *database.Message, sender networkid.UserID, emojiID networkid.EmojiID) id.EventID + + ServerName() string +} + +type MatrixConnectorWithArbitraryRoomState interface { + MatrixConnector + GetStateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string) (*event.Event, error) +} + +type MatrixConnectorWithServer interface { + MatrixConnector + GetPublicAddress() string + GetRouter() *http.ServeMux +} + +type IProvisioningAPI interface { + GetRouter() *http.ServeMux + GetUser(r *http.Request) *User +} + +type MatrixConnectorWithProvisioning interface { + MatrixConnector + GetProvisioning() IProvisioningAPI +} + +type MatrixConnectorWithPublicMedia interface { + MatrixConnector + GetPublicMediaAddress(contentURI id.ContentURIString) string + GetPublicMediaAddressForEvent(ctx context.Context, evt *event.MessageEventContent) (string, error) +} + +type MatrixConnectorWithNameDisambiguation interface { + MatrixConnector + IsConfusableName(ctx context.Context, roomID id.RoomID, userID id.UserID, name string) ([]id.UserID, error) +} + +type MatrixConnectorWithBridgeIdentifier interface { + MatrixConnector + GetUniqueBridgeID() string +} + +type MatrixConnectorWithURLPreviews interface { + MatrixConnector + GetURLPreview(ctx context.Context, url string) (*event.LinkPreview, error) +} + +type MatrixConnectorWithPostRoomBridgeHandling interface { + MatrixConnector + HandleNewlyBridgedRoom(ctx context.Context, roomID id.RoomID) error +} + +type MatrixConnectorWithAnalytics interface { + MatrixConnector + TrackAnalytics(userID id.UserID, event string, properties map[string]any) +} + +type DirectNotificationData struct { + Portal *Portal + Sender *Ghost + MessageID networkid.MessageID + Message string + + FormattedNotification string + FormattedTitle string +} + +type MatrixConnectorWithNotifications interface { + MatrixConnector + DisplayNotification(ctx context.Context, data *DirectNotificationData) +} + +type MatrixConnectorWithHTTPSettings interface { + MatrixConnector + GetHTTPClientSettings() exhttp.ClientSettings +} + +type MatrixSendExtra struct { + Timestamp time.Time + MessageMeta *database.Message + ReactionMeta *database.Reaction + StreamOrder int64 + PartIndex int +} + +// FileStreamResult is the result of a FileStreamCallback. +type FileStreamResult struct { + // ReplacementFile is the path to a new file that replaces the original file provided to the callback. + // Providing a replacement file is only allowed if the requireFile flag was set for the UploadMediaStream call. + ReplacementFile string + // FileName is the name of the file to be specified when uploading to the server. + // This should be the same as the file name that will be included in the Matrix event (body or filename field). + // If the file gets encrypted, this field will be ignored. + FileName string + // MimeType is the type of field to be specified when uploading to the server. + // This should be the same as the mime type that will be included in the Matrix event (info -> mimetype field). + // If the file gets encrypted, this field will be replaced with application/octet-stream. + MimeType string +} + +// FileStreamCallback is a callback function for file uploads that roundtrip via disk. +// +// The parameter is either a file or an in-memory buffer depending on the size of the file and whether the requireFile flag was set. +// +// The return value must be non-nil unless there's an error, and should always include FileName and MimeType. +type FileStreamCallback func(file io.Writer) (*FileStreamResult, error) + +type CallbackError struct { + Type string + Wrapped error +} + +func (ce CallbackError) Error() string { + return fmt.Sprintf("%s callback failed: %s", ce.Type, ce.Wrapped.Error()) +} + +func (ce CallbackError) Unwrap() error { + return ce.Wrapped +} + +type EnsureJoinedParams struct { + Via []string +} + +type MatrixAPI interface { + GetMXID() id.UserID + IsDoublePuppet() bool + + SendMessage(ctx context.Context, roomID id.RoomID, eventType event.Type, content *event.Content, extra *MatrixSendExtra) (*mautrix.RespSendEvent, error) + SendState(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, content *event.Content, ts time.Time) (*mautrix.RespSendEvent, error) + MarkRead(ctx context.Context, roomID id.RoomID, eventID id.EventID, ts time.Time) error + MarkUnread(ctx context.Context, roomID id.RoomID, unread bool) error + MarkTyping(ctx context.Context, roomID id.RoomID, typingType TypingType, timeout time.Duration) error + DownloadMedia(ctx context.Context, uri id.ContentURIString, file *event.EncryptedFileInfo) ([]byte, error) + DownloadMediaToFile(ctx context.Context, uri id.ContentURIString, file *event.EncryptedFileInfo, writable bool, callback func(*os.File) error) error + UploadMedia(ctx context.Context, roomID id.RoomID, data []byte, fileName, mimeType string) (url id.ContentURIString, file *event.EncryptedFileInfo, err error) + UploadMediaStream(ctx context.Context, roomID id.RoomID, size int64, requireFile bool, cb FileStreamCallback) (url id.ContentURIString, file *event.EncryptedFileInfo, err error) + + SetDisplayName(ctx context.Context, name string) error + SetAvatarURL(ctx context.Context, avatarURL id.ContentURIString) error + SetExtraProfileMeta(ctx context.Context, data any) error + + CreateRoom(ctx context.Context, req *mautrix.ReqCreateRoom) (id.RoomID, error) + DeleteRoom(ctx context.Context, roomID id.RoomID, puppetsOnly bool) error + EnsureJoined(ctx context.Context, roomID id.RoomID, params ...EnsureJoinedParams) error + EnsureInvited(ctx context.Context, roomID id.RoomID, userID id.UserID) error + + TagRoom(ctx context.Context, roomID id.RoomID, tag event.RoomTag, isTagged bool) error + MuteRoom(ctx context.Context, roomID id.RoomID, until time.Time) error + + GetEvent(ctx context.Context, roomID id.RoomID, eventID id.EventID) (*event.Event, error) +} + +type StreamOrderReadingMatrixAPI interface { + MatrixAPI + MarkStreamOrderRead(ctx context.Context, roomID id.RoomID, streamOrder int64, ts time.Time) error +} + +type MarkAsDMMatrixAPI interface { + MatrixAPI + MarkAsDM(ctx context.Context, roomID id.RoomID, otherUser id.UserID) error +} + +type EphemeralSendingMatrixAPI interface { + MatrixAPI + BeeperSendEphemeralEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, content *event.Content, txnID string) (*mautrix.RespSendEvent, error) +} diff --git a/bridgev2/matrixinvite.go b/bridgev2/matrixinvite.go new file mode 100644 index 00000000..75c00cb0 --- /dev/null +++ b/bridgev2/matrixinvite.go @@ -0,0 +1,294 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package bridgev2 + +import ( + "context" + "errors" + "fmt" + "time" + + "github.com/rs/zerolog" + + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/format" + "maunium.net/go/mautrix/id" +) + +func (br *Bridge) handleBotInvite(ctx context.Context, evt *event.Event, sender *User) EventHandlingResult { + log := zerolog.Ctx(ctx) + // These invites should already be rejected in QueueMatrixEvent + if !sender.Permissions.Commands { + log.Warn().Msg("Received bot invite from user without permission to send commands") + return EventHandlingResultIgnored + } + err := br.Bot.EnsureJoined(ctx, evt.RoomID) + if err != nil { + log.Err(err).Msg("Failed to accept invite to room") + return EventHandlingResultFailed + } + log.Debug().Msg("Accepted invite to room as bot") + members, err := br.Matrix.GetMembers(ctx, evt.RoomID) + if err != nil { + log.Err(err).Msg("Failed to get members of room after accepting invite") + } + if len(members) == 2 { + var message string + if sender.ManagementRoom == "" { + message = fmt.Sprintf("Hello, I'm a %s bridge bot.\n\nUse `help` for help or `login` to log in.\n\nThis room has been marked as your management room.", br.Network.GetName().DisplayName) + sender.ManagementRoom = evt.RoomID + err = br.DB.User.Update(ctx, sender.User) + if err != nil { + log.Err(err).Msg("Failed to update user's management room in database") + } + } else { + message = fmt.Sprintf("Hello, I'm a %s bridge bot.\n\nUse `%s help` for help.", br.Network.GetName().DisplayName, br.Config.CommandPrefix) + } + _, err = br.Bot.SendMessage(ctx, evt.RoomID, event.EventMessage, &event.Content{ + Parsed: format.RenderMarkdown(message, true, false), + }, nil) + if err != nil { + log.Err(err).Msg("Failed to send welcome message to room") + } + } + return EventHandlingResultSuccess +} + +func sendNotice(ctx context.Context, evt *event.Event, intent MatrixAPI, message string, args ...any) { + if len(args) > 0 { + message = fmt.Sprintf(message, args...) + } + content := format.RenderMarkdown(message, true, false) + content.MsgType = event.MsgNotice + resp, err := intent.SendMessage(ctx, evt.RoomID, event.EventMessage, &event.Content{Parsed: content}, nil) + if err != nil { + zerolog.Ctx(ctx).Err(err). + Stringer("room_id", evt.RoomID). + Stringer("inviter_id", evt.Sender). + Stringer("invitee_id", intent.GetMXID()). + Str("notice_text", message). + Msg("Failed to send notice") + } else { + zerolog.Ctx(ctx).Debug(). + Stringer("notice_event_id", resp.EventID). + Stringer("room_id", evt.RoomID). + Stringer("inviter_id", evt.Sender). + Stringer("invitee_id", intent.GetMXID()). + Str("notice_text", message). + Msg("Sent notice") + } +} + +func sendErrorAndLeave(ctx context.Context, evt *event.Event, intent MatrixAPI, message string, args ...any) { + sendNotice(ctx, evt, intent, message, args...) + rejectInvite(ctx, evt, intent, "") +} + +func (portal *Portal) CleanupOrphanedDM(ctx context.Context, userMXID id.UserID) { + if portal.MXID == "" { + return + } + log := zerolog.Ctx(ctx) + existingPortalMembers, err := portal.Bridge.Matrix.GetMembers(ctx, portal.MXID) + if err != nil { + log.Err(err). + Stringer("old_portal_mxid", portal.MXID). + Msg("Failed to check existing portal members, deleting room") + } else if targetUserMember, ok := existingPortalMembers[userMXID]; !ok { + log.Debug(). + Stringer("old_portal_mxid", portal.MXID). + Msg("Inviter has no member event in old portal, deleting room") + } else if targetUserMember.Membership.IsInviteOrJoin() { + return + } else { + log.Debug(). + Stringer("old_portal_mxid", portal.MXID). + Str("membership", string(targetUserMember.Membership)). + Msg("Inviter is not in old portal, deleting room") + } + + if err = portal.RemoveMXID(ctx); err != nil { + log.Err(err).Msg("Failed to delete old portal mxid") + } else if err = portal.Bridge.Bot.DeleteRoom(ctx, portal.MXID, true); err != nil { + log.Err(err).Msg("Failed to clean up old portal room") + } +} + +func (br *Bridge) handleGhostDMInvite(ctx context.Context, evt *event.Event, sender *User) EventHandlingResult { + ghostID, _ := br.Matrix.ParseGhostMXID(id.UserID(evt.GetStateKey())) + validator, ok := br.Network.(IdentifierValidatingNetwork) + if ghostID == "" || (ok && !validator.ValidateUserID(ghostID)) { + rejectInvite(ctx, evt, br.Matrix.GhostIntent(ghostID), "Malformed user ID") + return EventHandlingResultIgnored + } + log := zerolog.Ctx(ctx).With(). + Str("invitee_network_id", string(ghostID)). + Stringer("room_id", evt.RoomID). + Logger() + // TODO sort in preference order + logins := sender.GetUserLogins() + if len(logins) == 0 { + rejectInvite(ctx, evt, br.Matrix.GhostIntent(ghostID), "You're not logged in") + return EventHandlingResultIgnored + } + _, ok = logins[0].Client.(IdentifierResolvingNetworkAPI) + if !ok { + rejectInvite(ctx, evt, br.Matrix.GhostIntent(ghostID), "This bridge does not support starting chats") + return EventHandlingResultIgnored + } + invitedGhost, err := br.GetGhostByID(ctx, ghostID) + if err != nil { + log.Err(err).Msg("Failed to get invited ghost") + return EventHandlingResultFailed + } + err = invitedGhost.Intent.EnsureJoined(ctx, evt.RoomID) + if err != nil { + log.Err(err).Msg("Failed to accept invite to room") + return EventHandlingResultFailed + } + var resp *CreateChatResponse + var sourceLogin *UserLogin + // TODO this should somehow lock incoming event processing to avoid race conditions where a new portal room is created + // between ResolveIdentifier returning and the portal MXID being updated. + for _, login := range logins { + api, ok := login.Client.(IdentifierResolvingNetworkAPI) + if !ok { + continue + } + var resolveResp *ResolveIdentifierResponse + ghostAPI, ok := login.Client.(GhostDMCreatingNetworkAPI) + if ok { + resp, err = ghostAPI.CreateChatWithGhost(ctx, invitedGhost) + } else { + resolveResp, err = api.ResolveIdentifier(ctx, string(ghostID), true) + if resolveResp != nil { + resp = resolveResp.Chat + } + } + if errors.Is(err, ErrResolveIdentifierTryNext) { + log.Debug().Err(err).Str("login_id", string(login.ID)).Msg("Failed to resolve identifier, trying next login") + continue + } else if err != nil { + log.Err(err).Msg("Failed to resolve identifier") + sendErrorAndLeave(ctx, evt, invitedGhost.Intent, "Failed to create chat") + return EventHandlingResultFailed + } else { + sourceLogin = login + break + } + } + if resp == nil { + log.Warn().Msg("No login could resolve the identifier") + sendErrorAndLeave(ctx, evt, br.Matrix.GhostIntent(ghostID), "Failed to create chat via any login") + return EventHandlingResultFailed + } + portal := resp.Portal + if portal == nil { + portal, err = br.GetPortalByKey(ctx, resp.PortalKey) + if err != nil { + log.Err(err).Msg("Failed to get portal by key") + sendErrorAndLeave(ctx, evt, br.Matrix.GhostIntent(ghostID), "Failed to create portal entry") + return EventHandlingResultFailed + } + } + portal.CleanupOrphanedDM(ctx, sender.MXID) + err = invitedGhost.Intent.EnsureInvited(ctx, evt.RoomID, br.Bot.GetMXID()) + if err != nil { + log.Err(err).Msg("Failed to ensure bot is invited to room") + sendErrorAndLeave(ctx, evt, invitedGhost.Intent, "Failed to invite bridge bot") + return EventHandlingResultFailed + } + err = br.Bot.EnsureJoined(ctx, evt.RoomID) + if err != nil { + log.Err(err).Msg("Failed to ensure bot is joined to room") + sendErrorAndLeave(ctx, evt, invitedGhost.Intent, "Failed to join with bridge bot") + return EventHandlingResultFailed + } + + portal.roomCreateLock.Lock() + defer portal.roomCreateLock.Unlock() + portalMXID := portal.MXID + if portalMXID != "" { + sendErrorAndLeave(ctx, evt, invitedGhost.Intent, "You already have a direct chat with me at [%s](%s)", portalMXID, portalMXID.URI(br.Matrix.ServerName()).MatrixToURL()) + rejectInvite(ctx, evt, br.Bot, "") + return EventHandlingResultSuccess + } + err = br.givePowerToBot(ctx, evt.RoomID, invitedGhost.Intent) + if err != nil { + log.Err(err).Msg("Failed to give permissions to bridge bot") + sendErrorAndLeave(ctx, evt, invitedGhost.Intent, "Failed to give permissions to bridge bot") + rejectInvite(ctx, evt, br.Bot, "") + return EventHandlingResultSuccess + } + overrideIntent := invitedGhost.Intent + if resp.DMRedirectedTo != "" && resp.DMRedirectedTo != invitedGhost.ID { + log.Debug(). + Str("dm_redirected_to_id", string(resp.DMRedirectedTo)). + Msg("Created DM was redirected to another user ID") + _, err = invitedGhost.Intent.SendState(ctx, evt.RoomID, event.StateMember, invitedGhost.Intent.GetMXID().String(), &event.Content{ + Parsed: &event.MemberEventContent{ + Membership: event.MembershipLeave, + Reason: "Direct chat redirected to another internal user ID", + }, + }, time.Time{}) + if err != nil { + log.Err(err).Msg("Failed to make incorrect ghost leave new DM room") + } + if resp.DMRedirectedTo == SpecialValueDMRedirectedToBot { + overrideIntent = br.Bot + } else if otherUserGhost, err := br.GetGhostByID(ctx, resp.DMRedirectedTo); err != nil { + log.Err(err).Msg("Failed to get ghost of real portal other user ID") + } else { + invitedGhost = otherUserGhost + overrideIntent = otherUserGhost.Intent + } + } + err = portal.UpdateMatrixRoomID(ctx, evt.RoomID, UpdateMatrixRoomIDParams{ + // We locked it before checking the mxid + RoomCreateAlreadyLocked: true, + + FailIfMXIDSet: true, + ChatInfo: resp.PortalInfo, + ChatInfoSource: sourceLogin, + }) + if err != nil { + log.Err(err).Msg("Failed to update Matrix room ID for new DM portal") + sendNotice(ctx, evt, overrideIntent, "Failed to finish configuring portal. The chat may or may not work") + return EventHandlingResultSuccess + } + message := "Private chat portal created" + mx, ok := br.Matrix.(MatrixConnectorWithPostRoomBridgeHandling) + if ok { + err = mx.HandleNewlyBridgedRoom(ctx, evt.RoomID) + if err != nil { + log.Err(err).Msg("Error in connector newly bridged room handler") + message += fmt.Sprintf("\n\nWarning: %s", err.Error()) + } + } + sendNotice(ctx, evt, overrideIntent, message) + return EventHandlingResultSuccess +} + +func (br *Bridge) givePowerToBot(ctx context.Context, roomID id.RoomID, userWithPower MatrixAPI) error { + powers, err := br.Matrix.GetPowerLevels(ctx, roomID) + if err != nil { + return fmt.Errorf("failed to get power levels: %w", err) + } + userLevel := powers.GetUserLevel(userWithPower.GetMXID()) + if powers.EnsureUserLevelAs(userWithPower.GetMXID(), br.Bot.GetMXID(), userLevel) { + if userLevel > powers.UsersDefault { + powers.SetUserLevel(userWithPower.GetMXID(), userLevel-1) + } + _, err = userWithPower.SendState(ctx, roomID, event.StatePowerLevels, "", &event.Content{ + Parsed: powers, + }, time.Time{}) + if err != nil { + return fmt.Errorf("failed to give power to bot: %w", err) + } + } + return nil +} diff --git a/bridgev2/messagestatus.go b/bridgev2/messagestatus.go new file mode 100644 index 00000000..df0c9e4d --- /dev/null +++ b/bridgev2/messagestatus.go @@ -0,0 +1,239 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package bridgev2 + +import ( + "errors" + "fmt" + + "go.mau.fi/util/jsontime" + + "maunium.net/go/mautrix/appservice" + "maunium.net/go/mautrix/bridgev2/status" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +type MessageStatusEventInfo struct { + RoomID id.RoomID + TransactionID string + SourceEventID id.EventID + NewEventID id.EventID + EventType event.Type + MessageType event.MessageType + Sender id.UserID + ThreadRoot id.EventID + StreamOrder int64 + + IsSourceEventDoublePuppeted bool +} + +func StatusEventInfoFromEvent(evt *event.Event) *MessageStatusEventInfo { + var threadRoot id.EventID + if relatable, ok := evt.Content.Parsed.(event.Relatable); ok { + threadRoot = relatable.OptionalGetRelatesTo().GetThreadParent() + } + + _, isDoublePuppeted := evt.Content.Raw[appservice.DoublePuppetKey] + + return &MessageStatusEventInfo{ + RoomID: evt.RoomID, + TransactionID: evt.Unsigned.TransactionID, + SourceEventID: evt.ID, + EventType: evt.Type, + MessageType: evt.Content.AsMessage().MsgType, + Sender: evt.Sender, + ThreadRoot: threadRoot, + + IsSourceEventDoublePuppeted: isDoublePuppeted, + } +} + +// MessageStatus represents the status of a message. It also implements the error interface to allow network connectors +// to return errors which get translated into user-friendly error messages and/or status events. +type MessageStatus struct { + Step status.MessageCheckpointStep + RetryNum int + + Status event.MessageStatus + ErrorReason event.MessageStatusReason + DeliveredTo []id.UserID + InternalError error // Internal error to be tracked in message checkpoints + Message string // Human-readable message shown to users + + ErrorAsMessage bool + IsCertain bool + SendNotice bool + DisableMSS bool +} + +func WrapErrorInStatus(err error) MessageStatus { + var alreadyWrapped MessageStatus + var ok bool + if alreadyWrapped, ok = err.(MessageStatus); ok { + return alreadyWrapped + } else if errors.As(err, &alreadyWrapped) { + alreadyWrapped.InternalError = err + return alreadyWrapped + } + return MessageStatus{ + Status: event.MessageStatusRetriable, + ErrorReason: event.MessageStatusGenericError, + InternalError: err, + } +} + +func (ms MessageStatus) WithSendNotice(send bool) MessageStatus { + ms.SendNotice = send + return ms +} + +func (ms MessageStatus) WithIsCertain(certain bool) MessageStatus { + ms.IsCertain = certain + return ms +} + +func (ms MessageStatus) WithMessage(msg string) MessageStatus { + ms.Message = msg + ms.ErrorAsMessage = false + return ms +} + +func (ms MessageStatus) WithStep(step status.MessageCheckpointStep) MessageStatus { + ms.Step = step + return ms +} + +func (ms MessageStatus) WithStatus(status event.MessageStatus) MessageStatus { + ms.Status = status + return ms +} + +func (ms MessageStatus) WithErrorReason(reason event.MessageStatusReason) MessageStatus { + ms.ErrorReason = reason + return ms +} + +func (ms MessageStatus) WithErrorAsMessage() MessageStatus { + ms.ErrorAsMessage = true + return ms +} + +func (ms MessageStatus) Error() string { + return ms.InternalError.Error() +} + +func (ms MessageStatus) Unwrap() error { + return ms.InternalError +} + +func (ms *MessageStatus) checkpointStatus() status.MessageCheckpointStatus { + switch ms.Status { + case event.MessageStatusSuccess: + if len(ms.DeliveredTo) > 0 { + return status.MsgStatusDelivered + } + return status.MsgStatusSuccess + case event.MessageStatusPending: + return status.MsgStatusWillRetry + case event.MessageStatusRetriable, event.MessageStatusFail: + switch ms.ErrorReason { + case event.MessageStatusTooOld: + return status.MsgStatusTimeout + case event.MessageStatusUnsupported: + return status.MsgStatusUnsupported + default: + return status.MsgStatusPermFailure + } + default: + return "UNKNOWN" + } +} + +func (ms *MessageStatus) ToCheckpoint(evt *MessageStatusEventInfo) *status.MessageCheckpoint { + step := status.MsgStepRemote + if ms.Step != "" { + step = ms.Step + } + checkpoint := &status.MessageCheckpoint{ + RoomID: evt.RoomID, + EventID: evt.SourceEventID, + Step: step, + Timestamp: jsontime.UnixMilliNow(), + Status: ms.checkpointStatus(), + RetryNum: ms.RetryNum, + ReportedBy: status.MsgReportedByBridge, + EventType: evt.EventType, + MessageType: evt.MessageType, + } + if ms.InternalError != nil { + checkpoint.Info = ms.InternalError.Error() + } else if ms.Message != "" { + checkpoint.Info = ms.Message + } + return checkpoint +} + +func (ms *MessageStatus) ToMSSEvent(evt *MessageStatusEventInfo) *event.BeeperMessageStatusEventContent { + content := &event.BeeperMessageStatusEventContent{ + RelatesTo: event.RelatesTo{ + Type: event.RelReference, + EventID: evt.SourceEventID, + }, + TargetTxnID: evt.TransactionID, + Status: ms.Status, + Reason: ms.ErrorReason, + Message: ms.Message, + } + if ms.InternalError != nil { + content.InternalError = ms.InternalError.Error() + if ms.ErrorAsMessage { + content.Message = content.InternalError + } + } + if ms.DeliveredTo != nil { + content.DeliveredToUsers = &ms.DeliveredTo + } + return content +} + +func (ms *MessageStatus) ToNoticeEvent(evt *MessageStatusEventInfo) *event.MessageEventContent { + certainty := "may not have been" + if ms.IsCertain { + certainty = "was not" + } + evtType := "message" + switch evt.EventType { + case event.EventReaction: + evtType = "reaction" + case event.EventRedaction: + evtType = "redaction" + } + msg := ms.Message + if ms.ErrorAsMessage || msg == "" { + msg = ms.InternalError.Error() + } + messagePrefix := fmt.Sprintf("Your %s %s bridged", evtType, certainty) + if ms.Step == status.MsgStepCommand { + messagePrefix = "Handling your command panicked" + } + content := &event.MessageEventContent{ + MsgType: event.MsgNotice, + Body: fmt.Sprintf("\u26a0\ufe0f %s: %s", messagePrefix, msg), + RelatesTo: &event.RelatesTo{}, + Mentions: &event.Mentions{}, + } + if evt.ThreadRoot != "" { + content.RelatesTo.SetThread(evt.ThreadRoot, evt.SourceEventID) + } else { + content.RelatesTo.SetReplyTo(evt.SourceEventID) + } + if evt.Sender != "" { + content.Mentions.UserIDs = []id.UserID{evt.Sender} + } + return content +} diff --git a/bridgev2/networkid/bridgeid.go b/bridgev2/networkid/bridgeid.go new file mode 100644 index 00000000..e3a6df70 --- /dev/null +++ b/bridgev2/networkid/bridgeid.go @@ -0,0 +1,146 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +// Package networkid contains string types used to represent different kinds of identifiers on remote networks. +// +// Except for [BridgeID], all types in this package are only generated by network connectors. +// Network connectors may generate and parse these types any way they want, all other components +// will treat them as opaque identifiers and will not parse them nor assume anything about them. +// However, identifiers are stored in the bridge database, so backwards-compatibility must be +// considered when changing the format. +// +// All IDs are scoped to a bridge, i.e. they don't need to be unique across different bridges. +// However, most IDs need to be globally unique within the bridge, i.e. the same ID must refer +// to the same entity even from another user's point of view. If the remote network does not +// directly provide such globally unique identifiers, the network connector should prefix them +// with a user ID or other identifier to make them unique. +package networkid + +import ( + "fmt" + + "github.com/rs/zerolog" +) + +// BridgeID is an opaque identifier for a bridge +type BridgeID string + +// PortalID is the ID of a room on the remote network. A portal ID alone should identify group chats +// uniquely, and also DMs when scoped to a user login ID (see [PortalKey]). +type PortalID string + +// PortalKey is the unique key of a room on the remote network. It combines a portal ID and a receiver ID. +// +// The Receiver field is generally only used for DMs, and should be empty for group chats. +// The purpose is to segregate DMs by receiver, so that the same DM has separate rooms even +// if both sides are logged into the bridge. Also, for networks that use user IDs as DM chat IDs, +// the receiver is necessary to have separate rooms for separate users who have a DM with the same +// remote user. +// +// It is also permitted to use a non-empty receiver for group chats if there is a good reason to +// segregate them. For example, Telegram's non-supergroups have user-scoped message IDs instead +// of chat-scoped IDs, which is easier to manage with segregated rooms. +// +// As a special case, Receiver MUST be set if the Bridge.Config.SplitPortals flag is set to true. +// The flag is intended for puppeting-only bridges which want multiple logins to create separate portals for each user. +type PortalKey struct { + ID PortalID `json:"portal_id"` + Receiver UserLoginID `json:"portal_receiver,omitempty"` +} + +func (pk PortalKey) IsEmpty() bool { + return pk.ID == "" && pk.Receiver == "" +} + +func (pk PortalKey) String() string { + if pk.Receiver == "" { + return string(pk.ID) + } + return fmt.Sprintf("%s/%s", pk.ID, pk.Receiver) +} + +func (pk PortalKey) MarshalZerologObject(evt *zerolog.Event) { + evt.Str("portal_id", string(pk.ID)) + if pk.Receiver != "" { + evt.Str("portal_receiver", string(pk.Receiver)) + } +} + +// UserID is the ID of a user on the remote network. +// +// User IDs must be globally unique within the bridge for identifying a specific remote user. +type UserID string + +// UserLoginID is the ID of the user being controlled on the remote network. +// +// It may be the same shape as [UserID]. However, being the same shape is not required, and the +// central bridge module and Matrix connectors will never assume it is. Instead, the bridge will +// use methods like [maunium.net/go/mautrix/bridgev2.NetworkAPI.IsThisUser] to check if a user ID +// is associated with a given UserLogin. +// The network connector is of course allowed to assume a UserLoginID is equivalent to a UserID, +// because it is the one defining both types. +type UserLoginID string + +// MessageID is the ID of a message on the remote network. +// +// Message IDs must be unique across rooms and consistent across users (i.e. globally unique within the bridge). +type MessageID string + +// TransactionID is a client-generated identifier for a message send operation on the remote network. +// +// Transaction IDs must be unique across users in a room, but don't need to be unique across different rooms. +type TransactionID string + +// RawTransactionID is a client-generated identifier for a message send operation on the remote network. +// +// Unlike TransactionID, RawTransactionID's are only used for sending and don't have any uniqueness requirements. +type RawTransactionID string + +// PartID is the ID of a message part on the remote network (e.g. index of image in album). +// +// Part IDs are only unique within a message, not globally. +// To refer to a specific message part globally, use the MessagePartID tuple struct. +type PartID string + +// MessagePartID refers to a specific part of a message by combining a message ID and a part ID. +type MessagePartID struct { + MessageID MessageID + PartID PartID +} + +// MessageOptionalPartID refers to a specific part of a message by combining a message ID and an optional part ID. +// If the part ID is not set, this should refer to the first part ID sorted alphabetically. +type MessageOptionalPartID struct { + MessageID MessageID + PartID *PartID +} + +// PaginationCursor is a cursor used for paginating message history. +type PaginationCursor string + +// AvatarID is the ID of a user or room avatar on the remote network. +// +// It may be a real URL, an opaque identifier, or anything in between. It should be an identifier that +// can be acquired from the remote network without downloading the entire avatar. +// +// In general, it is preferred to use a stable identifier which only changes when the avatar changes. +// However, the bridge will also hash the avatar data to check for changes before sending an avatar +// update to Matrix, so the avatar ID being slightly unstable won't be the end of the world. +type AvatarID string + +// EmojiID is the ID of a reaction emoji on the remote network. +// +// On networks that only allow one reaction per message, an empty string should be used +// to apply the unique constraints in the database appropriately. +// On networks that allow multiple emojis, this is the unicode emoji or a network-specific shortcode. +type EmojiID string + +// MediaID represents a media identifier that can be downloaded from the remote network at any point in the future. +// +// This is used to implement on-demand media downloads. The network connector can ask the Matrix connector +// to generate a content URI from a media ID. Then, when the Matrix connector wants to download the media, +// it will parse the content URI and ask the network connector for the data using the media ID. +type MediaID []byte diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go new file mode 100644 index 00000000..efc5f100 --- /dev/null +++ b/bridgev2/networkinterface.go @@ -0,0 +1,1450 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package bridgev2 + +import ( + "context" + "encoding/json" + "fmt" + "strings" + "time" + + "github.com/rs/zerolog" + "go.mau.fi/util/configupgrade" + "go.mau.fi/util/ptr" + "go.mau.fi/util/random" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/bridgev2/database" + "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" + "maunium.net/go/mautrix/mediaproxy" +) + +type ConvertedMessagePart struct { + ID networkid.PartID + Type event.Type + Content *event.MessageEventContent + Extra map[string]any + DBMetadata any + DontBridge bool +} + +func (cmp *ConvertedMessagePart) ToEditPart(part *database.Message) *ConvertedEditPart { + if cmp == nil { + return nil + } + if cmp.DBMetadata != nil { + merger, ok := part.Metadata.(database.MetaMerger) + if ok { + merger.CopyFrom(cmp.DBMetadata) + } else { + part.Metadata = cmp.DBMetadata + } + } + return &ConvertedEditPart{ + Part: part, + Type: cmp.Type, + Content: cmp.Content, + Extra: cmp.Extra, + DontBridge: cmp.DontBridge, + } +} + +// EventSender represents a specific user in a chat. +type EventSender struct { + // If IsFromMe is true, the UserLogin who the event was received through is used as the sender. + // Double puppeting will be used if available. + IsFromMe bool + // SenderLogin is the ID of the UserLogin who sent the event. This may be different from the + // login the event was received through. It is used to ensure double puppeting can still be + // used even if the event is received through another login. + SenderLogin networkid.UserLoginID + // Sender is the remote user ID of the user who sent the event. + // For new events, this will not be used for double puppeting. + // + // However, in the member list, [ChatMemberList.CheckAllLogins] can be specified to go through every login + // and call [NetworkAPI.IsThisUser] to check if this ID belongs to that login. This method is not recommended, + // it is better to fill the IsFromMe and SenderLogin fields appropriately. + Sender networkid.UserID + + // ForceDMUser can be set if the event should be sent as the DM user even if the Sender is different. + // This only applies in DM rooms where [database.Portal.OtherUserID] is set and is ignored if IsFromMe is true. + // A warning will be logged if the sender is overridden due to this flag. + ForceDMUser bool +} + +func (es EventSender) MarshalZerologObject(evt *zerolog.Event) { + evt.Str("user_id", string(es.Sender)) + if string(es.SenderLogin) != string(es.Sender) { + evt.Str("sender_login", string(es.SenderLogin)) + } + if es.IsFromMe { + evt.Bool("is_from_me", true) + } + if es.ForceDMUser { + evt.Bool("force_dm_user", true) + } +} + +type ConvertedMessage struct { + ReplyTo *networkid.MessageOptionalPartID + // Optional additional info about the reply. This is only used when backfilling messages + // on Beeper, where replies may target messages that haven't been bridged yet. + // Standard Matrix servers can't backwards backfill, so these are never used. + ReplyToRoom networkid.PortalKey + ReplyToUser networkid.UserID + ReplyToLogin networkid.UserLoginID + + ThreadRoot *networkid.MessageID + Parts []*ConvertedMessagePart + Disappear database.DisappearingSetting +} + +func MergeCaption(textPart, mediaPart *ConvertedMessagePart) *ConvertedMessagePart { + if textPart == nil { + return mediaPart + } else if mediaPart == nil { + return textPart + } + mediaPart = ptr.Clone(mediaPart) + if mediaPart.Content.MsgType == event.MsgNotice || (mediaPart.Content.Body != "" && mediaPart.Content.FileName != "" && mediaPart.Content.Body != mediaPart.Content.FileName) { + textPart = ptr.Clone(textPart) + textPart.Content.EnsureHasHTML() + mediaPart.Content.EnsureHasHTML() + mediaPart.Content.Body += "\n\n" + textPart.Content.Body + mediaPart.Content.FormattedBody += "

" + textPart.Content.FormattedBody + mediaPart.Content.Mentions = mediaPart.Content.Mentions.Merge(textPart.Content.Mentions) + mediaPart.Content.BeeperLinkPreviews = append(mediaPart.Content.BeeperLinkPreviews, textPart.Content.BeeperLinkPreviews...) + } else { + mediaPart.Content.FileName = mediaPart.Content.Body + mediaPart.Content.Body = textPart.Content.Body + mediaPart.Content.Format = textPart.Content.Format + mediaPart.Content.FormattedBody = textPart.Content.FormattedBody + mediaPart.Content.Mentions = textPart.Content.Mentions + mediaPart.Content.BeeperLinkPreviews = textPart.Content.BeeperLinkPreviews + } + if metaMerger, ok := mediaPart.DBMetadata.(database.MetaMerger); ok { + metaMerger.CopyFrom(textPart.DBMetadata) + } + mediaPart.ID = textPart.ID + return mediaPart +} + +func (cm *ConvertedMessage) MergeCaption() bool { + if len(cm.Parts) != 2 { + return false + } + textPart, mediaPart := cm.Parts[1], cm.Parts[0] + if textPart.Content.MsgType != event.MsgText { + textPart, mediaPart = mediaPart, textPart + } + if (!mediaPart.Content.MsgType.IsMedia() && mediaPart.Content.MsgType != event.MsgNotice) || textPart.Content.MsgType != event.MsgText { + return false + } + merged := MergeCaption(textPart, mediaPart) + if merged != nil { + cm.Parts = []*ConvertedMessagePart{merged} + return true + } + return false +} + +type ConvertedEditPart struct { + Part *database.Message + + Type event.Type + // The Content and Extra fields will be put inside `m.new_content` automatically. + // SetEdit must NOT be called by the network connector. + Content *event.MessageEventContent + Extra map[string]any + // TopLevelExtra can be used to specify custom fields at the top level of the content rather than inside `m.new_content`. + TopLevelExtra map[string]any + // NewMentions can be used to specify new mentions that should ping the users again. + // Mentions inside the edited content will not ping. + NewMentions *event.Mentions + + DontBridge bool +} + +type ConvertedEdit struct { + ModifiedParts []*ConvertedEditPart + DeletedParts []*database.Message + // Warning: added parts will be sent at the end of the room. + // If other messages have been sent after the message being edited, + // these new parts will not be next to the existing parts. + AddedParts *ConvertedMessage +} + +// BridgeName contains information about the network that a connector bridges to. +type BridgeName struct { + // The displayname of the network, e.g. `Discord` + DisplayName string `json:"displayname"` + // The URL to the website of the network, e.g. `https://discord.com` + NetworkURL string `json:"network_url"` + // The icon of the network as a mxc:// URI + NetworkIcon id.ContentURIString `json:"network_icon"` + // An identifier uniquely identifying the network, e.g. `discord` + NetworkID string `json:"network_id"` + // An identifier uniquely identifying the bridge software. + // The Go import path is a good choice here (e.g. github.com/octocat/discordbridge) + BeeperBridgeType string `json:"beeper_bridge_type"` + // The default appservice port to use in the example config, defaults to 8080 if unset + // Official mautrix bridges will use ports defined in https://mau.fi/ports + DefaultPort uint16 `json:"default_port,omitempty"` + // The default command prefix to use in the example config, defaults to NetworkID if unset. Must include the ! prefix. + DefaultCommandPrefix string `json:"default_command_prefix,omitempty"` +} + +func (bn BridgeName) AsBridgeInfoSection() event.BridgeInfoSection { + return event.BridgeInfoSection{ + ID: bn.BeeperBridgeType, + DisplayName: bn.DisplayName, + AvatarURL: bn.NetworkIcon, + ExternalURL: bn.NetworkURL, + } +} + +// NetworkConnector is the main interface that a network connector must implement. +type NetworkConnector interface { + // Init is called when the bridge is initialized. The connector should store the bridge instance for later use. + // This should not do any network calls or other blocking operations. + Init(*Bridge) + // Start is called when the bridge is starting. + // The connector should do any non-user-specific startup actions necessary. + // User logins will be loaded separately, so the connector should not load them here. + Start(context.Context) error + + // GetName returns the name of the bridge and some additional metadata, + // which is used to fill `m.bridge` events among other things. + // + // The first call happens *before* the config is loaded, because the data here is also used to + // fill parts of the example config (like the default username template and bot localpart). + // The output can still be adjusted based on config variables, but the function must have + // default values when called without a config. + GetName() BridgeName + // GetDBMetaTypes returns struct types that are used to store connector-specific metadata in various tables. + // All fields are optional. If a field isn't provided, then the corresponding table will have no custom metadata. + // This will be called before Init, it should have a hardcoded response. + GetDBMetaTypes() database.MetaTypes + // GetCapabilities returns the general capabilities of the network connector. + // Note that most capabilities are scoped to rooms and are returned by [NetworkAPI.GetCapabilities] instead. + GetCapabilities() *NetworkGeneralCapabilities + // GetConfig returns all the parts of the network connector's config file. Specifically: + // - example: a string containing an example config file + // - data: an interface to unmarshal the actual config into + // - upgrader: a config upgrader to ensure all fields are present and to do any migrations from old configs + GetConfig() (example string, data any, upgrader configupgrade.Upgrader) + + // LoadUserLogin is called when a UserLogin is loaded from the database in order to fill the [UserLogin.Client] field. + // + // This is called within the bridge's global cache lock, so it must not do any slow operations, + // such as connecting to the network. Instead, connecting should happen when [NetworkAPI.Connect] is called later. + LoadUserLogin(ctx context.Context, login *UserLogin) error + + // GetLoginFlows returns a list of login flows that the network supports. + GetLoginFlows() []LoginFlow + // CreateLogin is called when a user wants to log in to the network. + // + // This should generally not do any work, it should just return a LoginProcess that remembers + // the user and will execute the requested flow. The actual work should start when [LoginProcess.Start] is called. + CreateLogin(ctx context.Context, user *User, flowID string) (LoginProcess, error) + + // GetBridgeInfoVersion returns version numbers for bridge info and room capabilities respectively. + // When the versions change, the bridge will automatically resend bridge info to all rooms. + GetBridgeInfoVersion() (info, capabilities int) +} + +type StoppableNetwork interface { + NetworkConnector + // Stop is called when the bridge is stopping, after all network clients have been disconnected. + Stop() +} + +// DirectMediableNetwork is an optional interface that network connectors can implement to support direct media access. +// +// If the Matrix connector has direct media enabled, SetUseDirectMedia will be called +// before the Start method of the network connector. Download will then be called +// whenever someone wants to download a direct media `mxc://` URI which was generated +// by calling GenerateContentURI on the Matrix connector. +type DirectMediableNetwork interface { + NetworkConnector + SetUseDirectMedia() + Download(ctx context.Context, mediaID networkid.MediaID, params map[string]string) (mediaproxy.GetMediaResponse, error) +} + +// IdentifierValidatingNetwork is an optional interface that network connectors can implement to validate the shape of user IDs. +// +// This should not perform any checks to see if the user ID actually exists on the network, just that the user ID looks valid. +type IdentifierValidatingNetwork interface { + NetworkConnector + ValidateUserID(id networkid.UserID) bool +} + +type TransactionIDGeneratingNetwork interface { + NetworkConnector + GenerateTransactionID(userID id.UserID, roomID id.RoomID, eventType event.Type) networkid.RawTransactionID +} + +type PortalBridgeInfoFillingNetwork interface { + NetworkConnector + FillPortalBridgeInfo(portal *Portal, content *event.BridgeEventContent) +} + +// ConfigValidatingNetwork is an optional interface that network connectors can implement to validate config fields +// before the bridge is started. +// +// When the ValidateConfig method is called, the config data will already be unmarshaled into the +// object returned by [NetworkConnector.GetConfig]. +// +// This mechanism is usually used to refuse bridge startup if a mandatory field has an invalid value. +type ConfigValidatingNetwork interface { + NetworkConnector + ValidateConfig() error +} + +// MaxFileSizeingNetwork is an optional interface that network connectors can implement +// to find out the maximum file size that can be uploaded to Matrix. +// +// The SetMaxFileSize will be called asynchronously soon after startup. +// Before the function is called, the connector may assume a default limit of 50 MiB. +type MaxFileSizeingNetwork interface { + NetworkConnector + SetMaxFileSize(maxSize int64) +} + +type NetworkResettingNetwork interface { + NetworkConnector + // ResetHTTPTransport should recreate the HTTP client used by the bridge. + // It should refetch settings from the Matrix connector using GetHTTPClientSettings if applicable. + ResetHTTPTransport() + // ResetNetworkConnections should forcefully disconnect and restart any persistent network connections. + // ResetHTTPTransport will usually be called before this, so resetting the transport is not necessary here. + ResetNetworkConnections() +} + +type RemoteEchoHandler func(RemoteMessage, *database.Message) (bool, error) + +type MatrixMessageResponse struct { + DB *database.Message + StreamOrder int64 + // If Pending is set, the bridge will not save the provided message to the database. + // This should only be used if AddPendingToSave has been called. + Pending bool + // If RemovePending is set, the bridge will remove the provided transaction ID from pending messages + // after saving the provided message to the database. This should be used with AddPendingToIgnore. + RemovePending networkid.TransactionID + // An optional function that is called after the message is saved to the database. + // Will not be called if the message is not saved for some reason. + PostSave func(context.Context, *database.Message) +} + +type OutgoingTimeoutConfig struct { + CheckInterval time.Duration + NoEchoTimeout time.Duration + NoEchoMessage string + NoAckTimeout time.Duration + NoAckMessage string +} + +type NetworkGeneralCapabilities struct { + // Does the network connector support disappearing messages? + // This flag enables the message disappearing loop in the bridge. + DisappearingMessages bool + // Should the bridge re-request user info on incoming messages even if the ghost already has info? + // By default, info is only requested for ghosts with no name, and other updating is left to events. + AggressiveUpdateInfo bool + // Should the bridge call HandleMatrixReadReceipt with fake data when receiving a new message? + // This should be enabled if the network requires each message to be marked as read independently, + // and doesn't automatically do it when sending a message. + ImplicitReadReceipts bool + // If the bridge uses the pending message mechanism ([MatrixMessage.AddPendingToSave]) + // to handle asynchronous message responses, this field can be set to enable + // automatic timeout errors in case the asynchronous response never arrives. + OutgoingMessageTimeouts *OutgoingTimeoutConfig + // Capabilities related to the provisioning API. + Provisioning ProvisioningCapabilities +} + +// NetworkAPI is an interface representing a remote network client for a single user login. +// +// Implementations of this interface are stored in [UserLogin.Client]. +// The [NetworkConnector.LoadUserLogin] method is responsible for filling the Client field with a NetworkAPI. +type NetworkAPI interface { + // Connect is called to actually connect to the remote network. + // If there's no persistent connection, this may just check access token validity, or even do nothing at all. + // This method isn't allowed to return errors, because any connection errors should be sent + // using the bridge state mechanism (UserLogin.BridgeState.Send) + Connect(ctx context.Context) + // Disconnect should disconnect from the remote network. + // A clean disconnection is preferred, but it should not take too long. + Disconnect() + // IsLoggedIn should return whether the access tokens in this NetworkAPI are valid. + // This should not do any IO operations, it should only return cached data which is updated elsewhere. + IsLoggedIn() bool + // LogoutRemote should invalidate the access tokens in this NetworkAPI if possible + // and disconnect from the remote network. + LogoutRemote(ctx context.Context) + + // IsThisUser should return whether the given remote network user ID is the same as this login. + // This is used when the bridge wants to convert a user login ID to a user ID. + IsThisUser(ctx context.Context, userID networkid.UserID) bool + // GetChatInfo returns info for a given chat. Any fields that are nil will be ignored and not processed at all, + // while empty strings will change the relevant value in the room to be an empty string. + // For example, a nil name will mean the room name is not changed, while an empty string name will remove the name. + GetChatInfo(ctx context.Context, portal *Portal) (*ChatInfo, error) + // GetUserInfo returns info for a given user. Like chat info, fields can be nil to skip them. + GetUserInfo(ctx context.Context, ghost *Ghost) (*UserInfo, error) + // GetCapabilities returns the bridging capabilities in a given room. + // This can simply return a static list if the remote network has no per-chat capability differences, + // but all calls will include the portal, because some networks do have per-chat differences. + GetCapabilities(ctx context.Context, portal *Portal) *event.RoomFeatures + + // HandleMatrixMessage is called when a message is sent from Matrix in an existing portal room. + // This function should convert the message as appropriate, send it over to the remote network, + // and return the info so the central bridge can store it in the database. + // + // This is only called for normal non-edit messages. For other types of events, see the optional extra interfaces (`XHandlingNetworkAPI`). + HandleMatrixMessage(ctx context.Context, msg *MatrixMessage) (message *MatrixMessageResponse, err error) +} + +type ConnectBackgroundParams struct { + // RawData is the raw data in the push that triggered the background connection. + RawData json.RawMessage + // ExtraData is the data returned by [PushParsingNetwork.ParsePushNotification]. + // It's only present for native pushes. Relayed pushes will only have the raw data. + ExtraData any +} + +// BackgroundSyncingNetworkAPI is an optional interface that network connectors can implement to support background resyncs. +type BackgroundSyncingNetworkAPI interface { + NetworkAPI + // ConnectBackground is called in place of Connect for background resyncs. + // The client should connect to the remote network, handle pending messages, and then disconnect. + // This call should block until the entire sync is complete and the client is disconnected. + ConnectBackground(ctx context.Context, params *ConnectBackgroundParams) error +} + +// CredentialExportingNetworkAPI is an optional interface that networks connectors can implement to support export of +// the credentials associated with that login. Credential type is bridge specific. +type CredentialExportingNetworkAPI interface { + NetworkAPI + ExportCredentials(ctx context.Context) any +} + +// FetchMessagesParams contains the parameters for a message history pagination request. +type FetchMessagesParams struct { + // The portal to fetch messages in. Always present. + Portal *Portal + // When fetching messages inside a thread, the ID of the thread. + ThreadRoot networkid.MessageID + // Whether to fetch new messages instead of old ones. + Forward bool + // The oldest known message in the thread or the portal. If Forward is true, this is the newest known message instead. + // If the portal doesn't have any bridged messages, this will be nil. + AnchorMessage *database.Message + // The cursor returned by the previous call to FetchMessages with the same portal and thread root. + // This will not be present in Forward calls. + Cursor networkid.PaginationCursor + // The preferred number of messages to return. The returned batch can be bigger or smaller + // without any side effects, but the network connector should aim for this number. + Count int + + // When a forward backfill is triggered by a [RemoteChatResyncBackfillBundle], this will contain + // the bundled data returned by the event. It can be used as an optimization to avoid fetching + // messages that were already provided by the remote network, while still supporting fetching + // more messages if the limit is higher. + BundledData any + + // When the messages are being fetched for a queued backfill, this is the task object. + Task *database.BackfillTask +} + +// BackfillReaction is an individual reaction to a message in a history pagination request. +// +// The target message is always the BackfillMessage that contains this item. +// Optionally, the reaction can target a specific part by specifying TargetPart. +// If not specified, the first part (sorted lexicographically) is targeted. +type BackfillReaction struct { + // Optional part of the message that the reaction targets. + // If nil, the reaction targets the first part of the message. + TargetPart *networkid.PartID + // Optional timestamp for the reaction. + // If unset, the reaction will have a fake timestamp that is slightly after the message timestamp. + Timestamp time.Time + + Sender EventSender + EmojiID networkid.EmojiID + Emoji string + ExtraContent map[string]any + DBMetadata any +} + +// BackfillMessage is an individual message in a history pagination request. +type BackfillMessage struct { + *ConvertedMessage + Sender EventSender + ID networkid.MessageID + TxnID networkid.TransactionID + Timestamp time.Time + StreamOrder int64 + Reactions []*BackfillReaction + + ShouldBackfillThread bool + LastThreadMessage networkid.MessageID +} + +var ( + _ RemoteMessageWithTransactionID = (*BackfillMessage)(nil) + _ RemoteEventWithTimestamp = (*BackfillMessage)(nil) +) + +func (b *BackfillMessage) GetType() RemoteEventType { + return RemoteEventMessage +} + +func (b *BackfillMessage) GetPortalKey() networkid.PortalKey { + panic("GetPortalKey called for BackfillMessage") +} + +func (b *BackfillMessage) AddLogContext(c zerolog.Context) zerolog.Context { + return c +} + +func (b *BackfillMessage) GetSender() EventSender { + return b.Sender +} + +func (b *BackfillMessage) GetID() networkid.MessageID { + return b.ID +} + +func (b *BackfillMessage) GetTransactionID() networkid.TransactionID { + return b.TxnID +} + +func (b *BackfillMessage) GetTimestamp() time.Time { + return b.Timestamp +} + +func (b *BackfillMessage) ConvertMessage(ctx context.Context, portal *Portal, intent MatrixAPI) (*ConvertedMessage, error) { + return b.ConvertedMessage, nil +} + +// FetchMessagesResponse contains the response for a message history pagination request. +type FetchMessagesResponse struct { + // The messages to backfill. Messages should always be sorted in chronological order (oldest to newest). + Messages []*BackfillMessage + // The next cursor to use for fetching more messages. + Cursor networkid.PaginationCursor + // Whether there are more messages that can be backfilled. + // This field is required. If it is false, FetchMessages will not be called again. + HasMore bool + // Whether the batch contains new messages rather than old ones. + // Cursor, HasMore and the progress fields will be ignored when this is present. + Forward bool + // When sending forward backfill (or the first batch in a room), this field can be set + // to mark the messages as read immediately after backfilling. + MarkRead bool + + // Should the bridge check each message against the database to ensure it's not a duplicate before bridging? + // By default, the bridge will only drop messages that are older than the last bridged message for forward backfills, + // or newer than the first for backward. + AggressiveDeduplication bool + + // When HasMore is true, one of the following fields can be set to report backfill progress: + + // Approximate backfill progress as a number between 0 and 1. + ApproxProgress float64 + // Approximate number of messages remaining that can be backfilled. + ApproxRemainingCount int + // Approximate total number of messages in the chat. + ApproxTotalCount int + + // An optional function that is called after the backfill batch has been sent. + CompleteCallback func() +} + +// BackfillingNetworkAPI is an optional interface that network connectors can implement to support backfilling message history. +type BackfillingNetworkAPI interface { + NetworkAPI + // FetchMessages returns a batch of messages to backfill in a portal room. + // For details on the input and output, see the documentation of [FetchMessagesParams] and [FetchMessagesResponse]. + FetchMessages(ctx context.Context, fetchParams FetchMessagesParams) (*FetchMessagesResponse, error) +} + +// BackfillingNetworkAPIWithLimits is an optional interface that network connectors can implement to customize +// the limit for backwards backfilling tasks. It is recommended to implement this by reading the MaxBatchesOverride +// config field with network-specific keys for different room types. +type BackfillingNetworkAPIWithLimits interface { + BackfillingNetworkAPI + // GetBackfillMaxBatchCount is called before a backfill task is executed to determine the maximum number of batches + // that should be backfilled. Return values less than 0 are treated as unlimited. + GetBackfillMaxBatchCount(ctx context.Context, portal *Portal, task *database.BackfillTask) int +} + +// EditHandlingNetworkAPI is an optional interface that network connectors can implement to handle message edits. +type EditHandlingNetworkAPI interface { + NetworkAPI + // HandleMatrixEdit is called when a previously bridged message is edited in a portal room. + // The central bridge module will save the [*database.Message] after this function returns, + // so the network connector is allowed to mutate the provided object. + HandleMatrixEdit(ctx context.Context, msg *MatrixEdit) error +} + +type PollHandlingNetworkAPI interface { + NetworkAPI + HandleMatrixPollStart(ctx context.Context, msg *MatrixPollStart) (*MatrixMessageResponse, error) + HandleMatrixPollVote(ctx context.Context, msg *MatrixPollVote) (*MatrixMessageResponse, error) +} + +// ReactionHandlingNetworkAPI is an optional interface that network connectors can implement to handle message reactions. +type ReactionHandlingNetworkAPI interface { + NetworkAPI + // PreHandleMatrixReaction is called as the first step of handling a reaction. It returns the emoji ID, + // sender user ID and max reaction count to allow the central bridge module to de-duplicate the reaction + // if appropriate. + PreHandleMatrixReaction(ctx context.Context, msg *MatrixReaction) (MatrixReactionPreResponse, error) + // HandleMatrixReaction is called after confirming that the reaction is not a duplicate. + // This is the method that should actually send the reaction to the remote network. + // The returned [database.Reaction] object may be empty: the central bridge module already has + // all the required fields and will fill them automatically if they're empty. However, network + // connectors are allowed to set fields themselves if any extra fields are necessary. + HandleMatrixReaction(ctx context.Context, msg *MatrixReaction) (reaction *database.Reaction, err error) + // HandleMatrixReactionRemove is called when a redaction event is received pointing at a previously + // bridged reaction. The network connector should remove the reaction from the remote network. + HandleMatrixReactionRemove(ctx context.Context, msg *MatrixReactionRemove) error +} + +// RedactionHandlingNetworkAPI is an optional interface that network connectors can implement to handle message deletions. +type RedactionHandlingNetworkAPI interface { + NetworkAPI + // HandleMatrixMessageRemove is called when a previously bridged message is deleted in a portal room. + HandleMatrixMessageRemove(ctx context.Context, msg *MatrixMessageRemove) error +} + +// ReadReceiptHandlingNetworkAPI is an optional interface that network connectors can implement to handle read receipts. +type ReadReceiptHandlingNetworkAPI interface { + NetworkAPI + // HandleMatrixReadReceipt is called when a read receipt is sent in a portal room. + // This will be called even if the target message is not a bridged message. + // Network connectors must gracefully handle [MatrixReadReceipt.ExactMessage] being nil. + // The exact handling is up to the network connector. + HandleMatrixReadReceipt(ctx context.Context, msg *MatrixReadReceipt) error +} + +// ChatViewingNetworkAPI is an optional interface that network connectors can implement to handle viewing chat status. +type ChatViewingNetworkAPI interface { + NetworkAPI + // HandleMatrixViewingChat is called when the user opens a portal room. + // This will never be called by the standard appservice connector, + // as Matrix doesn't have any standard way of signaling chat open status. + // Clients are expected to call this every 5 seconds. There is no signal for closing a chat. + HandleMatrixViewingChat(ctx context.Context, msg *MatrixViewingChat) error +} + +// TypingHandlingNetworkAPI is an optional interface that network connectors can implement to handle typing events. +type TypingHandlingNetworkAPI interface { + NetworkAPI + // HandleMatrixTyping is called when a user starts typing in a portal room. + // In the future, the central bridge module will likely get a loop to automatically repeat + // calls to this function until the user stops typing. + HandleMatrixTyping(ctx context.Context, msg *MatrixTyping) error +} + +type MarkedUnreadHandlingNetworkAPI interface { + NetworkAPI + HandleMarkedUnread(ctx context.Context, msg *MatrixMarkedUnread) error +} + +type MuteHandlingNetworkAPI interface { + NetworkAPI + HandleMute(ctx context.Context, msg *MatrixMute) error +} + +type TagHandlingNetworkAPI interface { + NetworkAPI + HandleRoomTag(ctx context.Context, msg *MatrixRoomTag) error +} + +// RoomNameHandlingNetworkAPI is an optional interface that network connectors can implement to handle room name changes. +type RoomNameHandlingNetworkAPI interface { + NetworkAPI + // HandleMatrixRoomName is called when the name of a portal room is changed. + // This method should update the Name and NameSet fields of the Portal with + // the new name and return true if the change was successful. + // If the change is not successful, then the fields should not be updated. + HandleMatrixRoomName(ctx context.Context, msg *MatrixRoomName) (bool, error) +} + +// RoomAvatarHandlingNetworkAPI is an optional interface that network connectors can implement to handle room avatar changes. +type RoomAvatarHandlingNetworkAPI interface { + NetworkAPI + // HandleMatrixRoomAvatar is called when the avatar of a portal room is changed. + // This method should update the AvatarID, AvatarHash and AvatarMXC fields + // with the new avatar details and return true if the change was successful. + // If the change is not successful, then the fields should not be updated. + HandleMatrixRoomAvatar(ctx context.Context, msg *MatrixRoomAvatar) (bool, error) +} + +// RoomTopicHandlingNetworkAPI is an optional interface that network connectors can implement to handle room topic changes. +type RoomTopicHandlingNetworkAPI interface { + NetworkAPI + // HandleMatrixRoomTopic is called when the topic of a portal room is changed. + // This method should update the Topic and TopicSet fields of the Portal with + // the new topic and return true if the change was successful. + // If the change is not successful, then the fields should not be updated. + HandleMatrixRoomTopic(ctx context.Context, msg *MatrixRoomTopic) (bool, error) +} + +type DisappearTimerChangingNetworkAPI interface { + NetworkAPI + // HandleMatrixDisappearingTimer is called when the disappearing timer of a portal room is changed. + // This method should update the Disappear field of the Portal with the new timer and return true + // if the change was successful. If the change is not successful, then the field should not be updated. + HandleMatrixDisappearingTimer(ctx context.Context, msg *MatrixDisappearingTimer) (bool, error) +} + +// DeleteChatHandlingNetworkAPI is an optional interface that network connectors +// can implement to delete a chat from the remote network. +type DeleteChatHandlingNetworkAPI interface { + NetworkAPI + // HandleMatrixDeleteChat is called when the user explicitly deletes a chat. + HandleMatrixDeleteChat(ctx context.Context, msg *MatrixDeleteChat) error +} + +// MessageRequestAcceptingNetworkAPI is an optional interface that network connectors +// can implement to accept message requests from the remote network. +type MessageRequestAcceptingNetworkAPI interface { + NetworkAPI + // HandleMatrixAcceptMessageRequest is called when the user accepts a message request. + HandleMatrixAcceptMessageRequest(ctx context.Context, msg *MatrixAcceptMessageRequest) error +} + +type BeeperAIStreamHandlingNetworkAPI interface { + NetworkAPI + HandleMatrixBeeperAIStream(ctx context.Context, msg *MatrixBeeperAIStream) error +} + +type ResolveIdentifierResponse struct { + // Ghost is the ghost of the user that the identifier resolves to. + // This field should be set whenever possible. However, it is not required, + // and the central bridge module will not try to create a ghost if it is not set. + Ghost *Ghost + + // UserID is the user ID of the user that the identifier resolves to. + UserID networkid.UserID + // UserInfo contains the info of the user that the identifier resolves to. + // If both this and the Ghost field are set, the central bridge module will + // automatically update the ghost's info with the data here. + UserInfo *UserInfo + + // Chat contains info about the direct chat with the resolved user. + // This field is required when createChat is true in the ResolveIdentifier call, + // and optional otherwise. + Chat *CreateChatResponse +} + +var SpecialValueDMRedirectedToBot = networkid.UserID("__fi.mau.bridgev2.dm_redirected_to_bot::" + random.String(10)) + +type CreateChatResponse struct { + PortalKey networkid.PortalKey + // Portal and PortalInfo are not required, the caller will fetch them automatically based on PortalKey if necessary. + Portal *Portal + PortalInfo *ChatInfo + // If a start DM request (CreateChatWithGhost or ResolveIdentifier) returns the DM to a different user, + // this field should have the user ID of said different user. + DMRedirectedTo networkid.UserID + + FailedParticipants map[networkid.UserID]*CreateChatFailedParticipant +} + +type CreateChatFailedParticipant struct { + Reason string `json:"reason"` + InviteEventType string `json:"invite_event_type,omitempty"` + InviteContent *event.Content `json:"invite_content,omitempty"` + + UserMXID id.UserID `json:"user_mxid,omitempty"` + DMRoomMXID id.RoomID `json:"dm_room_mxid,omitempty"` +} + +// IdentifierResolvingNetworkAPI is an optional interface that network connectors can implement to support starting new direct chats. +type IdentifierResolvingNetworkAPI interface { + NetworkAPI + // ResolveIdentifier is called when the user wants to start a new chat. + // This can happen via the `resolve-identifier` or `start-chat` bridge bot commands, + // or the corresponding provisioning API endpoints. + ResolveIdentifier(ctx context.Context, identifier string, createChat bool) (*ResolveIdentifierResponse, error) +} + +// GhostDMCreatingNetworkAPI is an optional extension to IdentifierResolvingNetworkAPI for starting chats with pre-validated user IDs. +type GhostDMCreatingNetworkAPI interface { + IdentifierResolvingNetworkAPI + // CreateChatWithGhost may be called instead of [IdentifierResolvingNetworkAPI.ResolveIdentifier] + // when starting a chat with an internal user identifier that has been pre-validated using + // [IdentifierValidatingNetwork.ValidateUserID]. If this is not implemented, ResolveIdentifier + // will be used instead (by stringifying the ghost ID). + CreateChatWithGhost(ctx context.Context, ghost *Ghost) (*CreateChatResponse, error) +} + +// ContactListingNetworkAPI is an optional interface that network connectors can implement to provide the user's contact list. +type ContactListingNetworkAPI interface { + NetworkAPI + GetContactList(ctx context.Context) ([]*ResolveIdentifierResponse, error) +} + +type UserSearchingNetworkAPI interface { + IdentifierResolvingNetworkAPI + SearchUsers(ctx context.Context, query string) ([]*ResolveIdentifierResponse, error) +} + +type GroupCreatingNetworkAPI interface { + IdentifierResolvingNetworkAPI + CreateGroup(ctx context.Context, params *GroupCreateParams) (*CreateChatResponse, error) +} + +type PersonalFilteringCustomizingNetworkAPI interface { + NetworkAPI + CustomizePersonalFilteringSpace(req *mautrix.ReqCreateRoom) +} + +type ProvisioningCapabilities struct { + ResolveIdentifier ResolveIdentifierCapabilities `json:"resolve_identifier"` + GroupCreation map[string]GroupTypeCapabilities `json:"group_creation"` +} + +type ResolveIdentifierCapabilities struct { + // Can DMs be created after resolving an identifier? + CreateDM bool `json:"create_dm"` + // Can users be looked up by phone number? + LookupPhone bool `json:"lookup_phone"` + // Can users be looked up by email address? + LookupEmail bool `json:"lookup_email"` + // Can users be looked up by network-specific username? + LookupUsername bool `json:"lookup_username"` + // Can any phone number be contacted without having to validate it via lookup first? + AnyPhone bool `json:"any_phone"` + // Can a contact list be retrieved from the bridge? + ContactList bool `json:"contact_list"` + // Can users be searched by name on the remote network? + Search bool `json:"search"` +} + +type GroupTypeCapabilities struct { + TypeDescription string `json:"type_description"` + + Name GroupFieldCapability `json:"name"` + Username GroupFieldCapability `json:"username"` + Avatar GroupFieldCapability `json:"avatar"` + Topic GroupFieldCapability `json:"topic"` + Disappear GroupFieldCapability `json:"disappear"` + Participants GroupFieldCapability `json:"participants"` + Parent GroupFieldCapability `json:"parent"` +} + +type GroupFieldCapability struct { + // Is setting this field allowed at all in the create request? + // Even if false, the network connector should attempt to set the metadata after group creation, + // as the allowed flag can't be enforced properly when creating a group for an existing Matrix room. + Allowed bool `json:"allowed"` + // Is setting this field mandatory for the creation to succeed? + Required bool `json:"required,omitempty"` + // The minimum/maximum length of the field, if applicable. + // For members, length means the number of members excluding the creator. + MinLength int `json:"min_length,omitempty"` + MaxLength int `json:"max_length,omitempty"` + + // Only for the disappear field: allowed disappearing settings + DisappearSettings *event.DisappearingTimerCapability `json:"settings,omitempty"` + + // This can be used to tell provisionutil not to call ValidateUserID on each participant. + // It only meant to allow hacks where ResolveIdentifier returns a fake ID that isn't actually valid for MXIDs. + SkipIdentifierValidation bool `json:"-"` +} + +type GroupCreateParams struct { + Type string `json:"type,omitempty"` + + Username string `json:"username,omitempty"` + // Clients may also provide MXIDs here, but provisionutil will normalize them, so bridges only need to handle network IDs + Participants []networkid.UserID `json:"participants,omitempty"` + Parent *networkid.PortalKey `json:"parent,omitempty"` + + Name *event.RoomNameEventContent `json:"name,omitempty"` + Avatar *event.RoomAvatarEventContent `json:"avatar,omitempty"` + Topic *event.TopicEventContent `json:"topic,omitempty"` + Disappear *event.BeeperDisappearingTimer `json:"disappear,omitempty"` + + // An existing room ID to bridge to. If unset, a new room will be created. + RoomID id.RoomID `json:"room_id,omitempty"` +} + +type MembershipChangeType struct { + From event.Membership + To event.Membership + IsSelf bool +} + +var ( + AcceptInvite = MembershipChangeType{From: event.MembershipInvite, To: event.MembershipJoin, IsSelf: true} + RevokeInvite = MembershipChangeType{From: event.MembershipInvite, To: event.MembershipLeave} + RejectInvite = MembershipChangeType{From: event.MembershipInvite, To: event.MembershipLeave, IsSelf: true} + BanInvited = MembershipChangeType{From: event.MembershipInvite, To: event.MembershipBan} + ProfileChange = MembershipChangeType{From: event.MembershipJoin, To: event.MembershipJoin, IsSelf: true} + Leave = MembershipChangeType{From: event.MembershipJoin, To: event.MembershipLeave, IsSelf: true} + Kick = MembershipChangeType{From: event.MembershipJoin, To: event.MembershipLeave} + BanJoined = MembershipChangeType{From: event.MembershipJoin, To: event.MembershipBan} + Invite = MembershipChangeType{From: event.MembershipLeave, To: event.MembershipInvite} + Join = MembershipChangeType{From: event.MembershipLeave, To: event.MembershipJoin} + BanLeft = MembershipChangeType{From: event.MembershipLeave, To: event.MembershipBan} + Knock = MembershipChangeType{From: event.MembershipLeave, To: event.MembershipKnock, IsSelf: true} + AcceptKnock = MembershipChangeType{From: event.MembershipKnock, To: event.MembershipInvite} + RejectKnock = MembershipChangeType{From: event.MembershipKnock, To: event.MembershipLeave} + RetractKnock = MembershipChangeType{From: event.MembershipKnock, To: event.MembershipLeave, IsSelf: true} + BanKnocked = MembershipChangeType{From: event.MembershipKnock, To: event.MembershipBan} + Unban = MembershipChangeType{From: event.MembershipBan, To: event.MembershipLeave} +) + +type GhostOrUserLogin interface { + isGhostOrUserLogin() +} + +func (*Ghost) isGhostOrUserLogin() {} +func (*UserLogin) isGhostOrUserLogin() {} + +type MatrixMembershipChange struct { + MatrixRoomMeta[*event.MemberEventContent] + Target GhostOrUserLogin + Type MembershipChangeType +} + +type MatrixMembershipResult struct { + RedirectTo networkid.UserID +} + +type MembershipHandlingNetworkAPI interface { + NetworkAPI + HandleMatrixMembership(ctx context.Context, msg *MatrixMembershipChange) (*MatrixMembershipResult, error) +} + +type SinglePowerLevelChange struct { + OrigLevel int + NewLevel int + NewIsSet bool +} + +type UserPowerLevelChange struct { + Target GhostOrUserLogin + SinglePowerLevelChange +} + +type MatrixPowerLevelChange struct { + MatrixRoomMeta[*event.PowerLevelsEventContent] + Users map[id.UserID]*UserPowerLevelChange + Events map[string]*SinglePowerLevelChange + UsersDefault *SinglePowerLevelChange + EventsDefault *SinglePowerLevelChange + StateDefault *SinglePowerLevelChange + Invite *SinglePowerLevelChange + Kick *SinglePowerLevelChange + Ban *SinglePowerLevelChange + Redact *SinglePowerLevelChange +} + +type PowerLevelHandlingNetworkAPI interface { + NetworkAPI + HandleMatrixPowerLevels(ctx context.Context, msg *MatrixPowerLevelChange) (bool, error) +} + +type PushType int + +func (pt PushType) String() string { + return pt.GoString() +} + +func PushTypeFromString(str string) PushType { + switch strings.TrimPrefix(strings.ToLower(str), "pushtype") { + case "web": + return PushTypeWeb + case "apns": + return PushTypeAPNs + case "fcm": + return PushTypeFCM + default: + return PushTypeUnknown + } +} + +func (pt PushType) GoString() string { + switch pt { + case PushTypeUnknown: + return "PushTypeUnknown" + case PushTypeWeb: + return "PushTypeWeb" + case PushTypeAPNs: + return "PushTypeAPNs" + case PushTypeFCM: + return "PushTypeFCM" + default: + return fmt.Sprintf("PushType(%d)", int(pt)) + } +} + +const ( + PushTypeUnknown PushType = iota + PushTypeWeb + PushTypeAPNs + PushTypeFCM +) + +type WebPushConfig struct { + VapidKey string `json:"vapid_key"` +} + +type FCMPushConfig struct { + SenderID string `json:"sender_id"` +} + +type APNsPushConfig struct { + BundleID string `json:"bundle_id"` +} + +type PushConfig struct { + Web *WebPushConfig `json:"web,omitempty"` + FCM *FCMPushConfig `json:"fcm,omitempty"` + APNs *APNsPushConfig `json:"apns,omitempty"` + // If Native is true, it means the network supports registering for pushes + // that are delivered directly to the app without the use of a push relay. + Native bool `json:"native,omitempty"` +} + +// PushableNetworkAPI is an optional interface that network connectors can implement +// to support waking up the wrapper app using push notifications. +type PushableNetworkAPI interface { + NetworkAPI + + // RegisterPushNotifications is called when the wrapper app wants to register a push token with the remote network. + RegisterPushNotifications(ctx context.Context, pushType PushType, token string) error + // GetPushConfigs is used to find which types of push notifications the remote network can provide. + GetPushConfigs() *PushConfig +} + +// PushParsingNetwork is an optional interface that network connectors can implement +// to support parsing native push notifications from networks. +type PushParsingNetwork interface { + NetworkConnector + + // ParsePushNotification is called when a native push is received. + // It must return the corresponding user login ID to wake up, plus optionally data to pass to the wakeup call. + ParsePushNotification(ctx context.Context, data json.RawMessage) (networkid.UserLoginID, any, error) +} + +type RemoteEventType int + +func (ret RemoteEventType) String() string { + switch ret { + case RemoteEventUnknown: + return "RemoteEventUnknown" + case RemoteEventMessage: + return "RemoteEventMessage" + case RemoteEventMessageUpsert: + return "RemoteEventMessageUpsert" + case RemoteEventEdit: + return "RemoteEventEdit" + case RemoteEventReaction: + return "RemoteEventReaction" + case RemoteEventReactionRemove: + return "RemoteEventReactionRemove" + case RemoteEventReactionSync: + return "RemoteEventReactionSync" + case RemoteEventMessageRemove: + return "RemoteEventMessageRemove" + case RemoteEventReadReceipt: + return "RemoteEventReadReceipt" + case RemoteEventDeliveryReceipt: + return "RemoteEventDeliveryReceipt" + case RemoteEventMarkUnread: + return "RemoteEventMarkUnread" + case RemoteEventTyping: + return "RemoteEventTyping" + case RemoteEventChatInfoChange: + return "RemoteEventChatInfoChange" + case RemoteEventChatResync: + return "RemoteEventChatResync" + case RemoteEventChatDelete: + return "RemoteEventChatDelete" + case RemoteEventBackfill: + return "RemoteEventBackfill" + default: + return fmt.Sprintf("RemoteEventType(%d)", int(ret)) + } +} + +const ( + RemoteEventUnknown RemoteEventType = iota + RemoteEventMessage + RemoteEventMessageUpsert + RemoteEventEdit + RemoteEventReaction + RemoteEventReactionRemove + RemoteEventReactionSync + RemoteEventMessageRemove + RemoteEventReadReceipt + RemoteEventDeliveryReceipt + RemoteEventMarkUnread + RemoteEventTyping + RemoteEventChatInfoChange + RemoteEventChatResync + RemoteEventChatDelete + RemoteEventBackfill +) + +// RemoteEvent represents a single event from the remote network, such as a message or a reaction. +// +// When a [NetworkAPI] receives an event from the remote network, it should convert it into a [RemoteEvent] +// and pass it to the bridge for processing using [Bridge.QueueRemoteEvent]. +type RemoteEvent interface { + GetType() RemoteEventType + GetPortalKey() networkid.PortalKey + AddLogContext(c zerolog.Context) zerolog.Context + GetSender() EventSender +} + +type RemoteEventWithUncertainPortalReceiver interface { + RemoteEvent + PortalReceiverIsUncertain() bool +} + +type RemotePreHandler interface { + RemoteEvent + PreHandle(ctx context.Context, portal *Portal) +} + +type RemotePostHandler interface { + RemoteEvent + PostHandle(ctx context.Context, portal *Portal) +} + +type RemoteChatInfoChange interface { + RemoteEvent + GetChatInfoChange(ctx context.Context) (*ChatInfoChange, error) +} + +type RemoteChatResync interface { + RemoteEvent +} + +type RemoteChatResyncWithInfo interface { + RemoteChatResync + GetChatInfo(ctx context.Context, portal *Portal) (*ChatInfo, error) +} + +type RemoteChatResyncBackfill interface { + RemoteChatResync + CheckNeedsBackfill(ctx context.Context, latestMessage *database.Message) (bool, error) +} + +type RemoteChatResyncBackfillBundle interface { + RemoteChatResyncBackfill + GetBundledBackfillData() any +} + +type RemoteBackfill interface { + RemoteEvent + GetBackfillData(ctx context.Context, portal *Portal) (*FetchMessagesResponse, error) +} + +type RemoteDeleteOnlyForMe interface { + RemoteEvent + DeleteOnlyForMe() bool +} + +type RemoteChatDelete interface { + RemoteDeleteOnlyForMe +} + +type RemoteChatDeleteWithChildren interface { + RemoteChatDelete + DeleteChildren() bool +} + +type RemoteEventThatMayCreatePortal interface { + RemoteEvent + ShouldCreatePortal() bool +} + +type RemoteEventWithTargetMessage interface { + RemoteEvent + GetTargetMessage() networkid.MessageID +} + +type RemoteEventWithBundledParts interface { + RemoteEventWithTargetMessage + GetTargetDBMessage() []*database.Message +} + +type RemoteEventWithTargetPart interface { + RemoteEventWithTargetMessage + GetTargetMessagePart() networkid.PartID +} + +type RemoteEventWithTimestamp interface { + RemoteEvent + GetTimestamp() time.Time +} + +type RemoteEventWithStreamOrder interface { + RemoteEvent + GetStreamOrder() int64 +} + +type RemoteMessage interface { + RemoteEvent + GetID() networkid.MessageID + ConvertMessage(ctx context.Context, portal *Portal, intent MatrixAPI) (*ConvertedMessage, error) +} + +type UpsertResult struct { + SubEvents []RemoteEvent + SaveParts bool + ContinueMessageHandling bool +} + +type RemoteMessageUpsert interface { + RemoteMessage + HandleExisting(ctx context.Context, portal *Portal, intent MatrixAPI, existing []*database.Message) (UpsertResult, error) +} + +type RemoteMessageWithTransactionID interface { + RemoteMessage + GetTransactionID() networkid.TransactionID +} + +type RemoteEdit interface { + RemoteEventWithTargetMessage + ConvertEdit(ctx context.Context, portal *Portal, intent MatrixAPI, existing []*database.Message) (*ConvertedEdit, error) +} + +type RemoteReaction interface { + RemoteEventWithTargetMessage + GetReactionEmoji() (string, networkid.EmojiID) +} + +type ReactionSyncUser struct { + Reactions []*BackfillReaction + // Whether the list contains all reactions the user has sent + HasAllReactions bool + // If the list doesn't contain all reactions from the user, + // then this field can be set to remove old reactions if there are more than a certain number. + MaxCount int +} + +type ReactionSyncData struct { + Users map[networkid.UserID]*ReactionSyncUser + // Whether the map contains all users who have reacted to the message + HasAllUsers bool +} + +func (rsd *ReactionSyncData) ToBackfill() []*BackfillReaction { + var reactions []*BackfillReaction + for _, user := range rsd.Users { + reactions = append(reactions, user.Reactions...) + } + return reactions +} + +type RemoteReactionSync interface { + RemoteEventWithTargetMessage + GetReactions() *ReactionSyncData +} + +type RemoteReactionWithExtraContent interface { + RemoteReaction + GetReactionExtraContent() map[string]any +} + +type RemoteReactionWithMeta interface { + RemoteReaction + GetReactionDBMetadata() any +} + +type RemoteReactionRemove interface { + RemoteEventWithTargetMessage + GetRemovedEmojiID() networkid.EmojiID +} + +type RemoteMessageRemove interface { + RemoteEventWithTargetMessage +} + +// Deprecated: Renamed to RemoteReadReceipt. +type RemoteReceipt = RemoteReadReceipt + +type RemoteReadReceipt interface { + RemoteEvent + GetLastReceiptTarget() networkid.MessageID + GetReceiptTargets() []networkid.MessageID + GetReadUpTo() time.Time +} + +type RemoteReadReceiptWithStreamOrder interface { + RemoteReadReceipt + GetReadUpToStreamOrder() int64 +} + +type RemoteDeliveryReceipt interface { + RemoteEvent + GetReceiptTargets() []networkid.MessageID +} + +type RemoteMarkUnread interface { + RemoteEvent + GetUnread() bool +} + +type RemoteTyping interface { + RemoteEvent + GetTimeout() time.Duration +} + +type TypingType int + +const ( + TypingTypeText TypingType = iota + TypingTypeUploadingMedia + TypingTypeRecordingMedia +) + +type RemoteTypingWithType interface { + RemoteTyping + GetTypingType() TypingType +} + +type OrigSender struct { + User *User + UserID id.UserID + + RequiresDisambiguation bool + DisambiguatedName string + FormattedName string + PerMessageProfile event.BeeperPerMessageProfile + + event.MemberEventContent +} + +type MatrixEventBase[ContentType any] struct { + // The raw event being bridged. + Event *event.Event + // The parsed content struct of the event. Custom fields can be found in Event.Content.Raw. + Content ContentType + // The room where the event happened. + Portal *Portal + + // The original sender user ID. Only present in case the event is being relayed (and Sender is not the same user). + OrigSender *OrigSender + + InputTransactionID networkid.RawTransactionID +} + +type MatrixMessage struct { + MatrixEventBase[*event.MessageEventContent] + ThreadRoot *database.Message + ReplyTo *database.Message + + pendingSaves []*outgoingMessage +} + +type MatrixEdit struct { + MatrixEventBase[*event.MessageEventContent] + EditTarget *database.Message +} + +type MatrixPollStart struct { + MatrixMessage + Content *event.PollStartEventContent +} + +type MatrixPollVote struct { + MatrixMessage + VoteTo *database.Message + Content *event.PollResponseEventContent +} + +type MatrixReaction struct { + MatrixEventBase[*event.ReactionEventContent] + TargetMessage *database.Message + PreHandleResp *MatrixReactionPreResponse + + // When EmojiID is blank and there's already an existing reaction, this is the old reaction that is being overridden. + ReactionToOverride *database.Reaction + // When MaxReactions is >0 in the pre-response, this is the list of previous reactions that should be preserved. + ExistingReactionsToKeep []*database.Reaction +} + +type MatrixReactionPreResponse struct { + SenderID networkid.UserID + EmojiID networkid.EmojiID + Emoji string + MaxReactions int +} + +type MatrixReactionRemove struct { + MatrixEventBase[*event.RedactionEventContent] + TargetReaction *database.Reaction +} + +type MatrixMessageRemove struct { + MatrixEventBase[*event.RedactionEventContent] + TargetMessage *database.Message +} + +type MatrixRoomMeta[ContentType any] struct { + MatrixEventBase[ContentType] + PrevContent ContentType + IsStateRequest bool +} + +type MatrixRoomName = MatrixRoomMeta[*event.RoomNameEventContent] +type MatrixRoomAvatar = MatrixRoomMeta[*event.RoomAvatarEventContent] +type MatrixRoomTopic = MatrixRoomMeta[*event.TopicEventContent] +type MatrixDisappearingTimer = MatrixRoomMeta[*event.BeeperDisappearingTimer] + +type MatrixReadReceipt struct { + Portal *Portal + // The event ID that the receipt is targeting + EventID id.EventID + // The exact message that was read. This may be nil if the event ID isn't a message. + ExactMessage *database.Message + // The timestamp that the user has read up to. This is either the timestamp of the message + // (if one is present) or the timestamp of the receipt. + ReadUpTo time.Time + // The ReadUpTo timestamp of the previous message + LastRead time.Time + // The receipt metadata. + Receipt event.ReadReceipt + // Whether the receipt is implicit, i.e. triggered by an incoming timeline event rather than an explicit receipt. + Implicit bool +} + +type MatrixTyping struct { + Portal *Portal + IsTyping bool + Type TypingType +} + +type MatrixViewingChat struct { + // The portal that the user is viewing. This will be nil when the user switches to a chat from a different bridge. + Portal *Portal +} + +type MatrixDeleteChat = MatrixEventBase[*event.BeeperChatDeleteEventContent] +type MatrixAcceptMessageRequest = MatrixEventBase[*event.BeeperAcceptMessageRequestEventContent] +type MatrixBeeperAIStream = MatrixEventBase[*event.BeeperAIStreamEventContent] +type MatrixMarkedUnread = MatrixRoomMeta[*event.MarkedUnreadEventContent] +type MatrixMute = MatrixRoomMeta[*event.BeeperMuteEventContent] +type MatrixRoomTag = MatrixRoomMeta[*event.TagEventContent] diff --git a/bridgev2/portal.go b/bridgev2/portal.go new file mode 100644 index 00000000..16aa703b --- /dev/null +++ b/bridgev2/portal.go @@ -0,0 +1,5436 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package bridgev2 + +import ( + "cmp" + "context" + "errors" + "fmt" + "runtime/debug" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/rs/zerolog" + "go.mau.fi/util/exfmt" + "go.mau.fi/util/exmaps" + "go.mau.fi/util/exslices" + "go.mau.fi/util/exsync" + "go.mau.fi/util/ptr" + "go.mau.fi/util/variationselector" + "golang.org/x/exp/maps" + "golang.org/x/exp/slices" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/bridgev2/database" + "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +type portalMatrixEvent struct { + evt *event.Event + sender *User +} + +type portalRemoteEvent struct { + evt RemoteEvent + source *UserLogin + evtType RemoteEventType +} + +type portalCreateEvent struct { + ctx context.Context + source *UserLogin + info *ChatInfo + cb func(error) +} + +func (pme *portalMatrixEvent) isPortalEvent() {} +func (pre *portalRemoteEvent) isPortalEvent() {} +func (pre *portalCreateEvent) isPortalEvent() {} + +type portalEvent interface { + isPortalEvent() +} + +type outgoingMessage struct { + db *database.Message + evt *event.Event + ignore bool + handle func(RemoteMessage, *database.Message) (bool, error) + ackedAt time.Time + timeouted bool +} + +type Portal struct { + *database.Portal + Bridge *Bridge + Log zerolog.Logger + Parent *Portal + Relay *UserLogin + + currentlyTyping []id.UserID + currentlyTypingLogins map[id.UserID]*UserLogin + currentlyTypingLock sync.Mutex + currentlyTypingGhosts *exsync.Set[id.UserID] + + outgoingMessages map[networkid.TransactionID]*outgoingMessage + outgoingMessagesLock sync.Mutex + + lastCapUpdate time.Time + + roomCreateLock sync.Mutex + cancelRoomCreate atomic.Pointer[context.CancelFunc] + RoomCreated *exsync.Event + + functionalMembersLock sync.Mutex + functionalMembersCache *event.ElementFunctionalMembersContent + + events chan portalEvent + deleted *exsync.Event + + eventsLock sync.Mutex + eventIdx int +} + +var PortalEventBuffer = 64 + +func (br *Bridge) loadPortal(ctx context.Context, dbPortal *database.Portal, queryErr error, key *networkid.PortalKey) (*Portal, error) { + if queryErr != nil { + return nil, fmt.Errorf("failed to query db: %w", queryErr) + } + if dbPortal == nil { + if key == nil { + return nil, nil + } + dbPortal = &database.Portal{ + BridgeID: br.ID, + PortalKey: *key, + } + err := br.DB.Portal.Insert(ctx, dbPortal) + if err != nil { + return nil, fmt.Errorf("failed to insert new portal: %w", err) + } + } + portal := &Portal{ + Portal: dbPortal, + Bridge: br, + + currentlyTypingLogins: make(map[id.UserID]*UserLogin), + currentlyTypingGhosts: exsync.NewSet[id.UserID](), + outgoingMessages: make(map[networkid.TransactionID]*outgoingMessage), + + RoomCreated: exsync.NewEvent(), + deleted: exsync.NewEvent(), + } + if portal.MXID != "" { + portal.RoomCreated.Set() + } + // Putting the portal in the cache before it's fully initialized is mildly dangerous, + // but loading the relay user login may depend on it. + br.portalsByKey[portal.PortalKey] = portal + if portal.MXID != "" { + br.portalsByMXID[portal.MXID] = portal + } + var err error + if portal.ParentKey.ID != "" { + portal.Parent, err = br.UnlockedGetPortalByKey(ctx, portal.ParentKey, false) + if err != nil { + delete(br.portalsByKey, portal.PortalKey) + if portal.MXID != "" { + delete(br.portalsByMXID, portal.MXID) + } + return nil, fmt.Errorf("failed to load parent portal (%s): %w", portal.ParentKey, err) + } + } + if portal.RelayLoginID != "" { + portal.Relay, err = br.unlockedGetExistingUserLoginByID(ctx, portal.RelayLoginID) + if err != nil { + delete(br.portalsByKey, portal.PortalKey) + if portal.MXID != "" { + delete(br.portalsByMXID, portal.MXID) + } + return nil, fmt.Errorf("failed to load relay login (%s): %w", portal.RelayLoginID, err) + } + } + portal.updateLogger() + if PortalEventBuffer != 0 { + portal.events = make(chan portalEvent, PortalEventBuffer) + go portal.eventLoop() + } + return portal, nil +} + +func (portal *Portal) updateLogger() { + logWith := portal.Bridge.Log.With(). + Str("portal_id", string(portal.ID)). + Str("portal_receiver", string(portal.Receiver)) + if portal.MXID != "" { + logWith = logWith.Stringer("portal_mxid", portal.MXID) + } + portal.Log = logWith.Logger() +} + +func (br *Bridge) loadManyPortals(ctx context.Context, portals []*database.Portal) ([]*Portal, error) { + output := make([]*Portal, 0, len(portals)) + for _, dbPortal := range portals { + if cached, ok := br.portalsByKey[dbPortal.PortalKey]; ok { + output = append(output, cached) + } else { + loaded, err := br.loadPortal(ctx, dbPortal, nil, nil) + if err != nil { + return nil, err + } else if loaded != nil { + output = append(output, loaded) + } + } + } + return output, nil +} + +func (br *Bridge) loadPortalWithCacheCheck(ctx context.Context, dbPortal *database.Portal) (*Portal, error) { + if dbPortal == nil { + return nil, nil + } else if cached, ok := br.portalsByKey[dbPortal.PortalKey]; ok { + return cached, nil + } else { + return br.loadPortal(ctx, dbPortal, nil, nil) + } +} + +func (br *Bridge) UnlockedGetPortalByKey(ctx context.Context, key networkid.PortalKey, onlyIfExists bool) (*Portal, error) { + if br.Config.SplitPortals && key.Receiver == "" { + return nil, fmt.Errorf("receiver must always be set when split portals is enabled") + } + cached, ok := br.portalsByKey[key] + if ok { + return cached, nil + } + keyPtr := &key + if onlyIfExists { + keyPtr = nil + } + db, err := br.DB.Portal.GetByKey(ctx, key) + return br.loadPortal(ctx, db, err, keyPtr) +} + +func (br *Bridge) FindPortalReceiver(ctx context.Context, id networkid.PortalID, maybeReceiver networkid.UserLoginID) (networkid.PortalKey, error) { + key := br.FindCachedPortalReceiver(id, maybeReceiver) + if !key.IsEmpty() { + return key, nil + } + key, err := br.DB.Portal.FindReceiver(ctx, id, maybeReceiver) + if err != nil { + return networkid.PortalKey{}, err + } + return key, nil +} + +func (br *Bridge) FindCachedPortalReceiver(id networkid.PortalID, maybeReceiver networkid.UserLoginID) networkid.PortalKey { + if br.Config.SplitPortals { + return networkid.PortalKey{ID: id, Receiver: maybeReceiver} + } + br.cacheLock.Lock() + defer br.cacheLock.Unlock() + portal, ok := br.portalsByKey[networkid.PortalKey{ + ID: id, + Receiver: maybeReceiver, + }] + if ok { + return portal.PortalKey + } + portal, ok = br.portalsByKey[networkid.PortalKey{ID: id}] + if ok { + return portal.PortalKey + } + return networkid.PortalKey{} +} + +func (br *Bridge) GetPortalByMXID(ctx context.Context, mxid id.RoomID) (*Portal, error) { + br.cacheLock.Lock() + defer br.cacheLock.Unlock() + cached, ok := br.portalsByMXID[mxid] + if ok { + return cached, nil + } + db, err := br.DB.Portal.GetByMXID(ctx, mxid) + return br.loadPortal(ctx, db, err, nil) +} + +func (br *Bridge) GetAllPortalsWithMXID(ctx context.Context) ([]*Portal, error) { + br.cacheLock.Lock() + defer br.cacheLock.Unlock() + rows, err := br.DB.Portal.GetAllWithMXID(ctx) + if err != nil { + return nil, err + } + return br.loadManyPortals(ctx, rows) +} + +func (br *Bridge) GetAllPortals(ctx context.Context) ([]*Portal, error) { + br.cacheLock.Lock() + defer br.cacheLock.Unlock() + rows, err := br.DB.Portal.GetAll(ctx) + if err != nil { + return nil, err + } + return br.loadManyPortals(ctx, rows) +} + +func (br *Bridge) GetDMPortalsWith(ctx context.Context, otherUserID networkid.UserID) ([]*Portal, error) { + br.cacheLock.Lock() + defer br.cacheLock.Unlock() + rows, err := br.DB.Portal.GetAllDMsWith(ctx, otherUserID) + if err != nil { + return nil, err + } + return br.loadManyPortals(ctx, rows) +} + +func (br *Bridge) GetChildPortals(ctx context.Context, parent networkid.PortalKey) ([]*Portal, error) { + br.cacheLock.Lock() + defer br.cacheLock.Unlock() + rows, err := br.DB.Portal.GetChildren(ctx, parent) + if err != nil { + return nil, err + } + return br.loadManyPortals(ctx, rows) +} + +func (br *Bridge) GetDMPortal(ctx context.Context, receiver networkid.UserLoginID, otherUserID networkid.UserID) (*Portal, error) { + br.cacheLock.Lock() + defer br.cacheLock.Unlock() + dbPortal, err := br.DB.Portal.GetDM(ctx, receiver, otherUserID) + if err != nil { + return nil, err + } + return br.loadPortalWithCacheCheck(ctx, dbPortal) +} + +func (br *Bridge) GetPortalByKey(ctx context.Context, key networkid.PortalKey) (*Portal, error) { + br.cacheLock.Lock() + defer br.cacheLock.Unlock() + return br.UnlockedGetPortalByKey(ctx, key, false) +} + +func (br *Bridge) GetExistingPortalByKey(ctx context.Context, key networkid.PortalKey) (*Portal, error) { + br.cacheLock.Lock() + defer br.cacheLock.Unlock() + if key.Receiver == "" || br.Config.SplitPortals { + return br.UnlockedGetPortalByKey(ctx, key, true) + } + cached, ok := br.portalsByKey[key] + if ok { + return cached, nil + } + cached, ok = br.portalsByKey[networkid.PortalKey{ID: key.ID}] + if ok { + return cached, nil + } + db, err := br.DB.Portal.GetByIDWithUncertainReceiver(ctx, key) + return br.loadPortal(ctx, db, err, nil) +} + +func (portal *Portal) queueEvent(ctx context.Context, evt portalEvent) EventHandlingResult { + if portal.deleted.IsSet() { + return EventHandlingResultIgnored + } + if PortalEventBuffer == 0 { + portal.eventsLock.Lock() + defer portal.eventsLock.Unlock() + portal.eventIdx++ + return portal.handleSingleEventWithDelayLogging(portal.eventIdx, evt) + } else { + if portal.events == nil { + panic(fmt.Errorf("queueEvent into uninitialized portal %s", portal.PortalKey)) + } + select { + case portal.events <- evt: + return EventHandlingResultQueued + case <-portal.deleted.GetChan(): + return EventHandlingResultIgnored + default: + zerolog.Ctx(ctx).Error(). + Str("portal_id", string(portal.ID)). + Msg("Portal event channel is full, queue will block") + for { + select { + case portal.events <- evt: + return EventHandlingResultQueued + case <-time.After(5 * time.Second): + zerolog.Ctx(ctx).Error(). + Str("portal_id", string(portal.ID)). + Msg("Portal event channel is still full") + } + } + } + } +} + +func (portal *Portal) eventLoop() { + if cfg := portal.Bridge.Network.GetCapabilities().OutgoingMessageTimeouts; cfg != nil { + ctx, cancel := context.WithCancel(portal.Log.WithContext(portal.Bridge.BackgroundCtx)) + go portal.pendingMessageTimeoutLoop(ctx, cfg) + defer cancel() + } + deleteCh := portal.deleted.GetChan() + for i := 0; ; i++ { + select { + case rawEvt := <-portal.events: + if rawEvt == nil { + return + } + if portal.Bridge.Config.AsyncEvents { + go portal.handleSingleEventWithDelayLogging(i, rawEvt) + } else { + portal.handleSingleEventWithDelayLogging(i, rawEvt) + } + case <-deleteCh: + return + } + } +} + +func (portal *Portal) handleSingleEventWithDelayLogging(idx int, rawEvt any) (outerRes EventHandlingResult) { + ctx := portal.getEventCtxWithLog(rawEvt, idx) + log := zerolog.Ctx(ctx) + doneCh := make(chan struct{}) + var backgrounded atomic.Bool + start := time.Now() + var handleDuration time.Duration + // Note: this will not set the success flag if the handler times out + outerRes = EventHandlingResult{Queued: true} + go portal.handleSingleEvent(ctx, rawEvt, func(res EventHandlingResult) { + outerRes = res + handleDuration = time.Since(start) + close(doneCh) + if backgrounded.Load() { + log.Debug(). + Time("started_at", start). + Stringer("duration", handleDuration). + Msg("Event that took too long finally finished handling") + } + }) + tick := time.NewTicker(30 * time.Second) + _, isCreate := rawEvt.(*portalCreateEvent) + defer tick.Stop() + for i := 0; i < 10; i++ { + select { + case <-doneCh: + if i > 0 { + log.Debug(). + Time("started_at", start). + Stringer("duration", handleDuration). + Msg("Event that took long finished handling") + } + return + case <-tick.C: + log.Warn(). + Time("started_at", start). + Msg("Event handling is taking long") + if isCreate { + // Never background portal creation events + i = 1 + } + } + } + log.Warn(). + Time("started_at", start). + Msg("Event handling is taking too long, continuing in background") + backgrounded.Store(true) + return +} + +func (portal *Portal) getEventCtxWithLog(rawEvt any, idx int) context.Context { + var logWith zerolog.Context + switch evt := rawEvt.(type) { + case *portalMatrixEvent: + logWith = portal.Log.With().Int("event_loop_index", idx). + Str("action", "handle matrix event"). + Stringer("event_id", evt.evt.ID). + Str("event_type", evt.evt.Type.Type) + if evt.evt.Type.Class != event.EphemeralEventType { + logWith = logWith. + Stringer("event_id", evt.evt.ID). + Stringer("sender", evt.sender.MXID) + } + case *portalRemoteEvent: + evt.evtType = evt.evt.GetType() + logWith = portal.Log.With().Int("event_loop_index", idx). + Str("action", "handle remote event"). + Str("source_id", string(evt.source.ID)). + Stringer("bridge_evt_type", evt.evtType) + logWith = evt.evt.AddLogContext(logWith) + if remoteSender := evt.evt.GetSender(); remoteSender.Sender != "" || remoteSender.IsFromMe { + logWith = logWith.Object("remote_sender", remoteSender) + } + if remoteMsg, ok := evt.evt.(RemoteMessage); ok { + if remoteMsgID := remoteMsg.GetID(); remoteMsgID != "" { + logWith = logWith.Str("remote_message_id", string(remoteMsgID)) + } + } + if remoteMsg, ok := evt.evt.(RemoteEventWithTargetMessage); ok { + if targetMsgID := remoteMsg.GetTargetMessage(); targetMsgID != "" { + logWith = logWith.Str("remote_target_message_id", string(targetMsgID)) + } + } + if remoteMsg, ok := evt.evt.(RemoteEventWithStreamOrder); ok { + if remoteStreamOrder := remoteMsg.GetStreamOrder(); remoteStreamOrder != 0 { + logWith = logWith.Int64("remote_stream_order", remoteStreamOrder) + } + } + if remoteMsg, ok := evt.evt.(RemoteEventWithTimestamp); ok { + if remoteTimestamp := remoteMsg.GetTimestamp(); !remoteTimestamp.IsZero() { + logWith = logWith.Time("remote_timestamp", remoteTimestamp) + } + } + case *portalCreateEvent: + return evt.ctx + } + return logWith.Logger().WithContext(portal.Bridge.BackgroundCtx) +} + +func (portal *Portal) handleSingleEvent(ctx context.Context, rawEvt any, doneCallback func(res EventHandlingResult)) { + log := zerolog.Ctx(ctx) + var res EventHandlingResult + defer func() { + doneCallback(res) + if err := recover(); err != nil { + logEvt := log.Error() + var errorString string + if realErr, ok := err.(error); ok { + logEvt = logEvt.Err(realErr) + errorString = realErr.Error() + } else { + logEvt = logEvt.Any(zerolog.ErrorFieldName, err) + errorString = fmt.Sprintf("%v", err) + } + logEvt. + Bytes("stack", debug.Stack()). + Msg("Event handling panicked") + switch evt := rawEvt.(type) { + case *portalMatrixEvent: + if evt.evt.ID != "" { + go portal.sendErrorStatus(ctx, evt.evt, ErrPanicInEventHandler) + } + case *portalCreateEvent: + evt.cb(fmt.Errorf("portal creation panicked")) + } + portal.Bridge.TrackAnalytics("", "Bridge Event Handler Panic", map[string]any{ + "error": errorString, + }) + } + }() + switch evt := rawEvt.(type) { + case *portalMatrixEvent: + isStateRequest := evt.evt.Type == event.BeeperSendState + if isStateRequest { + if err := portal.unwrapBeeperSendState(ctx, evt.evt); err != nil { + portal.sendErrorStatus(ctx, evt.evt, err) + return + } + } + res = portal.handleMatrixEvent(ctx, evt.sender, evt.evt, isStateRequest) + if res.SendMSS { + if res.Error != nil { + portal.sendErrorStatus(ctx, evt.evt, res.Error) + } else { + portal.sendSuccessStatus(ctx, evt.evt, 0, "") + } + } + if !isStateRequest && res.Error != nil && evt.evt.StateKey != nil { + portal.revertRoomMeta(ctx, evt.evt) + } + if isStateRequest && res.Success && !res.SkipStateEcho { + portal.sendRoomMeta( + ctx, + evt.sender.DoublePuppet(ctx), + time.UnixMilli(evt.evt.Timestamp), + evt.evt.Type, + evt.evt.GetStateKey(), + evt.evt.Content.Parsed, + false, + evt.evt.Content.Raw, + ) + } + case *portalRemoteEvent: + res = portal.handleRemoteEvent(ctx, evt.source, evt.evtType, evt.evt) + case *portalCreateEvent: + err := portal.createMatrixRoomInLoop(evt.ctx, evt.source, evt.info, nil) + res.Success = err == nil + evt.cb(err) + default: + panic(fmt.Errorf("illegal type %T in eventLoop", evt)) + } +} + +func (portal *Portal) unwrapBeeperSendState(ctx context.Context, evt *event.Event) error { + content, ok := evt.Content.Parsed.(*event.BeeperSendStateEventContent) + if !ok { + return fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed) + } + evt.Content = content.Content + evt.StateKey = &content.StateKey + evt.Type = event.Type{Type: content.Type, Class: event.StateEventType} + _ = evt.Content.ParseRaw(evt.Type) + mx, ok := portal.Bridge.Matrix.(MatrixConnectorWithArbitraryRoomState) + if !ok { + return fmt.Errorf("matrix connector doesn't support fetching state") + } + prevEvt, err := mx.GetStateEvent(ctx, portal.MXID, evt.Type, evt.GetStateKey()) + if err != nil && !errors.Is(err, mautrix.MNotFound) { + return fmt.Errorf("failed to get prev event: %w", err) + } else if prevEvt != nil { + evt.Unsigned.PrevContent = &prevEvt.Content + evt.Unsigned.PrevSender = prevEvt.Sender + } + return nil +} + +func (portal *Portal) FindPreferredLogin(ctx context.Context, user *User, allowRelay bool) (*UserLogin, *database.UserPortal, error) { + if portal.Receiver != "" { + login, err := portal.Bridge.GetExistingUserLoginByID(ctx, portal.Receiver) + if err != nil { + return nil, nil, err + } + if login == nil { + return nil, nil, fmt.Errorf("%w (receiver login is nil)", ErrNotLoggedIn) + } else if !login.Client.IsLoggedIn() { + return nil, nil, fmt.Errorf("%w (receiver login is not logged in)", ErrNotLoggedIn) + } else if login.UserMXID != user.MXID { + if allowRelay && portal.Relay != nil { + return nil, nil, nil + } + return nil, nil, fmt.Errorf("%w (relay not set and receiver login is owned by %s, not %s)", ErrNotLoggedIn, login.UserMXID, user.MXID) + } + up, err := portal.Bridge.DB.UserPortal.Get(ctx, login.UserLogin, portal.PortalKey) + return login, up, err + } + logins, err := portal.Bridge.DB.UserPortal.GetAllForUserInPortal(ctx, user.MXID, portal.PortalKey) + if err != nil { + return nil, nil, err + } + portal.Bridge.cacheLock.Lock() + defer portal.Bridge.cacheLock.Unlock() + for _, up := range logins { + login, ok := user.logins[up.LoginID] + if ok && login.Client != nil && login.Client.IsLoggedIn() { + return login, up, nil + } + } + if !allowRelay { + return nil, nil, ErrNotLoggedIn + } + // Portal has relay, use it + if portal.Relay != nil { + return nil, nil, nil + } + var firstLogin *UserLogin + for _, login := range user.logins { + firstLogin = login + break + } + if firstLogin != nil && firstLogin.Client.IsLoggedIn() { + zerolog.Ctx(ctx).Warn(). + Str("chosen_login_id", string(firstLogin.ID)). + Msg("No usable user portal rows found, returning random login") + return firstLogin, nil, nil + } else { + return nil, nil, ErrNotLoggedIn + } +} + +func (portal *Portal) sendSuccessStatus(ctx context.Context, evt *event.Event, streamOrder int64, newEventID id.EventID) { + info := StatusEventInfoFromEvent(evt) + info.StreamOrder = streamOrder + if newEventID != evt.ID { + info.NewEventID = newEventID + } + portal.Bridge.Matrix.SendMessageStatus(ctx, &MessageStatus{Status: event.MessageStatusSuccess}, info) +} + +func (portal *Portal) sendErrorStatus(ctx context.Context, evt *event.Event, err error) { + status := WrapErrorInStatus(err) + if status.Status == "" { + status.Status = event.MessageStatusRetriable + } + if status.ErrorReason == "" { + status.ErrorReason = event.MessageStatusGenericError + } + if status.InternalError == nil { + status.InternalError = err + } + portal.Bridge.Matrix.SendMessageStatus(ctx, &status, StatusEventInfoFromEvent(evt)) +} + +func (portal *Portal) checkConfusableName(ctx context.Context, userID id.UserID, name string) bool { + conn, ok := portal.Bridge.Matrix.(MatrixConnectorWithNameDisambiguation) + if !ok { + return false + } + confusableWith, err := conn.IsConfusableName(ctx, portal.MXID, userID, name) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to check if name is confusable") + return true + } + for _, confusable := range confusableWith { + // Don't disambiguate names that only conflict with ghosts of this bridge + if !portal.Bridge.IsGhostMXID(confusable) { + return true + } + } + return false +} + +var fakePerMessageProfileEventType = event.Type{Class: event.StateEventType, Type: "m.per_message_profile"} + +func (portal *Portal) handleMatrixEvent(ctx context.Context, sender *User, evt *event.Event, isStateRequest bool) EventHandlingResult { + log := zerolog.Ctx(ctx) + if evt.Mautrix.EventSource&event.SourceEphemeral != 0 { + switch evt.Type { + case event.EphemeralEventReceipt: + return portal.handleMatrixReceipts(ctx, evt) + case event.EphemeralEventTyping: + return portal.handleMatrixTyping(ctx, evt) + case event.BeeperEphemeralEventAIStream: + return portal.handleMatrixAIStream(ctx, sender, evt) + default: + return EventHandlingResultIgnored + } + } + if evt.Type == event.StateTombstone { + // Tombstones aren't bridged so they don't need a login + return portal.handleMatrixTombstone(ctx, evt) + } + login, userPortal, err := portal.FindPreferredLogin(ctx, sender, true) + if err != nil { + log.Err(err).Msg("Failed to get user login to handle Matrix event") + if errors.Is(err, ErrNotLoggedIn) { + shouldSendNotice := evt.Content.AsMessage().MsgType != event.MsgNotice + return EventHandlingResultFailed.WithMSSError( + WrapErrorInStatus(err).WithMessage("You're not logged in").WithIsCertain(true).WithSendNotice(shouldSendNotice), + ) + } else { + return EventHandlingResultFailed.WithMSSError( + WrapErrorInStatus(err).WithMessage("Failed to get login to handle event").WithIsCertain(true).WithSendNotice(true), + ) + } + } + var origSender *OrigSender + if login == nil { + if isStateRequest { + return EventHandlingResultFailed.WithMSSError(ErrCantRelayStateRequest) + } + login = portal.Relay + origSender = &OrigSender{ + User: sender, + UserID: sender.MXID, + } + + memberInfo, err := portal.Bridge.Matrix.GetMemberInfo(ctx, portal.MXID, sender.MXID) + if err != nil { + log.Warn().Err(err).Msg("Failed to get member info for user being relayed") + } else if memberInfo != nil { + origSender.MemberEventContent = *memberInfo + if memberInfo.Displayname == "" { + origSender.DisambiguatedName = sender.MXID.String() + } else if origSender.RequiresDisambiguation = portal.checkConfusableName(ctx, sender.MXID, memberInfo.Displayname); origSender.RequiresDisambiguation { + origSender.DisambiguatedName = fmt.Sprintf("%s (%s)", memberInfo.Displayname, sender.MXID) + } else { + origSender.DisambiguatedName = memberInfo.Displayname + } + } else { + origSender.DisambiguatedName = sender.MXID.String() + } + msg := evt.Content.AsMessage() + if msg != nil && msg.BeeperPerMessageProfile != nil && msg.BeeperPerMessageProfile.Displayname != "" { + pmp := msg.BeeperPerMessageProfile + origSender.PerMessageProfile = *pmp + roomPLs, err := portal.Bridge.Matrix.GetPowerLevels(ctx, portal.MXID) + if err != nil { + log.Warn().Err(err).Msg("Failed to get power levels to check relay profile") + } + if roomPLs != nil && + roomPLs.GetUserLevel(sender.MXID) >= roomPLs.GetEventLevel(fakePerMessageProfileEventType) && + !portal.checkConfusableName(ctx, sender.MXID, pmp.Displayname) { + origSender.DisambiguatedName = pmp.Displayname + origSender.RequiresDisambiguation = false + } else { + origSender.DisambiguatedName = fmt.Sprintf("%s via %s", pmp.Displayname, origSender.DisambiguatedName) + } + } + + origSender.FormattedName = portal.Bridge.Config.Relay.FormatName(origSender) + } + // Copy logger because many of the handlers will use UpdateContext + ctx = log.With().Str("login_id", string(login.ID)).Logger().WithContext(ctx) + + if origSender == nil && portal.Bridge.Network.GetCapabilities().ImplicitReadReceipts && !evt.Type.IsAccountData() { + rrLog := log.With().Str("subaction", "implicit read receipt").Logger() + rrCtx := rrLog.WithContext(ctx) + rrLog.Debug().Msg("Sending implicit read receipt for event") + evtTS := time.UnixMilli(evt.Timestamp) + portal.callReadReceiptHandler(rrCtx, login, nil, &MatrixReadReceipt{ + Portal: portal, + EventID: evt.ID, + Implicit: true, + ReadUpTo: evtTS, + Receipt: event.ReadReceipt{Timestamp: evtTS}, + }, userPortal) + } + + switch evt.Type { + case event.EventMessage, event.EventSticker, event.EventUnstablePollStart, event.EventUnstablePollResponse: + return portal.handleMatrixMessage(ctx, login, origSender, evt) + case event.EventReaction: + if origSender != nil { + log.Debug().Msg("Ignoring reaction event from relayed user") + return EventHandlingResultIgnored.WithMSSError(ErrIgnoringReactionFromRelayedUser) + } + return portal.handleMatrixReaction(ctx, login, evt) + case event.EventRedaction: + return portal.handleMatrixRedaction(ctx, login, origSender, evt) + case event.StateRoomName: + return handleMatrixRoomMeta(portal, ctx, login, origSender, evt, isStateRequest, RoomNameHandlingNetworkAPI.HandleMatrixRoomName) + case event.StateTopic: + return handleMatrixRoomMeta(portal, ctx, login, origSender, evt, isStateRequest, RoomTopicHandlingNetworkAPI.HandleMatrixRoomTopic) + case event.StateRoomAvatar: + return handleMatrixRoomMeta(portal, ctx, login, origSender, evt, isStateRequest, RoomAvatarHandlingNetworkAPI.HandleMatrixRoomAvatar) + case event.StateBeeperDisappearingTimer: + return handleMatrixRoomMeta(portal, ctx, login, origSender, evt, isStateRequest, DisappearTimerChangingNetworkAPI.HandleMatrixDisappearingTimer) + case event.StateEncryption: + // TODO? + return EventHandlingResultIgnored + case event.AccountDataMarkedUnread: + return handleMatrixAccountData(portal, ctx, login, evt, MarkedUnreadHandlingNetworkAPI.HandleMarkedUnread) + case event.AccountDataRoomTags: + return handleMatrixAccountData(portal, ctx, login, evt, TagHandlingNetworkAPI.HandleRoomTag) + case event.AccountDataBeeperMute: + return handleMatrixAccountData(portal, ctx, login, evt, MuteHandlingNetworkAPI.HandleMute) + case event.StateMember: + return portal.handleMatrixMembership(ctx, login, origSender, evt, isStateRequest) + case event.StatePowerLevels: + return portal.handleMatrixPowerLevels(ctx, login, origSender, evt, isStateRequest) + case event.BeeperDeleteChat: + return portal.handleMatrixDeleteChat(ctx, login, origSender, evt) + case event.BeeperAcceptMessageRequest: + return portal.handleMatrixAcceptMessageRequest(ctx, login, origSender, evt) + default: + return EventHandlingResultIgnored + } +} + +func (portal *Portal) handleMatrixReceipts(ctx context.Context, evt *event.Event) EventHandlingResult { + content, ok := evt.Content.Parsed.(*event.ReceiptEventContent) + if !ok { + return EventHandlingResultFailed + } + for evtID, receipts := range *content { + readReceipts, ok := receipts[event.ReceiptTypeRead] + if !ok { + continue + } + for userID, receipt := range readReceipts { + sender, err := portal.Bridge.GetUserByMXID(ctx, userID) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to get user to handle read receipt") + return EventHandlingResultFailed.WithError(err) + } + portal.handleMatrixReadReceipt(ctx, sender, evtID, receipt) + } + } + // TODO actual status + return EventHandlingResultSuccess +} + +func (portal *Portal) handleMatrixReadReceipt(ctx context.Context, user *User, eventID id.EventID, receipt event.ReadReceipt) { + log := zerolog.Ctx(ctx) + log.UpdateContext(func(c zerolog.Context) zerolog.Context { + return c. + Stringer("event_id", eventID). + Stringer("user_id", user.MXID). + Stringer("receipt_ts", receipt.Timestamp) + }) + login, userPortal, err := portal.FindPreferredLogin(ctx, user, false) + if err != nil { + if !errors.Is(err, ErrNotLoggedIn) { + log.Err(err).Msg("Failed to get preferred login for user") + } + return + } else if login == nil { + return + } + rrClient, ok := login.Client.(ReadReceiptHandlingNetworkAPI) + if !ok { + return + } + log.UpdateContext(func(c zerolog.Context) zerolog.Context { + return c.Str("user_login_id", string(login.ID)) + }) + evt := &MatrixReadReceipt{ + Portal: portal, + EventID: eventID, + Receipt: receipt, + } + evt.ExactMessage, err = portal.Bridge.DB.Message.GetPartByMXID(ctx, eventID) + if err != nil { + log.Err(err).Msg("Failed to get exact message from database") + evt.ReadUpTo = receipt.Timestamp + } else if evt.ExactMessage != nil { + log.UpdateContext(func(c zerolog.Context) zerolog.Context { + return c.Str("exact_message_id", string(evt.ExactMessage.ID)).Time("exact_message_ts", evt.ExactMessage.Timestamp) + }) + evt.ReadUpTo = evt.ExactMessage.Timestamp + } else { + evt.ReadUpTo = receipt.Timestamp + } + portal.callReadReceiptHandler(ctx, login, rrClient, evt, userPortal) +} + +func (portal *Portal) callReadReceiptHandler( + ctx context.Context, + login *UserLogin, + rrClient ReadReceiptHandlingNetworkAPI, + evt *MatrixReadReceipt, + userPortal *database.UserPortal, +) { + if rrClient == nil { + var ok bool + rrClient, ok = login.Client.(ReadReceiptHandlingNetworkAPI) + if !ok { + return + } + } + if userPortal == nil { + userPortal = database.UserPortalFor(login.UserLogin, portal.PortalKey) + } else { + evt.LastRead = userPortal.LastRead + userPortal = userPortal.CopyWithoutValues() + } + err := rrClient.HandleMatrixReadReceipt(ctx, evt) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to handle read receipt") + return + } + userPortal.LastRead = evt.ReadUpTo + err = portal.Bridge.DB.UserPortal.Put(ctx, userPortal) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to save user portal metadata") + } + portal.Bridge.DisappearLoop.StartAllBefore(ctx, portal.MXID, evt.ReadUpTo) +} + +func (portal *Portal) handleMatrixTyping(ctx context.Context, evt *event.Event) EventHandlingResult { + content, ok := evt.Content.Parsed.(*event.TypingEventContent) + if !ok { + return EventHandlingResultFailed + } + portal.currentlyTypingLock.Lock() + defer portal.currentlyTypingLock.Unlock() + slices.Sort(content.UserIDs) + stoppedTyping, startedTyping := exslices.SortedDiff(portal.currentlyTyping, content.UserIDs, func(a, b id.UserID) int { + return strings.Compare(string(a), string(b)) + }) + portal.sendTypings(ctx, stoppedTyping, false) + portal.sendTypings(ctx, startedTyping, true) + portal.currentlyTyping = content.UserIDs + // TODO actual status + return EventHandlingResultSuccess +} + +func (portal *Portal) handleMatrixAIStream(ctx context.Context, sender *User, evt *event.Event) EventHandlingResult { + log := zerolog.Ctx(ctx) + if sender == nil { + log.Error().Msg("Missing sender for Matrix AI stream event") + return EventHandlingResultIgnored + } + login, _, err := portal.FindPreferredLogin(ctx, sender, true) + if err != nil { + log.Err(err).Msg("Failed to get user login to handle Matrix AI stream event") + return EventHandlingResultFailed.WithMSSError(err) + } + var origSender *OrigSender + if login == nil { + if portal.Relay == nil { + return EventHandlingResultIgnored + } + login = portal.Relay + origSender = &OrigSender{ + User: sender, + UserID: sender.MXID, + } + } + content, ok := evt.Content.Parsed.(*event.BeeperAIStreamEventContent) + if !ok { + log.Error().Type("content_type", evt.Content.Parsed).Msg("Unexpected parsed content type") + return EventHandlingResultFailed.WithMSSError(fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed)) + } + api, ok := login.Client.(BeeperAIStreamHandlingNetworkAPI) + if !ok { + return EventHandlingResultIgnored.WithMSSError(ErrBeeperAIStreamNotSupported) + } + err = api.HandleMatrixBeeperAIStream(ctx, &MatrixBeeperAIStream{ + Event: evt, + Content: content, + Portal: portal, + OrigSender: origSender, + }) + if err != nil { + log.Err(err).Msg("Failed to handle Matrix AI stream event") + return EventHandlingResultFailed.WithMSSError(err) + } + return EventHandlingResultSuccess.WithMSS() +} + +func (portal *Portal) sendTypings(ctx context.Context, userIDs []id.UserID, typing bool) { + for _, userID := range userIDs { + login, ok := portal.currentlyTypingLogins[userID] + if !ok && !typing { + continue + } else if !ok { + user, err := portal.Bridge.GetUserByMXID(ctx, userID) + if err != nil { + zerolog.Ctx(ctx).Err(err).Stringer("user_id", userID).Msg("Failed to get user to send typing event") + continue + } else if user == nil { + continue + } + login, _, err = portal.FindPreferredLogin(ctx, user, false) + if err != nil { + if !errors.Is(err, ErrNotLoggedIn) { + zerolog.Ctx(ctx).Err(err).Stringer("user_id", userID).Msg("Failed to get user login to send typing event") + } + continue + } else if login == nil { + continue + } else if _, ok = login.Client.(TypingHandlingNetworkAPI); !ok { + continue + } + portal.currentlyTypingLogins[userID] = login + } + if !typing { + delete(portal.currentlyTypingLogins, userID) + } + typingAPI, ok := login.Client.(TypingHandlingNetworkAPI) + if !ok { + continue + } + err := typingAPI.HandleMatrixTyping(ctx, &MatrixTyping{ + Portal: portal, + IsTyping: typing, + Type: TypingTypeText, + }) + if err != nil { + zerolog.Ctx(ctx).Err(err).Stringer("user_id", userID).Msg("Failed to bridge Matrix typing event") + } else { + zerolog.Ctx(ctx).Debug(). + Stringer("user_id", userID). + Bool("typing", typing). + Msg("Sent typing event") + } + } +} + +func (portal *Portal) periodicTypingUpdater() { + // TODO actually call this function + log := portal.Log.With().Str("component", "typing updater").Logger() + ctx := log.WithContext(context.Background()) + for { + // TODO make delay configurable by network connector + time.Sleep(5 * time.Second) + portal.currentlyTypingLock.Lock() + if len(portal.currentlyTyping) == 0 { + portal.currentlyTypingLock.Unlock() + continue + } + for _, userID := range portal.currentlyTyping { + login, ok := portal.currentlyTypingLogins[userID] + if !ok { + continue + } + typingAPI, ok := login.Client.(TypingHandlingNetworkAPI) + if !ok { + continue + } + err := typingAPI.HandleMatrixTyping(ctx, &MatrixTyping{ + Portal: portal, + IsTyping: true, + Type: TypingTypeText, + }) + if err != nil { + log.Err(err).Stringer("user_id", userID).Msg("Failed to repeat Matrix typing event") + } else { + log.Debug(). + Stringer("user_id", userID). + Bool("typing", true). + Msg("Sent repeated typing event") + } + } + portal.currentlyTypingLock.Unlock() + } +} + +func (portal *Portal) checkMessageContentCaps(caps *event.RoomFeatures, content *event.MessageEventContent) error { + switch content.MsgType { + case event.MsgText, event.MsgNotice, event.MsgEmote: + // No checks for now, message length is safer to check after conversion inside connector + case event.MsgLocation: + if caps.LocationMessage.Reject() { + return ErrLocationMessagesNotAllowed + } + case event.MsgImage, event.MsgAudio, event.MsgVideo, event.MsgFile, event.CapMsgSticker: + capMsgType := content.GetCapMsgType() + feat, ok := caps.File[capMsgType] + if !ok { + return ErrUnsupportedMessageType + } + if content.MsgType != event.CapMsgSticker && + content.FileName != "" && + content.Body != content.FileName && + feat.Caption.Reject() { + return ErrCaptionsNotAllowed + } + if content.Info != nil { + dur := time.Duration(content.Info.Duration) * time.Millisecond + if feat.MaxDuration != nil && dur > feat.MaxDuration.Duration { + if capMsgType == event.CapMsgVoice { + return fmt.Errorf("%w: %s supports voice messages up to %s long", ErrVoiceMessageDurationTooLong, portal.Bridge.Network.GetName().DisplayName, exfmt.Duration(feat.MaxDuration.Duration)) + } + return fmt.Errorf("%w: %s is longer than the maximum of %s", ErrMediaDurationTooLong, exfmt.Duration(dur), exfmt.Duration(feat.MaxDuration.Duration)) + } + if feat.MaxSize != 0 && int64(content.Info.Size) > feat.MaxSize { + return fmt.Errorf("%w: %.1f MiB is larger than the maximum of %.1f MiB", ErrMediaTooLarge, float64(content.Info.Size)/1024/1024, float64(feat.MaxSize)/1024/1024) + } + if content.Info.MimeType != "" && feat.GetMimeSupport(content.Info.MimeType).Reject() { + return fmt.Errorf("%w (%s in %s)", ErrUnsupportedMediaType, content.Info.MimeType, capMsgType) + } + } + fallthrough + default: + } + return nil +} + +func (portal *Portal) parseInputTransactionID(origSender *OrigSender, evt *event.Event) networkid.RawTransactionID { + if origSender != nil || !strings.HasPrefix(evt.ID.String(), database.NetworkTxnMXIDPrefix) { + return "" + } + return networkid.RawTransactionID(strings.TrimPrefix(evt.ID.String(), database.NetworkTxnMXIDPrefix)) +} + +func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event) EventHandlingResult { + log := zerolog.Ctx(ctx) + var relatesTo *event.RelatesTo + var msgContent *event.MessageEventContent + var pollContent *event.PollStartEventContent + var pollResponseContent *event.PollResponseEventContent + var ok bool + if evt.Type == event.EventUnstablePollStart { + pollContent, ok = evt.Content.Parsed.(*event.PollStartEventContent) + relatesTo = pollContent.RelatesTo + } else if evt.Type == event.EventUnstablePollResponse { + pollResponseContent, ok = evt.Content.Parsed.(*event.PollResponseEventContent) + relatesTo = &pollResponseContent.RelatesTo + } else { + msgContent, ok = evt.Content.Parsed.(*event.MessageEventContent) + relatesTo = msgContent.RelatesTo + if evt.Type == event.EventSticker { + msgContent.MsgType = event.CapMsgSticker + } + if msgContent.MsgType == event.MsgNotice && !portal.Bridge.Config.BridgeNotices { + return EventHandlingResultIgnored.WithMSSError(ErrIgnoringMNotice) + } + } + if !ok { + log.Error().Type("content_type", evt.Content.Parsed).Msg("Unexpected parsed content type") + return EventHandlingResultFailed. + WithMSSError(fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed)) + } + caps := sender.Client.GetCapabilities(ctx, portal) + + if relatesTo.GetReplaceID() != "" { + if msgContent == nil { + log.Warn().Msg("Ignoring edit of poll") + return EventHandlingResultFailed.WithMSSError(fmt.Errorf("%w of polls", ErrEditsNotSupported)) + } + return portal.handleMatrixEdit(ctx, sender, origSender, evt, msgContent, caps) + } + var err error + if origSender != nil { + if msgContent == nil { + log.Debug().Msg("Ignoring poll event from relayed user") + return EventHandlingResultIgnored.WithMSSError(ErrIgnoringPollFromRelayedUser) + } + if !caps.PerMessageProfileRelay { + msgContent, err = portal.Bridge.Config.Relay.FormatMessage(msgContent, origSender) + if err != nil { + log.Err(err).Msg("Failed to format message for relaying") + return EventHandlingResultFailed.WithMSSError(err) + } + } + } + if msgContent != nil { + if err = portal.checkMessageContentCaps(caps, msgContent); err != nil { + return EventHandlingResultFailed.WithMSSError(err) + } + } else if pollResponseContent != nil || pollContent != nil { + if _, ok = sender.Client.(PollHandlingNetworkAPI); !ok { + log.Debug().Msg("Ignoring poll event as network connector doesn't implement PollHandlingNetworkAPI") + return EventHandlingResultIgnored.WithMSSError(ErrPollsNotSupported) + } + } + + var threadRoot, replyTo, voteTo *database.Message + if evt.Type == event.EventUnstablePollResponse { + voteTo, err = portal.Bridge.DB.Message.GetPartByMXID(ctx, relatesTo.GetReferenceID()) + if err != nil { + log.Err(err).Msg("Failed to get poll target message from database") + // TODO send status + return EventHandlingResultFailed + } else if voteTo == nil { + log.Warn().Stringer("vote_to_id", relatesTo.GetReferenceID()).Msg("Poll target message not found") + // TODO send status + return EventHandlingResultFailed + } + } + var replyToID id.EventID + threadRootID := relatesTo.GetThreadParent() + if caps.Thread.Partial() { + replyToID = relatesTo.GetNonFallbackReplyTo() + if threadRootID != "" { + threadRoot, err = portal.Bridge.DB.Message.GetPartByMXID(ctx, threadRootID) + if err != nil { + log.Err(err).Msg("Failed to get thread root message from database") + } else if threadRoot == nil { + log.Warn().Stringer("thread_root_id", threadRootID).Msg("Thread root message not found") + } + } + } else { + replyToID = relatesTo.GetReplyTo() + } + if replyToID != "" && (caps.Reply.Partial() || caps.Thread.Partial()) { + replyTo, err = portal.Bridge.DB.Message.GetPartByMXID(ctx, replyToID) + if err != nil { + log.Err(err).Msg("Failed to get reply target message from database") + } else if replyTo == nil { + log.Warn().Stringer("reply_to_id", replyToID).Msg("Reply target message not found") + } else { + // Support replying to threads from non-thread-capable clients. + // The fallback happens if the message is not a Matrix thread and either + // * the replied-to message is in a thread, or + // * the network only supports threads (assume the user wants to start a new thread) + if caps.Thread.Partial() && threadRoot == nil && (replyTo.ThreadRoot != "" || !caps.Reply.Partial()) { + threadRootRemoteID := replyTo.ThreadRoot + if threadRootRemoteID == "" { + threadRootRemoteID = replyTo.ID + } + threadRoot, err = portal.Bridge.DB.Message.GetFirstThreadMessage(ctx, portal.PortalKey, threadRootRemoteID) + if err != nil { + log.Err(err).Msg("Failed to get thread root message from database (via reply fallback)") + } + } + if !caps.Reply.Partial() { + replyTo = nil + } + } + } + var messageTimer *event.BeeperDisappearingTimer + if msgContent != nil { + messageTimer = msgContent.BeeperDisappearingTimer + } + if messageTimer != nil && *portal.Disappear.ToEventContent() != *messageTimer { + log.Warn(). + Any("event_timer", messageTimer). + Any("portal_timer", portal.Disappear.ToEventContent()). + Msg("Mismatching disappearing timer in event") + } + + wrappedMsgEvt := &MatrixMessage{ + MatrixEventBase: MatrixEventBase[*event.MessageEventContent]{ + Event: evt, + Content: msgContent, + OrigSender: origSender, + Portal: portal, + + InputTransactionID: portal.parseInputTransactionID(origSender, evt), + }, + ThreadRoot: threadRoot, + ReplyTo: replyTo, + } + if portal.Bridge.Config.DeduplicateMatrixMessages { + if part, err := portal.Bridge.DB.Message.GetPartByTxnID(ctx, portal.Receiver, evt.ID, wrappedMsgEvt.InputTransactionID); err != nil { + log.Err(err).Msg("Failed to check db if message is already sent") + } else if part != nil { + log.Debug(). + Stringer("message_mxid", part.MXID). + Stringer("input_event_id", evt.ID). + Msg("Message already sent, ignoring") + return EventHandlingResultIgnored + } + } + + err = portal.autoAcceptMessageRequest(ctx, evt, sender, origSender, caps) + if err != nil { + log.Warn().Err(err).Msg("Failed to auto-accept message request on message") + // TODO stop processing? + } + + var resp *MatrixMessageResponse + if msgContent != nil { + resp, err = sender.Client.HandleMatrixMessage(ctx, wrappedMsgEvt) + } else if pollContent != nil { + resp, err = sender.Client.(PollHandlingNetworkAPI).HandleMatrixPollStart(ctx, &MatrixPollStart{ + MatrixMessage: *wrappedMsgEvt, + Content: pollContent, + }) + } else if pollResponseContent != nil { + resp, err = sender.Client.(PollHandlingNetworkAPI).HandleMatrixPollVote(ctx, &MatrixPollVote{ + MatrixMessage: *wrappedMsgEvt, + VoteTo: voteTo, + Content: pollResponseContent, + }) + } else { + log.Error().Msg("Failed to handle Matrix message: all contents are nil?") + return EventHandlingResultFailed.WithMSSError(fmt.Errorf("all contents are nil")) + } + if err != nil { + log.Err(err).Msg("Failed to handle Matrix message") + return EventHandlingResultFailed.WithMSSError(err) + } + message := wrappedMsgEvt.fillDBMessage(resp.DB) + if resp.Pending { + for _, save := range wrappedMsgEvt.pendingSaves { + save.ackedAt = time.Now() + } + } else { + if resp.DB == nil { + log.Error().Msg("Network connector didn't return a message to save") + } else { + if portal.Bridge.Config.OutgoingMessageReID { + message.MXID = portal.Bridge.Matrix.GenerateDeterministicEventID(portal.MXID, portal.PortalKey, message.ID, message.PartID) + } + // Hack to ensure the ghost row exists + // TODO move to better place (like login) + portal.Bridge.GetGhostByID(ctx, message.SenderID) + err = portal.Bridge.DB.Message.Insert(ctx, message) + if err != nil { + log.Err(err).Msg("Failed to save message to database") + } else if resp.PostSave != nil { + resp.PostSave(ctx, message) + } + if resp.RemovePending != "" { + portal.outgoingMessagesLock.Lock() + delete(portal.outgoingMessages, resp.RemovePending) + portal.outgoingMessagesLock.Unlock() + } + } + portal.sendSuccessStatus(ctx, evt, resp.StreamOrder, message.MXID) + } + ds := portal.Disappear + if messageTimer != nil { + ds = database.DisappearingSettingFromEvent(messageTimer) + } + if ds.Type != event.DisappearingTypeNone { + go portal.Bridge.DisappearLoop.Add(ctx, &database.DisappearingMessage{ + RoomID: portal.MXID, + EventID: message.MXID, + Timestamp: message.Timestamp, + DisappearingSetting: ds.StartingAt(message.Timestamp), + }) + } + if resp.Pending { + // Not exactly queued, but not finished either + return EventHandlingResultQueued + } + return EventHandlingResultSuccess.WithEventID(message.MXID).WithStreamOrder(resp.StreamOrder) +} + +// AddPendingToIgnore adds a transaction ID that should be ignored if encountered as a new message. +// +// This should be used when the network connector will return the real message ID from HandleMatrixMessage. +// The [MatrixMessageResponse] should include RemovePending with the transaction ID sto remove it from the lit +// after saving to database. +// +// See also: [MatrixMessage.AddPendingToSave] +func (evt *MatrixMessage) AddPendingToIgnore(txnID networkid.TransactionID) { + evt.Portal.outgoingMessagesLock.Lock() + evt.Portal.outgoingMessages[txnID] = &outgoingMessage{ + ignore: true, + } + evt.Portal.outgoingMessagesLock.Unlock() +} + +// AddPendingToSave adds a transaction ID that should be processed and pointed at the existing event if encountered. +// +// This should be used when the network connector returns `Pending: true` from HandleMatrixMessage, +// i.e. when the network connector does not know the message ID at the end of the handler. +// The [MatrixMessageResponse] should set Pending to true to prevent saving the returned message to the database. +// +// The provided function will be called when the message is encountered. +func (evt *MatrixMessage) AddPendingToSave(message *database.Message, txnID networkid.TransactionID, handleEcho RemoteEchoHandler) { + pending := &outgoingMessage{ + db: evt.fillDBMessage(message), + evt: evt.Event, + handle: handleEcho, + } + evt.Portal.outgoingMessagesLock.Lock() + evt.Portal.outgoingMessages[txnID] = pending + evt.pendingSaves = append(evt.pendingSaves, pending) + evt.Portal.outgoingMessagesLock.Unlock() +} + +// RemovePending removes a transaction ID from the list of pending messages. +// This should only be called if sending the message fails. +func (evt *MatrixMessage) RemovePending(txnID networkid.TransactionID) { + evt.Portal.outgoingMessagesLock.Lock() + pendingSave := evt.Portal.outgoingMessages[txnID] + if pendingSave != nil { + evt.pendingSaves = slices.DeleteFunc(evt.pendingSaves, func(save *outgoingMessage) bool { + return save == pendingSave + }) + } + delete(evt.Portal.outgoingMessages, txnID) + evt.Portal.outgoingMessagesLock.Unlock() +} + +func (evt *MatrixMessage) fillDBMessage(message *database.Message) *database.Message { + if message == nil { + message = &database.Message{} + } + if message.MXID == "" { + message.MXID = evt.Event.ID + } + if message.Room.ID == "" { + message.Room = evt.Portal.PortalKey + } + if message.Timestamp.IsZero() { + message.Timestamp = time.UnixMilli(evt.Event.Timestamp) + } + if message.ReplyTo.MessageID == "" && evt.ReplyTo != nil { + message.ReplyTo.MessageID = evt.ReplyTo.ID + message.ReplyTo.PartID = &evt.ReplyTo.PartID + } + if message.ThreadRoot == "" && evt.ThreadRoot != nil { + message.ThreadRoot = evt.ThreadRoot.ID + if evt.ThreadRoot.ThreadRoot != "" { + message.ThreadRoot = evt.ThreadRoot.ThreadRoot + } + } + if message.SenderMXID == "" { + message.SenderMXID = evt.Event.Sender + } + if message.SendTxnID != "" { + message.SendTxnID = evt.InputTransactionID + } + return message +} + +func (portal *Portal) pendingMessageTimeoutLoop(ctx context.Context, cfg *OutgoingTimeoutConfig) { + ticker := time.NewTicker(cfg.CheckInterval) + defer ticker.Stop() + for { + select { + case <-ticker.C: + portal.checkPendingMessages(ctx, cfg) + case <-ctx.Done(): + return + } + } +} + +func (portal *Portal) checkPendingMessages(ctx context.Context, cfg *OutgoingTimeoutConfig) { + portal.outgoingMessagesLock.Lock() + defer portal.outgoingMessagesLock.Unlock() + for _, msg := range portal.outgoingMessages { + if msg.evt != nil && !msg.timeouted { + if cfg.NoEchoTimeout > 0 && !msg.ackedAt.IsZero() && time.Since(msg.ackedAt) > cfg.NoEchoTimeout { + msg.timeouted = true + portal.sendErrorStatus(ctx, msg.evt, ErrRemoteEchoTimeout.WithMessage(cfg.NoEchoMessage)) + } else if cfg.NoAckTimeout > 0 && time.Since(msg.db.Timestamp) > cfg.NoAckTimeout { + msg.timeouted = true + portal.sendErrorStatus(ctx, msg.evt, ErrRemoteAckTimeout.WithMessage(cfg.NoAckMessage)) + } + } + } +} + +func (portal *Portal) handleMatrixEdit( + ctx context.Context, + sender *UserLogin, + origSender *OrigSender, + evt *event.Event, + content *event.MessageEventContent, + caps *event.RoomFeatures, +) EventHandlingResult { + log := zerolog.Ctx(ctx) + editTargetID := content.RelatesTo.GetReplaceID() + log.UpdateContext(func(c zerolog.Context) zerolog.Context { + return c.Stringer("edit_target_mxid", editTargetID) + }) + if content.NewContent != nil { + content = content.NewContent + if evt.Type == event.EventSticker { + content.MsgType = event.CapMsgSticker + } + } + if origSender != nil { + var err error + content, err = portal.Bridge.Config.Relay.FormatMessage(content, origSender) + if err != nil { + log.Err(err).Msg("Failed to format message for relaying") + return EventHandlingResultFailed.WithMSSError(err) + } + } + + editingAPI, ok := sender.Client.(EditHandlingNetworkAPI) + if !ok { + log.Debug().Msg("Ignoring edit as network connector doesn't implement EditHandlingNetworkAPI") + return EventHandlingResultIgnored.WithMSSError(ErrEditsNotSupported) + } else if !caps.Edit.Partial() { + log.Debug().Msg("Ignoring edit as room doesn't support edits") + return EventHandlingResultIgnored.WithMSSError(ErrEditsNotSupportedInPortal) + } else if err := portal.checkMessageContentCaps(caps, content); err != nil { + return EventHandlingResultFailed.WithMSSError(err) + } + editTarget, err := portal.Bridge.DB.Message.GetPartByMXID(ctx, editTargetID) + if err != nil { + log.Err(err).Msg("Failed to get edit target message from database") + return EventHandlingResultFailed.WithMSSError(fmt.Errorf("%w: failed to get edit target: %w", ErrDatabaseError, err)) + } else if editTarget == nil { + log.Warn().Msg("Edit target message not found in database") + return EventHandlingResultFailed.WithMSSError(fmt.Errorf("edit %w", ErrTargetMessageNotFound)) + } else if caps.EditMaxAge != nil && caps.EditMaxAge.Duration > 0 && time.Since(editTarget.Timestamp) > caps.EditMaxAge.Duration { + return EventHandlingResultFailed.WithMSSError(ErrEditTargetTooOld) + } else if caps.EditMaxCount > 0 && editTarget.EditCount >= caps.EditMaxCount { + return EventHandlingResultFailed.WithMSSError(ErrEditTargetTooManyEdits) + } + log.UpdateContext(func(c zerolog.Context) zerolog.Context { + return c.Str("edit_target_remote_id", string(editTarget.ID)) + }) + err = editingAPI.HandleMatrixEdit(ctx, &MatrixEdit{ + MatrixEventBase: MatrixEventBase[*event.MessageEventContent]{ + Event: evt, + Content: content, + OrigSender: origSender, + Portal: portal, + + InputTransactionID: portal.parseInputTransactionID(origSender, evt), + }, + EditTarget: editTarget, + }) + if err != nil { + log.Err(err).Msg("Failed to handle Matrix edit") + return EventHandlingResultFailed.WithMSSError(err) + } + err = portal.Bridge.DB.Message.Update(ctx, editTarget) + if err != nil { + log.Err(err).Msg("Failed to save message to database after editing") + } + // TODO allow returning stream order from HandleMatrixEdit + portal.sendSuccessStatus(ctx, evt, 0, "") + return EventHandlingResultSuccess +} + +func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogin, evt *event.Event) (handleRes EventHandlingResult) { + log := zerolog.Ctx(ctx) + reactingAPI, ok := sender.Client.(ReactionHandlingNetworkAPI) + if !ok { + log.Debug().Msg("Ignoring reaction as network connector doesn't implement ReactionHandlingNetworkAPI") + return EventHandlingResultIgnored.WithMSSError(ErrReactionsNotSupported) + } + content, ok := evt.Content.Parsed.(*event.ReactionEventContent) + if !ok { + log.Error().Type("content_type", evt.Content.Parsed).Msg("Unexpected parsed content type") + return EventHandlingResultFailed.WithMSSError(fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed)) + } + log.UpdateContext(func(c zerolog.Context) zerolog.Context { + return c.Stringer("reaction_target_mxid", content.RelatesTo.EventID) + }) + reactionTarget, err := portal.Bridge.DB.Message.GetPartByMXID(ctx, content.RelatesTo.EventID) + if err != nil { + log.Err(err).Msg("Failed to get reaction target message from database") + return EventHandlingResultFailed.WithMSSError(fmt.Errorf("%w: failed to get reaction target: %w", ErrDatabaseError, err)) + } else if reactionTarget == nil { + log.Warn().Msg("Reaction target message not found in database") + return EventHandlingResultFailed.WithMSSError(fmt.Errorf("reaction %w", ErrTargetMessageNotFound)) + } + caps := sender.Client.GetCapabilities(ctx, portal) + err = portal.autoAcceptMessageRequest(ctx, evt, sender, nil, caps) + if err != nil { + log.Warn().Err(err).Msg("Failed to auto-accept message request on reaction") + // TODO stop processing? + } + log.UpdateContext(func(c zerolog.Context) zerolog.Context { + return c.Str("reaction_target_remote_id", string(reactionTarget.ID)) + }) + react := &MatrixReaction{ + MatrixEventBase: MatrixEventBase[*event.ReactionEventContent]{ + Event: evt, + Content: content, + Portal: portal, + + InputTransactionID: portal.parseInputTransactionID(nil, evt), + }, + TargetMessage: reactionTarget, + } + preResp, err := reactingAPI.PreHandleMatrixReaction(ctx, react) + if err != nil { + log.Err(err).Msg("Failed to pre-handle Matrix reaction") + return EventHandlingResultFailed.WithMSSError(err) + } + var deterministicID id.EventID + if portal.Bridge.Config.OutgoingMessageReID { + deterministicID = portal.Bridge.Matrix.GenerateReactionEventID(portal.MXID, reactionTarget, preResp.SenderID, preResp.EmojiID) + } + defer func() { + // Do this in a defer so that it happens after any potential defer calls to removeOutdatedReaction + if handleRes.Success { + portal.sendSuccessStatus(ctx, evt, 0, deterministicID) + } + }() + removeOutdatedReaction := func(oldReact *database.Reaction, deleteDB bool) { + if !handleRes.Success { + return + } + _, err := portal.Bridge.Bot.SendMessage(ctx, portal.MXID, event.EventRedaction, &event.Content{ + Parsed: &event.RedactionEventContent{ + Redacts: oldReact.MXID, + }, + }, nil) + if err != nil { + log.Err(err).Msg("Failed to remove old reaction") + } + if deleteDB { + err = portal.Bridge.DB.Reaction.Delete(ctx, oldReact) + if err != nil { + log.Err(err).Msg("Failed to delete old reaction from database") + } + } + } + existing, err := portal.Bridge.DB.Reaction.GetByID(ctx, portal.Receiver, reactionTarget.ID, reactionTarget.PartID, preResp.SenderID, preResp.EmojiID) + if err != nil { + log.Err(err).Msg("Failed to check if reaction is a duplicate") + return EventHandlingResultFailed.WithMSSError(fmt.Errorf("%w: failed to check for existing reaction: %w", ErrDatabaseError, err)) + } else if existing != nil { + if existing.EmojiID != "" || existing.Emoji == preResp.Emoji { + log.Debug().Msg("Ignoring duplicate reaction") + portal.sendSuccessStatus(ctx, evt, 0, deterministicID) + return EventHandlingResultIgnored.WithEventID(deterministicID) + } + react.ReactionToOverride = existing + defer removeOutdatedReaction(existing, false) + } + react.PreHandleResp = &preResp + if preResp.MaxReactions > 0 { + allReactions, err := portal.Bridge.DB.Reaction.GetAllToMessageBySender(ctx, portal.Receiver, reactionTarget.ID, preResp.SenderID) + if err != nil { + log.Err(err).Msg("Failed to get all reactions to message by sender") + return EventHandlingResultFailed.WithMSSError(fmt.Errorf("%w: failed to get previous reactions: %w", ErrDatabaseError, err)) + } + if len(allReactions) < preResp.MaxReactions { + react.ExistingReactionsToKeep = allReactions + } else { + // Keep n-1 previous reactions and remove the rest + react.ExistingReactionsToKeep = allReactions[:preResp.MaxReactions-1] + for _, oldReaction := range allReactions[preResp.MaxReactions-1:] { + if existing != nil && oldReaction.EmojiID == existing.EmojiID { + // Don't double-delete on networks that only allow one emoji + continue + } + // Intentionally defer in a loop, there won't be that many items, + // and we want all of them to be done after this function completes successfully + //goland:noinspection GoDeferInLoop + defer removeOutdatedReaction(oldReaction, true) + } + } + } + dbReaction, err := reactingAPI.HandleMatrixReaction(ctx, react) + if err != nil { + log.Err(err).Msg("Failed to handle Matrix reaction") + return EventHandlingResultFailed.WithMSSError(err) + } + if dbReaction == nil { + dbReaction = &database.Reaction{} + } + // Fill all fields that are known to allow omitting them in connector code + if dbReaction.Room.ID == "" { + dbReaction.Room = portal.PortalKey + } + if dbReaction.MessageID == "" { + dbReaction.MessageID = reactionTarget.ID + dbReaction.MessagePartID = reactionTarget.PartID + } + if deterministicID != "" { + dbReaction.MXID = deterministicID + } else if dbReaction.MXID == "" { + dbReaction.MXID = evt.ID + } + if dbReaction.Timestamp.IsZero() { + dbReaction.Timestamp = time.UnixMilli(evt.Timestamp) + } + if preResp.EmojiID == "" && dbReaction.EmojiID == "" { + if dbReaction.Emoji == "" { + dbReaction.Emoji = preResp.Emoji + } + } else if dbReaction.EmojiID == "" { + dbReaction.EmojiID = preResp.EmojiID + } + if dbReaction.SenderID == "" { + dbReaction.SenderID = preResp.SenderID + } + if dbReaction.SenderMXID == "" { + dbReaction.SenderMXID = evt.Sender + } + err = portal.Bridge.DB.Reaction.Upsert(ctx, dbReaction) + if err != nil { + log.Err(err).Msg("Failed to save reaction to database") + } + return EventHandlingResultSuccess.WithEventID(deterministicID) +} + +func handleMatrixRoomMeta[APIType any, ContentType any]( + portal *Portal, + ctx context.Context, + sender *UserLogin, + origSender *OrigSender, + evt *event.Event, + isStateRequest bool, + fn func(APIType, context.Context, *MatrixRoomMeta[ContentType]) (bool, error), +) EventHandlingResult { + if evt.StateKey == nil || *evt.StateKey != "" { + return EventHandlingResultFailed.WithMSSError(ErrInvalidStateKey) + } + //caps := sender.Client.GetCapabilities(ctx, portal) + //if stateCap, ok := caps.State[evt.Type.Type]; !ok || stateCap.Level <= event.CapLevelUnsupported { + // return EventHandlingResultIgnored.WithMSSError(fmt.Errorf("%s %w", evt.Type.Type, ErrRoomMetadataNotAllowed)) + //} + api, ok := sender.Client.(APIType) + if !ok { + return EventHandlingResultIgnored.WithMSSError(fmt.Errorf("%w of type %s", ErrRoomMetadataNotSupported, evt.Type)) + } + log := zerolog.Ctx(ctx) + content, ok := evt.Content.Parsed.(ContentType) + if !ok { + log.Error().Type("content_type", evt.Content.Parsed).Msg("Unexpected parsed content type") + return EventHandlingResultFailed.WithMSSError(fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed)) + } + switch typedContent := evt.Content.Parsed.(type) { + case *event.RoomNameEventContent: + if typedContent.Name == portal.Name { + portal.sendSuccessStatus(ctx, evt, 0, "") + return EventHandlingResultIgnored + } + case *event.TopicEventContent: + if typedContent.Topic == portal.Topic { + portal.sendSuccessStatus(ctx, evt, 0, "") + return EventHandlingResultIgnored + } + case *event.RoomAvatarEventContent: + if typedContent.URL == portal.AvatarMXC { + portal.sendSuccessStatus(ctx, evt, 0, "") + return EventHandlingResultIgnored + } + case *event.BeeperDisappearingTimer: + if typedContent.Type == event.DisappearingTypeNone || typedContent.Timer.Duration <= 0 { + typedContent.Type = event.DisappearingTypeNone + typedContent.Timer.Duration = 0 + } + if typedContent.Type == portal.Disappear.Type && typedContent.Timer.Duration == portal.Disappear.Timer { + portal.sendSuccessStatus(ctx, evt, 0, "") + return EventHandlingResultIgnored + } + if !sender.Client.GetCapabilities(ctx, portal).DisappearingTimer.Supports(typedContent) { + return EventHandlingResultFailed.WithMSSError(ErrDisappearingTimerUnsupported) + } + } + var prevContent ContentType + if evt.Unsigned.PrevContent != nil { + _ = evt.Unsigned.PrevContent.ParseRaw(evt.Type) + prevContent, _ = evt.Unsigned.PrevContent.Parsed.(ContentType) + } + + changed, err := fn(api, ctx, &MatrixRoomMeta[ContentType]{ + MatrixEventBase: MatrixEventBase[ContentType]{ + Event: evt, + Content: content, + Portal: portal, + OrigSender: origSender, + + InputTransactionID: portal.parseInputTransactionID(origSender, evt), + }, + IsStateRequest: isStateRequest, + PrevContent: prevContent, + }) + if err != nil { + log.Err(err).Msg("Failed to handle Matrix room metadata") + return EventHandlingResultFailed.WithMSSError(err) + } + if changed { + if evt.Type != event.StateBeeperDisappearingTimer { + portal.UpdateBridgeInfo(ctx) + } + err = portal.Save(ctx) + if err != nil { + log.Err(err).Msg("Failed to save portal after updating room metadata") + } + } + return EventHandlingResultSuccess.WithMSS() +} + +func handleMatrixAccountData[APIType any, ContentType any]( + portal *Portal, ctx context.Context, sender *UserLogin, evt *event.Event, + fn func(APIType, context.Context, *MatrixRoomMeta[ContentType]) error, +) EventHandlingResult { + api, ok := sender.Client.(APIType) + if !ok { + return EventHandlingResultIgnored + } + log := zerolog.Ctx(ctx) + content, ok := evt.Content.Parsed.(ContentType) + if !ok { + log.Error().Type("content_type", evt.Content.Parsed).Msg("Unexpected parsed content type") + return EventHandlingResultFailed.WithError(fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed)) + } + var prevContent ContentType + if evt.Unsigned.PrevContent != nil { + _ = evt.Unsigned.PrevContent.ParseRaw(evt.Type) + prevContent, _ = evt.Unsigned.PrevContent.Parsed.(ContentType) + } + + err := fn(api, ctx, &MatrixRoomMeta[ContentType]{ + MatrixEventBase: MatrixEventBase[ContentType]{ + Event: evt, + Content: content, + Portal: portal, + }, + PrevContent: prevContent, + }) + if err != nil { + log.Err(err).Msg("Failed to handle Matrix room account data") + return EventHandlingResultFailed.WithError(err) + } + return EventHandlingResultSuccess +} + +func (portal *Portal) getTargetUser(ctx context.Context, userID id.UserID) (GhostOrUserLogin, error) { + if targetGhost, err := portal.Bridge.GetGhostByMXID(ctx, userID); err != nil { + return nil, fmt.Errorf("failed to get ghost: %w", err) + } else if targetGhost != nil { + return targetGhost, nil + } else if targetUser, err := portal.Bridge.GetUserByMXID(ctx, userID); err != nil { + return nil, fmt.Errorf("failed to get user: %w", err) + } else if targetUserLogin, _, err := portal.FindPreferredLogin(ctx, targetUser, false); err != nil { + return nil, fmt.Errorf("failed to find preferred login: %w", err) + } else if targetUserLogin != nil { + return targetUserLogin, nil + } else { + // Return raw nil as a separate case to ensure a typed nil isn't returned + return nil, nil + } +} + +func (portal *Portal) handleMatrixAcceptMessageRequest( + ctx context.Context, + sender *UserLogin, + origSender *OrigSender, + evt *event.Event, +) EventHandlingResult { + if origSender != nil { + return EventHandlingResultFailed.WithMSSError(ErrIgnoringAcceptRequestRelayedUser) + } + log := zerolog.Ctx(ctx) + content, ok := evt.Content.Parsed.(*event.BeeperAcceptMessageRequestEventContent) + if !ok { + log.Error().Type("content_type", evt.Content.Parsed).Msg("Unexpected parsed content type") + return EventHandlingResultFailed.WithMSSError(fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed)) + } + api, ok := sender.Client.(MessageRequestAcceptingNetworkAPI) + if !ok { + return EventHandlingResultIgnored.WithMSSError(ErrDeleteChatNotSupported) + } + err := api.HandleMatrixAcceptMessageRequest(ctx, &MatrixAcceptMessageRequest{ + Event: evt, + Content: content, + Portal: portal, + }) + if err != nil { + log.Err(err).Msg("Failed to handle Matrix accept message request") + return EventHandlingResultFailed.WithMSSError(err) + } + if portal.MessageRequest { + portal.MessageRequest = false + portal.UpdateBridgeInfo(ctx) + err = portal.Save(ctx) + if err != nil { + log.Err(err).Msg("Failed to save portal after accepting message request") + } + } + return EventHandlingResultSuccess.WithMSS() +} + +func (portal *Portal) autoAcceptMessageRequest( + ctx context.Context, evt *event.Event, sender *UserLogin, origSender *OrigSender, caps *event.RoomFeatures, +) error { + if !portal.MessageRequest || caps.MessageRequest == nil || caps.MessageRequest.AcceptWithMessage == event.CapLevelFullySupported { + return nil + } + mran, ok := sender.Client.(MessageRequestAcceptingNetworkAPI) + if !ok { + return nil + } + err := mran.HandleMatrixAcceptMessageRequest(ctx, &MatrixAcceptMessageRequest{ + Event: evt, + Content: &event.BeeperAcceptMessageRequestEventContent{ + IsImplicit: true, + }, + Portal: portal, + OrigSender: origSender, + }) + if err != nil { + return err + } + if portal.MessageRequest { + portal.MessageRequest = false + portal.UpdateBridgeInfo(ctx) + err = portal.Save(ctx) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to save portal after accepting message request") + } + } + return nil +} + +func (portal *Portal) handleMatrixDeleteChat( + ctx context.Context, + sender *UserLogin, + origSender *OrigSender, + evt *event.Event, +) EventHandlingResult { + if origSender != nil { + return EventHandlingResultFailed.WithMSSError(ErrIgnoringDeleteChatRelayedUser) + } + log := zerolog.Ctx(ctx) + content, ok := evt.Content.Parsed.(*event.BeeperChatDeleteEventContent) + if !ok { + log.Error().Type("content_type", evt.Content.Parsed).Msg("Unexpected parsed content type") + return EventHandlingResultFailed.WithMSSError(fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed)) + } + api, ok := sender.Client.(DeleteChatHandlingNetworkAPI) + if !ok { + return EventHandlingResultIgnored.WithMSSError(ErrDeleteChatNotSupported) + } + err := api.HandleMatrixDeleteChat(ctx, &MatrixDeleteChat{ + Event: evt, + Content: content, + Portal: portal, + }) + if err != nil { + log.Err(err).Msg("Failed to handle Matrix chat delete") + return EventHandlingResultFailed.WithMSSError(err) + } + if portal.Receiver == "" { + _, others, err := portal.findOtherLogins(ctx, sender) + if err != nil { + log.Err(err).Msg("Failed to check if portal has other logins") + return EventHandlingResultFailed.WithError(err) + } else if len(others) > 0 { + log.Debug().Msg("Not deleting portal after chat delete as other logins are present") + return EventHandlingResultSuccess + } + } + err = portal.Delete(ctx) + if err != nil { + log.Err(err).Msg("Failed to delete portal from database") + return EventHandlingResultFailed.WithMSSError(err) + } + err = portal.Bridge.Bot.DeleteRoom(ctx, portal.MXID, false) + if err != nil { + log.Err(err).Msg("Failed to delete Matrix room") + return EventHandlingResultFailed.WithMSSError(err) + } + // No MSS here as the portal was deleted + return EventHandlingResultSuccess +} + +func (portal *Portal) handleMatrixMembership( + ctx context.Context, + sender *UserLogin, + origSender *OrigSender, + evt *event.Event, + isStateRequest bool, +) EventHandlingResult { + if evt.StateKey == nil { + return EventHandlingResultFailed.WithMSSError(ErrInvalidStateKey) + } + log := zerolog.Ctx(ctx) + content, ok := evt.Content.Parsed.(*event.MemberEventContent) + if !ok { + log.Error().Type("content_type", evt.Content.Parsed).Msg("Unexpected parsed content type") + return EventHandlingResultFailed.WithMSSError(fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed)) + } + prevContent := &event.MemberEventContent{Membership: event.MembershipLeave} + if evt.Unsigned.PrevContent != nil { + _ = evt.Unsigned.PrevContent.ParseRaw(evt.Type) + prevContent, _ = evt.Unsigned.PrevContent.Parsed.(*event.MemberEventContent) + } + log.UpdateContext(func(c zerolog.Context) zerolog.Context { + return c. + Str("membership", string(content.Membership)). + Str("prev_membership", string(prevContent.Membership)). + Str("target_user_id", evt.GetStateKey()) + }) + api, ok := sender.Client.(MembershipHandlingNetworkAPI) + if !ok { + return EventHandlingResultIgnored.WithMSSError(ErrMembershipNotSupported) + } + targetMXID := id.UserID(*evt.StateKey) + isSelf := sender.User.MXID == targetMXID + target, err := portal.getTargetUser(ctx, targetMXID) + if err != nil { + log.Err(err).Msg("Failed to get member event target") + return EventHandlingResultFailed.WithMSSError(err) + } + + membershipChangeType := MembershipChangeType{From: prevContent.Membership, To: content.Membership, IsSelf: isSelf} + if !portal.Bridge.Config.BridgeMatrixLeave && membershipChangeType == Leave { + log.Debug().Msg("Dropping leave event") + return EventHandlingResultIgnored //.WithMSSError(ErrIgnoringLeaveEvent) + } + targetGhost, _ := target.(*Ghost) + membershipChange := &MatrixMembershipChange{ + MatrixRoomMeta: MatrixRoomMeta[*event.MemberEventContent]{ + MatrixEventBase: MatrixEventBase[*event.MemberEventContent]{ + Event: evt, + Content: content, + Portal: portal, + OrigSender: origSender, + + InputTransactionID: portal.parseInputTransactionID(origSender, evt), + }, + IsStateRequest: isStateRequest, + PrevContent: prevContent, + }, + Target: target, + Type: membershipChangeType, + } + res, err := api.HandleMatrixMembership(ctx, membershipChange) + if err != nil { + log.Err(err).Msg("Failed to handle Matrix membership change") + return EventHandlingResultFailed.WithMSSError(err) + } + didRedirectInvite := membershipChangeType == Invite && + targetGhost != nil && + res != nil && + res.RedirectTo != "" && + res.RedirectTo != targetGhost.ID + if didRedirectInvite { + log.Debug(). + Str("orig_id", string(targetGhost.ID)). + Str("redirect_id", string(res.RedirectTo)). + Msg("Invite was redirected to different ghost") + var redirectGhost *Ghost + redirectGhost, err = portal.Bridge.GetGhostByID(ctx, res.RedirectTo) + if err != nil { + log.Err(err).Msg("Failed to get redirect target ghost") + return EventHandlingResultFailed.WithError(err) + } + if !isStateRequest { + portal.sendRoomMeta( + ctx, + sender.User.DoublePuppet(ctx), + time.UnixMilli(evt.Timestamp), + event.StateMember, + evt.GetStateKey(), + &event.MemberEventContent{ + Membership: event.MembershipLeave, + Reason: fmt.Sprintf("Invite redirected to %s", res.RedirectTo), + }, + true, + nil, + ) + } + portal.sendRoomMeta( + ctx, + sender.User.DoublePuppet(ctx), + time.UnixMilli(evt.Timestamp), + event.StateMember, + redirectGhost.Intent.GetMXID().String(), + content, + false, + nil, + ) + } + return EventHandlingResultSuccess.WithMSS().WithSkipStateEcho(didRedirectInvite) +} + +func makePLChange(old, new int, newIsSet bool) *SinglePowerLevelChange { + if old == new { + return nil + } + return &SinglePowerLevelChange{OrigLevel: old, NewLevel: new, NewIsSet: newIsSet} +} + +func getUniqueKeys[Key comparable, Value any](maps ...map[Key]Value) map[Key]struct{} { + unique := make(map[Key]struct{}) + for _, m := range maps { + for k := range m { + unique[k] = struct{}{} + } + } + return unique +} + +func (portal *Portal) handleMatrixPowerLevels( + ctx context.Context, + sender *UserLogin, + origSender *OrigSender, + evt *event.Event, + isStateRequest bool, +) EventHandlingResult { + if evt.StateKey == nil || *evt.StateKey != "" { + return EventHandlingResultFailed.WithMSSError(ErrInvalidStateKey) + } + log := zerolog.Ctx(ctx) + content, ok := evt.Content.Parsed.(*event.PowerLevelsEventContent) + if !ok { + log.Error().Type("content_type", evt.Content.Parsed).Msg("Unexpected parsed content type") + return EventHandlingResultFailed.WithMSSError(fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed)) + } + if content.CreateEvent == nil { + ars, ok := portal.Bridge.Matrix.(MatrixConnectorWithArbitraryRoomState) + if ok { + var err error + content.CreateEvent, err = ars.GetStateEvent(ctx, portal.MXID, event.StateCreate, "") + if err != nil { + return EventHandlingResultFailed.WithMSSError(fmt.Errorf("failed to get create event for power levels: %w", err)) + } + } + } + api, ok := sender.Client.(PowerLevelHandlingNetworkAPI) + if !ok { + return EventHandlingResultIgnored.WithMSSError(ErrPowerLevelsNotSupported) + } + prevContent := &event.PowerLevelsEventContent{} + if evt.Unsigned.PrevContent != nil { + _ = evt.Unsigned.PrevContent.ParseRaw(evt.Type) + prevContent, _ = evt.Unsigned.PrevContent.Parsed.(*event.PowerLevelsEventContent) + prevContent.CreateEvent = content.CreateEvent + } + + plChange := &MatrixPowerLevelChange{ + MatrixRoomMeta: MatrixRoomMeta[*event.PowerLevelsEventContent]{ + MatrixEventBase: MatrixEventBase[*event.PowerLevelsEventContent]{ + Event: evt, + Content: content, + Portal: portal, + OrigSender: origSender, + + InputTransactionID: portal.parseInputTransactionID(origSender, evt), + }, + IsStateRequest: isStateRequest, + PrevContent: prevContent, + }, + Users: make(map[id.UserID]*UserPowerLevelChange), + Events: make(map[string]*SinglePowerLevelChange), + UsersDefault: makePLChange(prevContent.UsersDefault, content.UsersDefault, true), + EventsDefault: makePLChange(prevContent.EventsDefault, content.EventsDefault, true), + StateDefault: makePLChange(prevContent.StateDefault(), content.StateDefault(), content.StateDefaultPtr != nil), + Invite: makePLChange(prevContent.Invite(), content.Invite(), content.InvitePtr != nil), + Kick: makePLChange(prevContent.Kick(), content.Kick(), content.KickPtr != nil), + Ban: makePLChange(prevContent.Ban(), content.Ban(), content.BanPtr != nil), + Redact: makePLChange(prevContent.Redact(), content.Redact(), content.RedactPtr != nil), + } + for eventType := range getUniqueKeys(content.Events, prevContent.Events) { + newLevel, hasNewLevel := content.Events[eventType] + if !hasNewLevel { + // TODO this doesn't handle state events properly + newLevel = content.EventsDefault + } + if change := makePLChange(prevContent.Events[eventType], newLevel, hasNewLevel); change != nil { + plChange.Events[eventType] = change + } + } + for user := range getUniqueKeys(content.Users, prevContent.Users) { + _, hasNewLevel := content.Users[user] + change := makePLChange(prevContent.GetUserLevel(user), content.GetUserLevel(user), hasNewLevel) + if change == nil { + continue + } + target, err := portal.getTargetUser(ctx, user) + if err != nil { + log.Err(err).Stringer("target_user_id", user).Msg("Failed to get user for power level change") + } else { + plChange.Users[user] = &UserPowerLevelChange{ + Target: target, + SinglePowerLevelChange: *change, + } + } + } + _, err := api.HandleMatrixPowerLevels(ctx, plChange) + if err != nil { + log.Err(err).Msg("Failed to handle Matrix power level change") + return EventHandlingResultFailed.WithMSSError(err) + } + return EventHandlingResultSuccess.WithMSS() +} + +func (portal *Portal) handleMatrixTombstone(ctx context.Context, evt *event.Event) EventHandlingResult { + if evt.StateKey == nil || *evt.StateKey != "" || portal.MXID != evt.RoomID { + return EventHandlingResultIgnored + } + log := *zerolog.Ctx(ctx) + sentByBridge := evt.Sender == portal.Bridge.Bot.GetMXID() || portal.Bridge.IsGhostMXID(evt.Sender) + var senderUser *User + var err error + if !sentByBridge { + senderUser, err = portal.Bridge.GetUserByMXID(ctx, evt.Sender) + if err != nil { + log.Err(err).Msg("Failed to get tombstone sender user") + return EventHandlingResultFailed.WithError(err) + } + } + content, ok := evt.Content.Parsed.(*event.TombstoneEventContent) + if !ok { + log.Error().Type("content_type", evt.Content.Parsed).Msg("Unexpected parsed content type") + return EventHandlingResultFailed.WithMSSError(fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed)) + } + log = log.With(). + Stringer("replacement_room", content.ReplacementRoom). + Logger() + if content.ReplacementRoom == "" { + log.Info().Msg("Received tombstone with no replacement room, cleaning up portal") + err := portal.RemoveMXID(ctx) + if err != nil { + log.Err(err).Msg("Failed to remove portal MXID") + return EventHandlingResultFailed.WithMSSError(err) + } + err = portal.Bridge.Bot.DeleteRoom(ctx, portal.MXID, true) + if err != nil { + log.Err(err).Msg("Failed to clean up Matrix room") + return EventHandlingResultFailed.WithError(err) + } + return EventHandlingResultSuccess + } + existingMemberEvt, err := portal.Bridge.Matrix.GetMemberInfo(ctx, content.ReplacementRoom, portal.Bridge.Bot.GetMXID()) + if err != nil { + log.Err(err).Msg("Failed to get member info of bot in replacement room") + return EventHandlingResultFailed.WithError(err) + } + leaveOnError := func() { + if existingMemberEvt != nil && existingMemberEvt.Membership == event.MembershipJoin { + return + } + log.Debug().Msg("Leaving replacement room with bot after tombstone validation failed") + _, err = portal.Bridge.Bot.SendState( + ctx, + content.ReplacementRoom, + event.StateMember, + portal.Bridge.Bot.GetMXID().String(), + &event.Content{ + Parsed: &event.MemberEventContent{ + Membership: event.MembershipLeave, + Reason: fmt.Sprintf("Failed to validate tombstone sent by %s from %s", evt.Sender, evt.RoomID), + }, + }, + time.Time{}, + ) + if err != nil { + log.Err(err).Msg("Failed to leave replacement room after tombstone validation failed") + } + } + var via []string + if senderHS := evt.Sender.Homeserver(); senderHS != "" { + via = []string{senderHS} + } + err = portal.Bridge.Bot.EnsureJoined(ctx, content.ReplacementRoom, EnsureJoinedParams{Via: via}) + if err != nil { + log.Err(err).Msg("Failed to join replacement room from tombstone") + return EventHandlingResultFailed.WithError(err) + } + if !sentByBridge && !senderUser.Permissions.Admin { + powers, err := portal.Bridge.Matrix.GetPowerLevels(ctx, content.ReplacementRoom) + if err != nil { + log.Err(err).Msg("Failed to get power levels in replacement room") + leaveOnError() + return EventHandlingResultFailed.WithError(err) + } + if powers.GetUserLevel(evt.Sender) < powers.Invite() { + log.Warn().Msg("Tombstone sender doesn't have enough power to invite the bot to the replacement room") + leaveOnError() + return EventHandlingResultIgnored + } + } + err = portal.UpdateMatrixRoomID(ctx, content.ReplacementRoom, UpdateMatrixRoomIDParams{ + DeleteOldRoom: true, + FetchInfoVia: senderUser, + }) + if errors.Is(err, ErrTargetRoomIsPortal) { + return EventHandlingResultIgnored + } else if err != nil { + return EventHandlingResultFailed.WithError(err) + } + return EventHandlingResultSuccess +} + +var ErrTargetRoomIsPortal = errors.New("target room is already a portal") +var ErrRoomAlreadyExists = errors.New("this portal already has a room") + +type UpdateMatrixRoomIDParams struct { + SyncDBMetadata func() + FailIfMXIDSet bool + OverwriteOldPortal bool + TombstoneOldRoom bool + DeleteOldRoom bool + + RoomCreateAlreadyLocked bool + + FetchInfoVia *User + ChatInfo *ChatInfo + ChatInfoSource *UserLogin +} + +func (portal *Portal) UpdateMatrixRoomID( + ctx context.Context, + newRoomID id.RoomID, + params UpdateMatrixRoomIDParams, +) error { + if !params.RoomCreateAlreadyLocked { + portal.roomCreateLock.Lock() + defer portal.roomCreateLock.Unlock() + } + oldRoom := portal.MXID + if oldRoom == newRoomID { + return nil + } else if oldRoom != "" && params.FailIfMXIDSet { + return ErrRoomAlreadyExists + } + log := zerolog.Ctx(ctx) + portal.Bridge.cacheLock.Lock() + // Wrap unlock in a sync.OnceFunc because we want to both defer it to catch early returns + // and unlock it before return if nothing goes wrong. + unlockCacheLock := sync.OnceFunc(portal.Bridge.cacheLock.Unlock) + defer unlockCacheLock() + if existingPortal, alreadyExists := portal.Bridge.portalsByMXID[newRoomID]; alreadyExists && !params.OverwriteOldPortal { + log.Warn().Msg("Replacement room is already a portal, ignoring") + return ErrTargetRoomIsPortal + } else if alreadyExists { + log.Debug().Msg("Replacement room is already a portal, overwriting") + existingPortal.MXID = "" + existingPortal.RoomCreated.Clear() + err := existingPortal.Save(ctx) + if err != nil { + return fmt.Errorf("failed to clear mxid of existing portal: %w", err) + } + delete(portal.Bridge.portalsByMXID, portal.MXID) + } + portal.MXID = newRoomID + portal.RoomCreated.Set() + portal.Bridge.portalsByMXID[portal.MXID] = portal + portal.NameSet = false + portal.AvatarSet = false + portal.TopicSet = false + portal.InSpace = false + portal.CapState = database.CapabilityState{} + portal.lastCapUpdate = time.Time{} + if params.SyncDBMetadata != nil { + params.SyncDBMetadata() + } + unlockCacheLock() + portal.updateLogger() + + err := portal.Save(ctx) + if err != nil { + log.Err(err).Msg("Failed to save portal in UpdateMatrixRoomID") + return err + } + log.Info().Msg("Successfully followed tombstone and updated portal MXID") + err = portal.Bridge.DB.UserPortal.MarkAllNotInSpace(ctx, portal.PortalKey) + if err != nil { + log.Err(err).Msg("Failed to update in_space flag for user portals after updating portal MXID") + } + go portal.addToUserSpaces(ctx) + if params.FetchInfoVia != nil { + go portal.updateInfoAfterTombstone(ctx, params.FetchInfoVia) + } else if params.ChatInfo != nil { + go portal.UpdateInfo(ctx, params.ChatInfo, params.ChatInfoSource, nil, time.Time{}) + } else if params.ChatInfoSource != nil { + portal.UpdateCapabilities(ctx, params.ChatInfoSource, true) + portal.UpdateBridgeInfo(ctx) + } + go func() { + // TODO this might become unnecessary if UpdateInfo starts taking care of it + _, err = portal.Bridge.Bot.SendState(ctx, portal.MXID, event.StateElementFunctionalMembers, "", &event.Content{ + Parsed: &event.ElementFunctionalMembersContent{ + ServiceMembers: []id.UserID{portal.Bridge.Bot.GetMXID()}, + }, + }, time.Time{}) + if err != nil { + if err != nil { + log.Warn().Err(err).Msg("Failed to set service members in new room") + } + } + }() + if params.TombstoneOldRoom && oldRoom != "" { + _, err = portal.Bridge.Bot.SendState(ctx, portal.MXID, event.StateTombstone, "", &event.Content{ + Parsed: &event.TombstoneEventContent{ + Body: "Room has been replaced.", + ReplacementRoom: newRoomID, + }, + }, time.Now()) + if err != nil { + log.Err(err).Msg("Failed to send tombstone event to old room") + } + } + if params.DeleteOldRoom && oldRoom != "" { + go func() { + err = portal.Bridge.Bot.DeleteRoom(ctx, oldRoom, true) + if err != nil { + log.Err(err).Msg("Failed to clean up old Matrix room after updating portal MXID") + } + }() + } + return nil +} + +func (portal *Portal) updateInfoAfterTombstone(ctx context.Context, senderUser *User) { + log := zerolog.Ctx(ctx) + logins, err := portal.Bridge.GetUserLoginsInPortal(ctx, portal.PortalKey) + if err != nil { + log.Err(err).Msg("Failed to get user logins in portal to sync info") + return + } + var preferredLogin *UserLogin + for _, login := range logins { + if !login.Client.IsLoggedIn() { + continue + } else if preferredLogin == nil { + preferredLogin = login + } else if senderUser != nil && login.User == senderUser { + preferredLogin = login + } + } + if preferredLogin == nil { + log.Warn().Msg("No logins found to sync info") + return + } + info, err := preferredLogin.Client.GetChatInfo(ctx, portal) + if err != nil { + log.Err(err).Msg("Failed to get chat info") + return + } + log.Info(). + Str("info_source_login", string(preferredLogin.ID)). + Msg("Fetched info to update portal after tombstone") + portal.UpdateInfo(ctx, info, preferredLogin, nil, time.Time{}) +} + +func (portal *Portal) handleMatrixRedaction( + ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event, +) EventHandlingResult { + log := zerolog.Ctx(ctx) + content, ok := evt.Content.Parsed.(*event.RedactionEventContent) + if !ok { + log.Error().Type("content_type", evt.Content.Parsed).Msg("Unexpected parsed content type") + return EventHandlingResultFailed.WithMSSError(fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed)) + } + if evt.Redacts != "" && content.Redacts != evt.Redacts { + content.Redacts = evt.Redacts + } + log.UpdateContext(func(c zerolog.Context) zerolog.Context { + return c.Stringer("redaction_target_mxid", content.Redacts) + }) + deletingAPI, deleteOK := sender.Client.(RedactionHandlingNetworkAPI) + reactingAPI, reactOK := sender.Client.(ReactionHandlingNetworkAPI) + if !deleteOK && !reactOK { + log.Debug().Msg("Ignoring redaction without checking target as network connector doesn't implement RedactionHandlingNetworkAPI nor ReactionHandlingNetworkAPI") + return EventHandlingResultIgnored.WithMSSError(ErrRedactionsNotSupported) + } + var redactionTargetReaction *database.Reaction + redactionTargetMsg, err := portal.Bridge.DB.Message.GetPartByMXID(ctx, content.Redacts) + if err != nil { + log.Err(err).Msg("Failed to get redaction target message from database") + return EventHandlingResultFailed.WithMSSError(fmt.Errorf("%w: failed to get redaction target message: %w", ErrDatabaseError, err)) + } else if redactionTargetMsg != nil { + if !deleteOK { + log.Debug().Msg("Ignoring message redaction event as network connector doesn't implement RedactionHandlingNetworkAPI") + return EventHandlingResultIgnored.WithMSSError(ErrRedactionsNotSupported) + } + err = deletingAPI.HandleMatrixMessageRemove(ctx, &MatrixMessageRemove{ + MatrixEventBase: MatrixEventBase[*event.RedactionEventContent]{ + Event: evt, + Content: content, + Portal: portal, + OrigSender: origSender, + + InputTransactionID: portal.parseInputTransactionID(origSender, evt), + }, + TargetMessage: redactionTargetMsg, + }) + } else if redactionTargetReaction, err = portal.Bridge.DB.Reaction.GetByMXID(ctx, content.Redacts); err != nil { + log.Err(err).Msg("Failed to get redaction target reaction from database") + return EventHandlingResultFailed.WithMSSError(fmt.Errorf("%w: failed to get redaction target message reaction: %w", ErrDatabaseError, err)) + } else if redactionTargetReaction != nil { + if !reactOK { + log.Debug().Msg("Ignoring reaction redaction event as network connector doesn't implement ReactionHandlingNetworkAPI") + return EventHandlingResultIgnored.WithMSSError(ErrReactionsNotSupported) + } + // TODO ignore if sender doesn't match? + err = reactingAPI.HandleMatrixReactionRemove(ctx, &MatrixReactionRemove{ + MatrixEventBase: MatrixEventBase[*event.RedactionEventContent]{ + Event: evt, + Content: content, + Portal: portal, + OrigSender: origSender, + + InputTransactionID: portal.parseInputTransactionID(origSender, evt), + }, + TargetReaction: redactionTargetReaction, + }) + } else { + log.Debug().Msg("Redaction target message not found in database") + return EventHandlingResultIgnored.WithMSSError(fmt.Errorf("redaction %w", ErrTargetMessageNotFound)) + } + if err != nil { + log.Err(err).Msg("Failed to handle Matrix redaction") + return EventHandlingResultFailed.WithMSSError(err) + } + // TODO delete msg/reaction db row + return EventHandlingResultSuccess.WithMSS() +} + +func (portal *Portal) handleRemoteEvent(ctx context.Context, source *UserLogin, evtType RemoteEventType, evt RemoteEvent) (res EventHandlingResult) { + log := zerolog.Ctx(ctx) + if portal.MXID == "" { + mcp, ok := evt.(RemoteEventThatMayCreatePortal) + if !ok || !mcp.ShouldCreatePortal() { + log.Debug().Msg("Dropping event as portal doesn't exist") + return EventHandlingResultIgnored + } + infoProvider, ok := mcp.(RemoteChatResyncWithInfo) + var info *ChatInfo + var err error + if ok { + info, err = infoProvider.GetChatInfo(ctx, portal) + if err != nil { + log.Err(err).Msg("Failed to get chat info for portal creation from chat resync event") + } + } + bundleProvider, ok := evt.(RemoteChatResyncBackfillBundle) + var bundle any + if ok { + bundle = bundleProvider.GetBundledBackfillData() + } + err = portal.createMatrixRoomInLoop(ctx, source, info, bundle) + if err != nil { + log.Err(err).Msg("Failed to create portal to handle event") + return EventHandlingResultFailed.WithError(err) + } + if evtType == RemoteEventChatResync { + log.Debug().Msg("Not handling chat resync event further as portal was created by it") + postHandler, ok := evt.(RemotePostHandler) + if ok { + postHandler.PostHandle(ctx, portal) + } + return EventHandlingResultSuccess + } + } + preHandler, ok := evt.(RemotePreHandler) + if ok { + preHandler.PreHandle(ctx, portal) + } + log.Debug().Msg("Handling remote event") + switch evtType { + case RemoteEventUnknown: + log.Debug().Msg("Ignoring remote event with type unknown") + res = EventHandlingResultIgnored + case RemoteEventMessage, RemoteEventMessageUpsert: + res = portal.handleRemoteMessage(ctx, source, evt.(RemoteMessage)) + case RemoteEventEdit: + res = portal.handleRemoteEdit(ctx, source, evt.(RemoteEdit)) + case RemoteEventReaction: + res = portal.handleRemoteReaction(ctx, source, evt.(RemoteReaction)) + case RemoteEventReactionRemove: + res = portal.handleRemoteReactionRemove(ctx, source, evt.(RemoteReactionRemove)) + case RemoteEventReactionSync: + res = portal.handleRemoteReactionSync(ctx, source, evt.(RemoteReactionSync)) + case RemoteEventMessageRemove: + res = portal.handleRemoteMessageRemove(ctx, source, evt.(RemoteMessageRemove)) + case RemoteEventReadReceipt: + res = portal.handleRemoteReadReceipt(ctx, source, evt.(RemoteReadReceipt)) + case RemoteEventMarkUnread: + res = portal.handleRemoteMarkUnread(ctx, source, evt.(RemoteMarkUnread)) + case RemoteEventDeliveryReceipt: + res = portal.handleRemoteDeliveryReceipt(ctx, source, evt.(RemoteDeliveryReceipt)) + case RemoteEventTyping: + res = portal.handleRemoteTyping(ctx, source, evt.(RemoteTyping)) + case RemoteEventChatInfoChange: + res = portal.handleRemoteChatInfoChange(ctx, source, evt.(RemoteChatInfoChange)) + case RemoteEventChatResync: + res = portal.handleRemoteChatResync(ctx, source, evt.(RemoteChatResync)) + case RemoteEventChatDelete: + res = portal.handleRemoteChatDelete(ctx, source, evt.(RemoteChatDelete)) + case RemoteEventBackfill: + res = portal.handleRemoteBackfill(ctx, source, evt.(RemoteBackfill)) + default: + log.Warn().Msg("Got remote event with unknown type") + } + postHandler, ok := evt.(RemotePostHandler) + if ok { + postHandler.PostHandle(ctx, portal) + } + return +} + +func (portal *Portal) ensureFunctionalMember(ctx context.Context, ghost *Ghost) { + if !ghost.IsBot || portal.RoomType != database.RoomTypeDM || portal.OtherUserID == ghost.ID || portal.MXID == "" { + return + } + ars, ok := portal.Bridge.Matrix.(MatrixConnectorWithArbitraryRoomState) + if !ok { + return + } + portal.functionalMembersLock.Lock() + defer portal.functionalMembersLock.Unlock() + var functionalMembers *event.ElementFunctionalMembersContent + if portal.functionalMembersCache != nil { + functionalMembers = portal.functionalMembersCache + } else { + evt, err := ars.GetStateEvent(ctx, portal.MXID, event.StateElementFunctionalMembers, "") + if err != nil && !errors.Is(err, mautrix.MNotFound) { + zerolog.Ctx(ctx).Err(err).Msg("Failed to get functional members state event") + return + } + functionalMembers = &event.ElementFunctionalMembersContent{} + if evt != nil { + evtContent, ok := evt.Content.Parsed.(*event.ElementFunctionalMembersContent) + if ok && evtContent != nil { + functionalMembers = evtContent + } + } + } + // TODO what about non-double-puppeted user ghosts? + functionalMembers.Add(portal.Bridge.Bot.GetMXID()) + if functionalMembers.Add(ghost.Intent.GetMXID()) { + _, err := portal.Bridge.Bot.SendState(ctx, portal.MXID, event.StateElementFunctionalMembers, "", &event.Content{ + Parsed: functionalMembers, + }, time.Time{}) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to update functional members state event") + return + } + } +} + +func (portal *Portal) getIntentAndUserMXIDFor(ctx context.Context, sender EventSender, source *UserLogin, otherLogins []*UserLogin, evtType RemoteEventType) (intent MatrixAPI, extraUserID id.UserID, err error) { + var ghost *Ghost + if !sender.IsFromMe && sender.ForceDMUser && portal.OtherUserID != "" && sender.Sender != portal.OtherUserID { + zerolog.Ctx(ctx).Warn(). + Str("original_id", string(sender.Sender)). + Str("default_other_user", string(portal.OtherUserID)). + Msg("Overriding event sender with primary other user in DM portal") + // Ensure the ghost row exists anyway to prevent foreign key errors when saving messages + // TODO it'd probably be better to override the sender in the saved message, but that's more effort + _, err = portal.Bridge.GetGhostByID(ctx, sender.Sender) + if err != nil { + zerolog.Ctx(ctx).Warn().Err(err).Msg("Failed to get ghost with original user ID") + return + } + sender.Sender = portal.OtherUserID + } + if sender.Sender != "" { + ghost, err = portal.Bridge.GetGhostByID(ctx, sender.Sender) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to get ghost for message sender") + return + } + ghost.UpdateInfoIfNecessary(ctx, source, evtType) + portal.ensureFunctionalMember(ctx, ghost) + } + if sender.IsFromMe { + intent = source.User.DoublePuppet(ctx) + if intent != nil { + return + } + extraUserID = source.UserMXID + } else if sender.SenderLogin != "" && portal.Receiver == "" { + senderLogin := portal.Bridge.GetCachedUserLoginByID(sender.SenderLogin) + if senderLogin != nil { + intent = senderLogin.User.DoublePuppet(ctx) + if intent != nil { + return + } + extraUserID = senderLogin.UserMXID + } + } + if sender.Sender != "" && portal.Receiver == "" && otherLogins != nil { + for _, login := range otherLogins { + if login.Client.IsThisUser(ctx, sender.Sender) { + intent = login.User.DoublePuppet(ctx) + if intent != nil { + return + } + extraUserID = login.UserMXID + } + } + } + if ghost != nil { + intent = ghost.Intent + } + return +} + +func (portal *Portal) GetIntentFor(ctx context.Context, sender EventSender, source *UserLogin, evtType RemoteEventType) (MatrixAPI, bool) { + intent, _, err := portal.getIntentAndUserMXIDFor(ctx, sender, source, nil, evtType) + if err != nil { + return nil, false + } + if intent == nil { + // TODO this is very hacky - we should either insert an empty ghost row automatically + // (and not fetch it at runtime) or make the message sender column nullable. + portal.Bridge.GetGhostByID(ctx, "") + intent = portal.Bridge.Bot + if intent == nil { + panic(fmt.Errorf("bridge bot is nil")) + } + } + return intent, true +} + +func (portal *Portal) getRelationMeta( + ctx context.Context, + currentMsgID networkid.MessageID, + currentMsg *ConvertedMessage, + isBatchSend bool, +) (replyTo, threadRoot, prevThreadEvent *database.Message) { + log := zerolog.Ctx(ctx) + var err error + if currentMsg.ReplyTo != nil { + replyTo, err = portal.Bridge.DB.Message.GetFirstOrSpecificPartByID(ctx, portal.Receiver, *currentMsg.ReplyTo) + if err != nil { + log.Err(err).Msg("Failed to get reply target message from database") + } else if replyTo == nil { + if isBatchSend || portal.Bridge.Config.OutgoingMessageReID { + // This is somewhat evil + replyTo = &database.Message{ + MXID: portal.Bridge.Matrix.GenerateDeterministicEventID(portal.MXID, portal.PortalKey, currentMsg.ReplyTo.MessageID, ptr.Val(currentMsg.ReplyTo.PartID)), + Room: currentMsg.ReplyToRoom, + SenderID: currentMsg.ReplyToUser, + } + if currentMsg.ReplyToLogin != "" && (portal.Receiver == "" || portal.Receiver == currentMsg.ReplyToLogin) { + userLogin, err := portal.Bridge.GetExistingUserLoginByID(ctx, currentMsg.ReplyToLogin) + if err != nil { + log.Err(err). + Str("reply_to_login", string(currentMsg.ReplyToLogin)). + Msg("Failed to get reply target user login") + } else if userLogin != nil { + replyTo.SenderMXID = userLogin.UserMXID + } + } else { + ghost, err := portal.Bridge.GetGhostByID(ctx, currentMsg.ReplyToUser) + if err != nil { + log.Err(err). + Str("reply_to_user_id", string(currentMsg.ReplyToUser)). + Msg("Failed to get reply target ghost") + } else { + replyTo.SenderMXID = ghost.Intent.GetMXID() + } + } + } else { + log.Warn().Any("reply_to", *currentMsg.ReplyTo).Msg("Reply target message not found in database") + } + } + } + if currentMsg.ThreadRoot != nil && *currentMsg.ThreadRoot != currentMsgID { + threadRoot, err = portal.Bridge.DB.Message.GetFirstThreadMessage(ctx, portal.PortalKey, *currentMsg.ThreadRoot) + if err != nil { + log.Err(err).Msg("Failed to get thread root message from database") + } else if threadRoot == nil { + if isBatchSend || portal.Bridge.Config.OutgoingMessageReID { + threadRoot = &database.Message{ + MXID: portal.Bridge.Matrix.GenerateDeterministicEventID(portal.MXID, portal.PortalKey, *currentMsg.ThreadRoot, ""), + } + } else { + log.Warn().Str("thread_root", string(*currentMsg.ThreadRoot)).Msg("Thread root message not found in database") + } + } else if prevThreadEvent, err = portal.Bridge.DB.Message.GetLastThreadMessage(ctx, portal.PortalKey, *currentMsg.ThreadRoot); err != nil { + log.Err(err).Msg("Failed to get last thread message from database") + } + if prevThreadEvent == nil { + prevThreadEvent = ptr.Clone(threadRoot) + } + } + return +} + +func (portal *Portal) applyRelationMeta(ctx context.Context, content *event.MessageEventContent, replyTo, threadRoot, prevThreadEvent *database.Message) { + if content.Mentions == nil { + content.Mentions = &event.Mentions{} + } + if threadRoot != nil && prevThreadEvent != nil { + content.GetRelatesTo().SetThread(threadRoot.MXID, prevThreadEvent.MXID) + } + if replyTo != nil { + crossRoom := !replyTo.Room.IsEmpty() && replyTo.Room != portal.PortalKey + if !crossRoom || portal.Bridge.Config.CrossRoomReplies { + content.GetRelatesTo().SetReplyTo(replyTo.MXID) + } + if crossRoom && portal.Bridge.Config.CrossRoomReplies { + targetPortal, err := portal.Bridge.GetExistingPortalByKey(ctx, replyTo.Room) + if err != nil { + zerolog.Ctx(ctx).Err(err). + Object("target_portal_key", replyTo.Room). + Msg("Failed to get cross-room reply portal") + } else if targetPortal == nil || targetPortal.MXID == "" { + zerolog.Ctx(ctx).Warn(). + Object("target_portal_key", replyTo.Room). + Msg("Cross-room reply portal not found") + } else { + content.RelatesTo.InReplyTo.UnstableRoomID = targetPortal.MXID + } + } + content.Mentions.Add(replyTo.SenderMXID) + } +} + +func (portal *Portal) sendConvertedMessage( + ctx context.Context, + id networkid.MessageID, + intent MatrixAPI, + senderID networkid.UserID, + converted *ConvertedMessage, + ts time.Time, + streamOrder int64, + logContext func(*zerolog.Event) *zerolog.Event, +) ([]*database.Message, EventHandlingResult) { + if logContext == nil { + logContext = func(e *zerolog.Event) *zerolog.Event { + return e + } + } + log := zerolog.Ctx(ctx) + replyTo, threadRoot, prevThreadEvent := portal.getRelationMeta( + ctx, id, converted, false, + ) + output := make([]*database.Message, 0, len(converted.Parts)) + allSuccess := true + for i, part := range converted.Parts { + portal.applyRelationMeta(ctx, part.Content, replyTo, threadRoot, prevThreadEvent) + part.Content.BeeperDisappearingTimer = converted.Disappear.ToEventContent() + dbMessage := &database.Message{ + ID: id, + PartID: part.ID, + Room: portal.PortalKey, + SenderID: senderID, + SenderMXID: intent.GetMXID(), + Timestamp: ts, + ThreadRoot: ptr.Val(converted.ThreadRoot), + ReplyTo: ptr.Val(converted.ReplyTo), + Metadata: part.DBMetadata, + IsDoublePuppeted: intent.IsDoublePuppet(), + } + if part.DontBridge { + dbMessage.SetFakeMXID() + logContext(log.Debug()). + Stringer("event_id", dbMessage.MXID). + Str("part_id", string(part.ID)). + Msg("Not bridging message part with DontBridge flag to Matrix") + } else { + resp, err := intent.SendMessage(ctx, portal.MXID, part.Type, &event.Content{ + Parsed: part.Content, + Raw: part.Extra, + }, &MatrixSendExtra{ + Timestamp: ts, + MessageMeta: dbMessage, + StreamOrder: streamOrder, + PartIndex: i, + }) + if err != nil { + logContext(log.Err(err)).Str("part_id", string(part.ID)).Msg("Failed to send message part to Matrix") + allSuccess = false + continue + } + logContext(log.Debug()). + Stringer("event_id", resp.EventID). + Str("part_id", string(part.ID)). + Msg("Sent message part to Matrix") + dbMessage.MXID = resp.EventID + } + err := portal.Bridge.DB.Message.Insert(ctx, dbMessage) + if err != nil { + logContext(log.Err(err)).Str("part_id", string(part.ID)).Msg("Failed to save message part to database") + allSuccess = false + } + if converted.Disappear.Type != event.DisappearingTypeNone && !dbMessage.HasFakeMXID() { + if converted.Disappear.Type == event.DisappearingTypeAfterSend && converted.Disappear.DisappearAt.IsZero() { + converted.Disappear.DisappearAt = dbMessage.Timestamp.Add(converted.Disappear.Timer) + } + portal.Bridge.DisappearLoop.Add(ctx, &database.DisappearingMessage{ + RoomID: portal.MXID, + EventID: dbMessage.MXID, + Timestamp: dbMessage.Timestamp, + DisappearingSetting: converted.Disappear, + }) + } + if prevThreadEvent != nil && !dbMessage.HasFakeMXID() { + prevThreadEvent = dbMessage + } + output = append(output, dbMessage) + } + if !allSuccess { + return output, EventHandlingResultFailed + } + return output, EventHandlingResultSuccess +} + +func (portal *Portal) checkPendingMessage(ctx context.Context, evt RemoteMessage) (bool, *database.Message) { + evtWithTxn, ok := evt.(RemoteMessageWithTransactionID) + if !ok { + return false, nil + } + txnID := evtWithTxn.GetTransactionID() + if txnID == "" { + return false, nil + } + portal.outgoingMessagesLock.Lock() + defer portal.outgoingMessagesLock.Unlock() + pending, ok := portal.outgoingMessages[txnID] + if !ok { + return false, nil + } else if pending.ignore { + return true, nil + } + delete(portal.outgoingMessages, txnID) + pending.db.ID = evt.GetID() + if pending.db.SenderID == "" { + pending.db.SenderID = evt.GetSender().Sender + } + evtWithTimestamp, ok := evt.(RemoteEventWithTimestamp) + if ok { + ts := evtWithTimestamp.GetTimestamp() + if !ts.IsZero() { + pending.db.Timestamp = ts + } + } + var statusErr error + saveMessage := true + if pending.handle != nil { + saveMessage, statusErr = pending.handle(evt, pending.db) + } + if saveMessage { + if portal.Bridge.Config.OutgoingMessageReID { + pending.db.MXID = portal.Bridge.Matrix.GenerateDeterministicEventID(portal.MXID, portal.PortalKey, pending.db.ID, pending.db.PartID) + } + // Hack to ensure the ghost row exists + // TODO move to better place (like login) + portal.Bridge.GetGhostByID(ctx, pending.db.SenderID) + err := portal.Bridge.DB.Message.Insert(ctx, pending.db) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to save message to database after receiving remote echo") + } + } + if !errors.Is(statusErr, ErrNoStatus) { + if statusErr != nil { + portal.sendErrorStatus(ctx, pending.evt, statusErr) + } else { + portal.sendSuccessStatus(ctx, pending.evt, getStreamOrder(evt), pending.evt.ID) + } + } + zerolog.Ctx(ctx).Debug().Stringer("event_id", pending.evt.ID).Msg("Received remote echo for message") + return true, pending.db +} + +func (portal *Portal) handleRemoteUpsert(ctx context.Context, source *UserLogin, evt RemoteMessageUpsert, existing []*database.Message) (handleRes EventHandlingResult, continueHandling bool) { + log := zerolog.Ctx(ctx) + intent, ok := portal.GetIntentFor(ctx, evt.GetSender(), source, RemoteEventMessageUpsert) + if !ok { + return + } + res, err := evt.HandleExisting(ctx, portal, intent, existing) + if err != nil { + log.Err(err).Msg("Failed to handle existing message in upsert event after receiving remote echo") + } else { + handleRes = EventHandlingResultSuccess + } + if res.SaveParts { + for _, part := range existing { + err = portal.Bridge.DB.Message.Update(ctx, part) + if err != nil { + log.Err(err).Str("part_id", string(part.PartID)).Msg("Failed to update message part in database") + handleRes = EventHandlingResultFailed.WithError(err) + } + } + } + if len(res.SubEvents) > 0 { + for _, subEvt := range res.SubEvents { + subType := subEvt.GetType() + log := portal.Log.With(). + Str("source_id", string(source.ID)). + Str("action", "handle remote subevent"). + Stringer("bridge_evt_type", subType). + Logger() + subRes := portal.handleRemoteEvent(log.WithContext(ctx), source, subType, subEvt) + if !subRes.Success { + handleRes.Success = false + } + } + } + continueHandling = res.ContinueMessageHandling + return +} + +func (portal *Portal) handleRemoteMessage(ctx context.Context, source *UserLogin, evt RemoteMessage) (res EventHandlingResult) { + log := zerolog.Ctx(ctx) + upsertEvt, isUpsert := evt.(RemoteMessageUpsert) + isUpsert = isUpsert && evt.GetType() == RemoteEventMessageUpsert + if wasPending, dbMessage := portal.checkPendingMessage(ctx, evt); wasPending { + if isUpsert && dbMessage != nil { + res, _ = portal.handleRemoteUpsert(ctx, source, upsertEvt, []*database.Message{dbMessage}) + } else { + res = EventHandlingResultIgnored + } + return + } + existing, err := portal.Bridge.DB.Message.GetAllPartsByID(ctx, portal.Receiver, evt.GetID()) + if err != nil { + log.Err(err).Msg("Failed to check if message is a duplicate") + } else if len(existing) > 0 { + if isUpsert { + var continueHandling bool + res, continueHandling = portal.handleRemoteUpsert(ctx, source, upsertEvt, existing) + if continueHandling { + log.Debug().Msg("Upsert handler said to continue message handling normally") + } else { + return res + } + } else { + log.Debug().Stringer("existing_mxid", existing[0].MXID).Msg("Ignoring duplicate message") + return EventHandlingResultIgnored + } + } + intent, ok := portal.GetIntentFor(ctx, evt.GetSender(), source, RemoteEventMessage) + if !ok { + return EventHandlingResultFailed + } + ts := getEventTS(evt) + converted, err := evt.ConvertMessage(ctx, portal, intent) + if err != nil { + if errors.Is(err, ErrIgnoringRemoteEvent) { + log.Debug().Err(err).Msg("Remote message handling was cancelled by convert function") + return EventHandlingResultIgnored + } else { + log.Err(err).Msg("Failed to convert remote message") + portal.sendRemoteErrorNotice(ctx, intent, err, ts, "message") + return EventHandlingResultFailed.WithError(err) + } + } + _, res = portal.sendConvertedMessage(ctx, evt.GetID(), intent, evt.GetSender().Sender, converted, ts, getStreamOrder(evt), nil) + if portal.currentlyTypingGhosts.Pop(intent.GetMXID()) { + err = intent.MarkTyping(ctx, portal.MXID, TypingTypeText, 0) + if err != nil { + log.Warn().Err(err).Msg("Failed to send stop typing event after bridging message") + } + } + return +} + +func (portal *Portal) sendRemoteErrorNotice(ctx context.Context, intent MatrixAPI, err error, ts time.Time, evtTypeName string) { + resp, sendErr := intent.SendMessage(ctx, portal.MXID, event.EventMessage, &event.Content{ + Parsed: &event.MessageEventContent{ + MsgType: event.MsgNotice, + Body: fmt.Sprintf("An error occurred while processing an incoming %s", evtTypeName), + Mentions: &event.Mentions{}, + }, + Raw: map[string]any{ + "fi.mau.bridge.internal_error": err.Error(), + }, + }, &MatrixSendExtra{ + Timestamp: ts, + }) + if sendErr != nil { + zerolog.Ctx(ctx).Err(sendErr).Msg("Failed to send error notice after remote event handling failed") + } else { + zerolog.Ctx(ctx).Debug().Stringer("event_id", resp.EventID).Msg("Sent error notice after remote event handling failed") + } +} + +func (portal *Portal) handleRemoteEdit(ctx context.Context, source *UserLogin, evt RemoteEdit) EventHandlingResult { + log := zerolog.Ctx(ctx) + var existing []*database.Message + if bundledEvt, ok := evt.(RemoteEventWithBundledParts); ok { + existing = bundledEvt.GetTargetDBMessage() + } + if existing == nil { + targetID := evt.GetTargetMessage() + var err error + existing, err = portal.Bridge.DB.Message.GetAllPartsByID(ctx, portal.Receiver, targetID) + if err != nil { + log.Err(err).Msg("Failed to get edit target message") + return EventHandlingResultFailed.WithError(err) + } + } + if existing == nil { + log.Warn().Msg("Edit target message not found") + return EventHandlingResultIgnored + } + intent, ok := portal.GetIntentFor(ctx, evt.GetSender(), source, RemoteEventEdit) + if !ok { + return EventHandlingResultFailed + } else if intent.GetMXID() != existing[0].SenderMXID { + log.Warn(). + Stringer("edit_sender_mxid", intent.GetMXID()). + Stringer("original_sender_mxid", existing[0].SenderMXID). + Msg("Not bridging edit: sender doesn't match original message sender") + return EventHandlingResultIgnored + } + ts := getEventTS(evt) + converted, err := evt.ConvertEdit(ctx, portal, intent, existing) + if errors.Is(err, ErrIgnoringRemoteEvent) { + log.Debug().Err(err).Msg("Remote edit handling was cancelled by convert function") + return EventHandlingResultIgnored + } else if err != nil { + log.Err(err).Msg("Failed to convert remote edit") + portal.sendRemoteErrorNotice(ctx, intent, err, ts, "edit") + return EventHandlingResultFailed.WithError(err) + } + res := portal.sendConvertedEdit(ctx, existing[0].ID, evt.GetSender().Sender, converted, intent, ts, getStreamOrder(evt)) + if portal.currentlyTypingGhosts.Pop(intent.GetMXID()) { + err = intent.MarkTyping(ctx, portal.MXID, TypingTypeText, 0) + if err != nil { + log.Warn().Err(err).Msg("Failed to send stop typing event after bridging edit") + } + } + return res +} + +func (portal *Portal) sendConvertedEdit( + ctx context.Context, + targetID networkid.MessageID, + senderID networkid.UserID, + converted *ConvertedEdit, + intent MatrixAPI, + ts time.Time, + streamOrder int64, +) EventHandlingResult { + log := zerolog.Ctx(ctx) + allSuccess := true + for i, part := range converted.ModifiedParts { + if part.Content.Mentions == nil { + part.Content.Mentions = &event.Mentions{} + } + overrideMXID := true + if part.Part.Room != portal.PortalKey { + part.Part.Room = portal.PortalKey + } else if !part.Part.HasFakeMXID() { + part.Content.SetEdit(part.Part.MXID) + overrideMXID = false + if part.NewMentions != nil { + part.Content.Mentions = part.NewMentions + } else { + part.Content.Mentions = &event.Mentions{} + } + } + if part.TopLevelExtra == nil { + part.TopLevelExtra = make(map[string]any) + } + if part.Extra != nil { + part.TopLevelExtra["m.new_content"] = part.Extra + } + wrappedContent := &event.Content{ + Parsed: part.Content, + Raw: part.TopLevelExtra, + } + if !part.DontBridge { + resp, err := intent.SendMessage(ctx, portal.MXID, part.Type, wrappedContent, &MatrixSendExtra{ + Timestamp: ts, + MessageMeta: part.Part, + StreamOrder: streamOrder, + PartIndex: i, + }) + if err != nil { + log.Err(err).Stringer("part_mxid", part.Part.MXID).Msg("Failed to edit message part") + allSuccess = false + continue + } else { + log.Debug(). + Stringer("event_id", resp.EventID). + Str("part_id", string(part.Part.ID)). + Msg("Sent message part edit to Matrix") + if overrideMXID { + part.Part.MXID = resp.EventID + } + } + } + err := portal.Bridge.DB.Message.Update(ctx, part.Part) + if err != nil { + log.Err(err).Int64("part_rowid", part.Part.RowID).Msg("Failed to update message part in database") + allSuccess = false + } + } + for _, part := range converted.DeletedParts { + redactContent := &event.Content{ + Parsed: &event.RedactionEventContent{ + Redacts: part.MXID, + }, + } + resp, err := intent.SendMessage(ctx, portal.MXID, event.EventRedaction, redactContent, &MatrixSendExtra{ + Timestamp: ts, + }) + if err != nil { + log.Err(err).Stringer("part_mxid", part.MXID).Msg("Failed to redact message part deleted in edit") + allSuccess = false + } else { + log.Debug(). + Stringer("redaction_event_id", resp.EventID). + Stringer("redacted_event_id", part.MXID). + Str("part_id", string(part.ID)). + Msg("Sent redaction of message part to Matrix") + } + err = portal.Bridge.DB.Message.Delete(ctx, part.RowID) + if err != nil { + log.Err(err).Int64("part_rowid", part.RowID).Msg("Failed to delete message part from database") + allSuccess = false + } + } + if converted.AddedParts != nil { + _, res := portal.sendConvertedMessage(ctx, targetID, intent, senderID, converted.AddedParts, ts, streamOrder, nil) + if !res.Success { + allSuccess = false + } + } + if !allSuccess { + return EventHandlingResultFailed + } + return EventHandlingResultSuccess +} + +func (portal *Portal) getTargetMessagePart(ctx context.Context, evt RemoteEventWithTargetMessage) (*database.Message, error) { + if partTargeter, ok := evt.(RemoteEventWithTargetPart); ok { + return portal.Bridge.DB.Message.GetPartByID(ctx, portal.Receiver, evt.GetTargetMessage(), partTargeter.GetTargetMessagePart()) + } else { + return portal.Bridge.DB.Message.GetFirstPartByID(ctx, portal.Receiver, evt.GetTargetMessage()) + } +} + +func (portal *Portal) getTargetReaction(ctx context.Context, evt RemoteReactionRemove) (*database.Reaction, error) { + if partTargeter, ok := evt.(RemoteEventWithTargetPart); ok { + return portal.Bridge.DB.Reaction.GetByID(ctx, portal.Receiver, evt.GetTargetMessage(), partTargeter.GetTargetMessagePart(), evt.GetSender().Sender, evt.GetRemovedEmojiID()) + } else { + return portal.Bridge.DB.Reaction.GetByIDWithoutMessagePart(ctx, portal.Receiver, evt.GetTargetMessage(), evt.GetSender().Sender, evt.GetRemovedEmojiID()) + } +} + +func getEventTS(evt RemoteEvent) time.Time { + if tsProvider, ok := evt.(RemoteEventWithTimestamp); ok { + return tsProvider.GetTimestamp() + } + return time.Now() +} + +func getStreamOrder(evt RemoteEvent) int64 { + if streamProvider, ok := evt.(RemoteEventWithStreamOrder); ok { + return streamProvider.GetStreamOrder() + } + return 0 +} + +func (portal *Portal) handleRemoteReactionSync(ctx context.Context, source *UserLogin, evt RemoteReactionSync) EventHandlingResult { + log := zerolog.Ctx(ctx) + eventTS := getEventTS(evt) + targetMessage, err := portal.getTargetMessagePart(ctx, evt) + if err != nil { + log.Err(err).Msg("Failed to get target message for reaction") + return EventHandlingResultFailed.WithError(err) + } else if targetMessage == nil { + // TODO use deterministic event ID as target if applicable? + log.Warn().Msg("Target message for reaction not found") + return EventHandlingResultIgnored + } + var existingReactions []*database.Reaction + if partTargeter, ok := evt.(RemoteEventWithTargetPart); ok { + existingReactions, err = portal.Bridge.DB.Reaction.GetAllToMessagePart(ctx, portal.Receiver, evt.GetTargetMessage(), partTargeter.GetTargetMessagePart()) + } else { + existingReactions, err = portal.Bridge.DB.Reaction.GetAllToMessage(ctx, portal.Receiver, evt.GetTargetMessage()) + } + if err != nil { + log.Err(err).Msg("Failed to get existing reactions for reaction sync") + return EventHandlingResultFailed.WithError(err) + } + existing := make(map[networkid.UserID]map[networkid.EmojiID]*database.Reaction) + for _, existingReaction := range existingReactions { + if existing[existingReaction.SenderID] == nil { + existing[existingReaction.SenderID] = make(map[networkid.EmojiID]*database.Reaction) + } + existing[existingReaction.SenderID][existingReaction.EmojiID] = existingReaction + } + + doAddReaction := func(new *BackfillReaction, intent MatrixAPI) { + if intent == nil { + var ok bool + intent, ok = portal.GetIntentFor(ctx, new.Sender, source, RemoteEventReactionSync) + if !ok { + return + } + } + portal.sendConvertedReaction( + ctx, new.Sender.Sender, intent, targetMessage, new.EmojiID, new.Emoji, + new.Timestamp, new.DBMetadata, new.ExtraContent, + func(z *zerolog.Event) *zerolog.Event { + return z. + Any("reaction_sender_id", new.Sender). + Time("reaction_ts", new.Timestamp) + }, + ) + } + doRemoveReaction := func(old *database.Reaction, intent MatrixAPI, deleteRow bool) { + if intent == nil && old.SenderMXID != "" { + intent, err = portal.getIntentForMXID(ctx, old.SenderMXID) + if err != nil { + log.Err(err). + Stringer("reaction_sender_mxid", old.SenderMXID). + Msg("Failed to get intent for removing reaction") + } + } + if intent == nil { + log.Warn(). + Str("reaction_sender_id", string(old.SenderID)). + Stringer("reaction_sender_mxid", old.SenderMXID). + Msg("Didn't find intent for removing reaction, using bridge bot") + intent = portal.Bridge.Bot + } + _, err = intent.SendMessage(ctx, portal.MXID, event.EventRedaction, &event.Content{ + Parsed: &event.RedactionEventContent{ + Redacts: old.MXID, + }, + }, &MatrixSendExtra{Timestamp: eventTS}) + if err != nil { + log.Err(err).Msg("Failed to redact old reaction") + } + if deleteRow { + err = portal.Bridge.DB.Reaction.Delete(ctx, old) + if err != nil { + log.Err(err).Msg("Failed to delete old reaction row") + } + } + } + doOverwriteReaction := func(new *BackfillReaction, old *database.Reaction) { + intent, ok := portal.GetIntentFor(ctx, new.Sender, source, RemoteEventReactionSync) + if !ok { + return + } + doRemoveReaction(old, intent, false) + doAddReaction(new, intent) + } + + newData := evt.GetReactions() + for userID, reactions := range newData.Users { + existingUserReactions := existing[userID] + delete(existing, userID) + for _, reaction := range reactions.Reactions { + if reaction.Timestamp.IsZero() { + reaction.Timestamp = eventTS + } + existingReaction, ok := existingUserReactions[reaction.EmojiID] + if ok { + delete(existingUserReactions, reaction.EmojiID) + if reaction.EmojiID != "" || reaction.Emoji == existingReaction.Emoji { + continue + } + doOverwriteReaction(reaction, existingReaction) + } else { + doAddReaction(reaction, nil) + } + } + totalReactionCount := len(existingUserReactions) + len(reactions.Reactions) + if reactions.HasAllReactions { + for _, existingReaction := range existingUserReactions { + doRemoveReaction(existingReaction, nil, true) + } + } else if reactions.MaxCount > 0 && totalReactionCount > reactions.MaxCount { + remainingReactionList := maps.Values(existingUserReactions) + slices.SortFunc(remainingReactionList, func(a, b *database.Reaction) int { + diff := a.Timestamp.Compare(b.Timestamp) + if diff == 0 { + return cmp.Compare(a.EmojiID, b.EmojiID) + } + return diff + }) + numberToRemove := totalReactionCount - reactions.MaxCount + for i := 0; i < numberToRemove && i < len(remainingReactionList); i++ { + doRemoveReaction(remainingReactionList[i], nil, true) + } + } + } + if newData.HasAllUsers { + for _, userReactions := range existing { + for _, existingReaction := range userReactions { + doRemoveReaction(existingReaction, nil, true) + } + } + } + return EventHandlingResultSuccess +} + +func (portal *Portal) handleRemoteReaction(ctx context.Context, source *UserLogin, evt RemoteReaction) EventHandlingResult { + log := zerolog.Ctx(ctx) + targetMessage, err := portal.getTargetMessagePart(ctx, evt) + if err != nil { + log.Err(err).Msg("Failed to get target message for reaction") + return EventHandlingResultFailed.WithError(err) + } else if targetMessage == nil { + // TODO use deterministic event ID as target if applicable? + log.Warn().Msg("Target message for reaction not found") + return EventHandlingResultIgnored + } + emoji, emojiID := evt.GetReactionEmoji() + existingReaction, err := portal.Bridge.DB.Reaction.GetByID(ctx, portal.Receiver, targetMessage.ID, targetMessage.PartID, evt.GetSender().Sender, emojiID) + if err != nil { + log.Err(err).Msg("Failed to check if reaction is a duplicate") + return EventHandlingResultFailed.WithError(err) + } else if existingReaction != nil && (emojiID != "" || existingReaction.Emoji == emoji) { + log.Debug().Msg("Ignoring duplicate reaction") + return EventHandlingResultIgnored + } + ts := getEventTS(evt) + intent, ok := portal.GetIntentFor(ctx, evt.GetSender(), source, RemoteEventReaction) + if !ok { + return EventHandlingResultFailed + } + var extra map[string]any + if extraContentProvider, ok := evt.(RemoteReactionWithExtraContent); ok { + extra = extraContentProvider.GetReactionExtraContent() + } + var dbMetadata any + if metaProvider, ok := evt.(RemoteReactionWithMeta); ok { + dbMetadata = metaProvider.GetReactionDBMetadata() + } + if existingReaction != nil { + _, err = intent.SendMessage(ctx, portal.MXID, event.EventRedaction, &event.Content{ + Parsed: &event.RedactionEventContent{ + Redacts: existingReaction.MXID, + }, + }, &MatrixSendExtra{Timestamp: ts}) + if err != nil { + log.Err(err).Msg("Failed to redact old reaction") + } + } + return portal.sendConvertedReaction(ctx, evt.GetSender().Sender, intent, targetMessage, emojiID, emoji, ts, dbMetadata, extra, nil) +} + +func (portal *Portal) sendConvertedReaction( + ctx context.Context, senderID networkid.UserID, intent MatrixAPI, targetMessage *database.Message, + emojiID networkid.EmojiID, emoji string, ts time.Time, dbMetadata any, extraContent map[string]any, + logContext func(*zerolog.Event) *zerolog.Event, +) EventHandlingResult { + if logContext == nil { + logContext = func(e *zerolog.Event) *zerolog.Event { + return e + } + } + log := zerolog.Ctx(ctx) + dbReaction := &database.Reaction{ + Room: portal.PortalKey, + MessageID: targetMessage.ID, + MessagePartID: targetMessage.PartID, + SenderID: senderID, + SenderMXID: intent.GetMXID(), + EmojiID: emojiID, + Timestamp: ts, + Metadata: dbMetadata, + } + if emojiID == "" { + dbReaction.Emoji = emoji + } + resp, err := intent.SendMessage(ctx, portal.MXID, event.EventReaction, &event.Content{ + Parsed: &event.ReactionEventContent{ + RelatesTo: event.RelatesTo{ + Type: event.RelAnnotation, + EventID: targetMessage.MXID, + Key: variationselector.Add(emoji), + }, + }, + Raw: extraContent, + }, &MatrixSendExtra{ + Timestamp: ts, + ReactionMeta: dbReaction, + }) + if err != nil { + logContext(log.Err(err)).Msg("Failed to send reaction to Matrix") + return EventHandlingResultFailed.WithError(err) + } + logContext(log.Debug()). + Stringer("event_id", resp.EventID). + Msg("Sent reaction to Matrix") + dbReaction.MXID = resp.EventID + err = portal.Bridge.DB.Reaction.Upsert(ctx, dbReaction) + if err != nil { + logContext(log.Err(err)).Msg("Failed to save reaction to database") + return EventHandlingResultFailed.WithError(err) + } + return EventHandlingResultSuccess +} + +func (portal *Portal) getIntentForMXID(ctx context.Context, userID id.UserID) (MatrixAPI, error) { + if userID == "" { + return nil, nil + } else if ghost, err := portal.Bridge.GetGhostByMXID(ctx, userID); err != nil { + return nil, fmt.Errorf("failed to get ghost: %w", err) + } else if ghost != nil { + return ghost.Intent, nil + } else if user, err := portal.Bridge.GetExistingUserByMXID(ctx, userID); err != nil { + return nil, fmt.Errorf("failed to get user: %w", err) + } else if user != nil { + return user.DoublePuppet(ctx), nil + } else { + return nil, nil + } +} + +func (portal *Portal) handleRemoteReactionRemove(ctx context.Context, source *UserLogin, evt RemoteReactionRemove) EventHandlingResult { + log := zerolog.Ctx(ctx) + targetReaction, err := portal.getTargetReaction(ctx, evt) + if err != nil { + log.Err(err).Msg("Failed to get target reaction for removal") + return EventHandlingResultFailed.WithError(err) + } else if targetReaction == nil { + log.Warn().Msg("Target reaction not found") + return EventHandlingResultIgnored + } + intent, err := portal.getIntentForMXID(ctx, targetReaction.SenderMXID) + if err != nil { + log.Err(err).Stringer("sender_mxid", targetReaction.SenderMXID).Msg("Failed to get intent for removing reaction") + } + if intent == nil { + var ok bool + intent, ok = portal.GetIntentFor(ctx, evt.GetSender(), source, RemoteEventReactionRemove) + if !ok { + return EventHandlingResultFailed + } + } + ts := getEventTS(evt) + _, err = intent.SendMessage(ctx, portal.MXID, event.EventRedaction, &event.Content{ + Parsed: &event.RedactionEventContent{ + Redacts: targetReaction.MXID, + }, + }, &MatrixSendExtra{Timestamp: ts, ReactionMeta: targetReaction}) + if err != nil { + log.Err(err).Stringer("reaction_mxid", targetReaction.MXID).Msg("Failed to redact reaction") + return EventHandlingResultFailed.WithError(err) + } + err = portal.Bridge.DB.Reaction.Delete(ctx, targetReaction) + if err != nil { + log.Err(err).Msg("Failed to delete target reaction from database") + } + return EventHandlingResultSuccess +} + +func (portal *Portal) handleRemoteMessageRemove(ctx context.Context, source *UserLogin, evt RemoteMessageRemove) EventHandlingResult { + log := zerolog.Ctx(ctx) + targetParts, err := portal.Bridge.DB.Message.GetAllPartsByID(ctx, portal.Receiver, evt.GetTargetMessage()) + if err != nil { + log.Err(err).Msg("Failed to get target message for removal") + return EventHandlingResultFailed.WithError(err) + } else if len(targetParts) == 0 { + log.Debug().Msg("Target message not found") + return EventHandlingResultIgnored + } + onlyForMeProvider, ok := evt.(RemoteDeleteOnlyForMe) + onlyForMe := ok && onlyForMeProvider.DeleteOnlyForMe() + if onlyForMe && portal.Receiver == "" { + _, others, err := portal.findOtherLogins(ctx, source) + if err != nil { + log.Err(err).Msg("Failed to check if portal has other logins") + return EventHandlingResultFailed.WithError(err) + } else if len(others) > 0 { + log.Debug().Msg("Ignoring delete for me event in portal with multiple logins") + return EventHandlingResultIgnored + } + } + + intent, ok := portal.GetIntentFor(ctx, evt.GetSender(), source, RemoteEventMessageRemove) + if !ok { + return EventHandlingResultFailed + } + if intent == portal.Bridge.Bot && len(targetParts) > 0 { + senderIntent, err := portal.getIntentForMXID(ctx, targetParts[0].SenderMXID) + if err != nil { + log.Err(err).Stringer("sender_mxid", targetParts[0].SenderMXID).Msg("Failed to get intent for removing message") + } else if senderIntent != nil { + intent = senderIntent + } + } + res := portal.redactMessageParts(ctx, targetParts, intent, getEventTS(evt)) + err = portal.Bridge.DB.Message.DeleteAllParts(ctx, portal.Receiver, evt.GetTargetMessage()) + if err != nil { + log.Err(err).Msg("Failed to delete target message from database") + } + return res +} + +func (portal *Portal) redactMessageParts(ctx context.Context, parts []*database.Message, intent MatrixAPI, ts time.Time) EventHandlingResult { + log := zerolog.Ctx(ctx) + var anyFailed bool + for _, part := range parts { + if part.HasFakeMXID() { + continue + } + resp, err := intent.SendMessage(ctx, portal.MXID, event.EventRedaction, &event.Content{ + Parsed: &event.RedactionEventContent{ + Redacts: part.MXID, + }, + }, &MatrixSendExtra{Timestamp: ts, MessageMeta: part}) + if err != nil { + log.Err(err).Stringer("part_mxid", part.MXID).Msg("Failed to redact message part") + anyFailed = true + } else { + log.Debug(). + Stringer("redaction_event_id", resp.EventID). + Stringer("redacted_event_id", part.MXID). + Str("part_id", string(part.ID)). + Msg("Sent redaction of message part to Matrix") + } + } + if anyFailed { + return EventHandlingResultFailed + } + return EventHandlingResultSuccess +} + +func (portal *Portal) handleRemoteReadReceipt(ctx context.Context, source *UserLogin, evt RemoteReadReceipt) EventHandlingResult { + log := zerolog.Ctx(ctx) + var err error + var lastTarget *database.Message + readUpTo := evt.GetReadUpTo() + if lastTargetID := evt.GetLastReceiptTarget(); lastTargetID != "" { + lastTarget, err = portal.Bridge.DB.Message.GetLastPartByID(ctx, portal.Receiver, lastTargetID) + if err != nil { + log.Err(err).Str("last_target_id", string(lastTargetID)). + Msg("Failed to get last target message for read receipt") + return EventHandlingResultFailed.WithError(err) + } else if lastTarget == nil { + log.Debug().Str("last_target_id", string(lastTargetID)). + Msg("Last target message not found") + } else if lastTarget.HasFakeMXID() { + log.Debug().Str("last_target_id", string(lastTargetID)). + Msg("Last target message is fake") + if readUpTo.IsZero() { + readUpTo = lastTarget.Timestamp + } + lastTarget = nil + } + } + if lastTarget == nil { + for _, targetID := range evt.GetReceiptTargets() { + target, err := portal.Bridge.DB.Message.GetLastPartByID(ctx, portal.Receiver, targetID) + if err != nil { + log.Err(err).Str("target_id", string(targetID)). + Msg("Failed to get target message for read receipt") + return EventHandlingResultFailed.WithError(err) + } else if target != nil && !target.HasFakeMXID() && (lastTarget == nil || target.Timestamp.After(lastTarget.Timestamp)) { + lastTarget = target + } + } + } + if lastTarget == nil && !readUpTo.IsZero() { + lastTarget, err = portal.Bridge.DB.Message.GetLastNonFakePartAtOrBeforeTime(ctx, portal.PortalKey, readUpTo) + if err != nil { + log.Err(err).Time("read_up_to", readUpTo).Msg("Failed to get target message for read receipt") + } + } + sender := evt.GetSender() + intent, ok := portal.GetIntentFor(ctx, sender, source, RemoteEventReadReceipt) + if !ok { + return EventHandlingResultFailed + } + var addTargetLog func(evt *zerolog.Event) *zerolog.Event + if lastTarget == nil { + sevt, evtOK := evt.(RemoteReadReceiptWithStreamOrder) + soIntent, soIntentOK := intent.(StreamOrderReadingMatrixAPI) + if !evtOK || !soIntentOK || sevt.GetReadUpToStreamOrder() == 0 { + log.Warn().Msg("No target message found for read receipt") + return EventHandlingResultIgnored + } + targetStreamOrder := sevt.GetReadUpToStreamOrder() + addTargetLog = func(evt *zerolog.Event) *zerolog.Event { + return evt.Int64("target_stream_order", targetStreamOrder) + } + err = soIntent.MarkStreamOrderRead(ctx, portal.MXID, targetStreamOrder, getEventTS(evt)) + if readUpTo.IsZero() { + readUpTo = getEventTS(evt) + } + } else { + addTargetLog = func(evt *zerolog.Event) *zerolog.Event { + return evt.Stringer("target_mxid", lastTarget.MXID) + } + err = intent.MarkRead(ctx, portal.MXID, lastTarget.MXID, getEventTS(evt)) + readUpTo = lastTarget.Timestamp + } + if err != nil { + addTargetLog(log.Err(err)).Msg("Failed to bridge read receipt") + return EventHandlingResultFailed.WithError(err) + } else { + addTargetLog(log.Debug()).Msg("Bridged read receipt") + } + if sender.IsFromMe { + portal.Bridge.DisappearLoop.StartAllBefore(ctx, portal.MXID, readUpTo) + } + return EventHandlingResultSuccess +} + +func (portal *Portal) handleRemoteMarkUnread(ctx context.Context, source *UserLogin, evt RemoteMarkUnread) EventHandlingResult { + if !evt.GetSender().IsFromMe { + zerolog.Ctx(ctx).Warn().Msg("Ignoring mark unread event from non-self user") + return EventHandlingResultIgnored + } + dp := source.User.DoublePuppet(ctx) + if dp == nil { + return EventHandlingResultIgnored + } + err := dp.MarkUnread(ctx, portal.MXID, evt.GetUnread()) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to bridge mark unread event") + return EventHandlingResultFailed.WithError(err) + } + return EventHandlingResultSuccess +} + +func (portal *Portal) handleRemoteDeliveryReceipt(ctx context.Context, source *UserLogin, evt RemoteDeliveryReceipt) EventHandlingResult { + if portal.RoomType != database.RoomTypeDM || (evt.GetSender().Sender != portal.OtherUserID && portal.OtherUserID != "") { + return EventHandlingResultIgnored + } + intent, ok := portal.GetIntentFor(ctx, evt.GetSender(), source, RemoteEventDeliveryReceipt) + if !ok { + return EventHandlingResultFailed + } + log := zerolog.Ctx(ctx) + for _, target := range evt.GetReceiptTargets() { + targetParts, err := portal.Bridge.DB.Message.GetAllPartsByID(ctx, portal.Receiver, target) + if err != nil { + log.Err(err).Str("target_id", string(target)).Msg("Failed to get target message for delivery receipt") + return EventHandlingResultFailed.WithError(err) + } else if len(targetParts) == 0 { + continue + } else if _, sentByGhost := portal.Bridge.Matrix.ParseGhostMXID(targetParts[0].SenderMXID); sentByGhost { + continue + } + for _, part := range targetParts { + portal.Bridge.Matrix.SendMessageStatus(ctx, &MessageStatus{ + Status: event.MessageStatusSuccess, + DeliveredTo: []id.UserID{intent.GetMXID()}, + }, &MessageStatusEventInfo{ + RoomID: portal.MXID, + SourceEventID: part.MXID, + Sender: part.SenderMXID, + + IsSourceEventDoublePuppeted: part.IsDoublePuppeted, + }) + } + } + return EventHandlingResultSuccess +} + +func (portal *Portal) handleRemoteTyping(ctx context.Context, source *UserLogin, evt RemoteTyping) EventHandlingResult { + var typingType TypingType + if typedEvt, ok := evt.(RemoteTypingWithType); ok { + typingType = typedEvt.GetTypingType() + } + intent, ok := portal.GetIntentFor(ctx, evt.GetSender(), source, RemoteEventTyping) + if !ok { + return EventHandlingResultFailed + } + timeout := evt.GetTimeout() + err := intent.MarkTyping(ctx, portal.MXID, typingType, timeout) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to bridge typing event") + return EventHandlingResultFailed.WithError(err) + } + if timeout == 0 { + portal.currentlyTypingGhosts.Remove(intent.GetMXID()) + } else { + portal.currentlyTypingGhosts.Add(intent.GetMXID()) + } + return EventHandlingResultSuccess +} + +func (portal *Portal) handleRemoteChatInfoChange(ctx context.Context, source *UserLogin, evt RemoteChatInfoChange) EventHandlingResult { + info, err := evt.GetChatInfoChange(ctx) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to get chat info change") + return EventHandlingResultFailed.WithError(err) + } + portal.ProcessChatInfoChange(ctx, evt.GetSender(), source, info, getEventTS(evt)) + return EventHandlingResultSuccess +} + +func (portal *Portal) handleRemoteChatResync(ctx context.Context, source *UserLogin, evt RemoteChatResync) EventHandlingResult { + log := zerolog.Ctx(ctx) + infoProvider, ok := evt.(RemoteChatResyncWithInfo) + if ok { + info, err := infoProvider.GetChatInfo(ctx, portal) + if err != nil { + log.Err(err).Msg("Failed to get chat info from resync event") + } else if info != nil { + portal.UpdateInfo(ctx, info, source, nil, time.Time{}) + } else { + log.Debug().Msg("No chat info provided in resync event") + } + } + backfillChecker, ok := evt.(RemoteChatResyncBackfill) + if portal.Bridge.Config.Backfill.Enabled && ok && portal.RoomType != database.RoomTypeSpace { + latestMessage, err := portal.Bridge.DB.Message.GetLastPartAtOrBeforeTime(ctx, portal.PortalKey, time.Now().Add(10*time.Second)) + if err != nil { + log.Err(err).Msg("Failed to get last message in portal to check if backfill is necessary") + } else if needsBackfill, err := backfillChecker.CheckNeedsBackfill(ctx, latestMessage); err != nil { + log.Err(err).Msg("Failed to check if backfill is needed") + } else if needsBackfill { + bundleProvider, ok := evt.(RemoteChatResyncBackfillBundle) + var bundle any + if ok { + bundle = bundleProvider.GetBundledBackfillData() + } + portal.doForwardBackfill(ctx, source, latestMessage, bundle) + } + } + return EventHandlingResultSuccess +} + +func (portal *Portal) findOtherLogins(ctx context.Context, source *UserLogin) (ownUP *database.UserPortal, others []*database.UserPortal, err error) { + others, err = portal.Bridge.DB.UserPortal.GetAllInPortal(ctx, portal.PortalKey) + if err != nil { + return + } + others = slices.DeleteFunc(others, func(up *database.UserPortal) bool { + if up.LoginID == source.ID { + ownUP = up + return true + } + return false + }) + return +} + +type childDeleteProxy struct { + RemoteChatDeleteWithChildren + child networkid.PortalKey + done func() +} + +func (cdp *childDeleteProxy) AddLogContext(c zerolog.Context) zerolog.Context { + return cdp.RemoteChatDeleteWithChildren.AddLogContext(c).Str("subaction", "delete children") +} +func (cdp *childDeleteProxy) GetPortalKey() networkid.PortalKey { return cdp.child } +func (cdp *childDeleteProxy) ShouldCreatePortal() bool { return false } +func (cdp *childDeleteProxy) PreHandle(ctx context.Context, portal *Portal) {} +func (cdp *childDeleteProxy) PostHandle(ctx context.Context, portal *Portal) { cdp.done() } + +func (portal *Portal) handleRemoteChatDelete(ctx context.Context, source *UserLogin, evt RemoteChatDelete) EventHandlingResult { + log := zerolog.Ctx(ctx) + if portal.Receiver == "" && evt.DeleteOnlyForMe() { + ownUP, logins, err := portal.findOtherLogins(ctx, source) + if err != nil { + log.Err(err).Msg("Failed to check if portal has other logins") + return EventHandlingResultFailed.WithError(err) + } + if len(logins) > 0 { + log.Debug().Msg("Not deleting portal with other logins in remote chat delete event") + if ownUP != nil { + err = portal.Bridge.DB.UserPortal.Delete(ctx, ownUP) + if err != nil { + log.Err(err).Msg("Failed to delete own user portal row from database") + } else { + log.Debug().Msg("Deleted own user portal row from database") + } + } + _, err = portal.sendStateWithIntentOrBot( + ctx, + source.User.DoublePuppet(ctx), + event.StateMember, + source.UserMXID.String(), + &event.Content{Parsed: &event.MemberEventContent{Membership: event.MembershipLeave}}, + getEventTS(evt), + ) + if err != nil { + log.Err(err).Msg("Failed to send leave state event for user after remote chat delete") + return EventHandlingResultFailed.WithError(err) + } else { + log.Debug().Msg("Sent leave state event for user after remote chat delete") + return EventHandlingResultSuccess + } + } + } + if childDeleter, ok := evt.(RemoteChatDeleteWithChildren); ok && childDeleter.DeleteChildren() && portal.RoomType == database.RoomTypeSpace { + children, err := portal.Bridge.GetChildPortals(ctx, portal.PortalKey) + if err != nil { + log.Err(err).Msg("Failed to fetch children to delete") + return EventHandlingResultFailed.WithError(err) + } + log.Debug(). + Int("portal_count", len(children)). + Msg("Deleting child portals before remote chat delete") + var wg sync.WaitGroup + wg.Add(len(children)) + for _, child := range children { + child.queueEvent(ctx, &portalRemoteEvent{ + evt: &childDeleteProxy{ + RemoteChatDeleteWithChildren: childDeleter, + child: child.PortalKey, + done: wg.Done, + }, + source: source, + evtType: RemoteEventChatDelete, + }) + } + wg.Wait() + log.Debug().Msg("Finished deleting child portals") + } + err := portal.Delete(ctx) + if err != nil { + log.Err(err).Msg("Failed to delete portal from database") + return EventHandlingResultFailed.WithError(err) + } + err = portal.Bridge.Bot.DeleteRoom(ctx, portal.MXID, false) + if err != nil { + log.Err(err).Msg("Failed to delete Matrix room") + return EventHandlingResultFailed.WithError(err) + } else { + log.Info().Msg("Deleted room after remote chat delete event") + return EventHandlingResultSuccess + } +} + +func (portal *Portal) handleRemoteBackfill(ctx context.Context, source *UserLogin, backfill RemoteBackfill) (res EventHandlingResult) { + //data, err := backfill.GetBackfillData(ctx, portal) + //if err != nil { + // zerolog.Ctx(ctx).Err(err).Msg("Failed to get backfill data") + // return + //} + return +} + +type ChatInfoChange struct { + // The chat info that changed. Any fields that did not change can be left as nil. + ChatInfo *ChatInfo + // A list of member changes. + // This list should only include changes, not the whole member list. + // To resync the whole list, use the field inside ChatInfo. + MemberChanges *ChatMemberList +} + +func (portal *Portal) ProcessChatInfoChange(ctx context.Context, sender EventSender, source *UserLogin, change *ChatInfoChange, ts time.Time) { + intent, ok := portal.GetIntentFor(ctx, sender, source, RemoteEventChatInfoChange) + if !ok { + return + } + if change.ChatInfo != nil { + portal.UpdateInfo(ctx, change.ChatInfo, source, intent, ts) + } + if change.MemberChanges != nil { + err := portal.syncParticipants(ctx, change.MemberChanges, source, intent, ts) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to sync room members") + } + } +} + +// Deprecated: Renamed to ChatInfo +type PortalInfo = ChatInfo + +type ChatMember struct { + EventSender + Membership event.Membership + // Per-room nickname for the user. Not yet used. + Nickname *string + // The power level to set for the user when syncing power levels. + PowerLevel *int + // Optional user info to sync the ghost user while updating membership. + UserInfo *UserInfo + // The user who sent the membership change (user who invited/kicked/banned this user). + // Not yet used. Not applicable if Membership is join or knock. + MemberSender EventSender + // Extra fields to include in the member event. + MemberEventExtra map[string]any + // The expected previous membership. If this doesn't match, the change is ignored. + PrevMembership event.Membership +} + +type ChatMemberMap map[networkid.UserID]ChatMember + +// Set adds the given entry to this map, overwriting any existing entry with the same Sender field. +func (cmm ChatMemberMap) Set(member ChatMember) ChatMemberMap { + if member.Sender == "" && member.SenderLogin == "" && !member.IsFromMe { + return cmm + } + cmm[member.Sender] = member + return cmm +} + +// Add adds the given entry to this map, but will ignore it if an entry with the same Sender field already exists. +// It returns true if the entry was added, false otherwise. +func (cmm ChatMemberMap) Add(member ChatMember) bool { + if member.Sender == "" && member.SenderLogin == "" && !member.IsFromMe { + return false + } + if _, exists := cmm[member.Sender]; exists { + return false + } + cmm[member.Sender] = member + return true +} + +type ChatMemberList struct { + // Whether this is the full member list. + // If true, any extra members not listed here will be removed from the portal. + IsFull bool + // Should the bridge call IsThisUser for every member in the list? + // This should be used when SenderLogin can't be filled accurately. + CheckAllLogins bool + // Should any changes have the `com.beeper.exclude_from_timeline` flag set by default? + // This is recommended for syncs with non-real-time changes. + // Real-time changes (e.g. a user joining) should not set this flag set. + ExcludeChangesFromTimeline bool + + // The total number of members in the chat, regardless of how many of those members are included in MemberMap. + TotalMemberCount int + + // For DM portals, the ID of the recipient user. + // This field is optional and will be automatically filled from MemberMap if there are only 2 entries in the map. + OtherUserID networkid.UserID + + // Deprecated: Use MemberMap instead to avoid duplicate entries + Members []ChatMember + MemberMap ChatMemberMap + PowerLevels *PowerLevelOverrides +} + +func (cml *ChatMemberList) memberListToMap(ctx context.Context) { + if cml.Members == nil || cml.MemberMap != nil { + return + } + cml.MemberMap = make(map[networkid.UserID]ChatMember, len(cml.Members)) + for _, member := range cml.Members { + if _, alreadyExists := cml.MemberMap[member.Sender]; alreadyExists { + zerolog.Ctx(ctx).Warn().Str("member_id", string(member.Sender)).Msg("Duplicate member in list") + } + cml.MemberMap[member.Sender] = member + } +} + +type PowerLevelOverrides struct { + Events map[event.Type]int + UsersDefault *int + EventsDefault *int + StateDefault *int + Invite *int + Kick *int + Ban *int + Redact *int + + Custom func(*event.PowerLevelsEventContent) bool +} + +// Deprecated: renamed to PowerLevelOverrides +type PowerLevelChanges = PowerLevelOverrides + +func allowChange(newLevel *int, oldLevel, actorLevel int) bool { + return newLevel != nil && + *newLevel <= actorLevel && oldLevel <= actorLevel && + oldLevel != *newLevel +} + +func (plc *PowerLevelOverrides) Apply(actor id.UserID, content *event.PowerLevelsEventContent) (changed bool) { + if plc == nil || content == nil { + return + } + for evtType, level := range plc.Events { + changed = content.EnsureEventLevelAs(actor, evtType, level) || changed + } + var actorLevel int + if actor != "" { + actorLevel = content.GetUserLevel(actor) + } else { + actorLevel = (1 << 31) - 1 + } + if allowChange(plc.UsersDefault, content.UsersDefault, actorLevel) { + changed = true + content.UsersDefault = *plc.UsersDefault + } + if allowChange(plc.EventsDefault, content.EventsDefault, actorLevel) { + changed = true + content.EventsDefault = *plc.EventsDefault + } + if allowChange(plc.StateDefault, content.StateDefault(), actorLevel) { + changed = true + content.StateDefaultPtr = plc.StateDefault + } + if allowChange(plc.Invite, content.Invite(), actorLevel) { + changed = true + content.InvitePtr = plc.Invite + } + if allowChange(plc.Kick, content.Kick(), actorLevel) { + changed = true + content.KickPtr = plc.Kick + } + if allowChange(plc.Ban, content.Ban(), actorLevel) { + changed = true + content.BanPtr = plc.Ban + } + if allowChange(plc.Redact, content.Redact(), actorLevel) { + changed = true + content.RedactPtr = plc.Redact + } + if plc.Custom != nil { + changed = plc.Custom(content) || changed + } + return changed +} + +// DefaultChatName can be used to explicitly clear the name of a room +// and reset it to the default one based on members. +var DefaultChatName = ptr.Ptr("") + +type ChatInfo struct { + Name *string + Topic *string + Avatar *Avatar + + Members *ChatMemberList + JoinRule *event.JoinRulesEventContent + + Type *database.RoomType + Disappear *database.DisappearingSetting + ParentID *networkid.PortalID + + UserLocal *UserLocalPortalInfo + MessageRequest *bool + CanBackfill bool + + ExcludeChangesFromTimeline bool + + ExtraUpdates ExtraUpdater[*Portal] +} + +type ExtraUpdater[T any] func(context.Context, T) bool + +func MergeExtraUpdaters[T any](funcs ...ExtraUpdater[T]) ExtraUpdater[T] { + funcs = slices.DeleteFunc(funcs, func(f ExtraUpdater[T]) bool { + return f == nil + }) + if len(funcs) == 0 { + return nil + } else if len(funcs) == 1 { + return funcs[0] + } + return func(ctx context.Context, p T) bool { + changed := false + for _, f := range funcs { + changed = f(ctx, p) || changed + } + return changed + } +} + +var Unmuted = time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC) + +type UserLocalPortalInfo struct { + // To signal an indefinite mute, use [event.MutedForever] as the value here. + // To unmute, set any time before now, e.g. [bridgev2.Unmuted]. + MutedUntil *time.Time + Tag *event.RoomTag +} + +func (portal *Portal) updateName( + ctx context.Context, name string, sender MatrixAPI, ts time.Time, excludeFromTimeline bool, +) bool { + if portal.Name == name && (portal.NameSet || portal.MXID == "") { + return false + } + portal.Name = name + portal.NameSet = portal.sendRoomMeta( + ctx, sender, ts, event.StateRoomName, "", &event.RoomNameEventContent{Name: name}, excludeFromTimeline, nil, + ) + return true +} + +func (portal *Portal) updateTopic( + ctx context.Context, topic string, sender MatrixAPI, ts time.Time, excludeFromTimeline bool, +) bool { + if portal.Topic == topic && (portal.TopicSet || portal.MXID == "") { + return false + } + portal.Topic = topic + portal.TopicSet = portal.sendRoomMeta( + ctx, sender, ts, event.StateTopic, "", &event.TopicEventContent{Topic: topic}, excludeFromTimeline, nil, + ) + return true +} + +func (portal *Portal) updateAvatar( + ctx context.Context, avatar *Avatar, sender MatrixAPI, ts time.Time, excludeFromTimeline bool, +) bool { + if portal.AvatarID == avatar.ID && (avatar.Remove || portal.AvatarMXC != "") && (portal.AvatarSet || portal.MXID == "") { + return false + } + portal.AvatarID = avatar.ID + if sender == nil { + sender = portal.Bridge.Bot + } + if avatar.Remove { + portal.AvatarMXC = "" + portal.AvatarHash = [32]byte{} + } else { + newMXC, newHash, err := avatar.Reupload(ctx, sender, portal.AvatarHash, portal.AvatarMXC) + if err != nil { + portal.AvatarSet = false + zerolog.Ctx(ctx).Err(err).Msg("Failed to reupload room avatar") + return true + } else if newHash == portal.AvatarHash && portal.AvatarMXC != "" && portal.AvatarSet { + return true + } + portal.AvatarMXC = newMXC + portal.AvatarHash = newHash + } + portal.AvatarSet = portal.sendRoomMeta( + ctx, sender, ts, event.StateRoomAvatar, "", &event.RoomAvatarEventContent{URL: portal.AvatarMXC}, excludeFromTimeline, nil, + ) + return true +} + +func (portal *Portal) GetTopLevelParent() *Portal { + if portal.Parent == nil { + if portal.RoomType != database.RoomTypeSpace { + return nil + } + return portal + } + return portal.Parent.GetTopLevelParent() +} + +func (portal *Portal) getBridgeInfoStateKey() string { + if portal.Bridge.Config.NoBridgeInfoStateKey { + return "" + } + idProvider, ok := portal.Bridge.Matrix.(MatrixConnectorWithBridgeIdentifier) + if ok { + return idProvider.GetUniqueBridgeID() + } + return string(portal.BridgeID) +} + +func (portal *Portal) getBridgeInfo() (string, event.BridgeEventContent) { + bridgeInfo := event.BridgeEventContent{ + BridgeBot: portal.Bridge.Bot.GetMXID(), + Creator: portal.Bridge.Bot.GetMXID(), + Protocol: portal.Bridge.Network.GetName().AsBridgeInfoSection(), + Channel: event.BridgeInfoSection{ + ID: string(portal.ID), + DisplayName: portal.Name, + AvatarURL: portal.AvatarMXC, + Receiver: string(portal.Receiver), + MessageRequest: portal.MessageRequest, + // TODO external URL? + }, + BeeperRoomTypeV2: string(portal.RoomType), + } + if portal.RoomType == database.RoomTypeDM || portal.RoomType == database.RoomTypeGroupDM { + bridgeInfo.BeeperRoomType = "dm" + } + if bridgeInfo.Protocol.ID == "slackgo" { + bridgeInfo.TempSlackRemoteIDMigratedFlag = true + bridgeInfo.TempSlackRemoteIDMigratedFlag2 = true + } + parent := portal.GetTopLevelParent() + if parent != nil { + bridgeInfo.Network = &event.BridgeInfoSection{ + ID: string(parent.ID), + DisplayName: parent.Name, + AvatarURL: parent.AvatarMXC, + // TODO external URL? + } + } + filler, ok := portal.Bridge.Network.(PortalBridgeInfoFillingNetwork) + if ok { + filler.FillPortalBridgeInfo(portal, &bridgeInfo) + } + return portal.getBridgeInfoStateKey(), bridgeInfo +} + +func (portal *Portal) UpdateBridgeInfo(ctx context.Context) { + if portal.MXID == "" { + return + } + stateKey, bridgeInfo := portal.getBridgeInfo() + portal.sendRoomMeta(ctx, nil, time.Now(), event.StateBridge, stateKey, &bridgeInfo, false, nil) + portal.sendRoomMeta(ctx, nil, time.Now(), event.StateHalfShotBridge, stateKey, &bridgeInfo, false, nil) +} + +func (portal *Portal) UpdateCapabilities(ctx context.Context, source *UserLogin, implicit bool) bool { + if portal.MXID == "" { + return false + } else if !implicit && time.Since(portal.lastCapUpdate) < 24*time.Hour { + return false + } else if portal.CapState.ID != "" && source.ID != portal.CapState.Source && source.ID != portal.Receiver { + // TODO allow capability state source to change if the old user login is removed from the portal + return false + } + caps := source.Client.GetCapabilities(ctx, portal) + capID := caps.GetID() + if capID == portal.CapState.ID { + return false + } + zerolog.Ctx(ctx).Debug(). + Str("user_login_id", string(source.ID)). + Str("old_id", portal.CapState.ID). + Str("new_id", capID). + Msg("Sending new room capability event") + success := portal.sendRoomMeta(ctx, nil, time.Now(), event.StateBeeperRoomFeatures, portal.getBridgeInfoStateKey(), caps, false, nil) + if !success { + return false + } + portal.CapState = database.CapabilityState{ + Source: source.ID, + ID: capID, + Flags: portal.CapState.Flags, + } + if caps.DisappearingTimer != nil && !portal.CapState.Flags.Has(database.CapStateFlagDisappearingTimerSet) { + zerolog.Ctx(ctx).Debug().Msg("Disappearing timer capability was added, sending disappearing timer state event") + success = portal.sendRoomMeta(ctx, nil, time.Now(), event.StateBeeperDisappearingTimer, "", portal.Disappear.ToEventContent(), true, nil) + if !success { + return false + } + portal.CapState.Flags |= database.CapStateFlagDisappearingTimerSet + } + portal.lastCapUpdate = time.Now() + if implicit { + err := portal.Save(ctx) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to save portal capability state after sending state event") + } + } + return true +} + +func (portal *Portal) sendStateWithIntentOrBot(ctx context.Context, sender MatrixAPI, eventType event.Type, stateKey string, content *event.Content, ts time.Time) (resp *mautrix.RespSendEvent, err error) { + if sender == nil { + sender = portal.Bridge.Bot + } + resp, err = sender.SendState(ctx, portal.MXID, eventType, stateKey, content, ts) + if errors.Is(err, mautrix.MForbidden) && sender != portal.Bridge.Bot { + if content.Raw == nil { + content.Raw = make(map[string]any) + } + content.Raw["fi.mau.bridge.set_by"] = sender.GetMXID() + resp, err = portal.Bridge.Bot.SendState(ctx, portal.MXID, eventType, stateKey, content, ts) + } + return +} + +func (portal *Portal) sendRoomMeta( + ctx context.Context, + sender MatrixAPI, + ts time.Time, + eventType event.Type, + stateKey string, + content any, + excludeFromTimeline bool, + extra map[string]any, +) bool { + if portal.MXID == "" { + return false + } + if extra == nil { + extra = make(map[string]any) + } + if excludeFromTimeline { + extra["com.beeper.exclude_from_timeline"] = true + } + if !portal.NameIsCustom && (eventType == event.StateRoomName || eventType == event.StateRoomAvatar) { + extra["fi.mau.implicit_name"] = true + } + _, err := portal.sendStateWithIntentOrBot(ctx, sender, eventType, stateKey, &event.Content{ + Parsed: content, + Raw: extra, + }, ts) + if err != nil { + zerolog.Ctx(ctx).Err(err). + Str("event_type", eventType.Type). + Msg("Failed to set room metadata") + return false + } + if eventType == event.StateBeeperDisappearingTimer { + // TODO remove this debug log at some point + zerolog.Ctx(ctx).Debug(). + Any("content", content). + Msg("Sent new disappearing timer event") + } + return true +} + +func (portal *Portal) revertRoomMeta(ctx context.Context, evt *event.Event) { + if !portal.Bridge.Config.RevertFailedStateChanges { + return + } + if evt.GetStateKey() != "" && evt.Type != event.StateMember { + return + } + switch evt.Type { + case event.StateRoomName: + portal.sendRoomMeta(ctx, nil, time.Time{}, event.StateRoomName, "", &event.RoomNameEventContent{Name: portal.Name}, true, nil) + case event.StateRoomAvatar: + portal.sendRoomMeta(ctx, nil, time.Time{}, event.StateRoomAvatar, "", &event.RoomAvatarEventContent{URL: portal.AvatarMXC}, true, nil) + case event.StateTopic: + portal.sendRoomMeta(ctx, nil, time.Time{}, event.StateTopic, "", &event.TopicEventContent{Topic: portal.Topic}, true, nil) + case event.StateBeeperDisappearingTimer: + portal.sendRoomMeta(ctx, nil, time.Time{}, event.StateBeeperDisappearingTimer, "", portal.Disappear.ToEventContent(), true, nil) + case event.StateMember: + var prevContent *event.MemberEventContent + var extra map[string]any + if evt.Unsigned.PrevContent != nil { + _ = evt.Unsigned.PrevContent.ParseRaw(evt.Type) + prevContent = evt.Unsigned.PrevContent.AsMember() + newContent := evt.Content.AsMember() + if prevContent.Membership == newContent.Membership { + return + } + extra = evt.Unsigned.PrevContent.Raw + } else { + prevContent = &event.MemberEventContent{Membership: event.MembershipLeave} + } + if portal.Bridge.Matrix.GetCapabilities().ArbitraryMemberChange { + if extra == nil { + extra = make(map[string]any) + } + extra["com.beeper.member_rollback"] = true + portal.sendRoomMeta(ctx, nil, time.Time{}, event.StateMember, evt.GetStateKey(), prevContent, true, extra) + } + } +} + +func (portal *Portal) getInitialMemberList(ctx context.Context, members *ChatMemberList, source *UserLogin, pl *event.PowerLevelsEventContent) (invite, functional []id.UserID, err error) { + if members == nil { + invite = []id.UserID{source.UserMXID} + return + } + var loginsInPortal []*UserLogin + if members.CheckAllLogins && !portal.Bridge.Config.SplitPortals { + loginsInPortal, err = portal.Bridge.GetUserLoginsInPortal(ctx, portal.PortalKey) + if err != nil { + err = fmt.Errorf("failed to get user logins in portal: %w", err) + return + } + } + members.PowerLevels.Apply("", pl) + members.memberListToMap(ctx) + for _, member := range members.MemberMap { + if ctx.Err() != nil { + err = ctx.Err() + return + } + if member.Membership != event.MembershipJoin && member.Membership != "" { + continue + } + if member.Sender != "" && member.UserInfo != nil { + ghost, err := portal.Bridge.GetGhostByID(ctx, member.Sender) + if err != nil { + zerolog.Ctx(ctx).Err(err).Str("ghost_id", string(member.Sender)).Msg("Failed to get ghost from member list to update info") + } else { + ghost.UpdateInfo(ctx, member.UserInfo) + } + } + intent, extraUserID, err := portal.getIntentAndUserMXIDFor(ctx, member.EventSender, source, loginsInPortal, 0) + if err != nil { + return nil, nil, err + } + if extraUserID != "" { + invite = append(invite, extraUserID) + if member.PowerLevel != nil { + pl.EnsureUserLevel(extraUserID, *member.PowerLevel) + } + if intent != nil { + // If intent is present along with a user ID, it's the ghost of a logged-in user, + // so add it to the functional members list + functional = append(functional, intent.GetMXID()) + } + } + if intent != nil { + invite = append(invite, intent.GetMXID()) + if member.PowerLevel != nil { + pl.EnsureUserLevel(intent.GetMXID(), *member.PowerLevel) + } + } + } + portal.updateOtherUser(ctx, members) + return +} + +func (portal *Portal) updateOtherUser(ctx context.Context, members *ChatMemberList) (changed bool) { + members.memberListToMap(ctx) + var expectedUserID networkid.UserID + if portal.RoomType != database.RoomTypeDM { + // expected user ID is empty + } else if members.OtherUserID != "" { + expectedUserID = members.OtherUserID + } else if len(members.MemberMap) == 2 && members.IsFull { + vals := maps.Values(members.MemberMap) + if vals[0].IsFromMe && !vals[1].IsFromMe { + expectedUserID = vals[1].Sender + } else if vals[1].IsFromMe && !vals[0].IsFromMe { + expectedUserID = vals[0].Sender + } + } + if portal.OtherUserID != expectedUserID { + zerolog.Ctx(ctx).Debug(). + Str("old_other_user_id", string(portal.OtherUserID)). + Str("new_other_user_id", string(expectedUserID)). + Msg("Updating other user ID in DM portal") + portal.OtherUserID = expectedUserID + return true + } + return false +} + +func looksDirectlyJoinable(rule *event.JoinRulesEventContent) bool { + switch rule.JoinRule { + case event.JoinRulePublic: + return true + case event.JoinRuleKnockRestricted, event.JoinRuleRestricted: + for _, allow := range rule.Allow { + if allow.Type == "fi.mau.spam_checker" { + return true + } + } + } + return false +} + +func (portal *Portal) roomIsPublic(ctx context.Context) bool { + mx, ok := portal.Bridge.Matrix.(MatrixConnectorWithArbitraryRoomState) + if !ok { + return false + } + evt, err := mx.GetStateEvent(ctx, portal.MXID, event.StateJoinRules, "") + if err != nil { + zerolog.Ctx(ctx).Warn().Err(err).Msg("Failed to get join rules to check if room is public") + return false + } else if evt == nil { + return false + } + content, ok := evt.Content.Parsed.(*event.JoinRulesEventContent) + if !ok { + return false + } + return looksDirectlyJoinable(content) +} + +func (portal *Portal) syncParticipants( + ctx context.Context, + members *ChatMemberList, + source *UserLogin, + sender MatrixAPI, + ts time.Time, +) error { + members.memberListToMap(ctx) + var loginsInPortal []*UserLogin + var err error + if members.CheckAllLogins && !portal.Bridge.Config.SplitPortals { + loginsInPortal, err = portal.Bridge.GetUserLoginsInPortal(ctx, portal.PortalKey) + if err != nil { + return fmt.Errorf("failed to get user logins in portal: %w", err) + } + } + if sender == nil { + sender = portal.Bridge.Bot + } + log := zerolog.Ctx(ctx) + currentPower, err := portal.Bridge.Matrix.GetPowerLevels(ctx, portal.MXID) + if err != nil { + return fmt.Errorf("failed to get current power levels: %w", err) + } + currentMembers, err := portal.Bridge.Matrix.GetMembers(ctx, portal.MXID) + if err != nil { + return fmt.Errorf("failed to get current members: %w", err) + } + delete(currentMembers, portal.Bridge.Bot.GetMXID()) + powerChanged := members.PowerLevels.Apply(portal.Bridge.Bot.GetMXID(), currentPower) + addExcludeFromTimeline := func(raw map[string]any) { + _, hasKey := raw["com.beeper.exclude_from_timeline"] + if !hasKey && members.ExcludeChangesFromTimeline { + raw["com.beeper.exclude_from_timeline"] = true + } + } + syncUser := func(extraUserID id.UserID, member ChatMember, intent MatrixAPI) bool { + if member.Membership == "" { + member.Membership = event.MembershipJoin + } + if member.PowerLevel != nil { + powerChanged = currentPower.EnsureUserLevelAs(portal.Bridge.Bot.GetMXID(), extraUserID, *member.PowerLevel) || powerChanged + } + currentMember, ok := currentMembers[extraUserID] + delete(currentMembers, extraUserID) + if ok && currentMember.Membership == member.Membership { + return false + } + if currentMember == nil { + currentMember = &event.MemberEventContent{Membership: event.MembershipLeave} + } + if member.PrevMembership != "" && member.PrevMembership != currentMember.Membership { + log.Trace(). + Stringer("user_id", extraUserID). + Str("expected_prev_membership", string(member.PrevMembership)). + Str("actual_prev_membership", string(currentMember.Membership)). + Str("target_membership", string(member.Membership)). + Msg("Not updating membership: prev membership mismatch") + return false + } + content := &event.MemberEventContent{ + Membership: member.Membership, + Displayname: currentMember.Displayname, + AvatarURL: currentMember.AvatarURL, + } + wrappedContent := &event.Content{Parsed: content, Raw: exmaps.NonNilClone(member.MemberEventExtra)} + addExcludeFromTimeline(wrappedContent.Raw) + thisEvtSender := sender + if member.Membership == event.MembershipJoin && (intent == nil || !portal.roomIsPublic(ctx)) { + content.Membership = event.MembershipInvite + if intent != nil { + wrappedContent.Raw["fi.mau.will_auto_accept"] = true + } + if thisEvtSender.GetMXID() == extraUserID { + thisEvtSender = portal.Bridge.Bot + } + } + addLogContext := func(e *zerolog.Event) *zerolog.Event { + return e.Stringer("target_user_id", extraUserID). + Stringer("sender_user_id", thisEvtSender.GetMXID()). + Str("prev_membership", string(currentMember.Membership)) + } + if currentMember != nil && currentMember.Membership == event.MembershipBan && member.Membership != event.MembershipLeave { + unbanContent := *content + unbanContent.Membership = event.MembershipLeave + wrappedUnbanContent := &event.Content{Parsed: &unbanContent} + _, err = portal.sendStateWithIntentOrBot(ctx, thisEvtSender, event.StateMember, extraUserID.String(), wrappedUnbanContent, ts) + if err != nil { + addLogContext(log.Err(err)). + Str("new_membership", string(unbanContent.Membership)). + Msg("Failed to unban user to update membership") + } else { + addLogContext(log.Trace()). + Str("new_membership", string(unbanContent.Membership)). + Msg("Unbanned user to update membership") + currentMember.Membership = event.MembershipLeave + } + } + if content.Membership == event.MembershipJoin && intent != nil && intent.GetMXID() == extraUserID { + _, err = intent.SendState(ctx, portal.MXID, event.StateMember, extraUserID.String(), wrappedContent, ts) + } else { + _, err = portal.sendStateWithIntentOrBot(ctx, thisEvtSender, event.StateMember, extraUserID.String(), wrappedContent, ts) + } + if err != nil { + addLogContext(log.Err(err)). + Str("new_membership", string(content.Membership)). + Msg("Failed to update user membership") + } else { + addLogContext(log.Trace()). + Str("new_membership", string(content.Membership)). + Msg("Updated membership in room") + currentMember.Membership = content.Membership + + if intent != nil && content.Membership == event.MembershipInvite && member.Membership == event.MembershipJoin { + content.Membership = event.MembershipJoin + wrappedJoinContent := &event.Content{Parsed: content, Raw: exmaps.NonNilClone(member.MemberEventExtra)} + addExcludeFromTimeline(wrappedContent.Raw) + _, err = intent.SendState(ctx, portal.MXID, event.StateMember, intent.GetMXID().String(), wrappedJoinContent, ts) + if err != nil { + addLogContext(log.Err(err)). + Str("new_membership", string(content.Membership)). + Msg("Failed to join with intent") + } else { + addLogContext(log.Trace()). + Str("new_membership", string(content.Membership)). + Msg("Joined room with intent") + } + } + } + return true + } + syncIntent := func(intent MatrixAPI, member ChatMember) { + if !syncUser(intent.GetMXID(), member, intent) { + return + } + if member.Membership == event.MembershipJoin || member.Membership == "" { + err = intent.EnsureJoined(ctx, portal.MXID) + if err != nil { + log.Err(err). + Stringer("user_id", intent.GetMXID()). + Msg("Failed to ensure user is joined to room") + } + } + } + for _, member := range members.MemberMap { + if ctx.Err() != nil { + return ctx.Err() + } + if member.Sender != "" && member.UserInfo != nil { + ghost, err := portal.Bridge.GetGhostByID(ctx, member.Sender) + if err != nil { + zerolog.Ctx(ctx).Err(err).Str("ghost_id", string(member.Sender)).Msg("Failed to get ghost from member list to update info") + } else { + ghost.UpdateInfo(ctx, member.UserInfo) + } + } + intent, extraUserID, err := portal.getIntentAndUserMXIDFor(ctx, member.EventSender, source, loginsInPortal, 0) + if err != nil { + return err + } + if intent != nil { + syncIntent(intent, member) + } + if extraUserID != "" { + syncUser(extraUserID, member, nil) + } + } + if powerChanged { + _, err = portal.sendStateWithIntentOrBot(ctx, sender, event.StatePowerLevels, "", &event.Content{Parsed: currentPower}, ts) + if err != nil { + log.Err(err).Msg("Failed to update power levels") + } + } + portal.updateOtherUser(ctx, members) + if members.IsFull { + for extraMember, memberEvt := range currentMembers { + if memberEvt.Membership == event.MembershipLeave || memberEvt.Membership == event.MembershipBan { + continue + } + if !portal.Bridge.IsGhostMXID(extraMember) && (portal.Relay != nil || !portal.Bridge.Config.KickMatrixUsers) { + continue + } + _, err = portal.Bridge.Bot.SendState(ctx, portal.MXID, event.StateMember, extraMember.String(), &event.Content{ + Parsed: &event.MemberEventContent{ + Membership: event.MembershipLeave, + AvatarURL: memberEvt.AvatarURL, + Displayname: memberEvt.Displayname, + Reason: "User is not in remote chat", + }, + Raw: map[string]any{ + "com.beeper.exclude_from_timeline": members.ExcludeChangesFromTimeline, + }, + }, time.Now()) + if err != nil { + zerolog.Ctx(ctx).Err(err). + Stringer("user_id", extraMember). + Msg("Failed to remove user from room") + } + } + } + return nil +} + +func (portal *Portal) updateUserLocalInfo(ctx context.Context, info *UserLocalPortalInfo, source *UserLogin, didJustCreate bool) { + if portal.MXID == "" { + return + } + dp := source.User.DoublePuppet(ctx) + if dp == nil { + return + } + dmMarkingMatrixAPI, canMarkDM := dp.(MarkAsDMMatrixAPI) + if canMarkDM && portal.OtherUserID != "" && portal.RoomType == database.RoomTypeDM { + dmGhost, err := portal.Bridge.GetGhostByID(ctx, portal.OtherUserID) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to get DM ghost to mark room as DM") + } else if err = dmMarkingMatrixAPI.MarkAsDM(ctx, portal.MXID, dmGhost.Intent.GetMXID()); err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to mark room as DM") + } + } + if info == nil { + return + } + if info.MutedUntil != nil && (didJustCreate || !portal.Bridge.Config.MuteOnlyOnCreate) && (!didJustCreate || info.MutedUntil.After(time.Now())) { + err := dp.MuteRoom(ctx, portal.MXID, *info.MutedUntil) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to mute room") + } + } + if info.Tag != nil && + len(portal.Bridge.Config.OnlyBridgeTags) > 0 && + (*info.Tag == "" || slices.Contains(portal.Bridge.Config.OnlyBridgeTags, *info.Tag)) && + (didJustCreate || !portal.Bridge.Config.TagOnlyOnCreate) && + (!didJustCreate || *info.Tag != "") { + err := dp.TagRoom(ctx, portal.MXID, *info.Tag, *info.Tag != "") + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to tag room") + } + } +} + +func DisappearingMessageNotice(expiration time.Duration, implicit bool) *event.MessageEventContent { + formattedDuration := exfmt.DurationCustom(expiration, nil, exfmt.Day, time.Hour, time.Minute, time.Second) + content := &event.MessageEventContent{ + MsgType: event.MsgNotice, + Body: fmt.Sprintf("Set the disappearing message timer to %s", formattedDuration), + Mentions: &event.Mentions{}, + } + if expiration == 0 { + if implicit { + content.Body = "Automatically turned off disappearing messages because incoming message is not disappearing" + } else { + content.Body = "Turned off disappearing messages" + } + } else if implicit { + content.Body = fmt.Sprintf("Automatically enabled disappearing message timer (%s) because incoming message is disappearing", formattedDuration) + } + return content +} + +type UpdateDisappearingSettingOpts struct { + Sender MatrixAPI + Timestamp time.Time + Implicit bool + Save bool + SendNotice bool + + ExcludeFromTimeline bool +} + +func (portal *Portal) UpdateDisappearingSetting( + ctx context.Context, + setting database.DisappearingSetting, + opts UpdateDisappearingSettingOpts, +) bool { + setting = setting.Normalize() + if portal.Disappear.Timer == setting.Timer && portal.Disappear.Type == setting.Type { + return false + } + portal.Disappear.Type = setting.Type + portal.Disappear.Timer = setting.Timer + if opts.Save { + err := portal.Save(ctx) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to save portal to database after updating disappearing setting") + } + } + if portal.MXID == "" { + return true + } + + if opts.Sender == nil { + opts.Sender = portal.Bridge.Bot + } + if opts.Timestamp.IsZero() { + opts.Timestamp = time.Now() + } + portal.sendRoomMeta( + ctx, + opts.Sender, + opts.Timestamp, + event.StateBeeperDisappearingTimer, + "", + setting.ToEventContent(), + opts.ExcludeFromTimeline, + nil, + ) + + if !opts.SendNotice { + return true + } + content := DisappearingMessageNotice(setting.Timer, opts.Implicit) + _, err := opts.Sender.SendMessage(ctx, portal.MXID, event.EventMessage, &event.Content{ + Parsed: content, + Raw: map[string]any{ + "com.beeper.action_message": map[string]any{ + "type": "disappearing_timer", + "timer": setting.Timer.Milliseconds(), + "timer_type": setting.Type, + "implicit": opts.Implicit, + }, + }, + }, &MatrixSendExtra{Timestamp: opts.Timestamp}) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to send disappearing messages notice") + } else { + zerolog.Ctx(ctx).Debug(). + Dur("new_timer", portal.Disappear.Timer). + Bool("implicit", opts.Implicit). + Msg("Sent disappearing messages notice") + } + return true +} + +func (portal *Portal) updateParent(ctx context.Context, newParentID networkid.PortalID, source *UserLogin) bool { + newParent := networkid.PortalKey{ID: newParentID} + if portal.Bridge.Config.SplitPortals { + newParent.Receiver = portal.Receiver + } + if portal.ParentKey == newParent { + return false + } + var err error + if portal.MXID != "" && portal.InSpace && portal.Parent != nil && portal.Parent.MXID != "" { + err = portal.toggleSpace(ctx, portal.Parent.MXID, false, true) + if err != nil { + zerolog.Ctx(ctx).Err(err).Stringer("old_space_mxid", portal.Parent.MXID).Msg("Failed to remove portal from old space") + } + } + portal.ParentKey = newParent + portal.InSpace = false + if newParent.ID != "" { + portal.Parent, err = portal.Bridge.GetPortalByKey(ctx, newParent) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to get new parent portal") + } + } + if portal.MXID != "" && portal.Parent != nil && (source != nil || portal.Parent.MXID != "") { + if portal.Parent.MXID == "" { + zerolog.Ctx(ctx).Info().Msg("Parent portal doesn't exist, creating") + err = portal.Parent.CreateMatrixRoom(ctx, source, nil) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to create parent portal") + } + } + if portal.Parent.MXID != "" { + portal.addToParentSpaceAndSave(ctx, false) + } + } + return true +} + +func (portal *Portal) lockedUpdateInfoFromGhost(ctx context.Context, ghost *Ghost) { + portal.roomCreateLock.Lock() + defer portal.roomCreateLock.Unlock() + portal.UpdateInfoFromGhost(ctx, ghost) +} + +func (portal *Portal) UpdateInfoFromGhost(ctx context.Context, ghost *Ghost) (changed bool) { + if portal.NameIsCustom || !portal.Bridge.Config.PrivateChatPortalMeta || (portal.OtherUserID == "" && ghost == nil) || portal.RoomType != database.RoomTypeDM { + return + } + var err error + if ghost == nil { + ghost, err = portal.Bridge.GetGhostByID(ctx, portal.OtherUserID) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to get ghost to update info from") + return + } + } + changed = portal.updateName(ctx, ghost.Name, nil, time.Time{}, false) || changed + changed = portal.updateAvatar(ctx, &Avatar{ + ID: ghost.AvatarID, + MXC: ghost.AvatarMXC, + Hash: ghost.AvatarHash, + Remove: ghost.AvatarID == "", + }, nil, time.Time{}, false) || changed + return +} + +func (portal *Portal) UpdateInfo(ctx context.Context, info *ChatInfo, source *UserLogin, sender MatrixAPI, ts time.Time) { + changed := false + if info.Name == DefaultChatName { + if portal.NameIsCustom { + portal.NameIsCustom = false + changed = portal.updateName(ctx, "", sender, ts, info.ExcludeChangesFromTimeline) || changed + } + } else if info.Name != nil { + portal.NameIsCustom = true + changed = portal.updateName(ctx, *info.Name, sender, ts, info.ExcludeChangesFromTimeline) || changed + } + if info.Topic != nil { + changed = portal.updateTopic(ctx, *info.Topic, sender, ts, info.ExcludeChangesFromTimeline) || changed + } + if info.Avatar != nil { + portal.NameIsCustom = true + changed = portal.updateAvatar(ctx, info.Avatar, sender, ts, info.ExcludeChangesFromTimeline) || changed + } + if info.Disappear != nil { + changed = portal.UpdateDisappearingSetting(ctx, *info.Disappear, UpdateDisappearingSettingOpts{ + Sender: sender, + Timestamp: ts, + Implicit: false, + Save: false, + + SendNotice: !info.ExcludeChangesFromTimeline, + ExcludeFromTimeline: info.ExcludeChangesFromTimeline, + }) || changed + } + if info.ParentID != nil { + changed = portal.updateParent(ctx, *info.ParentID, source) || changed + } + if info.JoinRule != nil { + // TODO change detection instead of spamming this every time? + portal.sendRoomMeta(ctx, sender, ts, event.StateJoinRules, "", info.JoinRule, info.ExcludeChangesFromTimeline, nil) + } + if info.Type != nil && portal.RoomType != *info.Type { + if portal.MXID != "" && (*info.Type == database.RoomTypeSpace || portal.RoomType == database.RoomTypeSpace) { + zerolog.Ctx(ctx).Warn(). + Str("current_type", string(portal.RoomType)). + Str("target_type", string(*info.Type)). + Msg("Tried to change existing room type from/to space") + } else { + changed = true + portal.RoomType = *info.Type + } + } + if info.MessageRequest != nil && *info.MessageRequest != portal.MessageRequest { + changed = true + portal.MessageRequest = *info.MessageRequest + } + if info.Members != nil && portal.MXID != "" && source != nil { + err := portal.syncParticipants(ctx, info.Members, source, nil, time.Time{}) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to sync room members") + } + // TODO detect changes to functional members list? + } else if info.Members != nil { + portal.updateOtherUser(ctx, info.Members) + } + changed = portal.UpdateInfoFromGhost(ctx, nil) || changed + if source != nil { + source.MarkInPortal(ctx, portal) + portal.updateUserLocalInfo(ctx, info.UserLocal, source, false) + changed = portal.UpdateCapabilities(ctx, source, false) || changed + } + if info.CanBackfill && source != nil && portal.MXID != "" { + err := portal.Bridge.DB.BackfillTask.EnsureExists(ctx, portal.PortalKey, source.ID) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to ensure backfill queue task exists") + } + // TODO wake up backfill queue if task was just created + } + if info.ExtraUpdates != nil { + changed = info.ExtraUpdates(ctx, portal) || changed + } + if changed { + portal.UpdateBridgeInfo(ctx) + err := portal.Save(ctx) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to save portal to database after updating info") + } + } +} + +func (portal *Portal) CreateMatrixRoom(ctx context.Context, source *UserLogin, info *ChatInfo) (retErr error) { + if portal.MXID != "" { + if source != nil { + source.MarkInPortal(ctx, portal) + } + return nil + } + if portal.deleted.IsSet() { + return ErrPortalIsDeleted + } + waiter := make(chan struct{}) + closed := false + evt := &portalCreateEvent{ + ctx: ctx, + source: source, + info: info, + cb: func(err error) { + retErr = err + if !closed { + closed = true + close(waiter) + } + }, + } + if PortalEventBuffer == 0 { + go portal.queueEvent(ctx, evt) + } else { + select { + case portal.events <- evt: + case <-portal.deleted.GetChan(): + return ErrPortalIsDeleted + } + } + select { + case <-ctx.Done(): + return ctx.Err() + case <-waiter: + return + } +} + +func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLogin, info *ChatInfo, backfillBundle any) error { + cancellableCtx, cancel := context.WithCancel(ctx) + defer cancel() + portal.cancelRoomCreate.CompareAndSwap(nil, &cancel) + portal.roomCreateLock.Lock() + portal.cancelRoomCreate.Store(&cancel) + defer portal.roomCreateLock.Unlock() + if portal.MXID != "" { + if source != nil { + source.MarkInPortal(ctx, portal) + } + return nil + } + log := zerolog.Ctx(ctx).With(). + Str("action", "create matrix room"). + Logger() + cancellableCtx = log.WithContext(cancellableCtx) + ctx = log.WithContext(ctx) + log.Info().Msg("Creating Matrix room") + + var err error + if info == nil || info.Members == nil { + if info != nil { + log.Warn().Msg("CreateMatrixRoom got info without members. Refetching info") + } + info, err = source.Client.GetChatInfo(cancellableCtx, portal) + if err != nil { + log.Err(err).Msg("Failed to update portal info for creation") + return err + } + } + + portal.UpdateInfo(cancellableCtx, info, source, nil, time.Time{}) + if cancellableCtx.Err() != nil { + return cancellableCtx.Err() + } + + powerLevels := &event.PowerLevelsEventContent{ + Events: map[string]int{ + event.StateTombstone.Type: 100, + event.StateServerACL.Type: 100, + event.StateEncryption.Type: 100, + }, + Users: map[id.UserID]int{ + portal.Bridge.Bot.GetMXID(): 9001, + }, + } + initialMembers, extraFunctionalMembers, err := portal.getInitialMemberList(cancellableCtx, info.Members, source, powerLevels) + if err != nil { + log.Err(err).Msg("Failed to process participant list for portal creation") + return err + } + powerLevels.EnsureUserLevel(portal.Bridge.Bot.GetMXID(), 9001) + + req := mautrix.ReqCreateRoom{ + Visibility: "private", + CreationContent: make(map[string]any), + InitialState: make([]*event.Event, 0, 6), + Preset: "private_chat", + IsDirect: portal.RoomType == database.RoomTypeDM, + PowerLevelOverride: powerLevels, + BeeperLocalRoomID: portal.Bridge.Matrix.GenerateDeterministicRoomID(portal.PortalKey), + } + autoJoinInvites := portal.Bridge.Matrix.GetCapabilities().AutoJoinInvites + if autoJoinInvites { + req.BeeperInitialMembers = initialMembers + // TODO remove this after initial_members is supported in hungryserv + req.BeeperAutoJoinInvites = true + req.Invite = initialMembers + } + if portal.RoomType == database.RoomTypeSpace { + req.CreationContent["type"] = event.RoomTypeSpace + } + bridgeInfoStateKey, bridgeInfo := portal.getBridgeInfo() + roomFeatures := source.Client.GetCapabilities(cancellableCtx, portal) + portal.CapState = database.CapabilityState{ + Source: source.ID, + ID: roomFeatures.GetID(), + } + + req.InitialState = append(req.InitialState, &event.Event{ + Type: event.StateElementFunctionalMembers, + Content: event.Content{Parsed: &event.ElementFunctionalMembersContent{ + ServiceMembers: append(extraFunctionalMembers, portal.Bridge.Bot.GetMXID()), + }}, + }, &event.Event{ + StateKey: &bridgeInfoStateKey, + Type: event.StateHalfShotBridge, + Content: event.Content{Parsed: &bridgeInfo}, + }, &event.Event{ + StateKey: &bridgeInfoStateKey, + Type: event.StateBridge, + Content: event.Content{Parsed: &bridgeInfo}, + }, &event.Event{ + StateKey: &bridgeInfoStateKey, + Type: event.StateBeeperRoomFeatures, + Content: event.Content{Parsed: roomFeatures}, + }, &event.Event{ + Type: event.StateTopic, + Content: event.Content{ + Parsed: &event.TopicEventContent{Topic: portal.Topic}, + Raw: map[string]any{ + "com.beeper.exclude_from_timeline": true, + }, + }, + }) + if roomFeatures.DisappearingTimer != nil { + req.InitialState = append(req.InitialState, &event.Event{ + Type: event.StateBeeperDisappearingTimer, + Content: event.Content{ + Parsed: portal.Disappear.ToEventContent(), + Raw: map[string]any{ + "com.beeper.exclude_from_timeline": true, + }, + }, + }) + portal.CapState.Flags |= database.CapStateFlagDisappearingTimerSet + } + if portal.Name != "" { + req.InitialState = append(req.InitialState, &event.Event{ + Type: event.StateRoomName, + Content: event.Content{ + Parsed: &event.RoomNameEventContent{Name: portal.Name}, + Raw: map[string]any{ + "com.beeper.exclude_from_timeline": true, + }, + }, + }) + } + if portal.AvatarMXC != "" { + req.InitialState = append(req.InitialState, &event.Event{ + Type: event.StateRoomAvatar, + Content: event.Content{ + Parsed: &event.RoomAvatarEventContent{URL: portal.AvatarMXC}, + Raw: map[string]any{ + "com.beeper.exclude_from_timeline": true, + }, + }, + }) + } + if portal.Parent != nil && portal.Parent.MXID != "" { + req.InitialState = append(req.InitialState, &event.Event{ + StateKey: ptr.Ptr(portal.Parent.MXID.String()), + Type: event.StateSpaceParent, + Content: event.Content{Parsed: &event.SpaceParentEventContent{ + Via: []string{portal.Bridge.Matrix.ServerName()}, + Canonical: true, + }}, + }) + } + if info.JoinRule != nil { + req.InitialState = append(req.InitialState, &event.Event{ + Type: event.StateJoinRules, + Content: event.Content{Parsed: info.JoinRule}, + }) + } + if cancellableCtx.Err() != nil { + return cancellableCtx.Err() + } + roomID, err := portal.Bridge.Bot.CreateRoom(ctx, &req) + if err != nil { + log.Err(err).Msg("Failed to create Matrix room") + return err + } + log.Info().Stringer("room_id", roomID).Msg("Matrix room created") + portal.AvatarSet = true + portal.TopicSet = true + portal.NameSet = true + portal.MXID = roomID + portal.RoomCreated.Set() + portal.Bridge.cacheLock.Lock() + portal.Bridge.portalsByMXID[roomID] = portal + portal.Bridge.cacheLock.Unlock() + portal.updateLogger() + err = portal.Save(ctx) + if err != nil { + log.Err(err).Msg("Failed to save portal to database after creating Matrix room") + return err + } + if info.CanBackfill && portal.RoomType != database.RoomTypeSpace { + err = portal.Bridge.DB.BackfillTask.Upsert(ctx, &database.BackfillTask{ + PortalKey: portal.PortalKey, + UserLoginID: source.ID, + NextDispatchMinTS: time.Now().Add(BackfillMinBackoffAfterRoomCreate), + }) + if err != nil { + log.Err(err).Msg("Failed to create backfill queue task after creating room") + } + portal.Bridge.WakeupBackfillQueue() + } + withoutCancelCtx := zerolog.Ctx(ctx).WithContext(portal.Bridge.BackgroundCtx) + if portal.Parent != nil { + if portal.Parent.MXID != "" { + portal.addToParentSpaceAndSave(ctx, true) + } else { + log.Info().Msg("Parent portal doesn't exist, creating in background") + go portal.createParentAndAddToSpace(withoutCancelCtx, source) + } + } + portal.updateUserLocalInfo(ctx, info.UserLocal, source, true) + if !autoJoinInvites { + if info.Members == nil { + dp := source.User.DoublePuppet(ctx) + if dp != nil { + err = dp.EnsureJoined(ctx, portal.MXID) + if err != nil { + log.Err(err).Msg("Failed to ensure user is joined to room after creation") + } + } + } else { + err = portal.syncParticipants(ctx, info.Members, source, nil, time.Time{}) + if err != nil { + log.Err(err).Msg("Failed to sync participants after room creation") + } + } + } + portal.addToUserSpaces(ctx) + if info.CanBackfill && + portal.Bridge.Config.Backfill.Enabled && + portal.RoomType != database.RoomTypeSpace && + !portal.Bridge.Background { + portal.doForwardBackfill(ctx, source, nil, backfillBundle) + } + return nil +} + +func (portal *Portal) addToUserSpaces(ctx context.Context) { + if portal.Parent != nil { + return + } + log := zerolog.Ctx(ctx) + withoutCancelCtx := log.WithContext(portal.Bridge.BackgroundCtx) + if portal.Receiver != "" { + login := portal.Bridge.GetCachedUserLoginByID(portal.Receiver) + if login != nil { + up, err := portal.Bridge.DB.UserPortal.GetOrCreate(ctx, login.UserLogin, portal.PortalKey) + if err != nil { + log.Err(err).Msg("Failed to get user portal to add portal to spaces") + } else { + login.inPortalCache.Remove(portal.PortalKey) + go login.tryAddPortalToSpace(withoutCancelCtx, portal, up.CopyWithoutValues()) + } + } + } else { + userPortals, err := portal.Bridge.DB.UserPortal.GetAllInPortal(ctx, portal.PortalKey) + if err != nil { + log.Err(err).Msg("Failed to get user logins in portal to add portal to spaces") + } else { + for _, up := range userPortals { + login := portal.Bridge.GetCachedUserLoginByID(up.LoginID) + if login != nil { + login.inPortalCache.Remove(portal.PortalKey) + go login.tryAddPortalToSpace(withoutCancelCtx, portal, up.CopyWithoutValues()) + } + } + } + } +} + +func (portal *Portal) Delete(ctx context.Context) error { + if portal.deleted.IsSet() { + return nil + } + portal.removeInPortalCache(ctx) + err := portal.safeDBDelete(ctx) + if err != nil { + return err + } + portal.Bridge.cacheLock.Lock() + defer portal.Bridge.cacheLock.Unlock() + portal.unlockedDeleteCache() + return nil +} + +func (portal *Portal) safeDBDelete(ctx context.Context) error { + err := portal.Bridge.DB.Message.DeleteInChunks(ctx, portal.PortalKey) + if err != nil { + return fmt.Errorf("failed to delete messages in portal: %w", err) + } + // TODO delete child portals? + return portal.Bridge.DB.Portal.Delete(ctx, portal.PortalKey) +} + +func (portal *Portal) RemoveMXID(ctx context.Context) error { + if portal.MXID == "" { + return nil + } + portal.MXID = "" + portal.RoomCreated.Clear() + err := portal.Save(ctx) + if err != nil { + return err + } + portal.Bridge.cacheLock.Lock() + defer portal.Bridge.cacheLock.Unlock() + delete(portal.Bridge.portalsByMXID, portal.MXID) + return nil +} + +func (portal *Portal) removeInPortalCache(ctx context.Context) { + if portal.Receiver != "" { + login := portal.Bridge.GetCachedUserLoginByID(portal.Receiver) + if login != nil { + login.inPortalCache.Remove(portal.PortalKey) + } + return + } + userPortals, err := portal.Bridge.DB.UserPortal.GetAllInPortal(ctx, portal.PortalKey) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to get user logins in portal to remove user portal cache") + } else { + for _, up := range userPortals { + login := portal.Bridge.GetCachedUserLoginByID(up.LoginID) + if login != nil { + login.inPortalCache.Remove(portal.PortalKey) + } + } + } +} + +func (portal *Portal) unlockedDelete(ctx context.Context) error { + if portal.deleted.IsSet() { + return nil + } + err := portal.safeDBDelete(ctx) + if err != nil { + return err + } + portal.unlockedDeleteCache() + return nil +} + +func (portal *Portal) unlockedDeleteCache() { + if portal.deleted.IsSet() { + return + } + delete(portal.Bridge.portalsByKey, portal.PortalKey) + if portal.MXID != "" { + delete(portal.Bridge.portalsByMXID, portal.MXID) + } + portal.deleted.Set() + if portal.events != nil { + // TODO there's a small risk of this racing with a queueEvent call + close(portal.events) + } +} + +func (portal *Portal) Save(ctx context.Context) error { + return portal.Bridge.DB.Portal.Update(ctx, portal.Portal) +} + +func (portal *Portal) SetRelay(ctx context.Context, relay *UserLogin) error { + if portal.Receiver != "" && relay.ID != portal.Receiver { + return fmt.Errorf("can't set non-receiver login as relay") + } + portal.Relay = relay + if relay == nil { + portal.RelayLoginID = "" + } else { + portal.RelayLoginID = relay.ID + } + err := portal.Save(ctx) + if err != nil { + return err + } + return nil +} + +func (portal *Portal) PerMessageProfileForSender(ctx context.Context, sender networkid.UserID) (profile event.BeeperPerMessageProfile, err error) { + var ghost *Ghost + ghost, err = portal.Bridge.GetGhostByID(ctx, sender) + if err != nil { + return + } + profile.ID = string(ghost.Intent.GetMXID()) + profile.Displayname = ghost.Name + if ghost.AvatarMXC != "" { + profile.AvatarURL = &ghost.AvatarMXC + } + return +} diff --git a/bridgev2/portalbackfill.go b/bridgev2/portalbackfill.go new file mode 100644 index 00000000..879f07ae --- /dev/null +++ b/bridgev2/portalbackfill.go @@ -0,0 +1,584 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package bridgev2 + +import ( + "context" + "fmt" + "slices" + "time" + + "github.com/rs/zerolog" + "go.mau.fi/util/ptr" + "go.mau.fi/util/variationselector" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/bridgev2/database" + "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +func (portal *Portal) doForwardBackfill(ctx context.Context, source *UserLogin, lastMessage *database.Message, bundledData any) { + log := zerolog.Ctx(ctx).With().Str("action", "forward backfill").Logger() + ctx = log.WithContext(ctx) + api, ok := source.Client.(BackfillingNetworkAPI) + if !ok { + log.Debug().Msg("Network API does not support backfilling") + return + } + logEvt := log.Info() + var limit int + if lastMessage != nil { + logEvt = logEvt.Str("latest_message_id", string(lastMessage.ID)) + limit = portal.Bridge.Config.Backfill.MaxCatchupMessages + } else { + logEvt = logEvt.Str("latest_message_id", "") + limit = portal.Bridge.Config.Backfill.MaxInitialMessages + } + if limit <= 0 { + logEvt.Discard().Send() + return + } + logEvt.Msg("Fetching messages for forward backfill") + resp, err := api.FetchMessages(ctx, FetchMessagesParams{ + Portal: portal, + ThreadRoot: "", + Forward: true, + AnchorMessage: lastMessage, + Count: limit, + BundledData: bundledData, + }) + if err != nil { + log.Err(err).Msg("Failed to fetch messages for forward backfill") + return + } else if resp == nil { + log.Debug().Msg("Didn't get backfill response") + return + } else if len(resp.Messages) == 0 { + log.Debug().Msg("No messages to backfill") + if resp.CompleteCallback != nil { + resp.CompleteCallback() + } + return + } + log.Debug(). + Int("message_count", len(resp.Messages)). + Bool("mark_read", resp.MarkRead). + Bool("aggressive_deduplication", resp.AggressiveDeduplication). + Msg("Fetched messages for forward backfill, deduplicating before sending") + // TODO mark backfill queue task as done if last message is nil (-> room was empty) and HasMore is false? + resp.Messages = portal.cutoffMessages(ctx, resp.Messages, resp.AggressiveDeduplication, true, lastMessage) + if len(resp.Messages) == 0 { + log.Warn().Msg("No messages left to backfill after cutting off old messages") + if resp.CompleteCallback != nil { + resp.CompleteCallback() + } + return + } + portal.sendBackfill(ctx, source, resp.Messages, true, resp.MarkRead, false, resp.CompleteCallback) +} + +func (portal *Portal) DoBackwardsBackfill(ctx context.Context, source *UserLogin, task *database.BackfillTask) error { + log := zerolog.Ctx(ctx) + api, ok := source.Client.(BackfillingNetworkAPI) + if !ok { + return fmt.Errorf("network API does not support backfilling") + } + firstMessage, err := portal.Bridge.DB.Message.GetFirstPortalMessage(ctx, portal.PortalKey) + if err != nil { + return fmt.Errorf("failed to get first portal message: %w", err) + } + logEvt := log.Info(). + Str("cursor", string(task.Cursor)). + Str("task_oldest_message_id", string(task.OldestMessageID)). + Int("current_batch_count", task.BatchCount) + if firstMessage != nil { + logEvt = logEvt.Str("db_oldest_message_id", string(firstMessage.ID)) + } else { + logEvt = logEvt.Str("db_oldest_message_id", "") + } + logEvt.Msg("Fetching messages for backward backfill") + resp, err := api.FetchMessages(ctx, FetchMessagesParams{ + Portal: portal, + ThreadRoot: "", + Forward: false, + Cursor: task.Cursor, + AnchorMessage: firstMessage, + Count: portal.Bridge.Config.Backfill.Queue.BatchSize, + Task: task, + }) + if err != nil { + return fmt.Errorf("failed to fetch messages for backward backfill: %w", err) + } else if resp == nil { + log.Debug().Msg("Didn't get backfill response, marking task as done") + task.IsDone = true + return nil + } + log.Debug(). + Str("new_cursor", string(resp.Cursor)). + Bool("has_more", resp.HasMore). + Int("message_count", len(resp.Messages)). + Msg("Fetched messages for backward backfill") + task.Cursor = resp.Cursor + if !resp.HasMore { + task.IsDone = true + } + if len(resp.Messages) == 0 { + if !resp.HasMore { + log.Debug().Msg("No messages to backfill, marking backfill task as done") + } else { + log.Warn().Msg("No messages to backfill, but HasMore is true") + } + if resp.CompleteCallback != nil { + resp.CompleteCallback() + } + return nil + } + resp.Messages = portal.cutoffMessages(ctx, resp.Messages, resp.AggressiveDeduplication, false, firstMessage) + if len(resp.Messages) == 0 { + if resp.CompleteCallback != nil { + resp.CompleteCallback() + } + return fmt.Errorf("no messages left to backfill after cutting off too new messages") + } + portal.sendBackfill(ctx, source, resp.Messages, false, resp.MarkRead, false, resp.CompleteCallback) + if len(resp.Messages) > 0 { + task.OldestMessageID = resp.Messages[0].ID + } + return nil +} + +func (portal *Portal) fetchThreadBackfill(ctx context.Context, source *UserLogin, anchor *database.Message) *FetchMessagesResponse { + log := zerolog.Ctx(ctx) + resp, err := source.Client.(BackfillingNetworkAPI).FetchMessages(ctx, FetchMessagesParams{ + Portal: portal, + ThreadRoot: anchor.ID, + Forward: true, + AnchorMessage: anchor, + Count: portal.Bridge.Config.Backfill.Threads.MaxInitialMessages, + }) + if err != nil { + log.Err(err).Msg("Failed to fetch messages for thread backfill") + return nil + } else if resp == nil { + log.Debug().Msg("Didn't get backfill response") + return nil + } else if len(resp.Messages) == 0 { + log.Debug().Msg("No messages to backfill") + return nil + } + resp.Messages = portal.cutoffMessages(ctx, resp.Messages, resp.AggressiveDeduplication, true, anchor) + if len(resp.Messages) == 0 { + log.Warn().Msg("No messages left to backfill after cutting off old messages") + if resp.CompleteCallback != nil { + resp.CompleteCallback() + } + return nil + } + return resp +} + +func (portal *Portal) doThreadBackfill(ctx context.Context, source *UserLogin, threadID networkid.MessageID) { + log := zerolog.Ctx(ctx).With(). + Str("subaction", "thread backfill"). + Str("thread_id", string(threadID)). + Logger() + ctx = log.WithContext(ctx) + log.Info().Msg("Backfilling thread inside other backfill") + anchorMessage, err := portal.Bridge.DB.Message.GetLastThreadMessage(ctx, portal.PortalKey, threadID) + if err != nil { + log.Err(err).Msg("Failed to get last thread message") + return + } else if anchorMessage == nil { + log.Warn().Msg("No messages found in thread?") + return + } + resp := portal.fetchThreadBackfill(ctx, source, anchorMessage) + if resp != nil { + portal.sendBackfill(ctx, source, resp.Messages, true, resp.MarkRead, true, resp.CompleteCallback) + } +} + +func (portal *Portal) cutoffMessages(ctx context.Context, messages []*BackfillMessage, aggressiveDedup, forward bool, lastMessage *database.Message) []*BackfillMessage { + if lastMessage == nil { + return messages + } + if forward { + cutoff := -1 + for i, msg := range messages { + if msg.ID == lastMessage.ID || msg.Timestamp.Before(lastMessage.Timestamp) { + cutoff = i + } else { + break + } + } + if cutoff != -1 { + zerolog.Ctx(ctx).Debug(). + Int("cutoff_count", cutoff+1). + Int("total_count", len(messages)). + Time("last_bridged_ts", lastMessage.Timestamp). + Msg("Cutting off forward backfill messages older than latest bridged message") + messages = messages[cutoff+1:] + } + } else { + cutoff := -1 + for i := len(messages) - 1; i >= 0; i-- { + if messages[i].ID == lastMessage.ID || messages[i].Timestamp.After(lastMessage.Timestamp) { + cutoff = i + } else { + break + } + } + if cutoff != -1 { + zerolog.Ctx(ctx).Debug(). + Int("cutoff_count", len(messages)-cutoff). + Int("total_count", len(messages)). + Time("oldest_bridged_ts", lastMessage.Timestamp). + Msg("Cutting off backward backfill messages newer than oldest bridged message") + messages = messages[:cutoff] + } + } + if aggressiveDedup { + filteredMessages := messages[:0] + for _, msg := range messages { + existingMsg, err := portal.Bridge.DB.Message.GetFirstPartByID(ctx, portal.Receiver, msg.ID) + if err != nil { + zerolog.Ctx(ctx).Err(err).Str("message_id", string(msg.ID)).Msg("Failed to check for existing message") + } else if existingMsg != nil { + zerolog.Ctx(ctx).Err(err). + Str("message_id", string(msg.ID)). + Time("message_ts", msg.Timestamp). + Str("message_sender", string(msg.Sender.Sender)). + Msg("Ignoring duplicate message in backfill") + continue + } + if forward && msg.TxnID != "" { + wasPending, _ := portal.checkPendingMessage(ctx, msg) + if wasPending { + zerolog.Ctx(ctx).Err(err). + Str("transaction_id", string(msg.TxnID)). + Str("message_id", string(msg.ID)). + Time("message_ts", msg.Timestamp). + Msg("Found pending message in backfill") + continue + } + } + filteredMessages = append(filteredMessages, msg) + } + messages = filteredMessages + } + return messages +} + +func (portal *Portal) sendBackfill( + ctx context.Context, + source *UserLogin, + messages []*BackfillMessage, + forceForward, + markRead, + inThread bool, + done func(), +) { + canBatchSend := portal.Bridge.Matrix.GetCapabilities().BatchSending + unreadThreshold := time.Duration(portal.Bridge.Config.Backfill.UnreadHoursThreshold) * time.Hour + forceMarkRead := unreadThreshold > 0 && time.Since(messages[len(messages)-1].Timestamp) > unreadThreshold + zerolog.Ctx(ctx).Info(). + Int("message_count", len(messages)). + Bool("batch_send", canBatchSend). + Bool("mark_read", markRead). + Bool("mark_read_past_threshold", forceMarkRead). + Msg("Sending backfill messages") + if canBatchSend { + portal.sendBatch(ctx, source, messages, forceForward, markRead || forceMarkRead, inThread) + } else { + portal.sendLegacyBackfill(ctx, source, messages, markRead || forceMarkRead) + } + if done != nil { + done() + } + zerolog.Ctx(ctx).Debug().Msg("Backfill finished") + if !canBatchSend && !inThread && portal.Bridge.Config.Backfill.Threads.MaxInitialMessages > 0 { + for _, msg := range messages { + if msg.ShouldBackfillThread { + portal.doThreadBackfill(ctx, source, msg.ID) + } + } + } +} + +type compileBatchOutput struct { + PrevThreadEvents map[networkid.MessageID]id.EventID + + Events []*event.Event + Extras []*MatrixSendExtra + + DBMessages []*database.Message + DBReactions []*database.Reaction + Disappear []*database.DisappearingMessage +} + +func (portal *Portal) compileBatchMessage(ctx context.Context, source *UserLogin, msg *BackfillMessage, out *compileBatchOutput, inThread bool) { + if len(msg.Parts) == 0 { + return + } + intent, ok := portal.GetIntentFor(ctx, msg.Sender, source, RemoteEventMessage) + if !ok { + return + } + replyTo, threadRoot, prevThreadEvent := portal.getRelationMeta( + ctx, msg.ID, msg.ConvertedMessage, true, + ) + if threadRoot != nil && out.PrevThreadEvents[*msg.ThreadRoot] != "" { + prevThreadEvent.MXID = out.PrevThreadEvents[*msg.ThreadRoot] + } + var partIDs []networkid.PartID + partMap := make(map[networkid.PartID]*database.Message, len(msg.Parts)) + var firstPart *database.Message + for i, part := range msg.Parts { + partIDs = append(partIDs, part.ID) + portal.applyRelationMeta(ctx, part.Content, replyTo, threadRoot, prevThreadEvent) + part.Content.BeeperDisappearingTimer = msg.Disappear.ToEventContent() + evtID := portal.Bridge.Matrix.GenerateDeterministicEventID(portal.MXID, portal.PortalKey, msg.ID, part.ID) + dbMessage := &database.Message{ + ID: msg.ID, + PartID: part.ID, + MXID: evtID, + Room: portal.PortalKey, + SenderID: msg.Sender.Sender, + SenderMXID: intent.GetMXID(), + Timestamp: msg.Timestamp, + ThreadRoot: ptr.Val(msg.ThreadRoot), + ReplyTo: ptr.Val(msg.ReplyTo), + Metadata: part.DBMetadata, + IsDoublePuppeted: intent.IsDoublePuppet(), + } + if part.DontBridge { + dbMessage.SetFakeMXID() + out.DBMessages = append(out.DBMessages, dbMessage) + continue + } + out.Events = append(out.Events, &event.Event{ + Sender: intent.GetMXID(), + Type: part.Type, + Timestamp: msg.Timestamp.UnixMilli(), + ID: evtID, + RoomID: portal.MXID, + Content: event.Content{ + Parsed: part.Content, + Raw: part.Extra, + }, + }) + if firstPart == nil { + firstPart = dbMessage + } + partMap[part.ID] = dbMessage + out.Extras = append(out.Extras, &MatrixSendExtra{MessageMeta: dbMessage, StreamOrder: msg.StreamOrder, PartIndex: i}) + out.DBMessages = append(out.DBMessages, dbMessage) + if prevThreadEvent != nil { + prevThreadEvent.MXID = evtID + out.PrevThreadEvents[*msg.ThreadRoot] = evtID + } + if msg.Disappear.Type != event.DisappearingTypeNone { + if msg.Disappear.Type == event.DisappearingTypeAfterSend && msg.Disappear.DisappearAt.IsZero() { + msg.Disappear.DisappearAt = msg.Timestamp.Add(msg.Disappear.Timer) + } + out.Disappear = append(out.Disappear, &database.DisappearingMessage{ + RoomID: portal.MXID, + EventID: evtID, + Timestamp: msg.Timestamp, + DisappearingSetting: msg.Disappear, + }) + } + } + slices.Sort(partIDs) + for _, reaction := range msg.Reactions { + if reaction == nil { + continue + } + reactionIntent, ok := portal.GetIntentFor(ctx, reaction.Sender, source, RemoteEventReactionRemove) + if !ok { + continue + } + if reaction.TargetPart == nil { + reaction.TargetPart = &partIDs[0] + } + if reaction.Timestamp.IsZero() { + reaction.Timestamp = msg.Timestamp.Add(10 * time.Millisecond) + } + //lint:ignore SA4006 it's a todo + targetPart, ok := partMap[*reaction.TargetPart] + if !ok { + // TODO warning log and/or skip reaction? + } + reactionMXID := portal.Bridge.Matrix.GenerateReactionEventID(portal.MXID, targetPart, reaction.Sender.Sender, reaction.EmojiID) + dbReaction := &database.Reaction{ + Room: portal.PortalKey, + MessageID: msg.ID, + MessagePartID: *reaction.TargetPart, + SenderID: reaction.Sender.Sender, + EmojiID: reaction.EmojiID, + MXID: reactionMXID, + Timestamp: reaction.Timestamp, + Emoji: reaction.Emoji, + Metadata: reaction.DBMetadata, + } + out.Events = append(out.Events, &event.Event{ + Sender: reactionIntent.GetMXID(), + Type: event.EventReaction, + Timestamp: reaction.Timestamp.UnixMilli(), + ID: reactionMXID, + RoomID: portal.MXID, + Content: event.Content{ + Parsed: &event.ReactionEventContent{ + RelatesTo: event.RelatesTo{ + Type: event.RelAnnotation, + EventID: portal.Bridge.Matrix.GenerateDeterministicEventID(portal.MXID, portal.PortalKey, msg.ID, *reaction.TargetPart), + Key: variationselector.Add(reaction.Emoji), + }, + }, + Raw: reaction.ExtraContent, + }, + }) + out.DBReactions = append(out.DBReactions, dbReaction) + out.Extras = append(out.Extras, &MatrixSendExtra{ReactionMeta: dbReaction}) + } + if firstPart != nil && !inThread && portal.Bridge.Config.Backfill.Threads.MaxInitialMessages > 0 && msg.ShouldBackfillThread { + portal.fetchThreadInsideBatch(ctx, source, firstPart, out) + } +} + +func (portal *Portal) fetchThreadInsideBatch(ctx context.Context, source *UserLogin, dbMsg *database.Message, out *compileBatchOutput) { + log := zerolog.Ctx(ctx).With(). + Str("subaction", "thread backfill in batch"). + Str("thread_id", string(dbMsg.ID)). + Logger() + ctx = log.WithContext(ctx) + resp := portal.fetchThreadBackfill(ctx, source, dbMsg) + if resp != nil { + for _, msg := range resp.Messages { + portal.compileBatchMessage(ctx, source, msg, out, true) + } + } +} + +func (portal *Portal) sendBatch(ctx context.Context, source *UserLogin, messages []*BackfillMessage, forceForward, markRead, inThread bool) { + out := &compileBatchOutput{ + PrevThreadEvents: make(map[networkid.MessageID]id.EventID), + Events: make([]*event.Event, 0, len(messages)), + Extras: make([]*MatrixSendExtra, 0, len(messages)), + DBMessages: make([]*database.Message, 0, len(messages)), + DBReactions: make([]*database.Reaction, 0), + Disappear: make([]*database.DisappearingMessage, 0), + } + for _, msg := range messages { + portal.compileBatchMessage(ctx, source, msg, out, inThread) + } + req := &mautrix.ReqBeeperBatchSend{ + ForwardIfNoMessages: !forceForward, + Forward: forceForward, + SendNotification: !markRead && forceForward && !inThread, + Events: out.Events, + } + if markRead { + req.MarkReadBy = source.UserMXID + } + _, err := portal.Bridge.Matrix.BatchSend(ctx, portal.MXID, req, out.Extras) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to send backfill messages") + } + if len(out.Disappear) > 0 { + // TODO mass insert disappearing messages + go func() { + for _, msg := range out.Disappear { + portal.Bridge.DisappearLoop.Add(ctx, msg) + } + }() + } + // TODO mass insert db messages + for _, msg := range out.DBMessages { + err = portal.Bridge.DB.Message.Insert(ctx, msg) + if err != nil { + zerolog.Ctx(ctx).Err(err). + Str("message_id", string(msg.ID)). + Str("part_id", string(msg.PartID)). + Str("sender_id", string(msg.SenderID)). + Str("portal_id", string(msg.Room.ID)). + Str("portal_receiver", string(msg.Room.Receiver)). + Msg("Failed to insert backfilled message to database") + } + } + // TODO mass insert db reactions + for _, react := range out.DBReactions { + err = portal.Bridge.DB.Reaction.Upsert(ctx, react) + if err != nil { + zerolog.Ctx(ctx).Err(err). + Str("message_id", string(react.MessageID)). + Str("part_id", string(react.MessagePartID)). + Str("sender_id", string(react.SenderID)). + Str("portal_id", string(react.Room.ID)). + Str("portal_receiver", string(react.Room.Receiver)). + Msg("Failed to insert backfilled reaction to database") + } + } +} + +func (portal *Portal) sendLegacyBackfill(ctx context.Context, source *UserLogin, messages []*BackfillMessage, markRead bool) { + var lastPart id.EventID + for _, msg := range messages { + intent, ok := portal.GetIntentFor(ctx, msg.Sender, source, RemoteEventMessage) + if !ok { + continue + } + dbMessages, _ := portal.sendConvertedMessage(ctx, msg.ID, intent, msg.Sender.Sender, msg.ConvertedMessage, msg.Timestamp, msg.StreamOrder, func(z *zerolog.Event) *zerolog.Event { + return z. + Str("message_id", string(msg.ID)). + Any("sender_id", msg.Sender). + Time("message_ts", msg.Timestamp) + }) + if len(dbMessages) > 0 { + lastPart = dbMessages[len(dbMessages)-1].MXID + for _, reaction := range msg.Reactions { + reactionIntent, ok := portal.GetIntentFor(ctx, reaction.Sender, source, RemoteEventReaction) + if !ok { + continue + } + targetPart := dbMessages[0] + if reaction.TargetPart != nil { + targetPartIdx := slices.IndexFunc(dbMessages, func(dbMsg *database.Message) bool { + return dbMsg.PartID == *reaction.TargetPart + }) + if targetPartIdx != -1 { + targetPart = dbMessages[targetPartIdx] + } else { + // TODO warning log and/or skip reaction? + } + } + portal.sendConvertedReaction( + ctx, reaction.Sender.Sender, reactionIntent, targetPart, reaction.EmojiID, reaction.Emoji, + reaction.Timestamp, reaction.DBMetadata, reaction.ExtraContent, + func(z *zerolog.Event) *zerolog.Event { + return z. + Str("target_message_id", string(msg.ID)). + Str("target_part_id", string(targetPart.PartID)). + Any("reaction_sender_id", reaction.Sender). + Time("reaction_ts", reaction.Timestamp) + }, + ) + } + } + } + if markRead { + dp := source.User.DoublePuppet(ctx) + if dp != nil { + err := dp.MarkRead(ctx, portal.MXID, lastPart, time.Now()) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to mark room as read after backfill") + } + } + } +} diff --git a/bridgev2/portalinternal.go b/bridgev2/portalinternal.go new file mode 100644 index 00000000..4c7e2447 --- /dev/null +++ b/bridgev2/portalinternal.go @@ -0,0 +1,402 @@ +// GENERATED BY portalinternal_generate.go; DO NOT EDIT + +//go:generate go run portalinternal_generate.go +//go:generate goimports -local maunium.net/go/mautrix -w portalinternal.go + +package bridgev2 + +import ( + "context" + "time" + + "github.com/rs/zerolog" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/bridgev2/database" + "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +type PortalInternals Portal + +// Deprecated: portal internals should be used carefully and only when necessary. +func (portal *Portal) Internal() *PortalInternals { + return (*PortalInternals)(portal) +} + +func (portal *PortalInternals) UpdateLogger() { + (*Portal)(portal).updateLogger() +} + +func (portal *PortalInternals) QueueEvent(ctx context.Context, evt portalEvent) EventHandlingResult { + return (*Portal)(portal).queueEvent(ctx, evt) +} + +func (portal *PortalInternals) EventLoop() { + (*Portal)(portal).eventLoop() +} + +func (portal *PortalInternals) HandleSingleEventWithDelayLogging(idx int, rawEvt any) (outerRes EventHandlingResult) { + return (*Portal)(portal).handleSingleEventWithDelayLogging(idx, rawEvt) +} + +func (portal *PortalInternals) GetEventCtxWithLog(rawEvt any, idx int) context.Context { + return (*Portal)(portal).getEventCtxWithLog(rawEvt, idx) +} + +func (portal *PortalInternals) HandleSingleEvent(ctx context.Context, rawEvt any, doneCallback func(EventHandlingResult)) { + (*Portal)(portal).handleSingleEvent(ctx, rawEvt, doneCallback) +} + +func (portal *PortalInternals) UnwrapBeeperSendState(ctx context.Context, evt *event.Event) error { + return (*Portal)(portal).unwrapBeeperSendState(ctx, evt) +} + +func (portal *PortalInternals) SendSuccessStatus(ctx context.Context, evt *event.Event, streamOrder int64, newEventID id.EventID) { + (*Portal)(portal).sendSuccessStatus(ctx, evt, streamOrder, newEventID) +} + +func (portal *PortalInternals) SendErrorStatus(ctx context.Context, evt *event.Event, err error) { + (*Portal)(portal).sendErrorStatus(ctx, evt, err) +} + +func (portal *PortalInternals) CheckConfusableName(ctx context.Context, userID id.UserID, name string) bool { + return (*Portal)(portal).checkConfusableName(ctx, userID, name) +} + +func (portal *PortalInternals) HandleMatrixEvent(ctx context.Context, sender *User, evt *event.Event, isStateRequest bool) EventHandlingResult { + return (*Portal)(portal).handleMatrixEvent(ctx, sender, evt, isStateRequest) +} + +func (portal *PortalInternals) HandleMatrixReceipts(ctx context.Context, evt *event.Event) EventHandlingResult { + return (*Portal)(portal).handleMatrixReceipts(ctx, evt) +} + +func (portal *PortalInternals) HandleMatrixReadReceipt(ctx context.Context, user *User, eventID id.EventID, receipt event.ReadReceipt) { + (*Portal)(portal).handleMatrixReadReceipt(ctx, user, eventID, receipt) +} + +func (portal *PortalInternals) CallReadReceiptHandler(ctx context.Context, login *UserLogin, rrClient ReadReceiptHandlingNetworkAPI, evt *MatrixReadReceipt, userPortal *database.UserPortal) { + (*Portal)(portal).callReadReceiptHandler(ctx, login, rrClient, evt, userPortal) +} + +func (portal *PortalInternals) HandleMatrixTyping(ctx context.Context, evt *event.Event) EventHandlingResult { + return (*Portal)(portal).handleMatrixTyping(ctx, evt) +} + +func (portal *PortalInternals) SendTypings(ctx context.Context, userIDs []id.UserID, typing bool) { + (*Portal)(portal).sendTypings(ctx, userIDs, typing) +} + +func (portal *PortalInternals) PeriodicTypingUpdater() { + (*Portal)(portal).periodicTypingUpdater() +} + +func (portal *PortalInternals) CheckMessageContentCaps(caps *event.RoomFeatures, content *event.MessageEventContent) error { + return (*Portal)(portal).checkMessageContentCaps(caps, content) +} + +func (portal *PortalInternals) ParseInputTransactionID(origSender *OrigSender, evt *event.Event) networkid.RawTransactionID { + return (*Portal)(portal).parseInputTransactionID(origSender, evt) +} + +func (portal *PortalInternals) HandleMatrixMessage(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event) EventHandlingResult { + return (*Portal)(portal).handleMatrixMessage(ctx, sender, origSender, evt) +} + +func (portal *PortalInternals) PendingMessageTimeoutLoop(ctx context.Context, cfg *OutgoingTimeoutConfig) { + (*Portal)(portal).pendingMessageTimeoutLoop(ctx, cfg) +} + +func (portal *PortalInternals) CheckPendingMessages(ctx context.Context, cfg *OutgoingTimeoutConfig) { + (*Portal)(portal).checkPendingMessages(ctx, cfg) +} + +func (portal *PortalInternals) HandleMatrixEdit(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event, content *event.MessageEventContent, caps *event.RoomFeatures) EventHandlingResult { + return (*Portal)(portal).handleMatrixEdit(ctx, sender, origSender, evt, content, caps) +} + +func (portal *PortalInternals) HandleMatrixReaction(ctx context.Context, sender *UserLogin, evt *event.Event) EventHandlingResult { + return (*Portal)(portal).handleMatrixReaction(ctx, sender, evt) +} + +func (portal *PortalInternals) GetTargetUser(ctx context.Context, userID id.UserID) (GhostOrUserLogin, error) { + return (*Portal)(portal).getTargetUser(ctx, userID) +} + +func (portal *PortalInternals) HandleMatrixDeleteChat(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event) EventHandlingResult { + return (*Portal)(portal).handleMatrixDeleteChat(ctx, sender, origSender, evt) +} + +func (portal *PortalInternals) HandleMatrixMembership(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event, isStateRequest bool) EventHandlingResult { + return (*Portal)(portal).handleMatrixMembership(ctx, sender, origSender, evt, isStateRequest) +} + +func (portal *PortalInternals) HandleMatrixPowerLevels(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event, isStateRequest bool) EventHandlingResult { + return (*Portal)(portal).handleMatrixPowerLevels(ctx, sender, origSender, evt, isStateRequest) +} + +func (portal *PortalInternals) HandleMatrixTombstone(ctx context.Context, evt *event.Event) EventHandlingResult { + return (*Portal)(portal).handleMatrixTombstone(ctx, evt) +} + +func (portal *PortalInternals) UpdateInfoAfterTombstone(ctx context.Context, senderUser *User) { + (*Portal)(portal).updateInfoAfterTombstone(ctx, senderUser) +} + +func (portal *PortalInternals) HandleMatrixRedaction(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event) EventHandlingResult { + return (*Portal)(portal).handleMatrixRedaction(ctx, sender, origSender, evt) +} + +func (portal *PortalInternals) HandleRemoteEvent(ctx context.Context, source *UserLogin, evtType RemoteEventType, evt RemoteEvent) (res EventHandlingResult) { + return (*Portal)(portal).handleRemoteEvent(ctx, source, evtType, evt) +} + +func (portal *PortalInternals) EnsureFunctionalMember(ctx context.Context, ghost *Ghost) { + (*Portal)(portal).ensureFunctionalMember(ctx, ghost) +} + +func (portal *PortalInternals) GetIntentAndUserMXIDFor(ctx context.Context, sender EventSender, source *UserLogin, otherLogins []*UserLogin, evtType RemoteEventType) (intent MatrixAPI, extraUserID id.UserID, err error) { + return (*Portal)(portal).getIntentAndUserMXIDFor(ctx, sender, source, otherLogins, evtType) +} + +func (portal *PortalInternals) GetRelationMeta(ctx context.Context, currentMsgID networkid.MessageID, currentMsg *ConvertedMessage, isBatchSend bool) (replyTo, threadRoot, prevThreadEvent *database.Message) { + return (*Portal)(portal).getRelationMeta(ctx, currentMsgID, currentMsg, isBatchSend) +} + +func (portal *PortalInternals) ApplyRelationMeta(ctx context.Context, content *event.MessageEventContent, replyTo, threadRoot, prevThreadEvent *database.Message) { + (*Portal)(portal).applyRelationMeta(ctx, content, replyTo, threadRoot, prevThreadEvent) +} + +func (portal *PortalInternals) SendConvertedMessage(ctx context.Context, id networkid.MessageID, intent MatrixAPI, senderID networkid.UserID, converted *ConvertedMessage, ts time.Time, streamOrder int64, logContext func(*zerolog.Event) *zerolog.Event) ([]*database.Message, EventHandlingResult) { + return (*Portal)(portal).sendConvertedMessage(ctx, id, intent, senderID, converted, ts, streamOrder, logContext) +} + +func (portal *PortalInternals) CheckPendingMessage(ctx context.Context, evt RemoteMessage) (bool, *database.Message) { + return (*Portal)(portal).checkPendingMessage(ctx, evt) +} + +func (portal *PortalInternals) HandleRemoteUpsert(ctx context.Context, source *UserLogin, evt RemoteMessageUpsert, existing []*database.Message) (handleRes EventHandlingResult, continueHandling bool) { + return (*Portal)(portal).handleRemoteUpsert(ctx, source, evt, existing) +} + +func (portal *PortalInternals) HandleRemoteMessage(ctx context.Context, source *UserLogin, evt RemoteMessage) (res EventHandlingResult) { + return (*Portal)(portal).handleRemoteMessage(ctx, source, evt) +} + +func (portal *PortalInternals) SendRemoteErrorNotice(ctx context.Context, intent MatrixAPI, err error, ts time.Time, evtTypeName string) { + (*Portal)(portal).sendRemoteErrorNotice(ctx, intent, err, ts, evtTypeName) +} + +func (portal *PortalInternals) HandleRemoteEdit(ctx context.Context, source *UserLogin, evt RemoteEdit) EventHandlingResult { + return (*Portal)(portal).handleRemoteEdit(ctx, source, evt) +} + +func (portal *PortalInternals) SendConvertedEdit(ctx context.Context, targetID networkid.MessageID, senderID networkid.UserID, converted *ConvertedEdit, intent MatrixAPI, ts time.Time, streamOrder int64) EventHandlingResult { + return (*Portal)(portal).sendConvertedEdit(ctx, targetID, senderID, converted, intent, ts, streamOrder) +} + +func (portal *PortalInternals) GetTargetMessagePart(ctx context.Context, evt RemoteEventWithTargetMessage) (*database.Message, error) { + return (*Portal)(portal).getTargetMessagePart(ctx, evt) +} + +func (portal *PortalInternals) GetTargetReaction(ctx context.Context, evt RemoteReactionRemove) (*database.Reaction, error) { + return (*Portal)(portal).getTargetReaction(ctx, evt) +} + +func (portal *PortalInternals) HandleRemoteReactionSync(ctx context.Context, source *UserLogin, evt RemoteReactionSync) EventHandlingResult { + return (*Portal)(portal).handleRemoteReactionSync(ctx, source, evt) +} + +func (portal *PortalInternals) HandleRemoteReaction(ctx context.Context, source *UserLogin, evt RemoteReaction) EventHandlingResult { + return (*Portal)(portal).handleRemoteReaction(ctx, source, evt) +} + +func (portal *PortalInternals) SendConvertedReaction(ctx context.Context, senderID networkid.UserID, intent MatrixAPI, targetMessage *database.Message, emojiID networkid.EmojiID, emoji string, ts time.Time, dbMetadata any, extraContent map[string]any, logContext func(*zerolog.Event) *zerolog.Event) EventHandlingResult { + return (*Portal)(portal).sendConvertedReaction(ctx, senderID, intent, targetMessage, emojiID, emoji, ts, dbMetadata, extraContent, logContext) +} + +func (portal *PortalInternals) GetIntentForMXID(ctx context.Context, userID id.UserID) (MatrixAPI, error) { + return (*Portal)(portal).getIntentForMXID(ctx, userID) +} + +func (portal *PortalInternals) HandleRemoteReactionRemove(ctx context.Context, source *UserLogin, evt RemoteReactionRemove) EventHandlingResult { + return (*Portal)(portal).handleRemoteReactionRemove(ctx, source, evt) +} + +func (portal *PortalInternals) HandleRemoteMessageRemove(ctx context.Context, source *UserLogin, evt RemoteMessageRemove) EventHandlingResult { + return (*Portal)(portal).handleRemoteMessageRemove(ctx, source, evt) +} + +func (portal *PortalInternals) RedactMessageParts(ctx context.Context, parts []*database.Message, intent MatrixAPI, ts time.Time) EventHandlingResult { + return (*Portal)(portal).redactMessageParts(ctx, parts, intent, ts) +} + +func (portal *PortalInternals) HandleRemoteReadReceipt(ctx context.Context, source *UserLogin, evt RemoteReadReceipt) EventHandlingResult { + return (*Portal)(portal).handleRemoteReadReceipt(ctx, source, evt) +} + +func (portal *PortalInternals) HandleRemoteMarkUnread(ctx context.Context, source *UserLogin, evt RemoteMarkUnread) EventHandlingResult { + return (*Portal)(portal).handleRemoteMarkUnread(ctx, source, evt) +} + +func (portal *PortalInternals) HandleRemoteDeliveryReceipt(ctx context.Context, source *UserLogin, evt RemoteDeliveryReceipt) EventHandlingResult { + return (*Portal)(portal).handleRemoteDeliveryReceipt(ctx, source, evt) +} + +func (portal *PortalInternals) HandleRemoteTyping(ctx context.Context, source *UserLogin, evt RemoteTyping) EventHandlingResult { + return (*Portal)(portal).handleRemoteTyping(ctx, source, evt) +} + +func (portal *PortalInternals) HandleRemoteChatInfoChange(ctx context.Context, source *UserLogin, evt RemoteChatInfoChange) EventHandlingResult { + return (*Portal)(portal).handleRemoteChatInfoChange(ctx, source, evt) +} + +func (portal *PortalInternals) HandleRemoteChatResync(ctx context.Context, source *UserLogin, evt RemoteChatResync) EventHandlingResult { + return (*Portal)(portal).handleRemoteChatResync(ctx, source, evt) +} + +func (portal *PortalInternals) FindOtherLogins(ctx context.Context, source *UserLogin) (ownUP *database.UserPortal, others []*database.UserPortal, err error) { + return (*Portal)(portal).findOtherLogins(ctx, source) +} + +func (portal *PortalInternals) HandleRemoteChatDelete(ctx context.Context, source *UserLogin, evt RemoteChatDelete) EventHandlingResult { + return (*Portal)(portal).handleRemoteChatDelete(ctx, source, evt) +} + +func (portal *PortalInternals) HandleRemoteBackfill(ctx context.Context, source *UserLogin, backfill RemoteBackfill) (res EventHandlingResult) { + return (*Portal)(portal).handleRemoteBackfill(ctx, source, backfill) +} + +func (portal *PortalInternals) UpdateName(ctx context.Context, name string, sender MatrixAPI, ts time.Time, excludeFromTimeline bool) bool { + return (*Portal)(portal).updateName(ctx, name, sender, ts, excludeFromTimeline) +} + +func (portal *PortalInternals) UpdateTopic(ctx context.Context, topic string, sender MatrixAPI, ts time.Time, excludeFromTimeline bool) bool { + return (*Portal)(portal).updateTopic(ctx, topic, sender, ts, excludeFromTimeline) +} + +func (portal *PortalInternals) UpdateAvatar(ctx context.Context, avatar *Avatar, sender MatrixAPI, ts time.Time, excludeFromTimeline bool) bool { + return (*Portal)(portal).updateAvatar(ctx, avatar, sender, ts, excludeFromTimeline) +} + +func (portal *PortalInternals) GetBridgeInfoStateKey() string { + return (*Portal)(portal).getBridgeInfoStateKey() +} + +func (portal *PortalInternals) GetBridgeInfo() (string, event.BridgeEventContent) { + return (*Portal)(portal).getBridgeInfo() +} + +func (portal *PortalInternals) SendStateWithIntentOrBot(ctx context.Context, sender MatrixAPI, eventType event.Type, stateKey string, content *event.Content, ts time.Time) (resp *mautrix.RespSendEvent, err error) { + return (*Portal)(portal).sendStateWithIntentOrBot(ctx, sender, eventType, stateKey, content, ts) +} + +func (portal *PortalInternals) SendRoomMeta(ctx context.Context, sender MatrixAPI, ts time.Time, eventType event.Type, stateKey string, content any, excludeFromTimeline bool, extra map[string]any) bool { + return (*Portal)(portal).sendRoomMeta(ctx, sender, ts, eventType, stateKey, content, excludeFromTimeline, extra) +} + +func (portal *PortalInternals) RevertRoomMeta(ctx context.Context, evt *event.Event) { + (*Portal)(portal).revertRoomMeta(ctx, evt) +} + +func (portal *PortalInternals) GetInitialMemberList(ctx context.Context, members *ChatMemberList, source *UserLogin, pl *event.PowerLevelsEventContent) (invite, functional []id.UserID, err error) { + return (*Portal)(portal).getInitialMemberList(ctx, members, source, pl) +} + +func (portal *PortalInternals) UpdateOtherUser(ctx context.Context, members *ChatMemberList) (changed bool) { + return (*Portal)(portal).updateOtherUser(ctx, members) +} + +func (portal *PortalInternals) RoomIsPublic(ctx context.Context) bool { + return (*Portal)(portal).roomIsPublic(ctx) +} + +func (portal *PortalInternals) SyncParticipants(ctx context.Context, members *ChatMemberList, source *UserLogin, sender MatrixAPI, ts time.Time) error { + return (*Portal)(portal).syncParticipants(ctx, members, source, sender, ts) +} + +func (portal *PortalInternals) UpdateUserLocalInfo(ctx context.Context, info *UserLocalPortalInfo, source *UserLogin, didJustCreate bool) { + (*Portal)(portal).updateUserLocalInfo(ctx, info, source, didJustCreate) +} + +func (portal *PortalInternals) UpdateParent(ctx context.Context, newParentID networkid.PortalID, source *UserLogin) bool { + return (*Portal)(portal).updateParent(ctx, newParentID, source) +} + +func (portal *PortalInternals) LockedUpdateInfoFromGhost(ctx context.Context, ghost *Ghost) { + (*Portal)(portal).lockedUpdateInfoFromGhost(ctx, ghost) +} + +func (portal *PortalInternals) CreateMatrixRoomInLoop(ctx context.Context, source *UserLogin, info *ChatInfo, backfillBundle any) error { + return (*Portal)(portal).createMatrixRoomInLoop(ctx, source, info, backfillBundle) +} + +func (portal *PortalInternals) AddToUserSpaces(ctx context.Context) { + (*Portal)(portal).addToUserSpaces(ctx) +} + +func (portal *PortalInternals) RemoveInPortalCache(ctx context.Context) { + (*Portal)(portal).removeInPortalCache(ctx) +} + +func (portal *PortalInternals) UnlockedDelete(ctx context.Context) error { + return (*Portal)(portal).unlockedDelete(ctx) +} + +func (portal *PortalInternals) UnlockedDeleteCache() { + (*Portal)(portal).unlockedDeleteCache() +} + +func (portal *PortalInternals) DoForwardBackfill(ctx context.Context, source *UserLogin, lastMessage *database.Message, bundledData any) { + (*Portal)(portal).doForwardBackfill(ctx, source, lastMessage, bundledData) +} + +func (portal *PortalInternals) FetchThreadBackfill(ctx context.Context, source *UserLogin, anchor *database.Message) *FetchMessagesResponse { + return (*Portal)(portal).fetchThreadBackfill(ctx, source, anchor) +} + +func (portal *PortalInternals) DoThreadBackfill(ctx context.Context, source *UserLogin, threadID networkid.MessageID) { + (*Portal)(portal).doThreadBackfill(ctx, source, threadID) +} + +func (portal *PortalInternals) CutoffMessages(ctx context.Context, messages []*BackfillMessage, aggressiveDedup, forward bool, lastMessage *database.Message) []*BackfillMessage { + return (*Portal)(portal).cutoffMessages(ctx, messages, aggressiveDedup, forward, lastMessage) +} + +func (portal *PortalInternals) SendBackfill(ctx context.Context, source *UserLogin, messages []*BackfillMessage, forceForward, markRead, inThread bool, done func()) { + (*Portal)(portal).sendBackfill(ctx, source, messages, forceForward, markRead, inThread, done) +} + +func (portal *PortalInternals) CompileBatchMessage(ctx context.Context, source *UserLogin, msg *BackfillMessage, out *compileBatchOutput, inThread bool) { + (*Portal)(portal).compileBatchMessage(ctx, source, msg, out, inThread) +} + +func (portal *PortalInternals) FetchThreadInsideBatch(ctx context.Context, source *UserLogin, dbMsg *database.Message, out *compileBatchOutput) { + (*Portal)(portal).fetchThreadInsideBatch(ctx, source, dbMsg, out) +} + +func (portal *PortalInternals) SendBatch(ctx context.Context, source *UserLogin, messages []*BackfillMessage, forceForward, markRead, inThread bool) { + (*Portal)(portal).sendBatch(ctx, source, messages, forceForward, markRead, inThread) +} + +func (portal *PortalInternals) SendLegacyBackfill(ctx context.Context, source *UserLogin, messages []*BackfillMessage, markRead bool) { + (*Portal)(portal).sendLegacyBackfill(ctx, source, messages, markRead) +} + +func (portal *PortalInternals) UnlockedReID(ctx context.Context, target networkid.PortalKey) error { + return (*Portal)(portal).unlockedReID(ctx, target) +} + +func (portal *PortalInternals) CreateParentAndAddToSpace(ctx context.Context, source *UserLogin) { + (*Portal)(portal).createParentAndAddToSpace(ctx, source) +} + +func (portal *PortalInternals) AddToParentSpaceAndSave(ctx context.Context, save bool) { + (*Portal)(portal).addToParentSpaceAndSave(ctx, save) +} + +func (portal *PortalInternals) ToggleSpace(ctx context.Context, spaceID id.RoomID, canonical, remove bool) error { + return (*Portal)(portal).toggleSpace(ctx, spaceID, canonical, remove) +} diff --git a/bridgev2/portalinternal_generate.go b/bridgev2/portalinternal_generate.go new file mode 100644 index 00000000..2ac6c898 --- /dev/null +++ b/bridgev2/portalinternal_generate.go @@ -0,0 +1,173 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +//go:build ignore + +package main + +import ( + "fmt" + "go/ast" + "go/parser" + "go/token" + "os" + "strings" + + "go.mau.fi/util/exerrors" +) + +const header = `// GENERATED BY portalinternal_generate.go; DO NOT EDIT + +//go:generate go run portalinternal_generate.go +//go:generate goimports -local maunium.net/go/mautrix -w portalinternal.go + +package bridgev2 + +` +const postImportHeader = ` +type PortalInternals Portal + +// Deprecated: portal internals should be used carefully and only when necessary. +func (portal *Portal) Internal() *PortalInternals { + return (*PortalInternals)(portal) +} +` + +func getTypeName(expr ast.Expr) string { + switch e := expr.(type) { + case *ast.Ident: + return e.Name + case *ast.StarExpr: + return "*" + getTypeName(e.X) + case *ast.ArrayType: + return "[]" + getTypeName(e.Elt) + case *ast.MapType: + return fmt.Sprintf("map[%s]%s", getTypeName(e.Key), getTypeName(e.Value)) + case *ast.ChanType: + return fmt.Sprintf("chan %s", getTypeName(e.Value)) + case *ast.FuncType: + var params []string + for _, param := range e.Params.List { + params = append(params, getTypeName(param.Type)) + } + var results []string + if e.Results != nil { + for _, result := range e.Results.List { + results = append(results, getTypeName(result.Type)) + } + } + return fmt.Sprintf("func(%s) %s", strings.Join(params, ", "), strings.Join(results, ", ")) + case *ast.SelectorExpr: + return fmt.Sprintf("%s.%s", getTypeName(e.X), e.Sel.Name) + default: + panic(fmt.Errorf("unknown type %T", e)) + } +} + +var write func(str string) +var writef func(format string, args ...any) + +func main() { + fset := token.NewFileSet() + fileNames := []string{"portal.go", "portalbackfill.go", "portalreid.go", "space.go", "matrixinvite.go"} + files := make([]*ast.File, len(fileNames)) + for i, name := range fileNames { + files[i] = exerrors.Must(parser.ParseFile(fset, name, nil, parser.SkipObjectResolution)) + } + file := exerrors.Must(os.OpenFile("portalinternal.go", os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)) + write = func(str string) { + exerrors.Must(file.WriteString(str)) + } + writef = func(format string, args ...any) { + exerrors.Must(fmt.Fprintf(file, format, args...)) + } + write(header) + write("import (\n") + for _, i := range files[0].Imports { + write("\t") + if i.Name != nil { + writef("%s ", i.Name.Name) + } + writef("%s\n", i.Path.Value) + } + write(")\n") + write(postImportHeader) + for _, f := range files { + processFile(f) + } + exerrors.PanicIfNotNil(file.Close()) +} + +func processFile(f *ast.File) { + ast.Inspect(f, func(node ast.Node) (retVal bool) { + retVal = true + funcDecl, ok := node.(*ast.FuncDecl) + if !ok || funcDecl.Name.IsExported() { + return + } + if funcDecl.Recv == nil || len(funcDecl.Recv.List) == 0 || len(funcDecl.Recv.List[0].Names) == 0 || + funcDecl.Recv.List[0].Names[0].Name != "portal" { + return + } + writef("\nfunc (portal *PortalInternals) %s%s(", strings.ToUpper(funcDecl.Name.Name[0:1]), funcDecl.Name.Name[1:]) + for i, param := range funcDecl.Type.Params.List { + if i != 0 { + write(", ") + } + for j, name := range param.Names { + if j != 0 { + write(", ") + } + write(name.Name) + } + if len(param.Names) > 0 { + write(" ") + } + write(getTypeName(param.Type)) + } + write(") ") + if funcDecl.Type.Results != nil && len(funcDecl.Type.Results.List) > 0 { + needsParentheses := len(funcDecl.Type.Results.List) > 1 || len(funcDecl.Type.Results.List[0].Names) > 0 + if needsParentheses { + write("(") + } + for i, result := range funcDecl.Type.Results.List { + if i != 0 { + write(", ") + } + for j, name := range result.Names { + if j != 0 { + write(", ") + } + write(name.Name) + } + if len(result.Names) > 0 { + write(" ") + } + write(getTypeName(result.Type)) + } + if needsParentheses { + write(")") + } + write(" ") + } + write("{\n\t") + if funcDecl.Type.Results != nil { + write("return ") + } + writef("(*Portal)(portal).%s(", funcDecl.Name.Name) + for i, param := range funcDecl.Type.Params.List { + for j, name := range param.Names { + if i != 0 || j != 0 { + write(", ") + } + write(name.Name) + } + } + write(")\n}\n") + return + }) +} diff --git a/bridgev2/portalreid.go b/bridgev2/portalreid.go new file mode 100644 index 00000000..c976d97c --- /dev/null +++ b/bridgev2/portalreid.go @@ -0,0 +1,161 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package bridgev2 + +import ( + "context" + "fmt" + "time" + + "github.com/rs/zerolog" + + "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/event" +) + +type ReIDResult int + +const ( + ReIDResultError ReIDResult = iota + ReIDResultNoOp + ReIDResultSourceDeleted + ReIDResultSourceReIDd + ReIDResultTargetDeletedAndSourceReIDd + ReIDResultSourceTombstonedIntoTarget +) + +func (br *Bridge) ReIDPortal(ctx context.Context, source, target networkid.PortalKey) (ReIDResult, *Portal, error) { + if source == target { + return ReIDResultError, nil, fmt.Errorf("illegal re-ID call: source and target are the same") + } + log := zerolog.Ctx(ctx).With(). + Str("action", "re-id portal"). + Stringer("source_portal_key", source). + Stringer("target_portal_key", target). + Logger() + ctx = log.WithContext(ctx) + defer func() { + log.Debug().Msg("Finished handling portal re-ID") + }() + acquireCacheLock := func() { + if !br.cacheLock.TryLock() { + log.Debug().Msg("Waiting for global cache lock") + br.cacheLock.Lock() + log.Debug().Msg("Acquired global cache lock after waiting") + } else { + log.Trace().Msg("Acquired global cache lock without waiting") + } + } + log.Debug().Msg("Re-ID'ing portal") + sourcePortal, err := br.GetExistingPortalByKey(ctx, source) + if err != nil { + return ReIDResultError, nil, fmt.Errorf("failed to get source portal: %w", err) + } else if sourcePortal == nil { + log.Debug().Msg("Source portal not found, re-ID is no-op") + return ReIDResultNoOp, nil, nil + } + if !sourcePortal.roomCreateLock.TryLock() { + if cancelCreate := sourcePortal.cancelRoomCreate.Swap(nil); cancelCreate != nil { + (*cancelCreate)() + } + log.Debug().Msg("Waiting for source portal room creation lock") + sourcePortal.roomCreateLock.Lock() + log.Debug().Msg("Acquired source portal room creation lock after waiting") + } + defer sourcePortal.roomCreateLock.Unlock() + if sourcePortal.MXID == "" { + log.Info().Msg("Source portal doesn't have Matrix room, deleting row") + err = sourcePortal.unlockedDelete(ctx) + if err != nil { + return ReIDResultError, nil, fmt.Errorf("failed to delete source portal: %w", err) + } + return ReIDResultSourceDeleted, nil, nil + } + log.UpdateContext(func(c zerolog.Context) zerolog.Context { + return c.Stringer("source_portal_mxid", sourcePortal.MXID) + }) + + acquireCacheLock() + targetPortal, err := br.UnlockedGetPortalByKey(ctx, target, true) + if err != nil { + br.cacheLock.Unlock() + return ReIDResultError, nil, fmt.Errorf("failed to get target portal: %w", err) + } + if targetPortal == nil { + log.Info().Msg("Target portal doesn't exist, re-ID'ing source portal") + err = sourcePortal.unlockedReID(ctx, target) + br.cacheLock.Unlock() + if err != nil { + return ReIDResultError, nil, fmt.Errorf("failed to re-ID source portal: %w", err) + } + return ReIDResultSourceReIDd, sourcePortal, nil + } + br.cacheLock.Unlock() + + if !targetPortal.roomCreateLock.TryLock() { + if cancelCreate := targetPortal.cancelRoomCreate.Swap(nil); cancelCreate != nil { + (*cancelCreate)() + } + log.Debug().Msg("Waiting for target portal room creation lock") + targetPortal.roomCreateLock.Lock() + log.Debug().Msg("Acquired target portal room creation lock after waiting") + } + defer targetPortal.roomCreateLock.Unlock() + if targetPortal.MXID == "" { + log.Info().Msg("Target portal row exists, but doesn't have a Matrix room. Deleting target portal row and re-ID'ing source portal") + acquireCacheLock() + defer br.cacheLock.Unlock() + err = targetPortal.unlockedDelete(ctx) + if err != nil { + return ReIDResultError, nil, fmt.Errorf("failed to delete target portal: %w", err) + } + err = sourcePortal.unlockedReID(ctx, target) + if err != nil { + return ReIDResultError, nil, fmt.Errorf("failed to re-ID source portal after deleting target: %w", err) + } + return ReIDResultTargetDeletedAndSourceReIDd, sourcePortal, nil + } else { + log.UpdateContext(func(c zerolog.Context) zerolog.Context { + return c.Stringer("target_portal_mxid", targetPortal.MXID) + }) + log.Info().Msg("Both target and source portals have Matrix rooms, tombstoning source portal") + sourcePortal.removeInPortalCache(ctx) + acquireCacheLock() + defer br.cacheLock.Unlock() + err = sourcePortal.unlockedDelete(ctx) + if err != nil { + return ReIDResultError, nil, fmt.Errorf("failed to delete source portal row: %w", err) + } + go func() { + _, err := br.Bot.SendState(ctx, sourcePortal.MXID, event.StateTombstone, "", &event.Content{ + Parsed: &event.TombstoneEventContent{ + Body: "This room has been merged", + ReplacementRoom: targetPortal.MXID, + }, + }, time.Now()) + if err != nil { + log.Err(err).Msg("Failed to send tombstone to source portal room") + } + err = br.Bot.DeleteRoom(ctx, sourcePortal.MXID, err == nil) + if err != nil { + log.Err(err).Msg("Failed to delete source portal room") + } + }() + return ReIDResultSourceTombstonedIntoTarget, targetPortal, nil + } +} + +func (portal *Portal) unlockedReID(ctx context.Context, target networkid.PortalKey) error { + err := portal.Bridge.DB.Portal.ReID(ctx, portal.PortalKey, target) + if err != nil { + return err + } + delete(portal.Bridge.portalsByKey, portal.PortalKey) + portal.Bridge.portalsByKey[target] = portal + portal.PortalKey = target + return nil +} diff --git a/bridgev2/provisionutil/creategroup.go b/bridgev2/provisionutil/creategroup.go new file mode 100644 index 00000000..72bacaff --- /dev/null +++ b/bridgev2/provisionutil/creategroup.go @@ -0,0 +1,149 @@ +// Copyright (c) 2025 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package provisionutil + +import ( + "context" + + "github.com/rs/zerolog" + "go.mau.fi/util/ptr" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/bridgev2" + "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +type RespCreateGroup struct { + ID networkid.PortalID `json:"id"` + MXID id.RoomID `json:"mxid"` + Portal *bridgev2.Portal `json:"-"` + + FailedParticipants map[networkid.UserID]*bridgev2.CreateChatFailedParticipant `json:"failed_participants,omitempty"` +} + +func CreateGroup(ctx context.Context, login *bridgev2.UserLogin, params *bridgev2.GroupCreateParams) (*RespCreateGroup, error) { + api, ok := login.Client.(bridgev2.GroupCreatingNetworkAPI) + if !ok { + return nil, bridgev2.RespError(mautrix.MUnrecognized.WithMessage("This bridge does not support creating groups")) + } + zerolog.Ctx(ctx).Debug(). + Any("create_params", params). + Msg("Creating group chat on remote network") + caps := login.Bridge.Network.GetCapabilities() + typeSpec, validType := caps.Provisioning.GroupCreation[params.Type] + if !validType { + return nil, bridgev2.RespError(mautrix.MUnrecognized.WithMessage("Unrecognized group type %s", params.Type)) + } + if len(params.Participants) < typeSpec.Participants.MinLength { + return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Must have at least %d members", typeSpec.Participants.MinLength)) + } else if typeSpec.Participants.MaxLength > 0 && len(params.Participants) > typeSpec.Participants.MaxLength { + return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Must have at most %d members", typeSpec.Participants.MaxLength)) + } + userIDValidatingNetwork, uidValOK := login.Bridge.Network.(bridgev2.IdentifierValidatingNetwork) + for i, participant := range params.Participants { + parsedParticipant, ok := login.Bridge.Matrix.ParseGhostMXID(id.UserID(participant)) + if ok { + participant = parsedParticipant + params.Participants[i] = participant + } + if !typeSpec.Participants.SkipIdentifierValidation { + if uidValOK && !userIDValidatingNetwork.ValidateUserID(participant) { + return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("User ID %q is not valid on this network", participant)) + } + } + if api.IsThisUser(ctx, participant) { + return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("You can't include yourself in the participants list", participant)) + } + } + if (params.Name == nil || params.Name.Name == "") && typeSpec.Name.Required { + return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Name is required")) + } else if nameLen := len(ptr.Val(params.Name).Name); nameLen > 0 && nameLen < typeSpec.Name.MinLength { + return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Name must be at least %d characters", typeSpec.Name.MinLength)) + } else if typeSpec.Name.MaxLength > 0 && nameLen > typeSpec.Name.MaxLength { + return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Name must be at most %d characters", typeSpec.Name.MaxLength)) + } + if (params.Avatar == nil || params.Avatar.URL == "") && typeSpec.Avatar.Required { + return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Avatar is required")) + } + if (params.Topic == nil || params.Topic.Topic == "") && typeSpec.Topic.Required { + return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Topic is required")) + } else if topicLen := len(ptr.Val(params.Topic).Topic); topicLen > 0 && topicLen < typeSpec.Topic.MinLength { + return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Topic must be at least %d characters", typeSpec.Topic.MinLength)) + } else if typeSpec.Topic.MaxLength > 0 && topicLen > typeSpec.Topic.MaxLength { + return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Topic must be at most %d characters", typeSpec.Topic.MaxLength)) + } + if (params.Disappear == nil || params.Disappear.Timer.Duration == 0) && typeSpec.Disappear.Required { + return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Disappearing timer is required")) + } else if !typeSpec.Disappear.DisappearSettings.Supports(params.Disappear) { + return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Unsupported value for disappearing timer")) + } + if params.Username == "" && typeSpec.Username.Required { + return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Username is required")) + } else if len(params.Username) > 0 && len(params.Username) < typeSpec.Username.MinLength { + return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Username must be at least %d characters", typeSpec.Username.MinLength)) + } else if typeSpec.Username.MaxLength > 0 && len(params.Username) > typeSpec.Username.MaxLength { + return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Username must be at most %d characters", typeSpec.Username.MaxLength)) + } + if params.Parent == nil && typeSpec.Parent.Required { + return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Parent is required")) + } + resp, err := api.CreateGroup(ctx, params) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to create group") + return nil, err + } + if resp.PortalKey.IsEmpty() { + return nil, ErrNoPortalKey + } + zerolog.Ctx(ctx).Debug(). + Object("portal_key", resp.PortalKey). + Msg("Successfully created group on remote network") + if resp.Portal == nil { + resp.Portal, err = login.Bridge.GetPortalByKey(ctx, resp.PortalKey) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to get portal") + return nil, bridgev2.RespError(mautrix.MUnknown.WithMessage("Failed to get portal")) + } + } + if resp.Portal.MXID == "" { + err = resp.Portal.CreateMatrixRoom(ctx, login, resp.PortalInfo) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to create portal room") + return nil, bridgev2.RespError(mautrix.MUnknown.WithMessage("Failed to create portal room")) + } + } + for key, fp := range resp.FailedParticipants { + if fp.InviteEventType == "" { + fp.InviteEventType = event.EventMessage.Type + } + if fp.UserMXID == "" { + ghost, err := login.Bridge.GetGhostByID(ctx, key) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to get ghost for failed participant") + } else if ghost != nil { + fp.UserMXID = ghost.Intent.GetMXID() + } + } + if fp.DMRoomMXID == "" { + portal, err := login.Bridge.GetDMPortal(ctx, login.ID, key) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to get DM portal for failed participant") + } else if portal != nil { + fp.DMRoomMXID = portal.MXID + } + } + } + return &RespCreateGroup{ + ID: resp.Portal.ID, + MXID: resp.Portal.MXID, + Portal: resp.Portal, + + FailedParticipants: resp.FailedParticipants, + }, nil +} diff --git a/bridgev2/provisionutil/listcontacts.go b/bridgev2/provisionutil/listcontacts.go new file mode 100644 index 00000000..ce163e67 --- /dev/null +++ b/bridgev2/provisionutil/listcontacts.go @@ -0,0 +1,98 @@ +// Copyright (c) 2025 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package provisionutil + +import ( + "context" + + "github.com/rs/zerolog" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/bridgev2" +) + +type RespGetContactList struct { + Contacts []*RespResolveIdentifier `json:"contacts"` +} + +type RespSearchUsers struct { + Results []*RespResolveIdentifier `json:"results"` +} + +func GetContactList(ctx context.Context, login *bridgev2.UserLogin) (*RespGetContactList, error) { + api, ok := login.Client.(bridgev2.ContactListingNetworkAPI) + if !ok { + return nil, bridgev2.RespError(mautrix.MUnrecognized.WithMessage("This bridge does not support listing contacts")) + } + resp, err := api.GetContactList(ctx) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to get contact list") + return nil, err + } + return &RespGetContactList{ + Contacts: processResolveIdentifiers(ctx, login.Bridge, resp, false), + }, nil +} + +func SearchUsers(ctx context.Context, login *bridgev2.UserLogin, query string) (*RespSearchUsers, error) { + api, ok := login.Client.(bridgev2.UserSearchingNetworkAPI) + if !ok { + return nil, bridgev2.RespError(mautrix.MUnrecognized.WithMessage("This bridge does not support searching for users")) + } + resp, err := api.SearchUsers(ctx, query) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to get contact list") + return nil, err + } + return &RespSearchUsers{ + Results: processResolveIdentifiers(ctx, login.Bridge, resp, true), + }, nil +} + +func processResolveIdentifiers(ctx context.Context, br *bridgev2.Bridge, resp []*bridgev2.ResolveIdentifierResponse, syncInfo bool) (apiResp []*RespResolveIdentifier) { + apiResp = make([]*RespResolveIdentifier, len(resp)) + for i, contact := range resp { + apiContact := &RespResolveIdentifier{ + ID: contact.UserID, + } + apiResp[i] = apiContact + if contact.UserInfo != nil { + if contact.UserInfo.Name != nil { + apiContact.Name = *contact.UserInfo.Name + } + if contact.UserInfo.Identifiers != nil { + apiContact.Identifiers = contact.UserInfo.Identifiers + } + } + if contact.Ghost != nil { + if syncInfo && contact.UserInfo != nil { + contact.Ghost.UpdateInfo(ctx, contact.UserInfo) + } + if contact.Ghost.Name != "" { + apiContact.Name = contact.Ghost.Name + } + if len(contact.Ghost.Identifiers) >= len(apiContact.Identifiers) { + apiContact.Identifiers = contact.Ghost.Identifiers + } + apiContact.AvatarURL = contact.Ghost.AvatarMXC + apiContact.MXID = contact.Ghost.Intent.GetMXID() + } + if contact.Chat != nil { + if contact.Chat.Portal == nil { + var err error + contact.Chat.Portal, err = br.GetPortalByKey(ctx, contact.Chat.PortalKey) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to get portal") + } + } + if contact.Chat.Portal != nil { + apiContact.DMRoomID = contact.Chat.Portal.MXID + } + } + } + return +} diff --git a/bridgev2/provisionutil/resolveidentifier.go b/bridgev2/provisionutil/resolveidentifier.go new file mode 100644 index 00000000..cfc388d0 --- /dev/null +++ b/bridgev2/provisionutil/resolveidentifier.go @@ -0,0 +1,125 @@ +// Copyright (c) 2025 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package provisionutil + +import ( + "context" + "errors" + + "github.com/rs/zerolog" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/bridgev2" + "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/id" +) + +type RespResolveIdentifier struct { + ID networkid.UserID `json:"id"` + Name string `json:"name,omitempty"` + AvatarURL id.ContentURIString `json:"avatar_url,omitempty"` + Identifiers []string `json:"identifiers,omitempty"` + MXID id.UserID `json:"mxid,omitempty"` + DMRoomID id.RoomID `json:"dm_room_mxid,omitempty"` + + Portal *bridgev2.Portal `json:"-"` + Ghost *bridgev2.Ghost `json:"-"` + JustCreated bool `json:"-"` +} + +var ErrNoPortalKey = errors.New("network API didn't return portal key for createChat request") + +func ResolveIdentifier( + ctx context.Context, + login *bridgev2.UserLogin, + identifier string, + createChat bool, +) (*RespResolveIdentifier, error) { + api, ok := login.Client.(bridgev2.IdentifierResolvingNetworkAPI) + if !ok { + return nil, bridgev2.RespError(mautrix.MUnrecognized.WithMessage("This bridge does not support resolving identifiers")) + } + var resp *bridgev2.ResolveIdentifierResponse + parsedUserID, ok := login.Bridge.Matrix.ParseGhostMXID(id.UserID(identifier)) + validator, vOK := login.Bridge.Network.(bridgev2.IdentifierValidatingNetwork) + if ok && (!vOK || validator.ValidateUserID(parsedUserID)) { + ghost, err := login.Bridge.GetGhostByID(ctx, parsedUserID) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to get ghost by ID") + return nil, err + } + resp = &bridgev2.ResolveIdentifierResponse{ + Ghost: ghost, + UserID: parsedUserID, + } + gdcAPI, ok := api.(bridgev2.GhostDMCreatingNetworkAPI) + if ok && createChat { + resp.Chat, err = gdcAPI.CreateChatWithGhost(ctx, ghost) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to create chat") + return nil, err + } + } else if createChat || ghost.Name == "" { + zerolog.Ctx(ctx).Debug(). + Bool("create_chat", createChat). + Bool("has_name", ghost.Name != ""). + Msg("Falling back to resolving identifier") + resp = nil + identifier = string(parsedUserID) + } + } + if resp == nil { + var err error + resp, err = api.ResolveIdentifier(ctx, identifier, createChat) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to resolve identifier") + return nil, err + } else if resp == nil { + return nil, nil + } + } + apiResp := &RespResolveIdentifier{ + ID: resp.UserID, + Ghost: resp.Ghost, + } + if resp.Ghost != nil { + if resp.UserInfo != nil { + resp.Ghost.UpdateInfo(ctx, resp.UserInfo) + } + apiResp.Name = resp.Ghost.Name + apiResp.AvatarURL = resp.Ghost.AvatarMXC + apiResp.Identifiers = resp.Ghost.Identifiers + apiResp.MXID = resp.Ghost.Intent.GetMXID() + } else if resp.UserInfo != nil && resp.UserInfo.Name != nil { + apiResp.Name = *resp.UserInfo.Name + } + if resp.Chat != nil { + if resp.Chat.PortalKey.IsEmpty() { + return nil, ErrNoPortalKey + } + if resp.Chat.Portal == nil { + var err error + resp.Chat.Portal, err = login.Bridge.GetPortalByKey(ctx, resp.Chat.PortalKey) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to get portal") + return nil, bridgev2.RespError(mautrix.MUnknown.WithMessage("Failed to get portal")) + } + } + resp.Chat.Portal.CleanupOrphanedDM(ctx, login.UserMXID) + if createChat && resp.Chat.Portal.MXID == "" { + apiResp.JustCreated = true + err := resp.Chat.Portal.CreateMatrixRoom(ctx, login, resp.Chat.PortalInfo) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to create portal room") + return nil, bridgev2.RespError(mautrix.MUnknown.WithMessage("Failed to create portal room")) + } + } + apiResp.Portal = resp.Chat.Portal + apiResp.DMRoomID = resp.Chat.Portal.MXID + } + return apiResp, nil +} diff --git a/bridgev2/queue.go b/bridgev2/queue.go new file mode 100644 index 00000000..3775c825 --- /dev/null +++ b/bridgev2/queue.go @@ -0,0 +1,254 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package bridgev2 + +import ( + "context" + "errors" + "fmt" + "strings" + "time" + + "github.com/rs/zerolog" + + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +func rejectInvite(ctx context.Context, evt *event.Event, intent MatrixAPI, reason string) { + resp, err := intent.SendState(ctx, evt.RoomID, event.StateMember, intent.GetMXID().String(), &event.Content{ + Parsed: &event.MemberEventContent{ + Membership: event.MembershipLeave, + Reason: reason, + }, + }, time.Time{}) + if err != nil { + zerolog.Ctx(ctx).Err(err). + Stringer("room_id", evt.RoomID). + Stringer("inviter_id", evt.Sender). + Stringer("invitee_id", intent.GetMXID()). + Str("reason", reason). + Msg("Failed to reject invite") + } else { + zerolog.Ctx(ctx).Debug(). + Stringer("leave_event_id", resp.EventID). + Stringer("room_id", evt.RoomID). + Stringer("inviter_id", evt.Sender). + Stringer("invitee_id", intent.GetMXID()). + Str("reason", reason). + Msg("Rejected invite") + } +} + +func (br *Bridge) rejectInviteOnNoPermission(ctx context.Context, evt *event.Event, permType string) bool { + if evt.Type != event.StateMember || evt.Content.AsMember().Membership != event.MembershipInvite { + return false + } + userID := id.UserID(evt.GetStateKey()) + parsed, isGhost := br.Matrix.ParseGhostMXID(userID) + if userID != br.Bot.GetMXID() && !isGhost { + return false + } + var intent MatrixAPI + if userID == br.Bot.GetMXID() { + intent = br.Bot + } else { + intent = br.Matrix.GhostIntent(parsed) + } + rejectInvite(ctx, evt, intent, "You don't have permission to "+permType+" this bridge") + return true +} + +var ( + ErrEventSenderUserNotFound = WrapErrorInStatus(errors.New("sender not found for event")).WithIsCertain(true).WithErrorAsMessage() + ErrNoPermissionToInteract = WrapErrorInStatus(errors.New("you don't have permission to send messages")).WithIsCertain(true).WithSendNotice(false).WithErrorAsMessage() + ErrNoPermissionForCommands = WrapErrorInStatus(WrapErrorInStatus(errors.New("you don't have permission to use commands")).WithIsCertain(true).WithSendNotice(false).WithErrorAsMessage()) + ErrCantRelayStateRequest = WrapErrorInStatus(errors.New("relayed users can't use beeper state requests")).WithIsCertain(true).WithErrorAsMessage() +) + +func (br *Bridge) QueueMatrixEvent(ctx context.Context, evt *event.Event) EventHandlingResult { + // TODO maybe HandleMatrixEvent would be more appropriate as this also handles bot invites and commands + + log := zerolog.Ctx(ctx) + var sender *User + if evt.Sender != "" { + var err error + sender, err = br.GetUserByMXID(ctx, evt.Sender) + if err != nil { + log.Err(err).Msg("Failed to get sender user for incoming Matrix event") + status := WrapErrorInStatus(fmt.Errorf("%w: failed to get sender user: %w", ErrDatabaseError, err)) + br.Matrix.SendMessageStatus(ctx, &status, StatusEventInfoFromEvent(evt)) + return EventHandlingResultFailed + } else if sender == nil { + log.Error().Msg("Couldn't get sender for incoming non-ephemeral Matrix event") + br.Matrix.SendMessageStatus(ctx, &ErrEventSenderUserNotFound, StatusEventInfoFromEvent(evt)) + return EventHandlingResultFailed + } else if !sender.Permissions.SendEvents { + if !br.rejectInviteOnNoPermission(ctx, evt, "interact with") { + br.Matrix.SendMessageStatus(ctx, &ErrNoPermissionToInteract, StatusEventInfoFromEvent(evt)) + } + return EventHandlingResultIgnored + } else if !sender.Permissions.Commands && br.rejectInviteOnNoPermission(ctx, evt, "send commands to") { + return EventHandlingResultIgnored + } + } else if evt.Type.Class != event.EphemeralEventType { + log.Error().Msg("Missing sender for incoming non-ephemeral Matrix event") + br.Matrix.SendMessageStatus(ctx, &ErrEventSenderUserNotFound, StatusEventInfoFromEvent(evt)) + return EventHandlingResultIgnored + } + if evt.Type == event.EventMessage && sender != nil { + msg := evt.Content.AsMessage() + msg.RemoveReplyFallback() + msg.RemovePerMessageProfileFallback() + if strings.HasPrefix(msg.Body, br.Config.CommandPrefix) || evt.RoomID == sender.ManagementRoom { + if !sender.Permissions.Commands { + br.Matrix.SendMessageStatus(ctx, &ErrNoPermissionForCommands, StatusEventInfoFromEvent(evt)) + return EventHandlingResultIgnored + } + go br.Commands.Handle( + ctx, + evt.RoomID, + evt.ID, + sender, + strings.TrimPrefix(msg.Body, br.Config.CommandPrefix+" "), + msg.RelatesTo.GetReplyTo(), + ) + return EventHandlingResultQueued + } + } + if evt.Type == event.StateMember && evt.GetStateKey() == br.Bot.GetMXID().String() && evt.Content.AsMember().Membership == event.MembershipInvite && sender != nil { + return br.handleBotInvite(ctx, evt, sender) + } else if sender != nil && evt.RoomID == sender.ManagementRoom { + if evt.Type == event.StateMember && evt.Content.AsMember().Membership == event.MembershipLeave && (evt.GetStateKey() == br.Bot.GetMXID().String() || evt.GetStateKey() == sender.MXID.String()) { + sender.ManagementRoom = "" + err := br.DB.User.Update(ctx, sender.User) + if err != nil { + log.Err(err).Msg("Failed to clear user's management room in database") + return EventHandlingResultFailed + } else { + log.Debug().Msg("Cleared user's management room due to leave event") + } + } + return EventHandlingResultSuccess + } + portal, err := br.GetPortalByMXID(ctx, evt.RoomID) + if err != nil { + log.Err(err).Msg("Failed to get portal for incoming Matrix event") + status := WrapErrorInStatus(fmt.Errorf("%w: failed to get portal: %w", ErrDatabaseError, err)) + br.Matrix.SendMessageStatus(ctx, &status, StatusEventInfoFromEvent(evt)) + return EventHandlingResultFailed + } else if portal != nil { + return portal.queueEvent(ctx, &portalMatrixEvent{ + evt: evt, + sender: sender, + }) + } else if evt.Type == event.StateMember && br.IsGhostMXID(id.UserID(evt.GetStateKey())) && evt.Content.AsMember().Membership == event.MembershipInvite && evt.Content.AsMember().IsDirect { + return br.handleGhostDMInvite(ctx, evt, sender) + } else { + status := WrapErrorInStatus(ErrNoPortal) + br.Matrix.SendMessageStatus(ctx, &status, StatusEventInfoFromEvent(evt)) + return EventHandlingResultIgnored + } +} + +type EventHandlingResult struct { + Success bool + Ignored bool + Queued bool + + SkipStateEcho bool + + // Error is an optional reason for failure. It is not required, Success may be false even without a specific error. + Error error + // Whether the Error should be sent as a MSS event. + SendMSS bool + + // EventID from the network + EventID id.EventID + // Stream order from the network + StreamOrder int64 +} + +func (ehr EventHandlingResult) WithEventID(id id.EventID) EventHandlingResult { + ehr.EventID = id + return ehr +} + +func (ehr EventHandlingResult) WithStreamOrder(order int64) EventHandlingResult { + ehr.StreamOrder = order + return ehr +} + +func (ehr EventHandlingResult) WithError(err error) EventHandlingResult { + if err == nil { + return ehr + } + ehr.Error = err + ehr.Success = false + return ehr +} + +func (ehr EventHandlingResult) WithMSS() EventHandlingResult { + ehr.SendMSS = true + return ehr +} + +func (ehr EventHandlingResult) WithSkipStateEcho(skip bool) EventHandlingResult { + ehr.SkipStateEcho = skip + return ehr +} + +func (ehr EventHandlingResult) WithMSSError(err error) EventHandlingResult { + if err == nil { + return ehr + } + return ehr.WithError(err).WithMSS() +} + +var ( + EventHandlingResultFailed = EventHandlingResult{} + EventHandlingResultQueued = EventHandlingResult{Success: true, Queued: true} + EventHandlingResultSuccess = EventHandlingResult{Success: true} + EventHandlingResultIgnored = EventHandlingResult{Success: true, Ignored: true} +) + +func (ul *UserLogin) QueueRemoteEvent(evt RemoteEvent) EventHandlingResult { + return ul.Bridge.QueueRemoteEvent(ul, evt) +} + +func (br *Bridge) QueueRemoteEvent(login *UserLogin, evt RemoteEvent) EventHandlingResult { + log := login.Log + ctx := log.WithContext(br.BackgroundCtx) + maybeUncertain, ok := evt.(RemoteEventWithUncertainPortalReceiver) + isUncertain := ok && maybeUncertain.PortalReceiverIsUncertain() + key := evt.GetPortalKey() + var portal *Portal + var err error + if isUncertain && !br.Config.SplitPortals { + portal, err = br.GetExistingPortalByKey(ctx, key) + } else { + portal, err = br.GetPortalByKey(ctx, key) + } + if err != nil { + log.Err(err).Object("portal_key", key).Bool("uncertain_receiver", isUncertain). + Msg("Failed to get portal to handle remote event") + return EventHandlingResultFailed.WithError(fmt.Errorf("failed to get portal: %w", err)) + } else if portal == nil { + log.Warn(). + Stringer("event_type", evt.GetType()). + Object("portal_key", key). + Bool("uncertain_receiver", isUncertain). + Msg("Portal not found to handle remote event") + return EventHandlingResultFailed.WithError(ErrPortalNotFoundInEventHandler) + } + // TODO put this in a better place, and maybe cache to avoid constant db queries + login.MarkInPortal(ctx, portal) + return portal.queueEvent(ctx, &portalRemoteEvent{ + evt: evt, + source: login, + }) +} diff --git a/bridgev2/simpleremoteevent.go b/bridgev2/simpleremoteevent.go new file mode 100644 index 00000000..66058e3e --- /dev/null +++ b/bridgev2/simpleremoteevent.go @@ -0,0 +1,133 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package bridgev2 + +import ( + "context" + "time" + + "github.com/rs/zerolog" + + "maunium.net/go/mautrix/bridgev2/database" + "maunium.net/go/mautrix/bridgev2/networkid" +) + +// SimpleRemoteEvent is a simple implementation of RemoteEvent that can be used with struct fields and some callbacks. +// +// Using this type is only recommended for simple bridges. More advanced ones should implement +// the remote event interfaces themselves by wrapping the remote network library event types. +// +// Deprecated: use the types in the simplevent package instead. +type SimpleRemoteEvent[T any] struct { + Type RemoteEventType + LogContext func(c zerolog.Context) zerolog.Context + PortalKey networkid.PortalKey + Data T + CreatePortal bool + + ID networkid.MessageID + Sender EventSender + TargetMessage networkid.MessageID + EmojiID networkid.EmojiID + Emoji string + ReactionDBMeta any + Timestamp time.Time + ChatInfoChange *ChatInfoChange + + ResyncChatInfo *ChatInfo + ResyncBackfillNeeded bool + + BackfillData *FetchMessagesResponse + + ConvertMessageFunc func(ctx context.Context, portal *Portal, intent MatrixAPI, data T) (*ConvertedMessage, error) + ConvertEditFunc func(ctx context.Context, portal *Portal, intent MatrixAPI, existing []*database.Message, data T) (*ConvertedEdit, error) +} + +var ( + _ RemoteMessage = (*SimpleRemoteEvent[any])(nil) + _ RemoteEdit = (*SimpleRemoteEvent[any])(nil) + _ RemoteEventWithTimestamp = (*SimpleRemoteEvent[any])(nil) + _ RemoteReaction = (*SimpleRemoteEvent[any])(nil) + _ RemoteReactionWithMeta = (*SimpleRemoteEvent[any])(nil) + _ RemoteReactionRemove = (*SimpleRemoteEvent[any])(nil) + _ RemoteMessageRemove = (*SimpleRemoteEvent[any])(nil) + _ RemoteChatInfoChange = (*SimpleRemoteEvent[any])(nil) + _ RemoteChatResyncWithInfo = (*SimpleRemoteEvent[any])(nil) + _ RemoteChatResyncBackfill = (*SimpleRemoteEvent[any])(nil) + _ RemoteBackfill = (*SimpleRemoteEvent[any])(nil) +) + +func (sre *SimpleRemoteEvent[T]) AddLogContext(c zerolog.Context) zerolog.Context { + return sre.LogContext(c) +} + +func (sre *SimpleRemoteEvent[T]) GetPortalKey() networkid.PortalKey { + return sre.PortalKey +} + +func (sre *SimpleRemoteEvent[T]) GetTimestamp() time.Time { + if sre.Timestamp.IsZero() { + return time.Now() + } + return sre.Timestamp +} + +func (sre *SimpleRemoteEvent[T]) ConvertMessage(ctx context.Context, portal *Portal, intent MatrixAPI) (*ConvertedMessage, error) { + return sre.ConvertMessageFunc(ctx, portal, intent, sre.Data) +} + +func (sre *SimpleRemoteEvent[T]) ConvertEdit(ctx context.Context, portal *Portal, intent MatrixAPI, existing []*database.Message) (*ConvertedEdit, error) { + return sre.ConvertEditFunc(ctx, portal, intent, existing, sre.Data) +} + +func (sre *SimpleRemoteEvent[T]) GetID() networkid.MessageID { + return sre.ID +} + +func (sre *SimpleRemoteEvent[T]) GetSender() EventSender { + return sre.Sender +} + +func (sre *SimpleRemoteEvent[T]) GetTargetMessage() networkid.MessageID { + return sre.TargetMessage +} + +func (sre *SimpleRemoteEvent[T]) GetReactionEmoji() (string, networkid.EmojiID) { + return sre.Emoji, sre.EmojiID +} + +func (sre *SimpleRemoteEvent[T]) GetRemovedEmojiID() networkid.EmojiID { + return sre.EmojiID +} + +func (sre *SimpleRemoteEvent[T]) GetReactionDBMetadata() any { + return sre.ReactionDBMeta +} + +func (sre *SimpleRemoteEvent[T]) GetChatInfoChange(ctx context.Context) (*ChatInfoChange, error) { + return sre.ChatInfoChange, nil +} + +func (sre *SimpleRemoteEvent[T]) GetType() RemoteEventType { + return sre.Type +} + +func (sre *SimpleRemoteEvent[T]) ShouldCreatePortal() bool { + return sre.CreatePortal +} + +func (sre *SimpleRemoteEvent[T]) GetBackfillData(ctx context.Context, portal *Portal) (*FetchMessagesResponse, error) { + return sre.BackfillData, nil +} + +func (sre *SimpleRemoteEvent[T]) CheckNeedsBackfill(ctx context.Context, latestMessage *database.Message) (bool, error) { + return sre.ResyncBackfillNeeded, nil +} + +func (sre *SimpleRemoteEvent[T]) GetChatInfo(ctx context.Context, portal *Portal) (*ChatInfo, error) { + return sre.ResyncChatInfo, nil +} diff --git a/bridgev2/simplevent/chat.go b/bridgev2/simplevent/chat.go new file mode 100644 index 00000000..56e3a6b1 --- /dev/null +++ b/bridgev2/simplevent/chat.go @@ -0,0 +1,92 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package simplevent + +import ( + "context" + "time" + + "maunium.net/go/mautrix/bridgev2" + "maunium.net/go/mautrix/bridgev2/database" +) + +// ChatResync is a simple implementation of [bridgev2.RemoteChatResync]. +// +// If GetChatInfoFunc is set, it will be used to get the chat info. Otherwise, ChatInfo will be used. +// +// If CheckNeedsBackfillFunc is set, it will be used to determine if backfill is required. +// Otherwise, the latest database message timestamp is compared to LatestMessageTS. +// +// All four fields are optional. +type ChatResync struct { + EventMeta + + ChatInfo *bridgev2.ChatInfo + GetChatInfoFunc func(ctx context.Context, portal *bridgev2.Portal) (*bridgev2.ChatInfo, error) + + LatestMessageTS time.Time + CheckNeedsBackfillFunc func(ctx context.Context, latestMessage *database.Message) (bool, error) + BundledBackfillData any +} + +var ( + _ bridgev2.RemoteChatResync = (*ChatResync)(nil) + _ bridgev2.RemoteChatResyncWithInfo = (*ChatResync)(nil) + _ bridgev2.RemoteChatResyncBackfill = (*ChatResync)(nil) + _ bridgev2.RemoteChatResyncBackfillBundle = (*ChatResync)(nil) +) + +func (evt *ChatResync) CheckNeedsBackfill(ctx context.Context, latestMessage *database.Message) (bool, error) { + if evt.CheckNeedsBackfillFunc != nil { + return evt.CheckNeedsBackfillFunc(ctx, latestMessage) + } else if latestMessage == nil { + return !evt.LatestMessageTS.IsZero(), nil + } else { + return evt.LatestMessageTS.After(latestMessage.Timestamp), nil + } +} + +func (evt *ChatResync) GetBundledBackfillData() any { + return evt.BundledBackfillData +} + +func (evt *ChatResync) GetChatInfo(ctx context.Context, portal *bridgev2.Portal) (*bridgev2.ChatInfo, error) { + if evt.GetChatInfoFunc != nil { + return evt.GetChatInfoFunc(ctx, portal) + } + return evt.ChatInfo, nil +} + +// ChatDelete is a simple implementation of [bridgev2.RemoteChatDelete]. +type ChatDelete struct { + EventMeta + OnlyForMe bool + Children bool +} + +var _ bridgev2.RemoteChatDeleteWithChildren = (*ChatDelete)(nil) + +func (evt *ChatDelete) DeleteOnlyForMe() bool { + return evt.OnlyForMe +} + +func (evt *ChatDelete) DeleteChildren() bool { + return evt.Children +} + +// ChatInfoChange is a simple implementation of [bridgev2.RemoteChatInfoChange]. +type ChatInfoChange struct { + EventMeta + + ChatInfoChange *bridgev2.ChatInfoChange +} + +var _ bridgev2.RemoteChatInfoChange = (*ChatInfoChange)(nil) + +func (evt *ChatInfoChange) GetChatInfoChange(ctx context.Context) (*bridgev2.ChatInfoChange, error) { + return evt.ChatInfoChange, nil +} diff --git a/bridgev2/simplevent/message.go b/bridgev2/simplevent/message.go new file mode 100644 index 00000000..f8f8d7e1 --- /dev/null +++ b/bridgev2/simplevent/message.go @@ -0,0 +1,115 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package simplevent + +import ( + "context" + + "maunium.net/go/mautrix/bridgev2" + "maunium.net/go/mautrix/bridgev2/database" + "maunium.net/go/mautrix/bridgev2/networkid" +) + +// Message is a simple implementation of [bridgev2.RemoteMessage], [bridgev2.RemoteEdit] and [bridgev2.RemoteMessageUpsert]. +type Message[T any] struct { + EventMeta + Data T + + ID networkid.MessageID + TransactionID networkid.TransactionID + TargetMessage networkid.MessageID + + ConvertMessageFunc func(ctx context.Context, portal *bridgev2.Portal, intent bridgev2.MatrixAPI, data T) (*bridgev2.ConvertedMessage, error) + ConvertEditFunc func(ctx context.Context, portal *bridgev2.Portal, intent bridgev2.MatrixAPI, existing []*database.Message, data T) (*bridgev2.ConvertedEdit, error) + HandleExistingFunc func(ctx context.Context, portal *bridgev2.Portal, intent bridgev2.MatrixAPI, existing []*database.Message, data T) (bridgev2.UpsertResult, error) +} + +var ( + _ bridgev2.RemoteMessage = (*Message[any])(nil) + _ bridgev2.RemoteEdit = (*Message[any])(nil) + _ bridgev2.RemoteMessageUpsert = (*Message[any])(nil) + _ bridgev2.RemoteMessageWithTransactionID = (*Message[any])(nil) +) + +func (evt *Message[T]) ConvertMessage(ctx context.Context, portal *bridgev2.Portal, intent bridgev2.MatrixAPI) (*bridgev2.ConvertedMessage, error) { + return evt.ConvertMessageFunc(ctx, portal, intent, evt.Data) +} + +func (evt *Message[T]) ConvertEdit(ctx context.Context, portal *bridgev2.Portal, intent bridgev2.MatrixAPI, existing []*database.Message) (*bridgev2.ConvertedEdit, error) { + return evt.ConvertEditFunc(ctx, portal, intent, existing, evt.Data) +} + +func (evt *Message[T]) HandleExisting(ctx context.Context, portal *bridgev2.Portal, intent bridgev2.MatrixAPI, existing []*database.Message) (bridgev2.UpsertResult, error) { + return evt.HandleExistingFunc(ctx, portal, intent, existing, evt.Data) +} + +func (evt *Message[T]) GetID() networkid.MessageID { + return evt.ID +} + +func (evt *Message[T]) GetTargetMessage() networkid.MessageID { + return evt.TargetMessage +} + +func (evt *Message[T]) GetTransactionID() networkid.TransactionID { + return evt.TransactionID +} + +// PreConvertedMessage is a simple implementation of [bridgev2.RemoteMessage] with pre-converted data. +type PreConvertedMessage struct { + EventMeta + Data *bridgev2.ConvertedMessage + ID networkid.MessageID + TransactionID networkid.TransactionID + + HandleExistingFunc func(ctx context.Context, portal *bridgev2.Portal, intent bridgev2.MatrixAPI, existing []*database.Message) (bridgev2.UpsertResult, error) +} + +var ( + _ bridgev2.RemoteMessage = (*PreConvertedMessage)(nil) + _ bridgev2.RemoteMessageUpsert = (*PreConvertedMessage)(nil) + _ bridgev2.RemoteMessageWithTransactionID = (*PreConvertedMessage)(nil) +) + +func (evt *PreConvertedMessage) ConvertMessage(ctx context.Context, portal *bridgev2.Portal, intent bridgev2.MatrixAPI) (*bridgev2.ConvertedMessage, error) { + return evt.Data, nil +} + +func (evt *PreConvertedMessage) GetID() networkid.MessageID { + return evt.ID +} + +func (evt *PreConvertedMessage) GetTransactionID() networkid.TransactionID { + return evt.TransactionID +} + +func (evt *PreConvertedMessage) HandleExisting(ctx context.Context, portal *bridgev2.Portal, intent bridgev2.MatrixAPI, existing []*database.Message) (bridgev2.UpsertResult, error) { + if evt.HandleExistingFunc == nil { + return bridgev2.UpsertResult{}, nil + } + return evt.HandleExistingFunc(ctx, portal, intent, existing) +} + +type MessageRemove struct { + EventMeta + + TargetMessage networkid.MessageID + OnlyForMe bool +} + +var ( + _ bridgev2.RemoteMessageRemove = (*MessageRemove)(nil) + _ bridgev2.RemoteDeleteOnlyForMe = (*MessageRemove)(nil) +) + +func (evt *MessageRemove) GetTargetMessage() networkid.MessageID { + return evt.TargetMessage +} + +func (evt *MessageRemove) DeleteOnlyForMe() bool { + return evt.OnlyForMe +} diff --git a/bridgev2/simplevent/meta.go b/bridgev2/simplevent/meta.go new file mode 100644 index 00000000..449a8773 --- /dev/null +++ b/bridgev2/simplevent/meta.go @@ -0,0 +1,144 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package simplevent + +import ( + "context" + "time" + + "github.com/rs/zerolog" + + "maunium.net/go/mautrix/bridgev2" + "maunium.net/go/mautrix/bridgev2/networkid" +) + +// EventMeta is a struct containing metadata fields used by most event types. +type EventMeta struct { + Type bridgev2.RemoteEventType + LogContext func(c zerolog.Context) zerolog.Context + PortalKey networkid.PortalKey + UncertainReceiver bool + Sender bridgev2.EventSender + CreatePortal bool + Timestamp time.Time + StreamOrder int64 + + PreHandleFunc func(context.Context, *bridgev2.Portal) + PostHandleFunc func(context.Context, *bridgev2.Portal) +} + +var ( + _ bridgev2.RemoteEvent = (*EventMeta)(nil) + _ bridgev2.RemoteEventWithUncertainPortalReceiver = (*EventMeta)(nil) + _ bridgev2.RemoteEventThatMayCreatePortal = (*EventMeta)(nil) + _ bridgev2.RemoteEventWithTimestamp = (*EventMeta)(nil) + _ bridgev2.RemoteEventWithStreamOrder = (*EventMeta)(nil) + _ bridgev2.RemotePreHandler = (*EventMeta)(nil) + _ bridgev2.RemotePostHandler = (*EventMeta)(nil) +) + +func (evt *EventMeta) AddLogContext(c zerolog.Context) zerolog.Context { + if evt.LogContext == nil { + return c + } + return evt.LogContext(c) +} + +func (evt *EventMeta) GetPortalKey() networkid.PortalKey { + return evt.PortalKey +} + +func (evt *EventMeta) PortalReceiverIsUncertain() bool { + return evt.UncertainReceiver +} + +func (evt *EventMeta) GetTimestamp() time.Time { + if evt.Timestamp.IsZero() { + return time.Now() + } + return evt.Timestamp +} + +func (evt *EventMeta) GetStreamOrder() int64 { + return evt.StreamOrder +} + +func (evt *EventMeta) GetSender() bridgev2.EventSender { + return evt.Sender +} + +func (evt *EventMeta) GetType() bridgev2.RemoteEventType { + return evt.Type +} + +func (evt *EventMeta) ShouldCreatePortal() bool { + return evt.CreatePortal +} + +func (evt *EventMeta) PreHandle(ctx context.Context, portal *bridgev2.Portal) { + if evt.PreHandleFunc != nil { + evt.PreHandleFunc(ctx, portal) + } +} + +func (evt *EventMeta) PostHandle(ctx context.Context, portal *bridgev2.Portal) { + if evt.PostHandleFunc != nil { + evt.PostHandleFunc(ctx, portal) + } +} + +func (evt EventMeta) WithType(t bridgev2.RemoteEventType) EventMeta { + evt.Type = t + return evt +} + +func (evt EventMeta) WithLogContext(f func(c zerolog.Context) zerolog.Context) EventMeta { + evt.LogContext = f + return evt +} + +func (evt EventMeta) WithMoreLogContext(f func(c zerolog.Context) zerolog.Context) EventMeta { + origFunc := evt.LogContext + if origFunc == nil { + evt.LogContext = f + return evt + } + evt.LogContext = func(c zerolog.Context) zerolog.Context { + return f(origFunc(c)) + } + return evt +} + +func (evt EventMeta) WithPortalKey(p networkid.PortalKey) EventMeta { + evt.PortalKey = p + return evt +} + +func (evt EventMeta) WithUncertainReceiver(u bool) EventMeta { + evt.UncertainReceiver = u + return evt +} + +func (evt EventMeta) WithSender(s bridgev2.EventSender) EventMeta { + evt.Sender = s + return evt +} + +func (evt EventMeta) WithCreatePortal(c bool) EventMeta { + evt.CreatePortal = c + return evt +} + +func (evt EventMeta) WithTimestamp(t time.Time) EventMeta { + evt.Timestamp = t + return evt +} + +func (evt EventMeta) WithStreamOrder(s int64) EventMeta { + evt.StreamOrder = s + return evt +} diff --git a/bridgev2/simplevent/reaction.go b/bridgev2/simplevent/reaction.go new file mode 100644 index 00000000..34e0b025 --- /dev/null +++ b/bridgev2/simplevent/reaction.go @@ -0,0 +1,67 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package simplevent + +import ( + "maunium.net/go/mautrix/bridgev2" + "maunium.net/go/mautrix/bridgev2/networkid" +) + +// Reaction is a simple implementation of [bridgev2.RemoteReaction] and [bridgev2.RemoteReactionRemove]. +type Reaction struct { + EventMeta + TargetMessage networkid.MessageID + EmojiID networkid.EmojiID + Emoji string + ExtraContent map[string]any + ReactionDBMeta any +} + +var ( + _ bridgev2.RemoteReaction = (*Reaction)(nil) + _ bridgev2.RemoteReactionWithMeta = (*Reaction)(nil) + _ bridgev2.RemoteReactionWithExtraContent = (*Reaction)(nil) + _ bridgev2.RemoteReactionRemove = (*Reaction)(nil) +) + +func (evt *Reaction) GetTargetMessage() networkid.MessageID { + return evt.TargetMessage +} + +func (evt *Reaction) GetReactionEmoji() (string, networkid.EmojiID) { + return evt.Emoji, evt.EmojiID +} + +func (evt *Reaction) GetRemovedEmojiID() networkid.EmojiID { + return evt.EmojiID +} + +func (evt *Reaction) GetReactionDBMetadata() any { + return evt.ReactionDBMeta +} + +func (evt *Reaction) GetReactionExtraContent() map[string]any { + return evt.ExtraContent +} + +type ReactionSync struct { + EventMeta + TargetMessage networkid.MessageID + Reactions *bridgev2.ReactionSyncData +} + +var ( + _ bridgev2.RemoteReactionSync = (*ReactionSync)(nil) +) + +func (evt *ReactionSync) GetTargetMessage() networkid.MessageID { + return evt.TargetMessage +} + +func (evt *ReactionSync) GetReactions() *bridgev2.ReactionSyncData { + return evt.Reactions +} diff --git a/bridgev2/simplevent/receipt.go b/bridgev2/simplevent/receipt.go new file mode 100644 index 00000000..41614e40 --- /dev/null +++ b/bridgev2/simplevent/receipt.go @@ -0,0 +1,77 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package simplevent + +import ( + "time" + + "maunium.net/go/mautrix/bridgev2" + "maunium.net/go/mautrix/bridgev2/networkid" +) + +type Receipt struct { + EventMeta + + LastTarget networkid.MessageID + Targets []networkid.MessageID + ReadUpTo time.Time + + ReadUpToStreamOrder int64 +} + +var ( + _ bridgev2.RemoteReadReceipt = (*Receipt)(nil) + _ bridgev2.RemoteDeliveryReceipt = (*Receipt)(nil) +) + +func (evt *Receipt) GetLastReceiptTarget() networkid.MessageID { + return evt.LastTarget +} + +func (evt *Receipt) GetReceiptTargets() []networkid.MessageID { + return evt.Targets +} + +func (evt *Receipt) GetReadUpTo() time.Time { + return evt.ReadUpTo +} + +func (evt *Receipt) GetReadUpToStreamOrder() int64 { + return evt.ReadUpToStreamOrder +} + +type MarkUnread struct { + EventMeta + Unread bool +} + +var ( + _ bridgev2.RemoteMarkUnread = (*MarkUnread)(nil) +) + +func (evt *MarkUnread) GetUnread() bool { + return evt.Unread +} + +type Typing struct { + EventMeta + Timeout time.Duration + Type bridgev2.TypingType +} + +var ( + _ bridgev2.RemoteTyping = (*Typing)(nil) + _ bridgev2.RemoteTypingWithType = (*Typing)(nil) +) + +func (evt *Typing) GetTimeout() time.Duration { + return evt.Timeout +} + +func (evt *Typing) GetTypingType() bridgev2.TypingType { + return evt.Type +} diff --git a/bridgev2/space.go b/bridgev2/space.go new file mode 100644 index 00000000..2ca2bce3 --- /dev/null +++ b/bridgev2/space.go @@ -0,0 +1,193 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package bridgev2 + +import ( + "context" + "fmt" + "time" + + "github.com/rs/zerolog" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/bridgev2/database" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +func (ul *UserLogin) MarkInPortal(ctx context.Context, portal *Portal) { + if ul.inPortalCache.Has(portal.PortalKey) { + return + } + userPortal, err := ul.Bridge.DB.UserPortal.GetOrCreate(ctx, ul.UserLogin, portal.PortalKey) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to ensure user portal row exists") + return + } + ul.inPortalCache.Add(portal.PortalKey) + if portal.MXID != "" { + dp := ul.User.DoublePuppet(ctx) + if dp != nil { + err = dp.EnsureJoined(ctx, portal.MXID) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to ensure double puppet is joined to portal") + } + } else { + err = ul.Bridge.Bot.EnsureInvited(ctx, portal.MXID, ul.UserMXID) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to ensure user is invited to portal") + } + } + if ul.Bridge.Config.PersonalFilteringSpaces && (userPortal.InSpace == nil || !*userPortal.InSpace) { + go ul.tryAddPortalToSpace(context.WithoutCancel(ctx), portal, userPortal.CopyWithoutValues()) + } + } +} + +func (ul *UserLogin) tryAddPortalToSpace(ctx context.Context, portal *Portal, userPortal *database.UserPortal) { + err := ul.AddPortalToSpace(ctx, portal, userPortal) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to add portal to space") + } +} + +func (ul *UserLogin) AddPortalToSpace(ctx context.Context, portal *Portal, userPortal *database.UserPortal) error { + if portal.MXID == "" || portal.Parent != nil { + return nil + } + spaceRoom, err := ul.GetSpaceRoom(ctx) + if err != nil { + return fmt.Errorf("failed to get space room: %w", err) + } else if spaceRoom == "" { + return nil + } + err = portal.toggleSpace(ctx, spaceRoom, false, false) + if err != nil { + return fmt.Errorf("failed to add portal to space: %w", err) + } + inSpace := true + userPortal.InSpace = &inSpace + err = ul.Bridge.DB.UserPortal.Put(ctx, userPortal) + if err != nil { + return fmt.Errorf("failed to save user portal row: %w", err) + } + zerolog.Ctx(ctx).Debug().Stringer("space_room_id", spaceRoom).Msg("Added portal to space") + return nil +} + +func (portal *Portal) createParentAndAddToSpace(ctx context.Context, source *UserLogin) { + err := portal.Parent.CreateMatrixRoom(ctx, source, nil) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to create parent portal") + } else { + portal.addToParentSpaceAndSave(ctx, true) + } +} + +func (portal *Portal) addToParentSpaceAndSave(ctx context.Context, save bool) { + err := portal.toggleSpace(ctx, portal.Parent.MXID, true, false) + if err != nil { + zerolog.Ctx(ctx).Err(err).Stringer("space_mxid", portal.Parent.MXID).Msg("Failed to add portal to space") + } else { + portal.InSpace = true + if save { + err = portal.Save(ctx) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to save portal to database after adding to space") + } + } + } +} + +func (portal *Portal) toggleSpace(ctx context.Context, spaceID id.RoomID, canonical, remove bool) error { + via := []string{portal.Bridge.Matrix.ServerName()} + if remove { + via = nil + } + _, err := portal.Bridge.Bot.SendState(ctx, spaceID, event.StateSpaceChild, portal.MXID.String(), &event.Content{ + Parsed: &event.SpaceChildEventContent{ + Via: via, + }, + }, time.Now()) + if err != nil { + return err + } + if canonical { + _, err = portal.Bridge.Bot.SendState(ctx, portal.MXID, event.StateSpaceParent, spaceID.String(), &event.Content{ + Parsed: &event.SpaceParentEventContent{ + Via: via, + Canonical: !remove, + }, + }, time.Now()) + if err != nil { + return err + } + } + return nil +} + +func (ul *UserLogin) GetSpaceRoom(ctx context.Context) (id.RoomID, error) { + if !ul.Bridge.Config.PersonalFilteringSpaces { + return ul.SpaceRoom, nil + } + ul.spaceCreateLock.Lock() + defer ul.spaceCreateLock.Unlock() + if ul.SpaceRoom != "" { + return ul.SpaceRoom, nil + } + netName := ul.Bridge.Network.GetName() + var err error + autoJoin := ul.Bridge.Matrix.GetCapabilities().AutoJoinInvites + doublePuppet := ul.User.DoublePuppet(ctx) + req := &mautrix.ReqCreateRoom{ + Visibility: "private", + Name: fmt.Sprintf("%s (%s)", netName.DisplayName, ul.RemoteName), + Topic: fmt.Sprintf("Your %s bridged chats - %s", netName.DisplayName, ul.RemoteName), + InitialState: []*event.Event{{ + Type: event.StateRoomAvatar, + Content: event.Content{ + Parsed: &event.RoomAvatarEventContent{ + URL: netName.NetworkIcon, + }, + }, + }}, + CreationContent: map[string]any{ + "type": event.RoomTypeSpace, + }, + PowerLevelOverride: &event.PowerLevelsEventContent{ + Users: map[id.UserID]int{ + ul.Bridge.Bot.GetMXID(): 9001, + ul.UserMXID: 50, + }, + }, + Invite: []id.UserID{ul.UserMXID}, + } + if autoJoin { + req.BeeperInitialMembers = []id.UserID{ul.UserMXID} + // TODO remove this after initial_members is supported in hungryserv + req.BeeperAutoJoinInvites = true + } + pfc, ok := ul.Client.(PersonalFilteringCustomizingNetworkAPI) + if ok { + pfc.CustomizePersonalFilteringSpace(req) + } + ul.SpaceRoom, err = ul.Bridge.Bot.CreateRoom(ctx, req) + if err != nil { + return "", fmt.Errorf("failed to create space room: %w", err) + } + if !autoJoin && doublePuppet != nil { + err = doublePuppet.EnsureJoined(ctx, ul.SpaceRoom) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to auto-join created space room with double puppet") + } + } + err = ul.Save(ctx) + if err != nil { + return "", fmt.Errorf("failed to save space room ID: %w", err) + } + return ul.SpaceRoom, nil +} diff --git a/bridge/status/bridgestate.go b/bridgev2/status/bridgestate.go similarity index 63% rename from bridge/status/bridgestate.go rename to bridgev2/status/bridgestate.go index bb98e283..5925dd4f 100644 --- a/bridge/status/bridgestate.go +++ b/bridgev2/status/bridgestate.go @@ -12,15 +12,17 @@ import ( "encoding/json" "fmt" "io" + "maps" "net/http" "reflect" "time" "github.com/tidwall/sjson" "go.mau.fi/util/jsontime" - "golang.org/x/exp/maps" "maunium.net/go/mautrix" + "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" ) @@ -52,6 +54,67 @@ const ( StateLoggedOut BridgeStateEvent = "LOGGED_OUT" ) +func (e BridgeStateEvent) IsValid() bool { + switch e { + case + StateStarting, + StateUnconfigured, + StateRunning, + StateBridgeUnreachable, + StateConnecting, + StateBackfilling, + StateConnected, + StateTransientDisconnect, + StateBadCredentials, + StateUnknownError, + StateLoggedOut: + return true + default: + return false + } +} + +type BridgeStateUserAction string + +const ( + UserActionOpenNative BridgeStateUserAction = "OPEN_NATIVE" + UserActionRelogin BridgeStateUserAction = "RELOGIN" + UserActionRestart BridgeStateUserAction = "RESTART" +) + +type RemoteProfile struct { + Phone string `json:"phone,omitempty"` + Email string `json:"email,omitempty"` + Username string `json:"username,omitempty"` + Name string `json:"name,omitempty"` + Avatar id.ContentURIString `json:"avatar,omitempty"` + + AvatarFile *event.EncryptedFileInfo `json:"avatar_file,omitempty"` +} + +func coalesce[T ~string](a, b T) T { + if a != "" { + return a + } + return b +} + +func (rp *RemoteProfile) Merge(other RemoteProfile) RemoteProfile { + other.Phone = coalesce(rp.Phone, other.Phone) + other.Email = coalesce(rp.Email, other.Email) + other.Username = coalesce(rp.Username, other.Username) + other.Name = coalesce(rp.Name, other.Name) + other.Avatar = coalesce(rp.Avatar, other.Avatar) + if rp.AvatarFile != nil { + other.AvatarFile = rp.AvatarFile + } + return other +} + +func (rp *RemoteProfile) IsZero() bool { + return rp == nil || (rp.Phone == "" && rp.Email == "" && rp.Username == "" && rp.Name == "" && rp.Avatar == "" && rp.AvatarFile == nil) +} + type BridgeState struct { StateEvent BridgeStateEvent `json:"state_event"` Timestamp jsontime.Unix `json:"timestamp"` @@ -61,9 +124,12 @@ type BridgeState struct { Error BridgeStateErrorCode `json:"error,omitempty"` Message string `json:"message,omitempty"` - UserID id.UserID `json:"user_id,omitempty"` - RemoteID string `json:"remote_id,omitempty"` - RemoteName string `json:"remote_name,omitempty"` + UserAction BridgeStateUserAction `json:"user_action,omitempty"` + + UserID id.UserID `json:"user_id,omitempty"` + RemoteID networkid.UserLoginID `json:"remote_id,omitempty"` + RemoteName string `json:"remote_name,omitempty"` + RemoteProfile RemoteProfile `json:"remote_profile,omitzero"` Reason string `json:"reason,omitempty"` Info map[string]interface{} `json:"info,omitempty"` @@ -75,25 +141,15 @@ type GlobalBridgeState struct { } type BridgeStateFiller interface { - GetMXID() id.UserID - GetRemoteID() string - GetRemoteName() string -} - -type CustomBridgeStateFiller interface { - BridgeStateFiller FillBridgeState(BridgeState) BridgeState } +// Deprecated: use BridgeStateFiller instead +type StandaloneCustomBridgeStateFiller = BridgeStateFiller + func (pong BridgeState) Fill(user BridgeStateFiller) BridgeState { if user != nil { - pong.UserID = user.GetMXID() - pong.RemoteID = user.GetRemoteID() - pong.RemoteName = user.GetRemoteName() - - if custom, ok := user.(CustomBridgeStateFiller); ok { - pong = custom.FillBridgeState(pong) - } + pong = user.FillBridgeState(pong) } pong.Timestamp = jsontime.UnixNow() @@ -151,6 +207,9 @@ func (pong *BridgeState) SendHTTP(ctx context.Context, url, token string) error func (pong *BridgeState) ShouldDeduplicate(newPong *BridgeState) bool { return pong != nil && pong.StateEvent == newPong.StateEvent && + pong.RemoteName == newPong.RemoteName && + pong.UserAction == newPong.UserAction && + pong.RemoteProfile == newPong.RemoteProfile && pong.Error == newPong.Error && maps.EqualFunc(pong.Info, newPong.Info, reflect.DeepEqual) && pong.Timestamp.Add(time.Duration(pong.TTL)*time.Second).After(time.Now()) diff --git a/bridgev2/status/localbridgestate.go b/bridgev2/status/localbridgestate.go new file mode 100644 index 00000000..3ad66538 --- /dev/null +++ b/bridgev2/status/localbridgestate.go @@ -0,0 +1,23 @@ +package status + +type LocalBridgeAccountState string + +const ( + // LocalBridgeAccountStateSetup means the user wants this account to be setup and connected + LocalBridgeAccountStateSetup LocalBridgeAccountState = "SETUP" + // LocalBridgeAccountStateDeleted means the user wants this account to be deleted + LocalBridgeAccountStateDeleted LocalBridgeAccountState = "DELETED" +) + +type LocalBridgeDeviceState string + +const ( + // LocalBridgeDeviceStateSetup means this device is setup to be connected to this account + LocalBridgeDeviceStateSetup LocalBridgeDeviceState = "SETUP" + // LocalBridgeDeviceStateLoggedOut means the user has logged this particular device out while wanting their other devices to remain setup + LocalBridgeDeviceStateLoggedOut LocalBridgeDeviceState = "LOGGED_OUT" + // LocalBridgeDeviceStateError means this particular device has fallen into a persistent error state that may need user intervention to fix + LocalBridgeDeviceStateError LocalBridgeDeviceState = "ERROR" + // LocalBridgeDeviceStateDeleted means this particular device has cleaned up after the account as a whole was requested to be deleted + LocalBridgeDeviceStateDeleted LocalBridgeDeviceState = "DELETED" +) diff --git a/bridge/status/messagecheckpoint.go b/bridgev2/status/messagecheckpoint.go similarity index 96% rename from bridge/status/messagecheckpoint.go rename to bridgev2/status/messagecheckpoint.go index ea859b84..b3c05f4f 100644 --- a/bridge/status/messagecheckpoint.go +++ b/bridgev2/status/messagecheckpoint.go @@ -169,13 +169,13 @@ type CheckpointsJSON struct { Checkpoints []*MessageCheckpoint `json:"checkpoints"` } -func (cj *CheckpointsJSON) SendHTTP(endpoint string, token string) error { +func (cj *CheckpointsJSON) SendHTTP(ctx context.Context, cli *http.Client, endpoint string, token string) error { var body bytes.Buffer if err := json.NewEncoder(&body).Encode(cj); err != nil { return fmt.Errorf("failed to encode message checkpoint JSON: %w", err) } - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := context.WithTimeout(ctx, 5*time.Second) defer cancel() req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, &body) if err != nil { @@ -186,7 +186,10 @@ func (cj *CheckpointsJSON) SendHTTP(endpoint string, token string) error { req.Header.Set("User-Agent", mautrix.DefaultUserAgent+" (checkpoint sender)") req.Header.Set("Content-Type", "application/json") - resp, err := http.DefaultClient.Do(req) + if cli == nil { + cli = http.DefaultClient + } + resp, err := cli.Do(req) if err != nil { return mautrix.HTTPError{ Request: req, diff --git a/bridgev2/unorganized-docs/FEATURES.md b/bridgev2/unorganized-docs/FEATURES.md new file mode 100644 index 00000000..73da364d --- /dev/null +++ b/bridgev2/unorganized-docs/FEATURES.md @@ -0,0 +1,49 @@ +# Megabridge features + +* [ ] Messages + * [x] Text (incl. formatting and mentions) + * [x] Attachments + * [ ] Polls + * [x] Replies + * [x] Threads + * [x] Edits + * [x] Reactions + * [x] Reaction mass-syncing + * [x] Deletions + * [x] Message status events and error notices + * [x] Backfilling history +* [x] Login +* [x] Logout +* [x] Re-login after credential expiry +* [x] Disappearing messages +* [x] Read receipts +* [ ] Presence +* [x] Typing notifications +* [x] Spaces +* [x] Relay mode +* [x] Chat metadata + * [x] Archive/low priority + * [x] Pin/favorite + * [x] Mark unread + * [x] Mute status + * [x] Temporary mutes ("snooze") +* [x] User metadata (name/avatar) +* [x] Group metadata + * [x] Initial meta and full resyncs + * [x] Name, avatar, topic + * [x] Members + * [x] Permissions + * [x] Change events + * [x] Name, avatar, topic + * [x] Members (join, leave, invite, kick, ban, knock) + * [x] Permissions (promote, demote) +* [ ] Misc actions + * [ ] Invites / accepting message requests + * [x] Create group + * [x] Create DM + * [x] Get contact list + * [x] Check if identifier is on remote network + * [x] Search users on remote network + * [ ] Delete chat + * [ ] Report spam +* [ ] Custom emojis diff --git a/bridgev2/unorganized-docs/README.md b/bridgev2/unorganized-docs/README.md new file mode 100644 index 00000000..62cc731a --- /dev/null +++ b/bridgev2/unorganized-docs/README.md @@ -0,0 +1,66 @@ +# Megabridge +Megabridge, also known as bridgev2 (final naming is subject to change), is a +new high-level framework for writing puppeting Matrix bridges with hopefully +minimal boilerplate code. + +## General architecture +Megabridge is split into three components: network connectors, the central +bridge module, and Matrix connectors. + +* Network connectors are responsible for connecting to the remote (non-Matrix) + network and handling all the protocol-specific details. +* The central bridge module has most of the generic bridge logic, such as + keeping track of portal mappings and handling messages. +* Matrix connectors are responsible for connecting to Matrix. Initially there + will be two Matrix connectors: one for the standard setup that connects to + a Matrix homeserver as an application service, and another for Beeper's local + bridge system. However, in the future there could be a third connector which + uses a single bot account and [MSC4144] instead of an appservice with ghost + users. + + [MSC4144]: https://github.com/matrix-org/matrix-spec-proposals/pull/4144 + +The central bridge module defines interfaces that it uses to interact with the +connectors on both sides. Additionally, the connectors are allowed to directly +call interface methods on other side. + +## Getting started with a new network connector +To create a new network connector, you need to implement the +`NetworkConnector`, `LoginProcess`, `NetworkAPI` and `RemoteEvent` interfaces. + +* `NetworkConnector` is the main entry point to the remote network. It is + responsible for general non-user-specific things, as well as creating + `NetworkAPI`s and starting login flows. +* `LoginProcess` is a state machine for logging into the remote network. +* `NetworkAPI` is the remote network client for a single login. It is + responsible for maintaining the connection to the remote network, receiving + incoming events, sending outgoing events, and fetching information like + chat/user metadata. +* `RemoteEvent` represents a single event from the remote network, such as a + message or a reaction. When the NetworkAPI receives an event, it should create + a `RemoteEvent` object and pass it to the bridge using `Bridge.QueueRemoteEvent`. + +### Login +Logins are implemented by combining three types of steps: + +* `user_input` asks the user to enter some information, such as a phone number, + username, email, password, or 2FA code. +* `cookies` either asks the user to extract cookies from their browser, or opens + a webview to do it automatically (depending on whether the login is being done + via bridge commands or a more advanced client). +* `display_and_wait` displays a QR code or other data to the user and waits until + the remote network accepts the login. + +The general flow is: + +1. Login handler (bridge command or client) calls `NetworkConnector.GetLoginFlows` + to get available login flows, and asks the user to pick one (or alternatively + automatically picks the first one if there's only one option). +2. Login handler calls `NetworkConnector.CreateLogin` with the chosen flow ID and + the network connector returns a `LoginProcess` object that remembers the user + and flow. +3. Login handler calls `LoginProcess.Start` to get the first step. +4. Login handler calls the appropriate functions (`Wait`, `SubmitUserInput` or + `SubmitCookies`) based on the step data as many times as needed. +5. When the login is done, the login process creates the `UserLogin` object and + returns a `complete` step. diff --git a/bridgev2/unorganized-docs/incoming-matrix-message.uml b/bridgev2/unorganized-docs/incoming-matrix-message.uml new file mode 100644 index 00000000..ae13ee74 --- /dev/null +++ b/bridgev2/unorganized-docs/incoming-matrix-message.uml @@ -0,0 +1,23 @@ +title Bridge v2 incoming Matrix message + +participant Network Library +participant Network Connector +participant Bridge +participant Portal +participant Database +participant Matrix + +Matrix->Bridge: QueueMatrixEvent(evt) +note over Bridge: GetPortalByID(evt.GetPortalID()) +Bridge->Portal: portal.events <- evt +loop event queue consumer + Portal->+Portal: \n evt := <-portal.events + note over Portal: Check for edit, reply/thread, etc + Portal->+Network Connector: HandleMatrixMessage(evt, replyTo) + Network Connector->Network Connector: msg := ConvertMatrixMessage(evt) + Network Connector->+Network Library: SendMessage(msg) + Network Library->-Network Connector: OK + Network Connector->-Portal: *database.Message{msg.ID} + Portal->-Database: Message.Insert() + Portal->Matrix: Success checkpoint +end diff --git a/bridgev2/unorganized-docs/incoming-remote-message.uml b/bridgev2/unorganized-docs/incoming-remote-message.uml new file mode 100644 index 00000000..f86d6e65 --- /dev/null +++ b/bridgev2/unorganized-docs/incoming-remote-message.uml @@ -0,0 +1,22 @@ +title Bridge v2 incoming remote message + +participant Network Library +participant Network Connector +participant Bridge +participant Portal +participant Database +participant Matrix + +Network Library->Network Connector: New event +Network Connector->Bridge: QueueRemoteEvent(evt) +note over Bridge: GetPortalByID(evt.GetPortalID()) +Bridge->Portal: portal.events <- evt +loop event queue consumer + Portal->+Portal: \n evt := <-portal.events + note over Portal: CreateMatrixRoom() if applicable + Portal->+Network Connector: ConvertRemoteMessage(evt) + Network Connector->-Portal: *ConvertedMessage + Portal->+Matrix: SendMessage(convertedMsg) + Matrix->-Portal: event ID + Portal->-Database: Message.Insert() +end diff --git a/bridgev2/unorganized-docs/login-step.schema.json b/bridgev2/unorganized-docs/login-step.schema.json new file mode 100644 index 00000000..b039354f --- /dev/null +++ b/bridgev2/unorganized-docs/login-step.schema.json @@ -0,0 +1,155 @@ +{ + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$id": "https://go.mau.fi/mautrix/bridgev2/login-step.json", + "title": "Login step data", + "type": "object", + "properties": { + "type": { + "type": "string", + "enum": ["user_input", "cookies", "display_and_wait", "complete"] + }, + "step_id": { + "type": "string", + "description": "An unique ID identifying this step. This can be used to implement special behavior in clients." + }, + "instructions": { + "type": "string", + "description": "Human-readable instructions for completing this login step." + }, + "user_input": { + "type": "object", + "title": "User input params", + "description": "Parameters for the `user_input` login type", + "properties": { + "fields": { + "type": "array", + "description": "The list of fields that the user must fill", + "items": { + "title": "Field", + "description": "A field that the user must fill", + "type": "object", + "properties": { + "type": { + "type": "string", + "enum": ["username", "phone_number", "email", "password", "2fa_code", "token"] + }, + "id": { + "type": "string", + "description": "The ID of the field. This should be used when submitting the form.", + "examples": ["uid", "email", "2fa_password", "meow"] + }, + "name": { + "type": "string", + "description": "The name of the field shown to the user", + "examples": ["Username", "Password", "Phone number", "2FA code", "Meow"] + }, + "description": { + "type": "string", + "description": "The description of the field shown to the user", + "examples": ["Include the country code with a +"] + }, + "pattern": { + "type": "string", + "description": "A regular expression that the field value must match" + } + }, + "required": ["type", "id", "name"] + } + } + }, + "required": ["fields"] + }, + "cookies": { + "type": "object", + "title": "Cookie params", + "description": "Parameters for the `cookies` login type", + "properties": { + "url": { + "type": "string", + "description": "The URL to open when using a webview to extract cookies" + }, + "user_agent": { + "type": "string", + "description": "The user agent to use when opening the URL" + }, + "fields": { + "type": "array", + "description": "The list of cookies (or other stored data) that must be extracted", + "items": { + "title": "Cookie Field", + "description": "A cookie (or other stored data) that must be extracted", + "type": "object", + "properties": { + "type": { + "type": "string", + "description": "The type of data to extract", + "enum": ["cookie", "local_storage", "request_header", "request_body", "special"] + }, + "name": { + "type": "string", + "description": "The name of the cookie or key in the storage" + }, + "request_url_regex": { + "type": "string", + "description": "For the `request_header` and `request_body` types, a regex that matches the URLs from which the values can be extracted." + }, + "cookie_domain": { + "type": "string", + "description": "For the `cookie` type, the domain of the cookie" + } + }, + "required": ["type", "name"] + } + }, + "extract_js": { + "type": "string", + "description": "JavaScript code that can be evaluated inside the webview to extract the special keys" + } + }, + "required": ["url"] + }, + "display_and_wait": { + "type": "object", + "title": "Display and wait params", + "description": "Parameters for the `display_and_wait` login type", + "properties": { + "type": { + "type": "string", + "description": "The type of thing to display", + "enum": ["qr", "emoji", "code", "nothing"] + }, + "data": { + "type": "string", + "description": "The thing to display (raw data for QR, unicode emoji for emoji, plain string for code)" + }, + "image_url": { + "type": "string", + "description": "An image containing the thing to display. If present, this is recommended over using data directly. For emojis, the URL to the canonical image representation of the emoji" + } + }, + "required": ["type"] + }, + "complete": { + "type": "object", + "title": "Login complete information", + "description": "Information about a successful login", + "properties": { + "user_login_id": { + "type": "string", + "description": "The ID of the user login entry" + } + } + } + }, + "required": [ + "type", + "step_id", + "instructions" + ], + "oneOf": [ + {"title":"User input type","properties":{"type": {"type":"string","const": "user_input"}}, "required": ["user_input"]}, + {"title":"Cookies type","properties":{"type": {"type":"string","const": "cookies"}}, "required": ["cookies"]}, + {"title":"Display and wait type","properties":{"type": {"type":"string","const": "display_and_wait"}}, "required": ["display_and_wait"]}, + {"title":"Login complete","properties":{"type": {"type":"string","const": "complete"}}} + ] +} diff --git a/bridgev2/unorganized-docs/login-steps.uml b/bridgev2/unorganized-docs/login-steps.uml new file mode 100644 index 00000000..5af9c88e --- /dev/null +++ b/bridgev2/unorganized-docs/login-steps.uml @@ -0,0 +1,43 @@ +title Login flows + +participant User +participant Client +participant Bridge +participant User's device + +alt Username+Password/Phone number/2FA code + Client->+Bridge: /login + Bridge->-Client: step=user_input, fields=[...] + Client->User: input box(es) + User->Client: submit input + Client->+Bridge: /login/user_input + Bridge->-Client: success=true, step=next step +end + +alt Cookies + Client->+Bridge: /login + Bridge->-Client: step=cookies, url=..., cookies=[...] + Client->User: webview + User->Client: login in webview + Client->Bridge: /login/cookies + Bridge->-Client: success=true, step=next step +end + +alt QR/Emoji/Code + Client->+Bridge: /login + Bridge->-Client: step=display_and_wait, data=... + Client->+Bridge: /login/wait + Client->User: display QR/emoji/code + loop Refresh QR + Bridge->-Client: step=display_and_wait, data=new QR + Client->User: display new QR + Client->+Bridge: /login/wait + end +else Successful case + User->User's device: Scan QR/tap emoji/enter code + User's device->Bridge: Login successful + Bridge->-Client: success=true, step=next step +else Error + Bridge->Client: error=timeout + Client->User: error +end diff --git a/bridgev2/user.go b/bridgev2/user.go new file mode 100644 index 00000000..9a7896d6 --- /dev/null +++ b/bridgev2/user.go @@ -0,0 +1,275 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package bridgev2 + +import ( + "context" + "fmt" + "strings" + "sync" + "unsafe" + + "github.com/rs/zerolog" + "golang.org/x/exp/maps" + "golang.org/x/exp/slices" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/bridgev2/bridgeconfig" + "maunium.net/go/mautrix/bridgev2/database" + "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +type User struct { + *database.User + Bridge *Bridge + Log zerolog.Logger + + CommandState unsafe.Pointer + Permissions bridgeconfig.Permissions + + doublePuppetIntent MatrixAPI + doublePuppetInitialized bool + doublePuppetLock sync.Mutex + + managementCreateLock sync.Mutex + + logins map[networkid.UserLoginID]*UserLogin +} + +func (br *Bridge) loadUser(ctx context.Context, dbUser *database.User, queryErr error, userID *id.UserID) (*User, error) { + if queryErr != nil { + return nil, fmt.Errorf("failed to query db: %w", queryErr) + } + if dbUser == nil { + if userID == nil { + return nil, nil + } + dbUser = &database.User{ + BridgeID: br.ID, + MXID: *userID, + } + err := br.DB.User.Insert(ctx, dbUser) + if err != nil { + return nil, fmt.Errorf("failed to insert new user: %w", err) + } + } + user := &User{ + User: dbUser, + Bridge: br, + Log: br.Log.With().Stringer("user_mxid", dbUser.MXID).Logger(), + logins: make(map[networkid.UserLoginID]*UserLogin), + Permissions: br.Config.Permissions.Get(dbUser.MXID), + } + br.usersByMXID[user.MXID] = user + err := br.unlockedLoadUserLoginsByMXID(ctx, user) + if err != nil { + return nil, fmt.Errorf("failed to load user logins: %w", err) + } + return user, nil +} + +func (br *Bridge) unlockedGetUserByMXID(ctx context.Context, userID id.UserID, onlyIfExists bool) (*User, error) { + cached, ok := br.usersByMXID[userID] + if ok { + return cached, nil + } + idPtr := &userID + if onlyIfExists { + idPtr = nil + } + db, err := br.DB.User.GetByMXID(ctx, userID) + return br.loadUser(ctx, db, err, idPtr) +} + +func (br *Bridge) GetUserByMXID(ctx context.Context, userID id.UserID) (*User, error) { + br.cacheLock.Lock() + defer br.cacheLock.Unlock() + return br.unlockedGetUserByMXID(ctx, userID, false) +} + +func (br *Bridge) GetExistingUserByMXID(ctx context.Context, userID id.UserID) (*User, error) { + br.cacheLock.Lock() + defer br.cacheLock.Unlock() + return br.unlockedGetUserByMXID(ctx, userID, true) +} + +func (user *User) LogoutDoublePuppet(ctx context.Context) { + user.doublePuppetLock.Lock() + defer user.doublePuppetLock.Unlock() + user.AccessToken = "" + err := user.Save(ctx) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to save removed access token") + } + user.doublePuppetIntent = nil + user.doublePuppetInitialized = false +} + +func (user *User) LoginDoublePuppet(ctx context.Context, token string) error { + if token == "" { + return fmt.Errorf("no token provided") + } + user.doublePuppetLock.Lock() + defer user.doublePuppetLock.Unlock() + intent, newToken, err := user.Bridge.Matrix.NewUserIntent(ctx, user.MXID, token) + if err != nil { + return err + } + user.AccessToken = newToken + user.doublePuppetIntent = intent + user.doublePuppetInitialized = true + err = user.Save(ctx) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to save new access token") + } + if newToken != token { + return fmt.Errorf("logging in manually is not supported when automatic double puppeting is enabled") + } + return nil +} + +func (user *User) DoublePuppet(ctx context.Context) MatrixAPI { + user.doublePuppetLock.Lock() + defer user.doublePuppetLock.Unlock() + if user.doublePuppetInitialized { + return user.doublePuppetIntent + } + user.doublePuppetInitialized = true + log := user.Log.With().Str("action", "setup double puppet").Logger() + ctx = log.WithContext(ctx) + intent, newToken, err := user.Bridge.Matrix.NewUserIntent(ctx, user.MXID, user.AccessToken) + if err != nil { + log.Err(err).Msg("Failed to create new user intent") + return nil + } + user.doublePuppetIntent = intent + if newToken != user.AccessToken { + user.AccessToken = newToken + err = user.Save(ctx) + if err != nil { + log.Warn().Err(err).Msg("Failed to save new access token") + } + } + return intent +} + +func (user *User) GetUserLoginIDs() []networkid.UserLoginID { + user.Bridge.cacheLock.Lock() + defer user.Bridge.cacheLock.Unlock() + return maps.Keys(user.logins) +} + +// Deprecated: renamed to GetUserLogins +func (user *User) GetCachedUserLogins() []*UserLogin { + return user.GetUserLogins() +} + +func (user *User) GetUserLogins() []*UserLogin { + user.Bridge.cacheLock.Lock() + defer user.Bridge.cacheLock.Unlock() + return maps.Values(user.logins) +} + +func (user *User) HasTooManyLogins() bool { + return user.Permissions.MaxLogins > 0 && len(user.GetUserLoginIDs()) >= user.Permissions.MaxLogins +} + +func (user *User) GetFormattedUserLogins() string { + user.Bridge.cacheLock.Lock() + logins := make([]string, len(user.logins)) + for key, val := range user.logins { + logins = append(logins, fmt.Sprintf("* `%s` (%s) - `%s`", key, val.RemoteName, val.BridgeState.GetPrev().StateEvent)) + } + user.Bridge.cacheLock.Unlock() + return strings.Join(logins, "\n") +} + +func (user *User) GetDefaultLogin() *UserLogin { + user.Bridge.cacheLock.Lock() + defer user.Bridge.cacheLock.Unlock() + if len(user.logins) == 0 { + return nil + } + loginKeys := maps.Keys(user.logins) + slices.Sort(loginKeys) + return user.logins[loginKeys[0]] +} + +func (user *User) GetManagementRoom(ctx context.Context) (id.RoomID, error) { + user.managementCreateLock.Lock() + defer user.managementCreateLock.Unlock() + if user.ManagementRoom != "" { + return user.ManagementRoom, nil + } + netName := user.Bridge.Network.GetName() + var err error + autoJoin := user.Bridge.Matrix.GetCapabilities().AutoJoinInvites + doublePuppet := user.DoublePuppet(ctx) + req := &mautrix.ReqCreateRoom{ + Visibility: "private", + Name: netName.DisplayName, + Topic: fmt.Sprintf("%s bridge management room", netName.DisplayName), + InitialState: []*event.Event{{ + Type: event.StateRoomAvatar, + Content: event.Content{ + Parsed: &event.RoomAvatarEventContent{ + URL: netName.NetworkIcon, + }, + }, + }}, + PowerLevelOverride: &event.PowerLevelsEventContent{ + Users: map[id.UserID]int{ + user.Bridge.Bot.GetMXID(): 9001, + user.MXID: 50, + }, + }, + Invite: []id.UserID{user.MXID}, + IsDirect: true, + } + if autoJoin { + req.BeeperInitialMembers = []id.UserID{user.MXID} + // TODO remove this after initial_members is supported in hungryserv + req.BeeperAutoJoinInvites = true + } + user.ManagementRoom, err = user.Bridge.Bot.CreateRoom(ctx, req) + if err != nil { + return "", fmt.Errorf("failed to create management room: %w", err) + } + if !autoJoin && doublePuppet != nil { + err = doublePuppet.EnsureJoined(ctx, user.ManagementRoom) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to auto-join created management room with double puppet") + } + } + err = user.Save(ctx) + if err != nil { + return "", fmt.Errorf("failed to save management room ID: %w", err) + } + return user.ManagementRoom, nil +} + +func (user *User) Save(ctx context.Context) error { + return user.Bridge.DB.User.Update(ctx, user.User) +} + +func (br *Bridge) TrackAnalytics(userID id.UserID, event string, props map[string]any) { + analyticSender, ok := br.Matrix.(MatrixConnectorWithAnalytics) + if ok { + analyticSender.TrackAnalytics(userID, event, props) + } +} + +func (user *User) TrackAnalytics(event string, props map[string]any) { + user.Bridge.TrackAnalytics(user.MXID, event, props) +} + +func (ul *UserLogin) TrackAnalytics(event string, props map[string]any) { + // TODO include user login ID? + ul.Bridge.TrackAnalytics(ul.UserMXID, event, props) +} diff --git a/bridgev2/userlogin.go b/bridgev2/userlogin.go new file mode 100644 index 00000000..d56dc4cc --- /dev/null +++ b/bridgev2/userlogin.go @@ -0,0 +1,571 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package bridgev2 + +import ( + "cmp" + "context" + "fmt" + "maps" + "slices" + "sync" + "time" + + "github.com/rs/zerolog" + "go.mau.fi/util/exsync" + + "maunium.net/go/mautrix/bridgev2/bridgeconfig" + "maunium.net/go/mautrix/bridgev2/database" + "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/bridgev2/status" + "maunium.net/go/mautrix/event" +) + +type UserLogin struct { + *database.UserLogin + Bridge *Bridge + User *User + Log zerolog.Logger + + Client NetworkAPI + BridgeState *BridgeStateQueue + + inPortalCache *exsync.Set[networkid.PortalKey] + + spaceCreateLock sync.Mutex + deleteLock sync.Mutex + disconnectOnce sync.Once +} + +func (br *Bridge) loadUserLogin(ctx context.Context, user *User, dbUserLogin *database.UserLogin) (*UserLogin, error) { + if dbUserLogin == nil { + return nil, nil + } + if user == nil { + var err error + user, err = br.unlockedGetUserByMXID(ctx, dbUserLogin.UserMXID, true) + if err != nil { + return nil, fmt.Errorf("failed to get user: %w", err) + } + // TODO if loading the user caused the provided userlogin to be loaded, cancel here? + // Currently this will double-load it + } + userLogin := &UserLogin{ + UserLogin: dbUserLogin, + Bridge: br, + User: user, + Log: user.Log.With().Str("login_id", string(dbUserLogin.ID)).Logger(), + + inPortalCache: exsync.NewSet[networkid.PortalKey](), + } + err := br.Network.LoadUserLogin(ctx, userLogin) + if err != nil { + userLogin.Log.Err(err).Msg("Failed to load user login") + return nil, nil + } else if userLogin.Client == nil { + userLogin.Log.Error().Msg("LoadUserLogin didn't fill Client") + return nil, nil + } + userLogin.BridgeState = br.NewBridgeStateQueue(userLogin) + user.logins[userLogin.ID] = userLogin + br.userLoginsByID[userLogin.ID] = userLogin + return userLogin, nil +} + +func (br *Bridge) loadManyUserLogins(ctx context.Context, user *User, logins []*database.UserLogin) ([]*UserLogin, error) { + output := make([]*UserLogin, 0, len(logins)) + for _, dbLogin := range logins { + if cached, ok := br.userLoginsByID[dbLogin.ID]; ok { + output = append(output, cached) + } else { + loaded, err := br.loadUserLogin(ctx, user, dbLogin) + if err != nil { + return nil, err + } else if loaded != nil { + output = append(output, loaded) + } + } + } + return output, nil +} + +func (br *Bridge) unlockedLoadUserLoginsByMXID(ctx context.Context, user *User) error { + logins, err := br.DB.UserLogin.GetAllForUser(ctx, user.MXID) + if err != nil { + return err + } + _, err = br.loadManyUserLogins(ctx, user, logins) + return err +} + +func (br *Bridge) GetUserLoginsInPortal(ctx context.Context, portal networkid.PortalKey) ([]*UserLogin, error) { + if portal.Receiver != "" { + ul := br.GetCachedUserLoginByID(portal.Receiver) + if ul == nil { + return nil, nil + } + return []*UserLogin{ul}, nil + } + logins, err := br.DB.UserLogin.GetAllInPortal(ctx, portal) + if err != nil { + return nil, err + } + br.cacheLock.Lock() + defer br.cacheLock.Unlock() + return br.loadManyUserLogins(ctx, nil, logins) +} + +func (br *Bridge) GetExistingUserLoginByID(ctx context.Context, id networkid.UserLoginID) (*UserLogin, error) { + br.cacheLock.Lock() + defer br.cacheLock.Unlock() + return br.unlockedGetExistingUserLoginByID(ctx, id) +} + +func (br *Bridge) unlockedGetExistingUserLoginByID(ctx context.Context, id networkid.UserLoginID) (*UserLogin, error) { + cached, ok := br.userLoginsByID[id] + if ok { + return cached, nil + } + login, err := br.DB.UserLogin.GetByID(ctx, id) + if err != nil { + return nil, err + } + return br.loadUserLogin(ctx, nil, login) +} + +func (br *Bridge) GetCachedUserLoginByID(id networkid.UserLoginID) *UserLogin { + br.cacheLock.Lock() + defer br.cacheLock.Unlock() + return br.userLoginsByID[id] +} + +func (br *Bridge) GetAllCachedUserLogins() (logins []*UserLogin) { + br.cacheLock.Lock() + defer br.cacheLock.Unlock() + return slices.Collect(maps.Values(br.userLoginsByID)) +} + +func (br *Bridge) GetCurrentBridgeStates() (states []status.BridgeState) { + br.cacheLock.Lock() + defer br.cacheLock.Unlock() + if len(br.userLoginsByID) == 0 { + return []status.BridgeState{{ + StateEvent: status.StateUnconfigured, + }} + } + states = make([]status.BridgeState, len(br.userLoginsByID)) + i := 0 + for _, login := range br.userLoginsByID { + states[i] = login.BridgeState.GetPrev() + i++ + } + return +} + +type NewLoginParams struct { + LoadUserLogin func(context.Context, *UserLogin) error + DeleteOnConflict bool + DontReuseExisting bool +} + +// NewLogin creates a UserLogin object for this user with the given parameters. +// +// If a login already exists with the same ID, it is reused after updating the remote name +// and metadata from the provided data, unless DontReuseExisting is set in params. +// +// If the existing login belongs to another user, this returns an error, +// unless DeleteOnConflict is set in the params, in which case the existing login is deleted. +// +// This will automatically call LoadUserLogin after creating the UserLogin object. +// The load method defaults to the network connector's LoadUserLogin method, but it can be overridden in params. +func (user *User) NewLogin(ctx context.Context, data *database.UserLogin, params *NewLoginParams) (*UserLogin, error) { + user.Bridge.cacheLock.Lock() + defer user.Bridge.cacheLock.Unlock() + data.BridgeID = user.BridgeID + data.UserMXID = user.MXID + if data.Metadata == nil { + metaTypes := user.Bridge.Network.GetDBMetaTypes() + if metaTypes.UserLogin != nil { + data.Metadata = metaTypes.UserLogin() + } + } + if params == nil { + params = &NewLoginParams{} + } + if params.LoadUserLogin == nil { + params.LoadUserLogin = user.Bridge.Network.LoadUserLogin + } + ul, err := user.Bridge.unlockedGetExistingUserLoginByID(ctx, data.ID) + if err != nil { + return nil, fmt.Errorf("failed to check if login already exists: %w", err) + } + var doInsert bool + if ul != nil && ul.UserMXID != user.MXID { + if params.DeleteOnConflict { + ul.Delete(ctx, status.BridgeState{StateEvent: status.StateLoggedOut, Reason: "LOGIN_OVERRIDDEN_ANOTHER_USER"}, DeleteOpts{ + LogoutRemote: false, + unlocked: true, + }) + ul = nil + } else { + return nil, fmt.Errorf("%s is already logged in with that account", ul.UserMXID) + } + } + if ul != nil { + if params.DontReuseExisting { + return nil, fmt.Errorf("login already exists") + } + doInsert = false + ul.RemoteName = data.RemoteName + ul.RemoteProfile = ul.RemoteProfile.Merge(data.RemoteProfile) + if merger, ok := ul.Metadata.(database.MetaMerger); ok { + merger.CopyFrom(data.Metadata) + } else { + ul.Metadata = data.Metadata + } + } else { + doInsert = true + ul = &UserLogin{ + UserLogin: data, + Bridge: user.Bridge, + User: user, + Log: user.Log.With().Str("login_id", string(data.ID)).Logger(), + } + ul.BridgeState = user.Bridge.NewBridgeStateQueue(ul) + } + noCancelCtx := ul.Log.WithContext(user.Bridge.BackgroundCtx) + err = params.LoadUserLogin(noCancelCtx, ul) + if err != nil { + return nil, err + } else if ul.Client == nil { + ul.Log.Error().Msg("LoadUserLogin didn't fill Client in NewLogin") + return nil, fmt.Errorf("client not filled by LoadUserLogin") + } + if doInsert { + err = user.Bridge.DB.UserLogin.Insert(noCancelCtx, ul.UserLogin) + if err != nil { + return nil, err + } + user.Bridge.userLoginsByID[ul.ID] = ul + user.logins[ul.ID] = ul + } else { + err = ul.Save(noCancelCtx) + if err != nil { + return nil, err + } + } + return ul, nil +} + +func (ul *UserLogin) Save(ctx context.Context) error { + return ul.Bridge.DB.UserLogin.Update(ctx, ul.UserLogin) +} + +func (ul *UserLogin) Logout(ctx context.Context) { + ul.Delete(ctx, status.BridgeState{StateEvent: status.StateLoggedOut}, DeleteOpts{LogoutRemote: true}) +} + +type DeleteOpts struct { + LogoutRemote bool + DontCleanupRooms bool + BlockingCleanup bool + unlocked bool +} + +func (ul *UserLogin) Delete(ctx context.Context, state status.BridgeState, opts DeleteOpts) { + cleanupRooms := !opts.DontCleanupRooms && ul.Bridge.Config.CleanupOnLogout.Enabled + zerolog.Ctx(ctx).Info().Str("user_login_id", string(ul.ID)). + Bool("logout_remote", opts.LogoutRemote). + Bool("cleanup_rooms", cleanupRooms). + Msg("Deleting user login") + ul.deleteLock.Lock() + defer ul.deleteLock.Unlock() + if ul.BridgeState == nil { + return + } + if opts.LogoutRemote { + ul.Client.LogoutRemote(ctx) + } else { + // we probably shouldn't delete the login if disconnect isn't finished + ul.Disconnect() + } + var portals []*database.UserPortal + var err error + if cleanupRooms { + portals, err = ul.Bridge.DB.UserPortal.GetAllForLogin(ctx, ul.UserLogin) + if err != nil { + ul.Log.Err(err).Msg("Failed to get user portals") + } + } + err = ul.Bridge.DB.UserLogin.Delete(ctx, ul.ID) + if err != nil { + ul.Log.Err(err).Msg("Failed to delete user login") + } + if !opts.unlocked { + ul.Bridge.cacheLock.Lock() + } + delete(ul.User.logins, ul.ID) + delete(ul.Bridge.userLoginsByID, ul.ID) + if !opts.unlocked { + ul.Bridge.cacheLock.Unlock() + } + backgroundCtx := zerolog.Ctx(ctx).WithContext(ul.Bridge.BackgroundCtx) + if !opts.BlockingCleanup { + go ul.deleteSpace(backgroundCtx) + } else { + ul.deleteSpace(backgroundCtx) + } + if portals != nil { + if !opts.BlockingCleanup { + go ul.kickUserFromPortals(backgroundCtx, portals, state.StateEvent == status.StateBadCredentials, false) + } else { + ul.kickUserFromPortals(backgroundCtx, portals, state.StateEvent == status.StateBadCredentials, false) + } + } + if state.StateEvent != "" { + ul.BridgeState.Send(state) + } + ul.BridgeState.Destroy() + ul.BridgeState = nil +} + +func (ul *UserLogin) deleteSpace(ctx context.Context) { + if ul.SpaceRoom == "" { + return + } + err := ul.Bridge.Bot.DeleteRoom(ctx, ul.SpaceRoom, false) + if err != nil { + ul.Log.Err(err).Msg("Failed to delete space room") + } +} + +// KickUserFromPortalsForBadCredentials can be called to kick the user from portals without deleting the entire UserLogin object. +func (ul *UserLogin) KickUserFromPortalsForBadCredentials(ctx context.Context) { + log := zerolog.Ctx(ctx) + portals, err := ul.Bridge.DB.UserPortal.GetAllForLogin(ctx, ul.UserLogin) + if err != nil { + log.Err(err).Msg("Failed to get user portals") + } + ul.kickUserFromPortals(ctx, portals, true, true) +} + +func DeleteManyPortals(ctx context.Context, portals []*Portal, errorCallback func(portal *Portal, delete bool, err error)) { + // TODO is there a more sensible place/name for this function? + if len(portals) == 0 { + return + } + getDepth := func(portal *Portal) int { + depth := 0 + for portal.Parent != nil { + depth++ + portal = portal.Parent + } + return depth + } + // Sort portals so parents are last (to avoid errors caused by deleting parent portals before children) + slices.SortFunc(portals, func(a, b *Portal) int { + return cmp.Compare(getDepth(b), getDepth(a)) + }) + for _, portal := range portals { + err := portal.Delete(ctx) + if err != nil { + zerolog.Ctx(ctx).Err(err). + Stringer("portal_mxid", portal.MXID). + Object("portal_key", portal.PortalKey). + Msg("Failed to delete portal row from database") + if errorCallback != nil { + errorCallback(portal, false, err) + } + continue + } + if portal.MXID != "" { + err = portal.Bridge.Bot.DeleteRoom(ctx, portal.MXID, false) + if err != nil { + zerolog.Ctx(ctx).Err(err). + Stringer("portal_mxid", portal.MXID). + Msg("Failed to clean up portal room") + if errorCallback != nil { + errorCallback(portal, true, err) + } + } + } + } +} + +func (ul *UserLogin) kickUserFromPortals(ctx context.Context, portals []*database.UserPortal, badCredentials, deleteRow bool) { + var portalsToDelete []*Portal + for _, up := range portals { + portalToDelete, err := ul.kickUserFromPortal(ctx, up, badCredentials, deleteRow) + if err != nil { + zerolog.Ctx(ctx).Err(err). + Object("portal_key", up.Portal). + Stringer("user_mxid", up.UserMXID). + Msg("Failed to apply logout action") + } else if portalToDelete != nil { + portalsToDelete = append(portalsToDelete, portalToDelete) + } + } + DeleteManyPortals(ctx, portalsToDelete, nil) +} + +func (ul *UserLogin) kickUserFromPortal(ctx context.Context, up *database.UserPortal, badCredentials, deleteRow bool) (*Portal, error) { + portal, action, reason, err := ul.getLogoutAction(ctx, up, badCredentials) + if err != nil { + return nil, err + } else if portal == nil { + return nil, nil + } + zerolog.Ctx(ctx).Debug(). + Str("login_id", string(ul.ID)). + Stringer("user_mxid", ul.UserMXID). + Str("logout_action", string(action)). + Str("action_reason", reason). + Object("portal_key", portal.PortalKey). + Stringer("portal_mxid", portal.MXID). + Msg("Calculated portal action for logout processing") + switch action { + case bridgeconfig.CleanupActionNull, bridgeconfig.CleanupActionNothing: + // do nothing + case bridgeconfig.CleanupActionKick: + _, err = ul.Bridge.Bot.SendState(ctx, portal.MXID, event.StateMember, ul.UserMXID.String(), &event.Content{ + Parsed: &event.MemberEventContent{ + Membership: event.MembershipLeave, + Reason: "Logged out of bridge", + }, + }, time.Time{}) + if err != nil { + return nil, fmt.Errorf("failed to kick user from portal: %w", err) + } + zerolog.Ctx(ctx).Debug(). + Str("login_id", string(ul.ID)). + Stringer("user_mxid", ul.UserMXID). + Stringer("portal_mxid", portal.MXID). + Msg("Kicked user from portal") + if deleteRow { + err = ul.Bridge.DB.UserPortal.Delete(ctx, up) + if err != nil { + zerolog.Ctx(ctx).Warn(). + Str("login_id", string(ul.ID)). + Stringer("user_mxid", ul.UserMXID). + Stringer("portal_mxid", portal.MXID). + Msg("Failed to delete user portal row") + } + } + case bridgeconfig.CleanupActionDelete, bridgeconfig.CleanupActionUnbridge: + // return portal instead of deleting here to allow sorting by depth + return portal, nil + } + return nil, nil +} + +func (ul *UserLogin) getLogoutAction(ctx context.Context, up *database.UserPortal, badCredentials bool) (*Portal, bridgeconfig.CleanupAction, string, error) { + portal, err := ul.Bridge.GetExistingPortalByKey(ctx, up.Portal) + if err != nil { + return nil, bridgeconfig.CleanupActionNull, "", fmt.Errorf("failed to get full portal: %w", err) + } else if portal == nil || portal.MXID == "" { + return nil, bridgeconfig.CleanupActionNull, "portal not found", nil + } + actionsSet := ul.Bridge.Config.CleanupOnLogout.Manual + if badCredentials { + actionsSet = ul.Bridge.Config.CleanupOnLogout.BadCredentials + } + if portal.Receiver != "" { + return portal, actionsSet.Private, "portal has receiver", nil + } + otherUPs, err := ul.Bridge.DB.UserPortal.GetAllInPortal(ctx, portal.PortalKey) + if err != nil { + return portal, bridgeconfig.CleanupActionNull, "", fmt.Errorf("failed to get other logins in portal: %w", err) + } + hasOtherUsers := false + for _, otherUP := range otherUPs { + if otherUP.LoginID == ul.ID { + continue + } + if otherUP.UserMXID == ul.UserMXID { + otherUL := ul.Bridge.GetCachedUserLoginByID(otherUP.LoginID) + if otherUL != nil && otherUL.Client.IsLoggedIn() { + return portal, bridgeconfig.CleanupActionNull, "user has another login in portal", nil + } + } else { + hasOtherUsers = true + } + } + if portal.RelayLoginID != "" { + return portal, actionsSet.Relayed, "portal has relay login", nil + } else if hasOtherUsers { + return portal, actionsSet.SharedHasUsers, "portal has logins of other users", nil + } + return portal, actionsSet.SharedNoUsers, "portal doesn't have logins of other users", nil +} + +func (ul *UserLogin) MarkAsPreferredIn(ctx context.Context, portal *Portal) error { + return ul.Bridge.DB.UserPortal.MarkAsPreferred(ctx, ul.UserLogin, portal.PortalKey) +} + +var _ status.BridgeStateFiller = (*UserLogin)(nil) + +func (ul *UserLogin) FillBridgeState(state status.BridgeState) status.BridgeState { + state.UserID = ul.UserMXID + state.RemoteID = ul.ID + state.RemoteName = ul.RemoteName + state.RemoteProfile = ul.RemoteProfile + filler, ok := ul.Client.(status.BridgeStateFiller) + if ok { + return filler.FillBridgeState(state) + } + return state +} + +func (ul *UserLogin) Disconnect() { + ul.DisconnectWithTimeout(0) +} + +func (ul *UserLogin) DisconnectWithTimeout(timeout time.Duration) { + ul.disconnectOnce.Do(func() { + ul.disconnectInternal(timeout) + }) +} + +func (ul *UserLogin) disconnectInternal(timeout time.Duration) { + ul.BridgeState.StopUnknownErrorReconnect() + disconnected := make(chan struct{}) + go func() { + ul.Client.Disconnect() + close(disconnected) + }() + + var timeoutC <-chan time.Time + if timeout > 0 { + timeoutC = time.After(timeout) + } + for { + select { + case <-disconnected: + return + case <-time.After(2 * time.Second): + ul.Log.Warn().Msg("Client disconnection taking long") + case <-timeoutC: + ul.Log.Error().Msg("Client disconnection timed out") + return + } + } +} + +func (ul *UserLogin) recreateClient(ctx context.Context) error { + oldClient := ul.Client + err := ul.Bridge.Network.LoadUserLogin(ctx, ul) + if err != nil { + return err + } + if ul.Client == oldClient { + zerolog.Ctx(ctx).Warn().Msg("LoadUserLogin didn't update client") + } else { + zerolog.Ctx(ctx).Debug().Msg("Recreated user login client") + } + ul.disconnectOnce = sync.Once{} + return nil +} diff --git a/client.go b/client.go index ad68e793..7062d9b9 100644 --- a/client.go +++ b/client.go @@ -13,12 +13,18 @@ import ( "net/http" "net/url" "os" + "runtime" + "slices" "strconv" + "strings" "sync/atomic" "time" "github.com/rs/zerolog" + "go.mau.fi/util/ptr" + "go.mau.fi/util/random" "go.mau.fi/util/retryafter" + "golang.org/x/exp/maps" "maunium.net/go/mautrix/crypto/backup" "maunium.net/go/mautrix/event" @@ -35,53 +41,86 @@ type CryptoHelper interface { } type VerificationHelper interface { + // Init initializes the helper. This should be called before any other + // methods. Init(context.Context) error + + // StartVerification starts an interactive verification flow with the given + // user via a to-device event. StartVerification(ctx context.Context, to id.UserID) (id.VerificationTransactionID, error) + // StartInRoomVerification starts an interactive verification flow with the + // given user in the given room. StartInRoomVerification(ctx context.Context, roomID id.RoomID, to id.UserID) (id.VerificationTransactionID, error) + + // AcceptVerification accepts a verification request. AcceptVerification(ctx context.Context, txnID id.VerificationTransactionID) error + // DismissVerification dismisses a verification request. This will not send + // a cancellation to the other device. This method should only be called + // *before* the request has been accepted and will error otherwise. + DismissVerification(ctx context.Context, txnID id.VerificationTransactionID) error + // CancelVerification cancels a verification request. This method should + // only be called *after* the request has been accepted, although it will + // not error if called beforehand. CancelVerification(ctx context.Context, txnID id.VerificationTransactionID, code event.VerificationCancelCode, reason string) error + // HandleScannedQRData handles the data from a QR code scan. HandleScannedQRData(ctx context.Context, data []byte) error + // ConfirmQRCodeScanned confirms that our QR code has been scanned. ConfirmQRCodeScanned(ctx context.Context, txnID id.VerificationTransactionID) error + // StartSAS starts a SAS verification flow. StartSAS(ctx context.Context, txnID id.VerificationTransactionID) error + // ConfirmSAS indicates that the user has confirmed that the SAS matches + // SAS shown on the other user's device. ConfirmSAS(ctx context.Context, txnID id.VerificationTransactionID) error } // Client represents a Matrix client. type Client struct { - HomeserverURL *url.URL // The base homeserver URL - UserID id.UserID // The user ID of the client. Used for forming HTTP paths which use the client's user ID. - DeviceID id.DeviceID // The device ID of the client. - AccessToken string // The access_token for the client. - UserAgent string // The value for the User-Agent header - Client *http.Client // The underlying HTTP client which will be used to make HTTP requests. - Syncer Syncer // The thing which can process /sync responses - Store SyncStore // The thing which can store tokens/ids - StateStore StateStore - Crypto CryptoHelper - Verification VerificationHelper + HomeserverURL *url.URL // The base homeserver URL + UserID id.UserID // The user ID of the client. Used for forming HTTP paths which use the client's user ID. + DeviceID id.DeviceID // The device ID of the client. + AccessToken string // The access_token for the client. + UserAgent string // The value for the User-Agent header + Client *http.Client // The underlying HTTP client which will be used to make HTTP requests. + Syncer Syncer // The thing which can process /sync responses + Store SyncStore // The thing which can store tokens/ids + StateStore StateStore + Crypto CryptoHelper + Verification VerificationHelper + SpecVersions *RespVersions + ExternalClient *http.Client // The HTTP client used for external (not matrix) media HTTP requests. Log zerolog.Logger RequestHook func(req *http.Request) - ResponseHook func(req *http.Request, resp *http.Response, duration time.Duration) + ResponseHook func(req *http.Request, resp *http.Response, err error, duration time.Duration) + + UpdateRequestOnRetry func(req *http.Request, cause error) *http.Request SyncPresence event.Presence + SyncTraceLog bool StreamSyncMinAge time.Duration // Number of times that mautrix will retry any HTTP request // if the request fails entirely or returns a HTTP gateway error (502-504) DefaultHTTPRetries int + // Amount of time to wait between HTTP retries, defaults to 4 seconds + DefaultHTTPBackoff time.Duration // Set to true to disable automatically sleeping on 429 errors. IgnoreRateLimit bool + ResponseSizeLimit int64 + txnID int32 // Should the ?user_id= query parameter be set in requests? // See https://spec.matrix.org/v1.6/application-service-api/#identity-assertion SetAppServiceUserID bool + // Should the org.matrix.msc3202.device_id query parameter be set in requests? + // See https://github.com/matrix-org/matrix-spec-proposals/pull/3202 + SetAppServiceDeviceID bool syncingID uint32 // Identifies the current Sync. Only one Sync can be active at any given time. } @@ -103,6 +142,12 @@ type IdentityServerInfo struct { // Use ParseUserID to extract the server name from a user ID. // https://spec.matrix.org/v1.2/client-server-api/#server-discovery func DiscoverClientAPI(ctx context.Context, serverName string) (*ClientWellKnown, error) { + return DiscoverClientAPIWithClient(ctx, &http.Client{Timeout: 30 * time.Second}, serverName) +} + +const WellKnownMaxSize = 64 * 1024 + +func DiscoverClientAPIWithClient(ctx context.Context, client *http.Client, serverName string) (*ClientWellKnown, error) { wellKnownURL := url.URL{ Scheme: "https", Host: serverName, @@ -114,10 +159,11 @@ func DiscoverClientAPI(ctx context.Context, serverName string) (*ClientWellKnown return nil, err } - req.Header.Set("Accept", "application/json") - req.Header.Set("User-Agent", DefaultUserAgent+" (.well-known fetcher)") + if runtime.GOOS != "js" { + req.Header.Set("Accept", "application/json") + req.Header.Set("User-Agent", DefaultUserAgent+" (.well-known fetcher)") + } - client := &http.Client{Timeout: 30 * time.Second} resp, err := client.Do(req) if err != nil { return nil, err @@ -126,11 +172,15 @@ func DiscoverClientAPI(ctx context.Context, serverName string) (*ClientWellKnown if resp.StatusCode == http.StatusNotFound { return nil, nil + } else if resp.ContentLength > WellKnownMaxSize { + return nil, errors.New(".well-known response too large") } - data, err := io.ReadAll(resp.Body) + data, err := io.ReadAll(io.LimitReader(resp.Body, WellKnownMaxSize)) if err != nil { return nil, err + } else if len(data) >= WellKnownMaxSize { + return nil, errors.New(".well-known response too large") } var wellKnown ClientWellKnown @@ -197,14 +247,20 @@ func (cli *Client) SyncWithContext(ctx context.Context) error { } } lastSuccessfulSync := time.Now().Add(-cli.StreamSyncMinAge - 1*time.Hour) + // Always do first sync with 0 timeout + isFailing := true for { streamResp := false if cli.StreamSyncMinAge > 0 && time.Since(lastSuccessfulSync) > cli.StreamSyncMinAge { cli.Log.Debug().Msg("Last sync is old, will stream next response") streamResp = true } + timeout := 30000 + if isFailing || nextBatch == "" { + timeout = 0 + } resSync, err := cli.FullSyncRequest(ctx, ReqSync{ - Timeout: 30000, + Timeout: timeout, Since: nextBatch, FilterID: filterID, FullState: false, @@ -212,6 +268,7 @@ func (cli *Client) SyncWithContext(ctx context.Context) error { StreamResponse: streamResp, }) if err != nil { + isFailing = true if ctx.Err() != nil { return ctx.Err() } @@ -219,6 +276,9 @@ func (cli *Client) SyncWithContext(ctx context.Context) error { if err2 != nil { return err2 } + if duration <= 0 { + continue + } select { case <-ctx.Done(): return ctx.Err() @@ -226,6 +286,7 @@ func (cli *Client) SyncWithContext(ctx context.Context) error { continue } } + isFailing = false lastSuccessfulSync = time.Now() // Check that the syncing state hasn't changed @@ -269,21 +330,38 @@ type contextKey int const ( LogBodyContextKey contextKey = iota LogRequestIDContextKey + MaxAttemptsContextKey + SyncTokenContextKey ) func (cli *Client) RequestStart(req *http.Request) { - if cli.RequestHook != nil { + if cli != nil && cli.RequestHook != nil { cli.RequestHook(req) } } +// WithMaxRetries updates the context to set the maximum number of retries for any HTTP requests made with the context. +// +// 0 means the request will only be attempted once and will not be retried. +// Negative values will remove the override and fallback to the defaults. +func WithMaxRetries(ctx context.Context, maxRetries int) context.Context { + return context.WithValue(ctx, MaxAttemptsContextKey, maxRetries+1) +} + func (cli *Client) LogRequestDone(req *http.Request, resp *http.Response, err error, handlerErr error, contentLength int, duration time.Duration) { + if cli == nil { + return + } var evt *zerolog.Event - if err != nil { + if errors.Is(err, context.Canceled) { + evt = zerolog.Ctx(req.Context()).Warn() + } else if err != nil { evt = zerolog.Ctx(req.Context()).Err(err) } else if handlerErr != nil { evt = zerolog.Ctx(req.Context()).Warn(). AnErr("body_parse_err", handlerErr) + } else if cli.SyncTraceLog && strings.HasSuffix(req.URL.Path, "/_matrix/client/v3/sync") { + evt = zerolog.Ctx(req.Context()).Trace() } else { evt = zerolog.Ctx(req.Context()).Debug() } @@ -291,10 +369,10 @@ func (cli *Client) LogRequestDone(req *http.Request, resp *http.Response, err er Str("method", req.Method). Str("url", req.URL.String()). Dur("duration", duration) + if cli.ResponseHook != nil { + cli.ResponseHook(req, resp, err, duration) + } if resp != nil { - if cli.ResponseHook != nil { - cli.ResponseHook(req, resp, duration) - } mime := resp.Header.Get("Content-Type") length := resp.ContentLength if length == -1 && contentLength > 0 { @@ -308,9 +386,18 @@ func (cli *Client) LogRequestDone(req *http.Request, resp *http.Response, err er } } if body := req.Context().Value(LogBodyContextKey); body != nil { - evt.Interface("req_body", body) + switch typedLogBody := body.(type) { + case json.RawMessage: + evt.RawJSON("req_body", typedLogBody) + case string: + evt.Str("req_body", typedLogBody) + default: + panic(fmt.Errorf("invalid type for LogBodyContextKey: %T", body)) + } } - if err != nil { + if errors.Is(err, context.Canceled) { + evt.Msg("Request canceled") + } else if err != nil { evt.Msg("Request failed") } else if handlerErr != nil { evt.Msg("Request parsing failed") @@ -319,33 +406,47 @@ func (cli *Client) LogRequestDone(req *http.Request, resp *http.Response, err er } } -func (cli *Client) MakeRequest(ctx context.Context, method string, httpURL string, reqBody interface{}, resBody interface{}) ([]byte, error) { +func (cli *Client) MakeRequest(ctx context.Context, method string, httpURL string, reqBody any, resBody any) ([]byte, error) { return cli.MakeFullRequest(ctx, FullRequest{Method: method, URL: httpURL, RequestJSON: reqBody, ResponseJSON: resBody}) } -type ClientResponseHandler = func(req *http.Request, res *http.Response, responseJSON interface{}) ([]byte, error) +type ClientResponseHandler = func(req *http.Request, res *http.Response, responseJSON any, sizeLimit int64) ([]byte, error) type FullRequest struct { - Method string - URL string - Headers http.Header - RequestJSON interface{} - RequestBytes []byte - RequestBody io.Reader - RequestLength int64 - ResponseJSON interface{} - MaxAttempts int - SensitiveContent bool - Handler ClientResponseHandler - Logger *zerolog.Logger + Method string + URL string + Headers http.Header + RequestJSON interface{} + RequestBytes []byte + RequestBody io.Reader + RequestLength int64 + ResponseJSON interface{} + MaxAttempts int + BackoffDuration time.Duration + SensitiveContent bool + Handler ClientResponseHandler + DontReadResponse bool + ResponseSizeLimit int64 + Logger *zerolog.Logger + Client *http.Client } var requestID int32 var logSensitiveContent = os.Getenv("MAUTRIX_LOG_SENSITIVE_CONTENT") == "yes" func (params *FullRequest) compileRequest(ctx context.Context) (*http.Request, error) { + reqID := atomic.AddInt32(&requestID, 1) + logger := zerolog.Ctx(ctx) + if logger.GetLevel() == zerolog.Disabled || logger == zerolog.DefaultContextLogger { + logger = params.Logger + } + ctx = logger.With(). + Int32("req_id", reqID). + Logger().WithContext(ctx) + var logBody any - reqBody := params.RequestBody + var reqBody io.Reader + var reqLen int64 if params.RequestJSON != nil { jsonStr, err := json.Marshal(params.RequestJSON) if err != nil { @@ -356,29 +457,38 @@ func (params *FullRequest) compileRequest(ctx context.Context) (*http.Request, e } if params.SensitiveContent && !logSensitiveContent { logBody = "" + } else if len(jsonStr) > 32768 { + logBody = fmt.Sprintf("", len(jsonStr)) } else { - logBody = params.RequestJSON + logBody = json.RawMessage(jsonStr) } reqBody = bytes.NewReader(jsonStr) + reqLen = int64(len(jsonStr)) } else if params.RequestBytes != nil { logBody = fmt.Sprintf("<%d bytes>", len(params.RequestBytes)) reqBody = bytes.NewReader(params.RequestBytes) - params.RequestLength = int64(len(params.RequestBytes)) - } else if params.RequestLength > 0 && params.RequestBody != nil { - logBody = fmt.Sprintf("<%d bytes>", params.RequestLength) + reqLen = int64(len(params.RequestBytes)) + } else if params.RequestBody != nil { + logBody = "" + reqLen = -1 + if params.RequestLength > 0 { + logBody = fmt.Sprintf("<%d bytes>", params.RequestLength) + reqLen = params.RequestLength + } else if params.RequestLength == 0 { + zerolog.Ctx(ctx).Warn(). + Msg("RequestBody passed without specifying request length") + } + reqBody = params.RequestBody + if rsc, ok := params.RequestBody.(io.ReadSeekCloser); ok { + // Prevent HTTP from closing the request body, it might be needed for retries + reqBody = nopCloseSeeker{rsc} + } } else if params.Method != http.MethodGet && params.Method != http.MethodHead { params.RequestJSON = struct{}{} - logBody = params.RequestJSON + logBody = json.RawMessage("{}") reqBody = bytes.NewReader([]byte("{}")) + reqLen = 2 } - reqID := atomic.AddInt32(&requestID, 1) - logger := zerolog.Ctx(ctx) - if logger.GetLevel() == zerolog.Disabled || logger == zerolog.DefaultContextLogger { - logger = params.Logger - } - ctx = logger.With(). - Int32("req_id", reqID). - Logger().WithContext(ctx) ctx = context.WithValue(ctx, LogBodyContextKey, logBody) ctx = context.WithValue(ctx, LogRequestIDContextKey, int(reqID)) req, err := http.NewRequestWithContext(ctx, params.Method, params.URL, reqBody) @@ -394,37 +504,76 @@ func (params *FullRequest) compileRequest(ctx context.Context) (*http.Request, e if params.RequestJSON != nil { req.Header.Set("Content-Type", "application/json") } - if params.RequestLength > 0 && params.RequestBody != nil { - req.ContentLength = params.RequestLength - } + req.ContentLength = reqLen return req, nil } -// MakeFullRequest makes a JSON HTTP request to the given URL. -// If "resBody" is not nil, the response body will be json.Unmarshalled into it. -// -// Returns the HTTP body as bytes on 2xx with a nil error. Returns an error if the response is not 2xx along -// with the HTTP body bytes if it got that far. This error is an HTTPError which includes the returned -// HTTP status code and possibly a RespError as the WrappedError, if the HTTP body could be decoded as a RespError. func (cli *Client) MakeFullRequest(ctx context.Context, params FullRequest) ([]byte, error) { + data, _, err := cli.MakeFullRequestWithResp(ctx, params) + return data, err +} + +func (cli *Client) MakeFullRequestWithResp(ctx context.Context, params FullRequest) ([]byte, *http.Response, error) { + if cli == nil { + return nil, nil, ErrClientIsNil + } + if cli.HomeserverURL == nil || cli.HomeserverURL.Scheme == "" { + return nil, nil, ErrClientHasNoHomeserver + } if params.MaxAttempts == 0 { - params.MaxAttempts = 1 + cli.DefaultHTTPRetries + maxAttempts, ok := ctx.Value(MaxAttemptsContextKey).(int) + if ok && maxAttempts > 0 { + params.MaxAttempts = maxAttempts + } else { + params.MaxAttempts = 1 + cli.DefaultHTTPRetries + } + } + if params.BackoffDuration == 0 { + if cli.DefaultHTTPBackoff == 0 { + params.BackoffDuration = 4 * time.Second + } else { + params.BackoffDuration = cli.DefaultHTTPBackoff + } } if params.Logger == nil { params.Logger = &cli.Log } req, err := params.compileRequest(ctx) if err != nil { - return nil, err + return nil, nil, err } if params.Handler == nil { - params.Handler = handleNormalResponse + if params.DontReadResponse { + params.Handler = noopHandleResponse + } else { + params.Handler = handleNormalResponse + } + } + if cli.UserAgent != "" { + req.Header.Set("User-Agent", cli.UserAgent) } - req.Header.Set("User-Agent", cli.UserAgent) if len(cli.AccessToken) > 0 { req.Header.Set("Authorization", "Bearer "+cli.AccessToken) } - return cli.executeCompiledRequest(req, params.MaxAttempts-1, 4*time.Second, params.ResponseJSON, params.Handler) + if params.ResponseSizeLimit == 0 { + params.ResponseSizeLimit = cli.ResponseSizeLimit + } + if params.ResponseSizeLimit == 0 { + params.ResponseSizeLimit = DefaultResponseSizeLimit + } + if params.Client == nil { + params.Client = cli.Client + } + return cli.executeCompiledRequest( + req, + params.MaxAttempts-1, + params.BackoffDuration, + params.ResponseJSON, + params.Handler, + params.DontReadResponse, + params.ResponseSizeLimit, + params.Client, + ) } func (cli *Client) cliOrContextLog(ctx context.Context) *zerolog.Logger { @@ -435,29 +584,69 @@ func (cli *Client) cliOrContextLog(ctx context.Context) *zerolog.Logger { return log } -func (cli *Client) doRetry(req *http.Request, cause error, retries int, backoff time.Duration, responseJSON interface{}, handler ClientResponseHandler) ([]byte, error) { +func (cli *Client) doRetry( + req *http.Request, + cause error, + retries int, + backoff time.Duration, + responseJSON any, + handler ClientResponseHandler, + dontReadResponse bool, + sizeLimit int64, + client *http.Client, +) ([]byte, *http.Response, error) { log := zerolog.Ctx(req.Context()) if req.Body != nil { - if req.GetBody == nil { - log.Warn().Msg("Failed to get new body to retry request: GetBody is nil") - return nil, cause - } var err error - req.Body, err = req.GetBody() - if err != nil { - log.Warn().Err(err).Msg("Failed to get new body to retry request") - return nil, cause + if req.GetBody != nil { + req.Body, err = req.GetBody() + if err != nil { + log.Warn().Err(err).Msg("Failed to get new body to retry request") + return nil, nil, cause + } + } else if bodySeeker, ok := req.Body.(io.ReadSeeker); ok { + _, err = bodySeeker.Seek(0, io.SeekStart) + if err != nil { + log.Warn().Err(err).Msg("Failed to seek to beginning of request body") + return nil, nil, cause + } + } else { + log.Warn().Msg("Failed to get new body to retry request: GetBody is nil and Body is not an io.ReadSeeker") + return nil, nil, cause } } log.Warn().Err(cause). + Str("method", req.Method). + Str("url", req.URL.String()). Int("retry_in_seconds", int(backoff.Seconds())). Msg("Request failed, retrying") - time.Sleep(backoff) - return cli.executeCompiledRequest(req, retries-1, backoff*2, responseJSON, handler) + select { + case <-time.After(backoff): + case <-req.Context().Done(): + if !errors.Is(context.Cause(req.Context()), ErrContextCancelRetry) { + return nil, nil, req.Context().Err() + } + } + if cli.UpdateRequestOnRetry != nil { + req = cli.UpdateRequestOnRetry(req, cause) + } + return cli.executeCompiledRequest(req, retries-1, backoff*2, responseJSON, handler, dontReadResponse, sizeLimit, client) } -func readRequestBody(req *http.Request, res *http.Response) ([]byte, error) { - contents, err := io.ReadAll(res.Body) +func readResponseBody(req *http.Request, res *http.Response, limit int64) ([]byte, error) { + if res.ContentLength > limit { + return nil, HTTPError{ + Request: req, + Response: res, + + Message: "not reading response", + WrappedError: fmt.Errorf("%w (%.2f MiB)", ErrResponseTooLong, float64(res.ContentLength)/1024/1024), + } + } + contents, err := io.ReadAll(io.LimitReader(res.Body, limit+1)) + if err == nil && len(contents) > int(limit) { + err = ErrBodyReadReachedLimit + } if err != nil { return nil, HTTPError{ Request: req, @@ -478,17 +667,20 @@ func closeTemp(log *zerolog.Logger, file *os.File) { } } -func streamResponse(req *http.Request, res *http.Response, responseJSON interface{}) ([]byte, error) { +func streamResponse(req *http.Request, res *http.Response, responseJSON any, limit int64) ([]byte, error) { log := zerolog.Ctx(req.Context()) file, err := os.CreateTemp("", "mautrix-response-") if err != nil { log.Warn().Err(err).Msg("Failed to create temporary file for streaming response") - _, err = handleNormalResponse(req, res, responseJSON) + _, err = handleNormalResponse(req, res, responseJSON, limit) return nil, err } defer closeTemp(log, file) - if _, err = io.Copy(file, res.Body); err != nil { + var n int64 + if n, err = io.Copy(file, io.LimitReader(res.Body, limit+1)); err != nil { return nil, fmt.Errorf("failed to copy response to file: %w", err) + } else if n > limit { + return nil, ErrBodyReadReachedLimit } else if _, err = file.Seek(0, 0); err != nil { return nil, fmt.Errorf("failed to seek to beginning of response file: %w", err) } else if err = json.NewDecoder(file).Decode(responseJSON); err != nil { @@ -498,8 +690,12 @@ func streamResponse(req *http.Request, res *http.Response, responseJSON interfac } } -func handleNormalResponse(req *http.Request, res *http.Response, responseJSON interface{}) ([]byte, error) { - if contents, err := readRequestBody(req, res); err != nil { +func noopHandleResponse(req *http.Request, res *http.Response, responseJSON any, limit int64) ([]byte, error) { + return nil, nil +} + +func handleNormalResponse(req *http.Request, res *http.Response, responseJSON any, limit int64) ([]byte, error) { + if contents, err := readResponseBody(req, res, limit); err != nil { return nil, err } else if responseJSON == nil { return contents, nil @@ -517,13 +713,20 @@ func handleNormalResponse(req *http.Request, res *http.Response, responseJSON in } } +const ErrorResponseSizeLimit = 512 * 1024 + +var DefaultResponseSizeLimit int64 = 512 * 1024 * 1024 + func ParseErrorResponse(req *http.Request, res *http.Response) ([]byte, error) { - contents, err := readRequestBody(req, res) + defer res.Body.Close() + contents, err := readResponseBody(req, res, ErrorResponseSizeLimit) if err != nil { return contents, err } - respErr := &RespError{} + respErr := &RespError{ + StatusCode: res.StatusCode, + } if _ = json.Unmarshal(contents, respErr); respErr.ErrCode == "" { respErr = nil } @@ -535,17 +738,31 @@ func ParseErrorResponse(req *http.Request, res *http.Response) ([]byte, error) { } } -func (cli *Client) executeCompiledRequest(req *http.Request, retries int, backoff time.Duration, responseJSON interface{}, handler ClientResponseHandler) ([]byte, error) { +func (cli *Client) executeCompiledRequest( + req *http.Request, + retries int, + backoff time.Duration, + responseJSON any, + handler ClientResponseHandler, + dontReadResponse bool, + sizeLimit int64, + client *http.Client, +) ([]byte, *http.Response, error) { cli.RequestStart(req) startTime := time.Now() - res, err := cli.Client.Do(req) - duration := time.Now().Sub(startTime) - if res != nil { + res, err := client.Do(req) + duration := time.Since(startTime) + if res != nil && !dontReadResponse { defer res.Body.Close() } if err != nil { - if retries > 0 { - return cli.doRetry(req, err, retries, backoff, responseJSON, handler) + // Either error is *not* canceled or the underlying cause of cancelation explicitly asks to retry + canRetry := !errors.Is(err, context.Canceled) || + errors.Is(context.Cause(req.Context()), ErrContextCancelRetry) + if retries > 0 && canRetry { + return cli.doRetry( + req, err, retries, backoff, responseJSON, handler, dontReadResponse, sizeLimit, client, + ) } err = HTTPError{ Request: req, @@ -555,12 +772,14 @@ func (cli *Client) executeCompiledRequest(req *http.Request, retries int, backof WrappedError: err, } cli.LogRequestDone(req, res, err, nil, 0, duration) - return nil, err + return nil, res, err } if retries > 0 && retryafter.Should(res.StatusCode, !cli.IgnoreRateLimit) { backoff = retryafter.Parse(res.Header.Get("Retry-After"), backoff) - return cli.doRetry(req, fmt.Errorf("HTTP %d", res.StatusCode), retries, backoff, responseJSON, handler) + return cli.doRetry( + req, fmt.Errorf("HTTP %d", res.StatusCode), retries, backoff, responseJSON, handler, dontReadResponse, sizeLimit, client, + ) } var body []byte @@ -568,15 +787,14 @@ func (cli *Client) executeCompiledRequest(req *http.Request, retries int, backof body, err = ParseErrorResponse(req, res) cli.LogRequestDone(req, res, nil, nil, len(body), duration) } else { - body, err = handler(req, res, responseJSON) + body, err = handler(req, res, responseJSON, sizeLimit) cli.LogRequestDone(req, res, nil, err, len(body), duration) } - return body, err + return body, res, err } // Whoami gets the user ID of the current user. See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3accountwhoami func (cli *Client) Whoami(ctx context.Context) (resp *RespWhoami, err error) { - urlPath := cli.BuildClientURL("v3", "account", "whoami") _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp) return @@ -601,12 +819,15 @@ func (cli *Client) SyncRequest(ctx context.Context, timeout int, since, filterID } type ReqSync struct { - Timeout int - Since string - FilterID string - FullState bool - SetPresence event.Presence - StreamResponse bool + Timeout int + Since string + FilterID string + FullState bool + SetPresence event.Presence + StreamResponse bool + UseStateAfter bool + BeeperStreaming bool + Client *http.Client } func (req *ReqSync) BuildQuery() map[string]string { @@ -625,6 +846,12 @@ func (req *ReqSync) BuildQuery() map[string]string { if req.FullState { query["full_state"] = "true" } + if req.UseStateAfter { + query["use_state_after"] = "true" + } + if req.BeeperStreaming { + query["com.beeper.streaming"] = "true" + } return query } @@ -635,6 +862,7 @@ func (cli *Client) FullSyncRequest(ctx context.Context, req ReqSync) (resp *Resp Method: http.MethodGet, URL: urlPath, ResponseJSON: &resp, + Client: req.Client, // We don't want automatic retries for SyncRequest, the Sync() wrapper handles those. MaxAttempts: 1, } @@ -643,7 +871,7 @@ func (cli *Client) FullSyncRequest(ctx context.Context, req ReqSync) (resp *Resp } start := time.Now() _, err = cli.MakeFullRequest(ctx, fullReq) - duration := time.Now().Sub(start) + duration := time.Since(start) timeout := time.Duration(req.Timeout) * time.Millisecond buffer := 10 * time.Second if req.Since == "" { @@ -690,7 +918,7 @@ func (cli *Client) RegisterAvailable(ctx context.Context, username string) (resp return } -func (cli *Client) register(ctx context.Context, url string, req *ReqRegister) (resp *RespRegister, uiaResp *RespUserInteractive, err error) { +func (cli *Client) register(ctx context.Context, url string, req *ReqRegister[any]) (resp *RespRegister, uiaResp *RespUserInteractive, err error) { var bodyBytes []byte bodyBytes, err = cli.MakeFullRequest(ctx, FullRequest{ Method: http.MethodPost, @@ -714,7 +942,7 @@ func (cli *Client) register(ctx context.Context, url string, req *ReqRegister) ( // Register makes an HTTP request according to https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3register // // Registers with kind=user. For kind=guest, see RegisterGuest. -func (cli *Client) Register(ctx context.Context, req *ReqRegister) (*RespRegister, *RespUserInteractive, error) { +func (cli *Client) Register(ctx context.Context, req *ReqRegister[any]) (*RespRegister, *RespUserInteractive, error) { u := cli.BuildClientURL("v3", "register") return cli.register(ctx, u, req) } @@ -723,7 +951,7 @@ func (cli *Client) Register(ctx context.Context, req *ReqRegister) (*RespRegiste // with kind=guest. // // For kind=user, see Register. -func (cli *Client) RegisterGuest(ctx context.Context, req *ReqRegister) (*RespRegister, *RespUserInteractive, error) { +func (cli *Client) RegisterGuest(ctx context.Context, req *ReqRegister[any]) (*RespRegister, *RespUserInteractive, error) { query := map[string]string{ "kind": "guest", } @@ -746,8 +974,8 @@ func (cli *Client) RegisterGuest(ctx context.Context, req *ReqRegister) (*RespRe // panic(err) // } // token := res.AccessToken -func (cli *Client) RegisterDummy(ctx context.Context, req *ReqRegister) (*RespRegister, error) { - res, uia, err := cli.Register(ctx, req) +func (cli *Client) RegisterDummy(ctx context.Context, req *ReqRegister[any]) (*RespRegister, error) { + _, uia, err := cli.Register(ctx, req) if err != nil && uia == nil { return nil, err } else if uia == nil { @@ -756,7 +984,7 @@ func (cli *Client) RegisterDummy(ctx context.Context, req *ReqRegister) (*RespRe return nil, errors.New("server does not support m.login.dummy") } req.Auth = BaseAuthData{Type: AuthTypeDummy, Session: uia.Session} - res, _, err = cli.Register(ctx, req) + res, _, err := cli.Register(ctx, req) if err != nil { return nil, err } @@ -806,6 +1034,22 @@ func (cli *Client) Login(ctx context.Context, req *ReqLogin) (resp *RespLogin, e return } +// Create a device for an appservice user using MSC4190. +func (cli *Client) CreateDeviceMSC4190(ctx context.Context, deviceID id.DeviceID, initialDisplayName string) error { + if len(deviceID) == 0 { + deviceID = id.DeviceID(strings.ToUpper(random.String(10))) + } + _, err := cli.MakeRequest(ctx, http.MethodPut, cli.BuildClientURL("v3", "devices", deviceID), &ReqPutDevice{ + DisplayName: initialDisplayName, + }, nil) + if err != nil { + return err + } + cli.DeviceID = deviceID + cli.SetAppServiceDeviceID = true + return nil +} + // Logout the current user. See https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3logout // This does not clear the credentials from the client instance. See ClearCredentials() instead. func (cli *Client) Logout(ctx context.Context) (resp *RespLogout, err error) { @@ -826,6 +1070,9 @@ func (cli *Client) LogoutAll(ctx context.Context) (resp *RespLogout, err error) func (cli *Client) Versions(ctx context.Context) (resp *RespVersions, err error) { urlPath := cli.BuildClientURL("versions") _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp) + if resp != nil { + cli.SpecVersions = resp + } return } @@ -836,20 +1083,19 @@ func (cli *Client) Capabilities(ctx context.Context) (resp *RespCapabilities, er return } -// JoinRoom joins the client to a room ID or alias. See https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3joinroomidoralias +// JoinRoom joins the client to a room ID or alias. See https://spec.matrix.org/v1.13/client-server-api/#post_matrixclientv3joinroomidoralias // -// If serverName is specified, this will be added as a query param to instruct the homeserver to join via that server. If content is specified, it will -// be JSON encoded and used as the request body. -func (cli *Client) JoinRoom(ctx context.Context, roomIDorAlias, serverName string, content interface{}) (resp *RespJoinRoom, err error) { - var urlPath string - if serverName != "" { - urlPath = cli.BuildURLWithQuery(ClientURLPath{"v3", "join", roomIDorAlias}, map[string]string{ - "server_name": serverName, - }) - } else { - urlPath = cli.BuildClientURL("v3", "join", roomIDorAlias) +// The last parameter contains optional extra fields and can be left nil. +func (cli *Client) JoinRoom(ctx context.Context, roomIDorAlias string, req *ReqJoinRoom) (resp *RespJoinRoom, err error) { + if req == nil { + req = &ReqJoinRoom{} } - _, err = cli.MakeRequest(ctx, http.MethodPost, urlPath, content, &resp) + urlPath := cli.BuildURLWithFullQuery(ClientURLPath{"v3", "join", roomIDorAlias}, func(q url.Values) { + if len(req.Via) > 0 { + q["via"] = req.Via + } + }) + _, err = cli.MakeRequest(ctx, http.MethodPost, urlPath, req, &resp) if err == nil && cli.StateStore != nil { err = cli.StateStore.SetMembership(ctx, resp.RoomID, cli.UserID, event.MembershipJoin) if err != nil { @@ -859,6 +1105,28 @@ func (cli *Client) JoinRoom(ctx context.Context, roomIDorAlias, serverName strin return } +// KnockRoom requests to join a room ID or alias. See https://spec.matrix.org/v1.13/client-server-api/#post_matrixclientv3knockroomidoralias +// +// The last parameter contains optional extra fields and can be left nil. +func (cli *Client) KnockRoom(ctx context.Context, roomIDorAlias string, req *ReqKnockRoom) (resp *RespKnockRoom, err error) { + if req == nil { + req = &ReqKnockRoom{} + } + urlPath := cli.BuildURLWithFullQuery(ClientURLPath{"v3", "knock", roomIDorAlias}, func(q url.Values) { + if len(req.Via) > 0 { + q["via"] = req.Via + } + }) + _, err = cli.MakeRequest(ctx, http.MethodPost, urlPath, req, &resp) + if err == nil && cli.StateStore != nil { + err = cli.StateStore.SetMembership(ctx, resp.RoomID, cli.UserID, event.MembershipKnock) + if err != nil { + err = fmt.Errorf("failed to update state store: %w", err) + } + } + return +} + // JoinRoomByID joins the client to a room ID. See https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3roomsroomidjoin // // Unlike JoinRoom, this method can only be used to join rooms that the server already knows about. @@ -880,10 +1148,54 @@ func (cli *Client) GetProfile(ctx context.Context, mxid id.UserID) (resp *RespUs return } +func (cli *Client) SearchUserDirectory(ctx context.Context, query string, limit int) (resp *RespSearchUserDirectory, err error) { + urlPath := cli.BuildClientURL("v3", "user_directory", "search") + _, err = cli.MakeRequest(ctx, http.MethodPost, urlPath, &ReqSearchUserDirectory{ + SearchTerm: query, + Limit: limit, + }, &resp) + return +} + +func (cli *Client) GetMutualRooms(ctx context.Context, otherUserID id.UserID, extras ...ReqMutualRooms) (resp *RespMutualRooms, err error) { + supportsStable := cli.SpecVersions.Supports(FeatureStableMutualRooms) + supportsUnstable := cli.SpecVersions.Supports(FeatureUnstableMutualRooms) + if cli.SpecVersions != nil && !supportsUnstable && !supportsStable { + err = fmt.Errorf("server does not support fetching mutual rooms") + return + } + query := map[string]string{ + "user_id": otherUserID.String(), + } + if len(extras) > 0 { + query["from"] = extras[0].From + } + urlPath := cli.BuildURLWithQuery(ClientURLPath{"v1", "user", "mutual_rooms"}, query) + if !supportsStable && supportsUnstable { + urlPath = cli.BuildURLWithQuery(ClientURLPath{"unstable", "uk.half-shot.msc2666", "user", "mutual_rooms"}, query) + } + _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp) + return +} + +func (cli *Client) GetRoomSummary(ctx context.Context, roomIDOrAlias string, via ...string) (resp *RespRoomSummary, err error) { + urlPath := ClientURLPath{"unstable", "im.nheko.summary", "summary", roomIDOrAlias} + if cli.SpecVersions.ContainsGreaterOrEqual(SpecV115) { + urlPath = ClientURLPath{"v1", "room_summary", roomIDOrAlias} + } + // TODO add version check after one is added to MSC3266 + fullURL := cli.BuildURLWithFullQuery(urlPath, func(q url.Values) { + if len(via) > 0 { + q["via"] = via + } + }) + _, err = cli.MakeRequest(ctx, http.MethodGet, fullURL, nil, &resp) + return +} + // GetDisplayName returns the display name of the user with the specified MXID. See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3profileuseriddisplayname func (cli *Client) GetDisplayName(ctx context.Context, mxid id.UserID) (resp *RespUserDisplayName, err error) { - urlPath := cli.BuildClientURL("v3", "profile", mxid, "displayname") - _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp) + err = cli.GetProfileField(ctx, mxid, "displayname", &resp) return } @@ -894,25 +1206,47 @@ func (cli *Client) GetOwnDisplayName(ctx context.Context) (resp *RespUserDisplay // SetDisplayName sets the user's profile display name. See https://spec.matrix.org/v1.2/client-server-api/#put_matrixclientv3profileuseriddisplayname func (cli *Client) SetDisplayName(ctx context.Context, displayName string) (err error) { - urlPath := cli.BuildClientURL("v3", "profile", cli.UserID, "displayname") - s := struct { - DisplayName string `json:"displayname"` - }{displayName} - _, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, &s, nil) + return cli.SetProfileField(ctx, "displayname", displayName) +} + +// SetProfileField sets an arbitrary profile field. See https://spec.matrix.org/v1.16/client-server-api/#put_matrixclientv3profileuseridkeyname +func (cli *Client) SetProfileField(ctx context.Context, key string, value any) (err error) { + urlPath := cli.BuildClientURL("v3", "profile", cli.UserID, key) + if key != "displayname" && key != "avatar_url" && !cli.SpecVersions.Supports(FeatureArbitraryProfileFields) && cli.SpecVersions.Supports(FeatureUnstableProfileFields) { + urlPath = cli.BuildClientURL("unstable", "uk.tcpip.msc4133", "profile", cli.UserID, key) + } + _, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, map[string]any{ + key: value, + }, nil) + return +} + +// DeleteProfileField deletes an arbitrary profile field. See https://spec.matrix.org/v1.16/client-server-api/#put_matrixclientv3profileuseridkeyname +func (cli *Client) DeleteProfileField(ctx context.Context, key string) (err error) { + urlPath := cli.BuildClientURL("v3", "profile", cli.UserID, key) + if key != "displayname" && key != "avatar_url" && !cli.SpecVersions.Supports(FeatureArbitraryProfileFields) && cli.SpecVersions.Supports(FeatureUnstableProfileFields) { + urlPath = cli.BuildClientURL("unstable", "uk.tcpip.msc4133", "profile", cli.UserID, key) + } + _, err = cli.MakeRequest(ctx, http.MethodDelete, urlPath, nil, nil) + return +} + +// GetProfileField gets an arbitrary profile field and parses the response into the given struct. See https://spec.matrix.org/unstable/client-server-api/#get_matrixclientv3profileuseridkeyname +func (cli *Client) GetProfileField(ctx context.Context, userID id.UserID, key string, into any) (err error) { + urlPath := cli.BuildClientURL("v3", "profile", userID, key) + if key != "displayname" && key != "avatar_url" && !cli.SpecVersions.Supports(FeatureArbitraryProfileFields) && cli.SpecVersions.Supports(FeatureUnstableProfileFields) { + urlPath = cli.BuildClientURL("unstable", "uk.tcpip.msc4133", "profile", cli.UserID, key) + } + _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, into) return } // GetAvatarURL gets the avatar URL of the user with the specified MXID. See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3profileuseridavatar_url func (cli *Client) GetAvatarURL(ctx context.Context, mxid id.UserID) (url id.ContentURI, err error) { - urlPath := cli.BuildClientURL("v3", "profile", mxid, "avatar_url") s := struct { AvatarURL id.ContentURI `json:"avatar_url"` }{} - - _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &s) - if err != nil { - return - } + err = cli.GetProfileField(ctx, mxid, "avatar_url", &s) url = s.AvatarURL return } @@ -937,9 +1271,9 @@ func (cli *Client) SetAvatarURL(ctx context.Context, url id.ContentURI) (err err } // BeeperUpdateProfile sets custom fields in the user's profile. -func (cli *Client) BeeperUpdateProfile(ctx context.Context, data map[string]any) (err error) { +func (cli *Client) BeeperUpdateProfile(ctx context.Context, data any) (err error) { urlPath := cli.BuildClientURL("v3", "profile", cli.UserID) - _, err = cli.MakeRequest(ctx, http.MethodPatch, urlPath, &data, nil) + _, err = cli.MakeRequest(ctx, http.MethodPatch, urlPath, data, nil) return } @@ -979,15 +1313,6 @@ func (cli *Client) SetRoomAccountData(ctx context.Context, roomID id.RoomID, nam return nil } -type ReqSendEvent struct { - Timestamp int64 - TransactionID string - - DontEncrypt bool - - MeowEventID id.EventID -} - // SendMessageEvent sends a message event into a room. See https://spec.matrix.org/v1.2/client-server-api/#put_matrixclientv3roomsroomidsendeventtypetxnid // contentJSON should be a pointer to something that can be encoded as JSON using json.Marshal. func (cli *Client) SendMessageEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, contentJSON interface{}, extra ...ReqSendEvent) (resp *RespSendEvent, err error) { @@ -1010,8 +1335,14 @@ func (cli *Client) SendMessageEvent(ctx context.Context, roomID id.RoomID, event if req.MeowEventID != "" { queryParams["fi.mau.event_id"] = req.MeowEventID.String() } + if req.UnstableDelay > 0 { + queryParams["org.matrix.msc4140.delay"] = strconv.FormatInt(req.UnstableDelay.Milliseconds(), 10) + } + if req.UnstableStickyDuration > 0 { + queryParams["org.matrix.msc4354.sticky_duration_ms"] = strconv.FormatInt(req.UnstableStickyDuration.Milliseconds(), 10) + } - if !req.DontEncrypt && cli.Crypto != nil && eventType != event.EventReaction && eventType != event.EventEncrypted { + if !req.DontEncrypt && cli != nil && cli.Crypto != nil && eventType != event.EventReaction && eventType != event.EventEncrypted { var isEncrypted bool isEncrypted, err = cli.StateStore.IsEncrypted(ctx, roomID) if err != nil { @@ -1033,12 +1364,77 @@ func (cli *Client) SendMessageEvent(ctx context.Context, roomID id.RoomID, event return } -// SendStateEvent sends a state event into a room. See https://spec.matrix.org/v1.2/client-server-api/#put_matrixclientv3roomsroomidstateeventtypestatekey -// contentJSON should be a pointer to something that can be encoded as JSON using json.Marshal. -func (cli *Client) SendStateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, contentJSON interface{}) (resp *RespSendEvent, err error) { - urlPath := cli.BuildClientURL("v3", "rooms", roomID, "state", eventType.String(), stateKey) +// BeeperSendEphemeralEvent sends an ephemeral event into a room using Beeper's unstable endpoint. +// contentJSON should be a value that can be encoded as JSON using json.Marshal. +func (cli *Client) BeeperSendEphemeralEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, contentJSON any, extra ...ReqSendEvent) (resp *RespSendEvent, err error) { + var req ReqSendEvent + if len(extra) > 0 { + req = extra[0] + } + + var txnID string + if len(req.TransactionID) > 0 { + txnID = req.TransactionID + } else { + txnID = cli.TxnID() + } + + queryParams := map[string]string{} + if req.Timestamp > 0 { + queryParams["ts"] = strconv.FormatInt(req.Timestamp, 10) + } + + if !req.DontEncrypt && cli != nil && cli.Crypto != nil && eventType != event.EventEncrypted { + var isEncrypted bool + isEncrypted, err = cli.StateStore.IsEncrypted(ctx, roomID) + if err != nil { + err = fmt.Errorf("failed to check if room is encrypted: %w", err) + return + } + if isEncrypted { + if contentJSON, err = cli.Crypto.Encrypt(ctx, roomID, eventType, contentJSON); err != nil { + err = fmt.Errorf("failed to encrypt event: %w", err) + return + } + eventType = event.EventEncrypted + } + } + + urlData := ClientURLPath{"unstable", "com.beeper.ephemeral", "rooms", roomID, "ephemeral", eventType.String(), txnID} + urlPath := cli.BuildURLWithQuery(urlData, queryParams) _, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, contentJSON, &resp) - if err == nil && cli.StateStore != nil { + return +} + +// SendStateEvent sends a state event into a room. See https://spec.matrix.org/v1.16/client-server-api/#put_matrixclientv3roomsroomidstateeventtypestatekey +// contentJSON should be a pointer to something that can be encoded as JSON using json.Marshal. +func (cli *Client) SendStateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, contentJSON any, extra ...ReqSendEvent) (resp *RespSendEvent, err error) { + var req ReqSendEvent + if len(extra) > 0 { + req = extra[0] + } + + queryParams := map[string]string{} + if req.MeowEventID != "" { + queryParams["fi.mau.event_id"] = req.MeowEventID.String() + } + if req.TransactionID != "" { + queryParams["fi.mau.transaction_id"] = req.TransactionID + } + if req.UnstableDelay > 0 { + queryParams["org.matrix.msc4140.delay"] = strconv.FormatInt(req.UnstableDelay.Milliseconds(), 10) + } + if req.UnstableStickyDuration > 0 { + queryParams["org.matrix.msc4354.sticky_duration_ms"] = strconv.FormatInt(req.UnstableStickyDuration.Milliseconds(), 10) + } + if req.Timestamp > 0 { + queryParams["ts"] = strconv.FormatInt(req.Timestamp, 10) + } + + urlData := ClientURLPath{"v3", "rooms", roomID, "state", eventType.String(), stateKey} + urlPath := cli.BuildURLWithQuery(urlData, queryParams) + _, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, contentJSON, &resp) + if err == nil && cli.StateStore != nil && req.UnstableDelay == 0 { cli.updateStoreWithOutgoingEvent(ctx, roomID, eventType, stateKey, contentJSON) } return @@ -1046,14 +1442,44 @@ func (cli *Client) SendStateEvent(ctx context.Context, roomID id.RoomID, eventTy // SendMassagedStateEvent sends a state event into a room with a custom timestamp. See https://spec.matrix.org/v1.2/client-server-api/#put_matrixclientv3roomsroomidstateeventtypestatekey // contentJSON should be a pointer to something that can be encoded as JSON using json.Marshal. +// +// Deprecated: SendStateEvent accepts a timestamp via ReqSendEvent and should be used instead. func (cli *Client) SendMassagedStateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, contentJSON interface{}, ts int64) (resp *RespSendEvent, err error) { - urlPath := cli.BuildURLWithQuery(ClientURLPath{"v3", "rooms", roomID, "state", eventType.String(), stateKey}, map[string]string{ - "ts": strconv.FormatInt(ts, 10), + resp, err = cli.SendStateEvent(ctx, roomID, eventType, stateKey, contentJSON, ReqSendEvent{ + Timestamp: ts, }) - _, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, contentJSON, &resp) - if err == nil && cli.StateStore != nil { - cli.updateStoreWithOutgoingEvent(ctx, roomID, eventType, stateKey, contentJSON) + return +} + +func (cli *Client) DelayedEvents(ctx context.Context, req *ReqDelayedEvents) (resp *RespDelayedEvents, err error) { + query := map[string]string{} + if req.DelayID != "" { + query["delay_id"] = string(req.DelayID) } + if req.Status != "" { + query["status"] = string(req.Status) + } + if req.NextBatch != "" { + query["next_batch"] = req.NextBatch + } + + urlPath := cli.BuildURLWithQuery(ClientURLPath{"unstable", "org.matrix.msc4140", "delayed_events"}, query) + _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, req, &resp) + + // Migration: merge old keys with new ones + if resp != nil { + resp.Scheduled = append(resp.Scheduled, resp.DelayedEvents...) + resp.DelayedEvents = nil + resp.Finalised = append(resp.Finalised, resp.FinalisedEvents...) + resp.FinalisedEvents = nil + } + + return +} + +func (cli *Client) UpdateDelayedEvent(ctx context.Context, req *ReqUpdateDelayedEvent) (resp *RespUpdateDelayedEvent, err error) { + urlPath := cli.BuildClientURL("unstable", "org.matrix.msc4140", "delayed_events", req.DelayID) + _, err = cli.MakeRequest(ctx, http.MethodPost, urlPath, req, &resp) return } @@ -1108,6 +1534,19 @@ func (cli *Client) RedactEvent(ctx context.Context, roomID id.RoomID, eventID id return } +func (cli *Client) UnstableRedactUserEvents(ctx context.Context, roomID id.RoomID, userID id.UserID, req *ReqRedactUser) (resp *RespRedactUserEvents, err error) { + if req == nil { + req = &ReqRedactUser{} + } + query := map[string]string{} + if req.Limit > 0 { + query["limit"] = strconv.Itoa(req.Limit) + } + urlPath := cli.BuildURLWithQuery(ClientURLPath{"unstable", "org.matrix.msc4194", "rooms", roomID, "redact", "user", userID}, query) + _, err = cli.MakeRequest(ctx, http.MethodPost, urlPath, req, &resp) + return +} + // CreateRoom creates a new Matrix room. See https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3createroom // // resp, err := cli.CreateRoom(&mautrix.ReqCreateRoom{ @@ -1125,6 +1564,10 @@ func (cli *Client) CreateRoom(ctx context.Context, req *ReqCreateRoom) (resp *Re Msg("Failed to update creator membership in state store after creating room") } for _, evt := range req.InitialState { + evt.RoomID = resp.RoomID + if evt.StateKey == nil { + evt.StateKey = ptr.Ptr("") + } UpdateStateStore(ctx, cli.StateStore, evt) } inviteMembership := event.MembershipInvite @@ -1139,9 +1582,6 @@ func (cli *Client) CreateRoom(ctx context.Context, req *ReqCreateRoom) (resp *Re Msg("Failed to update membership in state store after creating room") } } - for _, evt := range req.InitialState { - cli.updateStoreWithOutgoingEvent(ctx, resp.RoomID, evt.Type, evt.GetStateKey(), &evt.Content) - } } return } @@ -1252,15 +1692,14 @@ func (cli *Client) GetOwnPresence(ctx context.Context) (resp *RespPresence, err return cli.GetPresence(ctx, cli.UserID) } -func (cli *Client) SetPresence(ctx context.Context, status event.Presence) (err error) { - req := ReqPresence{Presence: status} +func (cli *Client) SetPresence(ctx context.Context, presence ReqPresence) (err error) { u := cli.BuildClientURL("v3", "presence", cli.UserID, "status") - _, err = cli.MakeRequest(ctx, http.MethodPut, u, req, nil) + _, err = cli.MakeRequest(ctx, http.MethodPut, u, presence, nil) return } func (cli *Client) updateStoreWithOutgoingEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, contentJSON interface{}) { - if cli.StateStore == nil { + if cli == nil || cli.StateStore == nil { return } fakeEvt := &event.Event{ @@ -1292,8 +1731,8 @@ func (cli *Client) updateStoreWithOutgoingEvent(ctx context.Context, roomID id.R UpdateStateStore(ctx, cli.StateStore, fakeEvt) } -// StateEvent gets a single state event in a room. It will attempt to JSON unmarshal into the given "outContent" struct with -// the HTTP response body, or return an error. +// StateEvent gets the content of a single state event in a room. +// It will attempt to JSON unmarshal into the given "outContent" struct with the HTTP response body, or return an error. // See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3roomsroomidstateeventtypestatekey func (cli *Client) StateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, outContent interface{}) (err error) { u := cli.BuildClientURL("v3", "rooms", roomID, "state", eventType.String(), stateKey) @@ -1304,12 +1743,43 @@ func (cli *Client) StateEvent(ctx context.Context, roomID id.RoomID, eventType e return } +// FullStateEvent gets a single state event in a room. Unlike [StateEvent], this gets the entire event +// (including details like the sender and timestamp). +// This requires the server to support the ?format=event query parameter, which is currently missing from the spec. +// See https://github.com/matrix-org/matrix-spec/issues/1047 for more info +func (cli *Client) FullStateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string) (evt *event.Event, err error) { + u := cli.BuildURLWithQuery(ClientURLPath{"v3", "rooms", roomID, "state", eventType.String(), stateKey}, map[string]string{ + "format": "event", + }) + _, err = cli.MakeRequest(ctx, http.MethodGet, u, nil, &evt) + if evt != nil { + evt.Type.Class = event.StateEventType + _ = evt.Content.ParseRaw(evt.Type) + if evt.RoomID == "" { + evt.RoomID = roomID + } + } + if err == nil && cli.StateStore != nil { + UpdateStateStore(ctx, cli.StateStore, evt) + } + return +} + // parseRoomStateArray parses a JSON array as a stream and stores the events inside it in a room state map. -func parseRoomStateArray(_ *http.Request, res *http.Response, responseJSON interface{}) ([]byte, error) { +func parseRoomStateArray(req *http.Request, res *http.Response, responseJSON any, limit int64) ([]byte, error) { + if res.ContentLength > limit { + return nil, HTTPError{ + Request: req, + Response: res, + + Message: "not reading response", + WrappedError: fmt.Errorf("%w (%.2f MiB)", ErrResponseTooLong, float64(res.ContentLength)/1024/1024), + } + } response := make(RoomStateMap) responsePtr := responseJSON.(*map[event.Type]map[string]*event.Event) *responsePtr = response - dec := json.NewDecoder(res.Body) + dec := json.NewDecoder(io.LimitReader(res.Body, limit)) arrayStart, err := dec.Token() if err != nil { @@ -1343,6 +1813,8 @@ func parseRoomStateArray(_ *http.Request, res *http.Response, responseJSON inter return nil, nil } +type RoomStateMap = map[event.Type]map[string]*event.Event + // State gets all state in a room. // See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3roomsroomidstate func (cli *Client) State(ctx context.Context, roomID id.RoomID) (stateMap RoomStateMap, err error) { @@ -1352,29 +1824,63 @@ func (cli *Client) State(ctx context.Context, roomID id.RoomID) (stateMap RoomSt ResponseJSON: &stateMap, Handler: parseRoomStateArray, }) + if stateMap != nil { + pls, ok := stateMap[event.StatePowerLevels][""] + if ok { + pls.Content.AsPowerLevels().CreateEvent = stateMap[event.StateCreate][""] + } + } if err == nil && cli.StateStore != nil { - clearErr := cli.StateStore.ClearCachedMembers(ctx, roomID) - cli.cliOrContextLog(ctx).Warn().Err(clearErr). - Stringer("room_id", roomID). - Msg("Failed to clear cached member list after fetching state") - for _, evts := range stateMap { + for evtType, evts := range stateMap { + if evtType == event.StateMember { + continue + } for _, evt := range evts { + if evt.RoomID == "" { + evt.RoomID = roomID + } UpdateStateStore(ctx, cli.StateStore, evt) } } + updateErr := cli.StateStore.ReplaceCachedMembers(ctx, roomID, maps.Values(stateMap[event.StateMember])) + if updateErr != nil { + cli.cliOrContextLog(ctx).Warn().Err(updateErr). + Stringer("room_id", roomID). + Msg("Failed to update members in state store after fetching members") + } + } + return +} + +// StateAsArray gets all the state in a room as an array. It does not update the state store. +// Use State to get the events as a map and also update the state store. +func (cli *Client) StateAsArray(ctx context.Context, roomID id.RoomID) (state []*event.Event, err error) { + _, err = cli.MakeRequest(ctx, http.MethodGet, cli.BuildClientURL("v3", "rooms", roomID, "state"), nil, &state) + if err == nil { + for _, evt := range state { + evt.Type.Class = event.StateEventType + } } return } // GetMediaConfig fetches the configuration of the content repository, such as upload limitations. func (cli *Client) GetMediaConfig(ctx context.Context) (resp *RespMediaConfig, err error) { - u := cli.BuildURL(MediaURLPath{"v3", "config"}) - _, err = cli.MakeRequest(ctx, http.MethodGet, u, nil, &resp) + _, err = cli.MakeRequest(ctx, http.MethodGet, cli.BuildClientURL("v1", "media", "config"), nil, &resp) + return +} + +func (cli *Client) RequestOpenIDToken(ctx context.Context) (resp *RespOpenIDToken, err error) { + _, err = cli.MakeRequest(ctx, http.MethodPost, cli.BuildClientURL("v3", "user", cli.UserID, "openid", "request_token"), nil, &resp) return } // UploadLink uploads an HTTP URL and then returns an MXC URI. func (cli *Client) UploadLink(ctx context.Context, link string) (*RespMediaUpload, error) { + if cli == nil { + return nil, ErrClientIsNil + } + req, err := http.NewRequestWithContext(ctx, http.MethodGet, link, nil) if err != nil { return nil, err @@ -1390,89 +1896,55 @@ func (cli *Client) UploadLink(ctx context.Context, link string) (*RespMediaUploa return cli.Upload(ctx, res.Body, res.Header.Get("Content-Type"), res.ContentLength) } -func (cli *Client) GetDownloadURL(mxcURL id.ContentURI) string { - return cli.BuildURLWithQuery(MediaURLPath{"v3", "download", mxcURL.Homeserver, mxcURL.FileID}, map[string]string{"allow_redirect": "true"}) +func (cli *Client) Download(ctx context.Context, mxcURL id.ContentURI) (*http.Response, error) { + if mxcURL.IsEmpty() { + return nil, fmt.Errorf("empty mxc uri provided to Download") + } + _, resp, err := cli.MakeFullRequestWithResp(ctx, FullRequest{ + Method: http.MethodGet, + URL: cli.BuildClientURL("v1", "media", "download", mxcURL.Homeserver, mxcURL.FileID), + DontReadResponse: true, + }) + return resp, err } -func (cli *Client) Download(ctx context.Context, mxcURL id.ContentURI) (io.ReadCloser, error) { - resp, err := cli.download(ctx, mxcURL) - if err != nil { - return nil, err - } - return resp.Body, nil +type DownloadThumbnailExtra struct { + Method string + Animated bool } -func (cli *Client) doMediaRetry(req *http.Request, cause error, retries int, backoff time.Duration) (*http.Response, error) { - log := zerolog.Ctx(req.Context()) - if req.Body != nil { - if req.GetBody == nil { - log.Warn().Msg("Failed to get new body to retry request: GetBody is nil") - return nil, cause - } - var err error - req.Body, err = req.GetBody() - if err != nil { - log.Warn().Err(err).Msg("Failed to get new body to retry request") - return nil, cause - } +func (cli *Client) DownloadThumbnail(ctx context.Context, mxcURL id.ContentURI, height, width int, extras ...DownloadThumbnailExtra) (*http.Response, error) { + if mxcURL.IsEmpty() { + return nil, fmt.Errorf("empty mxc uri provided to DownloadThumbnail") } - log.Warn().Err(cause). - Int("retry_in_seconds", int(backoff.Seconds())). - Msg("Request failed, retrying") - time.Sleep(backoff) - return cli.doMediaRequest(req, retries-1, backoff*2) -} - -func (cli *Client) doMediaRequest(req *http.Request, retries int, backoff time.Duration) (*http.Response, error) { - cli.RequestStart(req) - startTime := time.Now() - res, err := cli.Client.Do(req) - duration := time.Now().Sub(startTime) - if err != nil { - if retries > 0 { - return cli.doMediaRetry(req, err, retries, backoff) - } - err = HTTPError{ - Request: req, - Response: res, - - Message: "request error", - WrappedError: err, - } - cli.LogRequestDone(req, res, err, nil, 0, duration) - return nil, err + if len(extras) > 1 { + panic(fmt.Errorf("invalid number of arguments to DownloadThumbnail: %d", len(extras))) } - - if retries > 0 && retryafter.Should(res.StatusCode, !cli.IgnoreRateLimit) { - backoff = retryafter.Parse(res.Header.Get("Retry-After"), backoff) - return cli.doMediaRetry(req, fmt.Errorf("HTTP %d", res.StatusCode), retries, backoff) + var extra DownloadThumbnailExtra + if len(extras) == 1 { + extra = extras[0] } - - if res.StatusCode < 200 || res.StatusCode >= 300 { - var body []byte - body, err = ParseErrorResponse(req, res) - cli.LogRequestDone(req, res, err, nil, len(body), duration) - } else { - cli.LogRequestDone(req, res, nil, nil, -1, duration) + path := ClientURLPath{"v1", "media", "thumbnail", mxcURL.Homeserver, mxcURL.FileID} + query := map[string]string{ + "height": strconv.Itoa(height), + "width": strconv.Itoa(width), } - return res, err -} - -func (cli *Client) download(ctx context.Context, mxcURL id.ContentURI) (*http.Response, error) { - ctxLog := zerolog.Ctx(ctx) - if ctxLog.GetLevel() == zerolog.Disabled || ctxLog == zerolog.DefaultContextLogger { - ctx = cli.Log.WithContext(ctx) + if extra.Method != "" { + query["method"] = extra.Method } - req, err := http.NewRequestWithContext(ctx, http.MethodGet, cli.GetDownloadURL(mxcURL), nil) - if err != nil { - return nil, err + if extra.Animated { + query["animated"] = "true" } - req.Header.Set("User-Agent", cli.UserAgent+" (media downloader)") - return cli.doMediaRequest(req, cli.DefaultHTTPRetries, 4*time.Second) + _, resp, err := cli.MakeFullRequestWithResp(ctx, FullRequest{ + Method: http.MethodGet, + URL: cli.BuildURLWithQuery(path, query), + DontReadResponse: true, + }) + return resp, err } func (cli *Client) DownloadBytes(ctx context.Context, mxcURL id.ContentURI) ([]byte, error) { - resp, err := cli.download(ctx, mxcURL) + resp, err := cli.Download(ctx, mxcURL) if err != nil { return nil, err } @@ -1480,17 +1952,27 @@ func (cli *Client) DownloadBytes(ctx context.Context, mxcURL id.ContentURI) ([]b return io.ReadAll(resp.Body) } +type ReqCreateMXC struct { + BeeperUniqueID string + BeeperRoomID id.RoomID +} + // CreateMXC creates a blank Matrix content URI to allow uploading the content asynchronously later. // // See https://spec.matrix.org/v1.7/client-server-api/#post_matrixmediav1create -func (cli *Client) CreateMXC(ctx context.Context) (*RespCreateMXC, error) { - u, _ := url.Parse(cli.BuildURL(MediaURLPath{"v1", "create"})) +func (cli *Client) CreateMXC(ctx context.Context, extra ...ReqCreateMXC) (*RespCreateMXC, error) { var m RespCreateMXC - _, err := cli.MakeFullRequest(ctx, FullRequest{ - Method: http.MethodPost, - URL: u.String(), - ResponseJSON: &m, - }) + query := map[string]string{} + if len(extra) > 0 { + if extra[0].BeeperUniqueID != "" { + query["com.beeper.unique_id"] = extra[0].BeeperUniqueID + } + if extra[0].BeeperRoomID != "" { + query["com.beeper.room_id"] = string(extra[0].BeeperRoomID) + } + } + createURL := cli.BuildURLWithQuery(MediaURLPath{"v1", "create"}, query) + _, err := cli.MakeRequest(ctx, http.MethodPost, createURL, nil, &m) return &m, err } @@ -1502,14 +1984,20 @@ func (cli *Client) CreateMXC(ctx context.Context) (*RespCreateMXC, error) { func (cli *Client) UploadAsync(ctx context.Context, req ReqUploadMedia) (*RespCreateMXC, error) { resp, err := cli.CreateMXC(ctx) if err != nil { + req.DoneCallback() return nil, err } req.MXC = resp.ContentURI req.UnstableUploadURL = resp.UnstableUploadURL + if req.AsyncContext == nil { + req.AsyncContext = cli.cliOrContextLog(ctx).WithContext(context.Background()) + } go func() { - _, err = cli.UploadMedia(ctx, req) + _, err = cli.UploadMedia(req.AsyncContext, req) if err != nil { - cli.Log.Error().Str("mxc", req.MXC.String()).Err(err).Msg("Async upload of media failed") + zerolog.Ctx(req.AsyncContext).Err(err). + Stringer("mxc", req.MXC). + Msg("Async upload of media failed") } }() return resp, nil @@ -1545,6 +2033,9 @@ type ReqUploadMedia struct { ContentType string FileName string + AsyncContext context.Context + DoneCallback func() + // MXC specifies an existing MXC URI which doesn't have content yet to upload into. // See https://spec.matrix.org/unstable/client-server-api/#put_matrixmediav3uploadservernamemediaid MXC id.ContentURI @@ -1554,47 +2045,74 @@ type ReqUploadMedia struct { UnstableUploadURL string } -func (cli *Client) tryUploadMediaToURL(ctx context.Context, url, contentType string, content io.Reader) (*http.Response, error) { - cli.Log.Debug().Str("url", url).Msg("Uploading media to external URL") +func (cli *Client) tryUploadMediaToURL(ctx context.Context, url, contentType string, content io.Reader, contentLength int64) (*http.Response, error) { + cli.Log.Debug(). + Str("url", url). + Int64("content_length", contentLength). + Msg("Uploading media to external URL") req, err := http.NewRequestWithContext(ctx, http.MethodPut, url, content) if err != nil { return nil, err } + req.ContentLength = contentLength req.Header.Set("Content-Type", contentType) - req.Header.Set("User-Agent", cli.UserAgent+" (external media uploader)") + if cli.UserAgent != "" { + req.Header.Set("User-Agent", cli.UserAgent+" (external media uploader)") + } - return http.DefaultClient.Do(req) + if cli.ExternalClient != nil { + return cli.ExternalClient.Do(req) + } else { + return http.DefaultClient.Do(req) + } } func (cli *Client) uploadMediaToURL(ctx context.Context, data ReqUploadMedia) (*RespMediaUpload, error) { retries := cli.DefaultHTTPRetries - if data.ContentBytes == nil { - // Can't retry with a reader + reader := data.Content + if data.ContentBytes != nil { + data.ContentLength = int64(len(data.ContentBytes)) + reader = bytes.NewReader(data.ContentBytes) + } else if rsc, ok := reader.(io.ReadSeekCloser); ok { + // Prevent HTTP from closing the request body, it might be needed for retries + reader = nopCloseSeeker{rsc} + } + readerSeeker, canSeek := reader.(io.ReadSeeker) + if !canSeek { retries = 0 } for { - reader := data.Content - if reader == nil { - reader = bytes.NewReader(data.ContentBytes) - } else { - data.Content = nil - } - resp, err := cli.tryUploadMediaToURL(ctx, data.UnstableUploadURL, data.ContentType, reader) + resp, err := cli.tryUploadMediaToURL(ctx, data.UnstableUploadURL, data.ContentType, reader, data.ContentLength) if err == nil { if resp.StatusCode >= 200 && resp.StatusCode < 300 { // Everything is fine break } err = fmt.Errorf("HTTP %d", resp.StatusCode) + } else if errors.Is(err, context.Canceled) { + cli.Log.Warn().Str("url", data.UnstableUploadURL).Msg("External media upload canceled") + return nil, err } if retries <= 0 { cli.Log.Warn().Str("url", data.UnstableUploadURL).Err(err). Msg("Error uploading media to external URL, not retrying") return nil, err } - cli.Log.Warn().Str("url", data.UnstableUploadURL).Err(err). + backoff := time.Second * time.Duration(cli.DefaultHTTPRetries-retries) + cli.Log.Warn().Err(err). + Str("url", data.UnstableUploadURL). + Int("retry_in_seconds", int(backoff.Seconds())). Msg("Error uploading media to external URL, retrying") + select { + case <-time.After(backoff): + case <-ctx.Done(): + return nil, ctx.Err() + } retries-- + _, err = readerSeeker.Seek(0, io.SeekStart) + if err != nil { + return nil, fmt.Errorf("failed to seek back to start of reader: %w", err) + } } query := map[string]string{} @@ -1605,11 +2123,7 @@ func (cli *Client) uploadMediaToURL(ctx context.Context, data ReqUploadMedia) (* notifyURL := cli.BuildURLWithQuery(MediaURLPath{"unstable", "com.beeper.msc3870", "upload", data.MXC.Homeserver, data.MXC.FileID, "complete"}, query) var m *RespMediaUpload - _, err := cli.MakeFullRequest(ctx, FullRequest{ - Method: http.MethodPost, - URL: notifyURL, - ResponseJSON: m, - }) + _, err := cli.MakeRequest(ctx, http.MethodPost, notifyURL, nil, &m) if err != nil { return nil, err } @@ -1617,9 +2131,23 @@ func (cli *Client) uploadMediaToURL(ctx context.Context, data ReqUploadMedia) (* return m, nil } +type nopCloseSeeker struct { + io.ReadSeeker +} + +func (nopCloseSeeker) Close() error { + return nil +} + // UploadMedia uploads the given data to the content repository and returns an MXC URI. // See https://spec.matrix.org/v1.7/client-server-api/#post_matrixmediav3upload func (cli *Client) UploadMedia(ctx context.Context, data ReqUploadMedia) (*RespMediaUpload, error) { + if data.DoneCallback != nil { + defer data.DoneCallback() + } + if cli == nil { + return nil, ErrClientIsNil + } if data.UnstableUploadURL != "" { if data.MXC.IsEmpty() { return nil, errors.New("MXC must also be set when uploading to external URL") @@ -1660,7 +2188,7 @@ func (cli *Client) UploadMedia(ctx context.Context, data ReqUploadMedia) (*RespM // // See https://spec.matrix.org/v1.2/client-server-api/#get_matrixmediav3preview_url func (cli *Client) GetURLPreview(ctx context.Context, url string) (*RespPreviewURL, error) { - reqURL := cli.BuildURLWithQuery(MediaURLPath{"v3", "preview_url"}, map[string]string{ + reqURL := cli.BuildURLWithQuery(ClientURLPath{"v1", "media", "preview_url"}, map[string]string{ "url": url, }) var output RespPreviewURL @@ -1676,24 +2204,26 @@ func (cli *Client) JoinedMembers(ctx context.Context, roomID id.RoomID) (resp *R u := cli.BuildClientURL("v3", "rooms", roomID, "joined_members") _, err = cli.MakeRequest(ctx, http.MethodGet, u, nil, &resp) if err == nil && cli.StateStore != nil { - clearErr := cli.StateStore.ClearCachedMembers(ctx, roomID, event.MembershipJoin) - if clearErr != nil { - cli.cliOrContextLog(ctx).Warn().Err(clearErr). - Stringer("room_id", roomID). - Msg("Failed to clear cached member list after fetching joined members") - } + fakeEvents := make([]*event.Event, len(resp.Joined)) + i := 0 for userID, member := range resp.Joined { - updateErr := cli.StateStore.SetMember(ctx, roomID, userID, &event.MemberEventContent{ - Membership: event.MembershipJoin, - AvatarURL: id.ContentURIString(member.AvatarURL), - Displayname: member.DisplayName, - }) - if updateErr != nil { - cli.cliOrContextLog(ctx).Warn().Err(updateErr). - Stringer("room_id", roomID). - Stringer("user_id", userID). - Msg("Failed to update membership in state store after fetching joined members") + fakeEvents[i] = &event.Event{ + StateKey: ptr.Ptr(userID.String()), + Type: event.StateMember, + RoomID: roomID, + Content: event.Content{Parsed: &event.MemberEventContent{ + Membership: event.MembershipJoin, + AvatarURL: id.ContentURIString(member.AvatarURL), + Displayname: member.DisplayName, + }}, } + i++ + } + updateErr := cli.StateStore.ReplaceCachedMembers(ctx, roomID, fakeEvents, event.MembershipJoin) + if updateErr != nil { + cli.cliOrContextLog(ctx).Warn().Err(updateErr). + Stringer("room_id", roomID). + Msg("Failed to update members in state store after fetching joined members") } } return @@ -1716,21 +2246,26 @@ func (cli *Client) Members(ctx context.Context, roomID id.RoomID, req ...ReqMemb } u := cli.BuildURLWithQuery(ClientURLPath{"v3", "rooms", roomID, "members"}, query) _, err = cli.MakeRequest(ctx, http.MethodGet, u, nil, &resp) - if err == nil && cli.StateStore != nil { - var clearMemberships []event.Membership - if extra.Membership != "" { - clearMemberships = append(clearMemberships, extra.Membership) - } - if extra.NotMembership == "" { - clearErr := cli.StateStore.ClearCachedMembers(ctx, roomID, clearMemberships...) - if clearErr != nil { - cli.cliOrContextLog(ctx).Warn().Err(clearErr). - Stringer("room_id", roomID). - Msg("Failed to clear cached member list after fetching joined members") - } - } + if err == nil { for _, evt := range resp.Chunk { - UpdateStateStore(ctx, cli.StateStore, evt) + _ = evt.Content.ParseRaw(evt.Type) + } + } + if err == nil && cli.StateStore != nil { + var onlyMemberships []event.Membership + if extra.Membership != "" { + onlyMemberships = []event.Membership{extra.Membership} + } else if extra.NotMembership != "" { + onlyMemberships = []event.Membership{event.MembershipJoin, event.MembershipLeave, event.MembershipInvite, event.MembershipBan, event.MembershipKnock} + onlyMemberships = slices.DeleteFunc(onlyMemberships, func(m event.Membership) bool { + return m == extra.NotMembership + }) + } + updateErr := cli.StateStore.ReplaceCachedMembers(ctx, roomID, resp.Chunk, onlyMemberships...) + if updateErr != nil { + cli.cliOrContextLog(ctx).Warn().Err(updateErr). + Stringer("room_id", roomID). + Msg("Failed to update members in state store after fetching members") } } return @@ -1746,6 +2281,12 @@ func (cli *Client) JoinedRooms(ctx context.Context) (resp *RespJoinedRooms, err return } +func (cli *Client) PublicRooms(ctx context.Context, req *ReqPublicRooms) (resp *RespPublicRooms, err error) { + urlPath := cli.BuildURLWithQuery(ClientURLPath{"v3", "publicRooms"}, req.Query()) + _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp) + return +} + // Hierarchy returns a list of rooms that are in the room's hierarchy. See https://spec.matrix.org/v1.4/client-server-api/#get_matrixclientv1roomsroomidhierarchy // // The hierarchy API is provided to walk the space tree and discover the rooms with their aesthetic details. works in a depth-first manner: @@ -1760,11 +2301,10 @@ func (cli *Client) Hierarchy(ctx context.Context, roomID id.RoomID, req *ReqHier // Messages returns a list of message and state events for a room. It uses // pagination query parameters to paginate history in the room. -// See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3roomsroomidmessages +// See https://spec.matrix.org/v1.12/client-server-api/#get_matrixclientv3roomsroomidmessages func (cli *Client) Messages(ctx context.Context, roomID id.RoomID, from, to string, dir Direction, filter *FilterPart, limit int) (resp *RespMessages, err error) { query := map[string]string{ - "from": from, - "dir": string(dir), + "dir": string(dir), } if filter != nil { filterJSON, err := json.Marshal(filter) @@ -1773,6 +2313,9 @@ func (cli *Client) Messages(ctx context.Context, roomID id.RoomID, from, to stri } query["filter"] = string(filterJSON) } + if from != "" { + query["from"] = from + } if to != "" { query["to"] = to } @@ -1826,6 +2369,20 @@ func (cli *Client) GetEvent(ctx context.Context, roomID id.RoomID, eventID id.Ev return } +func (cli *Client) GetUnredactedEventContent(ctx context.Context, roomID id.RoomID, eventID id.EventID) (resp *event.Event, err error) { + urlPath := cli.BuildURLWithQuery(ClientURLPath{"v3", "rooms", roomID, "event", eventID}, map[string]string{ + "fi.mau.msc2815.include_unredacted_content": "true", + }) + _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp) + return +} + +func (cli *Client) GetRelations(ctx context.Context, roomID id.RoomID, eventID id.EventID, req *ReqGetRelations) (resp *RespGetRelations, err error) { + urlPath := cli.BuildURLWithQuery(append(ClientURLPath{"v1", "rooms", roomID, "relations", eventID}, req.PathSuffix()...), req.Query()) + _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp) + return +} + func (cli *Client) MarkRead(ctx context.Context, roomID id.RoomID, eventID id.EventID) (err error) { return cli.SendReceipt(ctx, roomID, eventID, event.ReceiptTypeRead, nil) } @@ -1853,15 +2410,19 @@ func (cli *Client) SetReadMarkers(ctx context.Context, roomID id.RoomID, content return } -func (cli *Client) AddTag(ctx context.Context, roomID id.RoomID, tag string, order float64) error { - var tagData event.Tag - if order == order { - tagData.Order = json.Number(strconv.FormatFloat(order, 'e', -1, 64)) - } - return cli.AddTagWithCustomData(ctx, roomID, tag, tagData) +func (cli *Client) SetBeeperInboxState(ctx context.Context, roomID id.RoomID, content *ReqSetBeeperInboxState) (err error) { + urlPath := cli.BuildClientURL("unstable", "com.beeper.inbox", "user", cli.UserID, "rooms", roomID, "inbox_state") + _, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, content, nil) + return } -func (cli *Client) AddTagWithCustomData(ctx context.Context, roomID id.RoomID, tag string, data interface{}) (err error) { +func (cli *Client) AddTag(ctx context.Context, roomID id.RoomID, tag event.RoomTag, order float64) error { + return cli.AddTagWithCustomData(ctx, roomID, tag, &event.TagMetadata{ + Order: json.Number(strconv.FormatFloat(order, 'e', -1, 64)), + }) +} + +func (cli *Client) AddTagWithCustomData(ctx context.Context, roomID id.RoomID, tag event.RoomTag, data any) (err error) { urlPath := cli.BuildClientURL("v3", "user", cli.UserID, "rooms", roomID, "tags", tag) _, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, data, nil) return @@ -1872,13 +2433,13 @@ func (cli *Client) GetTags(ctx context.Context, roomID id.RoomID) (tags event.Ta return } -func (cli *Client) GetTagsWithCustomData(ctx context.Context, roomID id.RoomID, resp interface{}) (err error) { +func (cli *Client) GetTagsWithCustomData(ctx context.Context, roomID id.RoomID, resp any) (err error) { urlPath := cli.BuildClientURL("v3", "user", cli.UserID, "rooms", roomID, "tags") _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp) return } -func (cli *Client) RemoveTag(ctx context.Context, roomID id.RoomID, tag string) (err error) { +func (cli *Client) RemoveTag(ctx context.Context, roomID id.RoomID, tag event.RoomTag) (err error) { urlPath := cli.BuildClientURL("v3", "user", cli.UserID, "rooms", roomID, "tags", tag) _, err = cli.MakeRequest(ctx, http.MethodDelete, urlPath, nil, nil) return @@ -2126,15 +2687,15 @@ func (cli *Client) SetDeviceInfo(ctx context.Context, deviceID id.DeviceID, req return err } -func (cli *Client) DeleteDevice(ctx context.Context, deviceID id.DeviceID, req *ReqDeleteDevice) error { +func (cli *Client) DeleteDevice(ctx context.Context, deviceID id.DeviceID, req *ReqDeleteDevice[any]) error { urlPath := cli.BuildClientURL("v3", "devices", deviceID) _, err := cli.MakeRequest(ctx, http.MethodDelete, urlPath, req, nil) return err } -func (cli *Client) DeleteDevices(ctx context.Context, req *ReqDeleteDevices) error { +func (cli *Client) DeleteDevices(ctx context.Context, req *ReqDeleteDevices[any]) error { urlPath := cli.BuildClientURL("v3", "delete_devices") - _, err := cli.MakeRequest(ctx, http.MethodDelete, urlPath, req, nil) + _, err := cli.MakeRequest(ctx, http.MethodPost, urlPath, req, nil) return err } @@ -2143,7 +2704,7 @@ type UIACallback = func(*RespUserInteractive) interface{} // UploadCrossSigningKeys uploads the given cross-signing keys to the server. // Because the endpoint requires user-interactive authentication a callback must be provided that, // given the UI auth parameters, produces the required result (or nil to end the flow). -func (cli *Client) UploadCrossSigningKeys(ctx context.Context, keys *UploadCrossSigningKeysReq, uiaCallback UIACallback) error { +func (cli *Client) UploadCrossSigningKeys(ctx context.Context, keys *UploadCrossSigningKeysReq[any], uiaCallback UIACallback) error { content, err := cli.MakeFullRequest(ctx, FullRequest{ Method: http.MethodPost, URL: cli.BuildClientURL("v3", "keys", "device_signing", "upload"), @@ -2213,24 +2774,73 @@ func (cli *Client) PutPushRule(ctx context.Context, scope string, kind pushrules return err } -// BatchSend sends a batch of historical events into a room. This is only available for appservices. +func (cli *Client) ReportEvent(ctx context.Context, roomID id.RoomID, eventID id.EventID, reason string) error { + urlPath := cli.BuildClientURL("v3", "rooms", roomID, "report", eventID) + _, err := cli.MakeRequest(ctx, http.MethodPost, urlPath, &ReqReport{Reason: reason, Score: -100}, nil) + return err +} + +func (cli *Client) ReportRoom(ctx context.Context, roomID id.RoomID, reason string) error { + urlPath := cli.BuildClientURL("v3", "rooms", roomID, "report") + _, err := cli.MakeRequest(ctx, http.MethodPost, urlPath, &ReqReport{Reason: reason, Score: -100}, nil) + return err +} + +// AdminWhoIs fetches session information belonging to a specific user. Typically requires being a server admin. // -// Deprecated: MSC2716 has been abandoned, so this is now Beeper-specific. BeeperBatchSend should be used instead. -func (cli *Client) BatchSend(ctx context.Context, roomID id.RoomID, req *ReqBatchSend) (resp *RespBatchSend, err error) { - path := ClientURLPath{"unstable", "org.matrix.msc2716", "rooms", roomID, "batch_send"} - query := map[string]string{ - "prev_event_id": req.PrevEventID.String(), +// https://spec.matrix.org/v1.15/client-server-api/#get_matrixclientv3adminwhoisuserid +func (cli *Client) AdminWhoIs(ctx context.Context, userID id.UserID) (resp RespWhoIs, err error) { + urlPath := cli.BuildClientURL("v3", "admin", "whois", userID) + _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp) + return +} + +func (cli *Client) makeMSC4323URL(action string, target id.UserID) string { + if cli.SpecVersions.Supports(FeatureUnstableAccountModeration) { + return cli.BuildClientURL("unstable", "uk.timedout.msc4323", "admin", action, target) + } else if cli.SpecVersions.Supports(FeatureStableAccountModeration) { + return cli.BuildClientURL("v1", "admin", action, target) } - if req.BeeperNewMessages { - query["com.beeper.new_messages"] = "true" + return "" +} + +// GetSuspendedStatus uses MSC4323 to check if a user is suspended. +func (cli *Client) GetSuspendedStatus(ctx context.Context, userID id.UserID) (res *RespSuspended, err error) { + urlPath := cli.makeMSC4323URL("suspend", userID) + if urlPath == "" { + return nil, MUnrecognized.WithMessage("Homeserver does not advertise MSC4323 support") } - if req.BeeperMarkReadBy != "" { - query["com.beeper.mark_read_by"] = req.BeeperMarkReadBy.String() + _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, res) + return +} + +// GetLockStatus uses MSC4323 to check if a user is locked. +func (cli *Client) GetLockStatus(ctx context.Context, userID id.UserID) (res *RespLocked, err error) { + urlPath := cli.makeMSC4323URL("lock", userID) + if urlPath == "" { + return nil, MUnrecognized.WithMessage("Homeserver does not advertise MSC4323 support") } - if len(req.BatchID) > 0 { - query["batch_id"] = req.BatchID.String() + _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, res) + return +} + +// SetSuspendedStatus uses MSC4323 to set whether a user account is suspended. +func (cli *Client) SetSuspendedStatus(ctx context.Context, userID id.UserID, suspended bool) (res *RespSuspended, err error) { + urlPath := cli.makeMSC4323URL("suspend", userID) + if urlPath == "" { + return nil, MUnrecognized.WithMessage("Homeserver does not advertise MSC4323 support") } - _, err = cli.MakeRequest(ctx, http.MethodPost, cli.BuildURLWithQuery(path, query), req, &resp) + _, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, &ReqSuspend{Suspended: suspended}, res) + return +} + +// SetLockStatus uses MSC4323 to set whether a user account is locked. +func (cli *Client) SetLockStatus(ctx context.Context, userID id.UserID, locked bool) (res *RespLocked, err error) { + urlPath := cli.makeMSC4323URL("lock", userID) + if urlPath == "" { + return nil, MUnrecognized.WithMessage("Homeserver does not advertise MSC4323 support") + } + _, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, &ReqLocked{Locked: locked}, res) return } @@ -2271,6 +2881,9 @@ func (cli *Client) BeeperDeleteRoom(ctx context.Context, roomID id.RoomID) (err // TxnID returns the next transaction ID. func (cli *Client) TxnID() string { + if cli == nil { + return "client is nil" + } txnID := atomic.AddInt32(&cli.txnID, 1) return fmt.Sprintf("mautrix-go_%d_%d", time.Now().UnixNano(), txnID) } diff --git a/client_ephemeral_test.go b/client_ephemeral_test.go new file mode 100644 index 00000000..c2846427 --- /dev/null +++ b/client_ephemeral_test.go @@ -0,0 +1,158 @@ +// Copyright (c) 2026 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package mautrix_test + +import ( + "context" + "encoding/json" + "errors" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +func TestClient_SendEphemeralEvent_UsesUnstablePathTxnAndTS(t *testing.T) { + roomID := id.RoomID("!room:example.com") + evtType := event.Type{Type: "com.example.ephemeral", Class: event.EphemeralEventType} + txnID := "txn-123" + + var gotPath string + var gotQueryTS string + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotPath = r.URL.Path + gotQueryTS = r.URL.Query().Get("ts") + assert.Equal(t, http.MethodPut, r.Method) + _, _ = w.Write([]byte(`{"event_id":"$evt"}`)) + })) + defer ts.Close() + + cli, err := mautrix.NewClient(ts.URL, "", "") + require.NoError(t, err) + + _, err = cli.BeeperSendEphemeralEvent( + context.Background(), + roomID, + evtType, + map[string]any{"foo": "bar"}, + mautrix.ReqSendEvent{TransactionID: txnID, Timestamp: 1234}, + ) + require.NoError(t, err) + + assert.True(t, strings.Contains(gotPath, "/_matrix/client/unstable/com.beeper.ephemeral/rooms/")) + assert.True(t, strings.HasSuffix(gotPath, "/ephemeral/com.example.ephemeral/"+txnID)) + assert.Equal(t, "1234", gotQueryTS) +} + +func TestClient_SendEphemeralEvent_UnsupportedReturnsMUnrecognized(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte(`{"errcode":"M_UNRECOGNIZED","error":"Unrecognized endpoint"}`)) + })) + defer ts.Close() + + cli, err := mautrix.NewClient(ts.URL, "", "") + require.NoError(t, err) + + _, err = cli.BeeperSendEphemeralEvent( + context.Background(), + id.RoomID("!room:example.com"), + event.Type{Type: "com.example.ephemeral", Class: event.EphemeralEventType}, + map[string]any{"foo": "bar"}, + ) + require.Error(t, err) + assert.True(t, errors.Is(err, mautrix.MUnrecognized)) +} + +func TestClient_SendEphemeralEvent_EncryptsInEncryptedRooms(t *testing.T) { + roomID := id.RoomID("!room:example.com") + evtType := event.Type{Type: "com.example.ephemeral", Class: event.EphemeralEventType} + txnID := "txn-encrypted" + + stateStore := mautrix.NewMemoryStateStore() + err := stateStore.SetEncryptionEvent(context.Background(), roomID, &event.EncryptionEventContent{ + Algorithm: id.AlgorithmMegolmV1, + }) + require.NoError(t, err) + + fakeCrypto := &fakeCryptoHelper{ + encryptedContent: &event.EncryptedEventContent{ + Algorithm: id.AlgorithmMegolmV1, + MegolmCiphertext: []byte("ciphertext"), + }, + } + + var gotPath string + var gotBody map[string]any + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotPath = r.URL.Path + assert.Equal(t, http.MethodPut, r.Method) + err := json.NewDecoder(r.Body).Decode(&gotBody) + require.NoError(t, err) + _, _ = w.Write([]byte(`{"event_id":"$evt"}`)) + })) + defer ts.Close() + + cli, err := mautrix.NewClient(ts.URL, "", "") + require.NoError(t, err) + cli.StateStore = stateStore + cli.Crypto = fakeCrypto + + _, err = cli.BeeperSendEphemeralEvent( + context.Background(), + roomID, + evtType, + map[string]any{"foo": "bar"}, + mautrix.ReqSendEvent{TransactionID: txnID}, + ) + require.NoError(t, err) + + assert.True(t, strings.HasSuffix(gotPath, "/ephemeral/m.room.encrypted/"+txnID)) + assert.Equal(t, string(id.AlgorithmMegolmV1), gotBody["algorithm"]) + assert.Equal(t, 1, fakeCrypto.encryptCalls) + assert.Equal(t, roomID, fakeCrypto.lastRoomID) + assert.Equal(t, evtType, fakeCrypto.lastEventType) +} + +type fakeCryptoHelper struct { + encryptCalls int + lastRoomID id.RoomID + lastEventType event.Type + lastEncryptInput any + encryptedContent *event.EncryptedEventContent +} + +func (f *fakeCryptoHelper) Encrypt(_ context.Context, roomID id.RoomID, eventType event.Type, content any) (*event.EncryptedEventContent, error) { + f.encryptCalls++ + f.lastRoomID = roomID + f.lastEventType = eventType + f.lastEncryptInput = content + return f.encryptedContent, nil +} + +func (f *fakeCryptoHelper) Decrypt(context.Context, *event.Event) (*event.Event, error) { + return nil, nil +} + +func (f *fakeCryptoHelper) WaitForSession(context.Context, id.RoomID, id.SenderKey, id.SessionID, time.Duration) bool { + return false +} + +func (f *fakeCryptoHelper) RequestSession(context.Context, id.RoomID, id.SenderKey, id.SessionID, id.UserID, id.DeviceID) { +} + +func (f *fakeCryptoHelper) Init(context.Context) error { + return nil +} diff --git a/commands/container.go b/commands/container.go new file mode 100644 index 00000000..9b909b75 --- /dev/null +++ b/commands/container.go @@ -0,0 +1,133 @@ +// Copyright (c) 2026 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package commands + +import ( + "fmt" + "slices" + "strings" + "sync" + + "go.mau.fi/util/exmaps" + + "maunium.net/go/mautrix/event/cmdschema" +) + +type CommandContainer[MetaType any] struct { + commands map[string]*Handler[MetaType] + aliases map[string]string + lock sync.RWMutex + parent *Handler[MetaType] +} + +func NewCommandContainer[MetaType any]() *CommandContainer[MetaType] { + return &CommandContainer[MetaType]{ + commands: make(map[string]*Handler[MetaType]), + aliases: make(map[string]string), + } +} + +func (cont *CommandContainer[MetaType]) AllSpecs() []*cmdschema.EventContent { + data := make(exmaps.Set[*Handler[MetaType]]) + cont.collectHandlers(data) + specs := make([]*cmdschema.EventContent, 0, data.Size()) + for handler := range data.Iter() { + if handler.Parameters != nil { + specs = append(specs, handler.Spec()) + } + } + return specs +} + +func (cont *CommandContainer[MetaType]) collectHandlers(into exmaps.Set[*Handler[MetaType]]) { + cont.lock.RLock() + defer cont.lock.RUnlock() + for _, handler := range cont.commands { + into.Add(handler) + if handler.subcommandContainer != nil { + handler.subcommandContainer.collectHandlers(into) + } + } +} + +// Register registers the given command handlers. +func (cont *CommandContainer[MetaType]) Register(handlers ...*Handler[MetaType]) { + if cont == nil { + return + } + cont.lock.Lock() + defer cont.lock.Unlock() + for i, handler := range handlers { + if handler == nil { + panic(fmt.Errorf("handler #%d is nil", i+1)) + } + cont.registerOne(handler) + } +} + +func (cont *CommandContainer[MetaType]) registerOne(handler *Handler[MetaType]) { + if strings.ToLower(handler.Name) != handler.Name { + panic(fmt.Errorf("command %q is not lowercase", handler.Name)) + } else if val, alreadyExists := cont.commands[handler.Name]; alreadyExists && val != handler { + panic(fmt.Errorf("tried to register command %q, but it's already registered", handler.Name)) + } else if aliasTarget, alreadyExists := cont.aliases[handler.Name]; alreadyExists { + panic(fmt.Errorf("tried to register command %q, but it's already registered as an alias for %q", handler.Name, aliasTarget)) + } + if !slices.Contains(handler.parents, cont.parent) { + handler.parents = append(handler.parents, cont.parent) + handler.nestedNameCache = nil + } + cont.commands[handler.Name] = handler + for _, alias := range handler.Aliases { + if strings.ToLower(alias) != alias { + panic(fmt.Errorf("alias %q is not lowercase", alias)) + } else if val, alreadyExists := cont.aliases[alias]; alreadyExists && val != handler.Name { + panic(fmt.Errorf("tried to register alias %q for %q, but it's already registered for %q", alias, handler.Name, cont.aliases[alias])) + } else if _, alreadyExists = cont.commands[alias]; alreadyExists { + panic(fmt.Errorf("tried to register alias %q for %q, but it's already registered as a command", alias, handler.Name)) + } + cont.aliases[alias] = handler.Name + } + handler.initSubcommandContainer() +} + +func (cont *CommandContainer[MetaType]) Unregister(handlers ...*Handler[MetaType]) { + if cont == nil { + return + } + cont.lock.Lock() + defer cont.lock.Unlock() + for _, handler := range handlers { + cont.unregisterOne(handler) + } +} + +func (cont *CommandContainer[MetaType]) unregisterOne(handler *Handler[MetaType]) { + delete(cont.commands, handler.Name) + for _, alias := range handler.Aliases { + if cont.aliases[alias] == handler.Name { + delete(cont.aliases, alias) + } + } +} + +func (cont *CommandContainer[MetaType]) GetHandler(name string) *Handler[MetaType] { + if cont == nil { + return nil + } + cont.lock.RLock() + defer cont.lock.RUnlock() + alias, ok := cont.aliases[name] + if ok { + name = alias + } + handler, ok := cont.commands[name] + if !ok { + handler = cont.commands[UnknownCommandName] + } + return handler +} diff --git a/commands/event.go b/commands/event.go new file mode 100644 index 00000000..76d6c9f0 --- /dev/null +++ b/commands/event.go @@ -0,0 +1,237 @@ +// Copyright (c) 2026 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package commands + +import ( + "context" + "encoding/json" + "fmt" + "strings" + + "github.com/rs/zerolog" + + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/format" + "maunium.net/go/mautrix/id" +) + +// Event contains the data of a single command event. +// It also provides some helper methods for responding to the command. +type Event[MetaType any] struct { + *event.Event + // RawInput is the entire message before splitting into command and arguments. + RawInput string + // ParentCommands is the chain of commands leading up to this command. + // This is only set if the command is a subcommand. + ParentCommands []string + ParentHandlers []*Handler[MetaType] + // Command is the lowercased first word of the message. + Command string + // Args are the rest of the message split by whitespace ([strings.Fields]). + Args []string + // RawArgs is the same as args, but without the splitting by whitespace. + RawArgs string + + StructuredArgs json.RawMessage + + Ctx context.Context + Log *zerolog.Logger + Proc *Processor[MetaType] + Handler *Handler[MetaType] + Meta MetaType + + redactedBy id.EventID +} + +var IDHTMLParser = &format.HTMLParser{ + PillConverter: func(displayname, mxid, eventID string, ctx format.Context) string { + if len(mxid) == 0 { + return displayname + } + if eventID != "" { + return fmt.Sprintf("https://matrix.to/#/%s/%s", mxid, eventID) + } + return mxid + }, + ItalicConverter: func(s string, c format.Context) string { + return fmt.Sprintf("*%s*", s) + }, + Newline: "\n", +} + +// ParseEvent parses a message into a command event struct. +func (proc *Processor[MetaType]) ParseEvent(ctx context.Context, evt *event.Event) *Event[MetaType] { + content, ok := evt.Content.Parsed.(*event.MessageEventContent) + if !ok || content.MsgType == event.MsgNotice || content.RelatesTo.GetReplaceID() != "" { + return nil + } + text := content.Body + if content.Format == event.FormatHTML { + text = IDHTMLParser.Parse(content.FormattedBody, format.NewContext(ctx)) + } + if content.MSC4391BotCommand != nil { + if !content.Mentions.Has(proc.Client.UserID) || len(content.Mentions.UserIDs) != 1 { + return nil + } + wrapped := StructuredCommandToEvent[MetaType](ctx, evt, content.MSC4391BotCommand) + wrapped.RawInput = text + return wrapped + } + if len(text) == 0 { + return nil + } + return RawTextToEvent[MetaType](ctx, evt, text) +} + +func StructuredCommandToEvent[MetaType any](ctx context.Context, evt *event.Event, content *event.MSC4391BotCommandInput) *Event[MetaType] { + commandParts := strings.Split(content.Command, " ") + return &Event[MetaType]{ + Event: evt, + // Fake a command and args to let the subcommand finder in Process work. + Command: commandParts[0], + Args: commandParts[1:], + Ctx: ctx, + Log: zerolog.Ctx(ctx), + + StructuredArgs: content.Arguments, + } +} + +func RawTextToEvent[MetaType any](ctx context.Context, evt *event.Event, text string) *Event[MetaType] { + parts := strings.Fields(text) + if len(parts) == 0 { + parts = []string{""} + } + return &Event[MetaType]{ + Event: evt, + RawInput: text, + Command: strings.ToLower(parts[0]), + Args: parts[1:], + RawArgs: strings.TrimLeft(strings.TrimPrefix(text, parts[0]), " "), + Log: zerolog.Ctx(ctx), + Ctx: ctx, + } +} + +type ReplyOpts struct { + AllowHTML bool + AllowMarkdown bool + Reply bool + Thread bool + SendAsText bool + Edit id.EventID + OverrideMentions *event.Mentions + Extra map[string]any +} + +func (evt *Event[MetaType]) Reply(msg string, args ...any) id.EventID { + if len(args) > 0 { + msg = fmt.Sprintf(msg, args...) + } + return evt.Respond(msg, ReplyOpts{AllowMarkdown: true, Reply: true}) +} + +func (evt *Event[MetaType]) Respond(msg string, opts ReplyOpts) id.EventID { + content := format.RenderMarkdown(msg, opts.AllowMarkdown, opts.AllowHTML) + if opts.Thread { + content.SetThread(evt.Event) + } + if opts.Reply { + content.SetReply(evt.Event) + } + if !opts.SendAsText { + content.MsgType = event.MsgNotice + } + if opts.Edit != "" { + content.SetEdit(opts.Edit) + } + if opts.OverrideMentions != nil { + content.Mentions = opts.OverrideMentions + } + var wrapped any = &content + if opts.Extra != nil { + wrapped = &event.Content{ + Parsed: &content, + Raw: opts.Extra, + } + } + resp, err := evt.Proc.Client.SendMessageEvent(evt.Ctx, evt.RoomID, event.EventMessage, wrapped) + if err != nil { + zerolog.Ctx(evt.Ctx).Err(err).Msg("Failed to send reply") + return "" + } + return resp.EventID +} + +func (evt *Event[MetaType]) React(emoji string) id.EventID { + resp, err := evt.Proc.Client.SendReaction(evt.Ctx, evt.RoomID, evt.ID, emoji) + if err != nil { + zerolog.Ctx(evt.Ctx).Err(err).Msg("Failed to send reaction") + return "" + } + return resp.EventID +} + +func (evt *Event[MetaType]) Redact() id.EventID { + if evt.redactedBy != "" { + return evt.redactedBy + } + resp, err := evt.Proc.Client.RedactEvent(evt.Ctx, evt.RoomID, evt.ID) + if err != nil { + zerolog.Ctx(evt.Ctx).Err(err).Msg("Failed to redact command") + return "" + } + evt.redactedBy = resp.EventID + return resp.EventID +} + +func (evt *Event[MetaType]) MarkRead() { + err := evt.Proc.Client.MarkRead(evt.Ctx, evt.RoomID, evt.ID) + if err != nil { + zerolog.Ctx(evt.Ctx).Err(err).Msg("Failed to send read receipt") + } +} + +// ShiftArg removes the first argument from the Args list and RawArgs data and returns it. +// RawInput will not be modified. +func (evt *Event[MetaType]) ShiftArg() string { + if len(evt.Args) == 0 { + return "" + } + firstArg := evt.Args[0] + evt.RawArgs = strings.TrimLeft(strings.TrimPrefix(evt.RawArgs, evt.Args[0]), " ") + evt.Args = evt.Args[1:] + return firstArg +} + +// UnshiftArg reverses ShiftArg by adding the given value to the beginning of the Args list and RawArgs data. +func (evt *Event[MetaType]) UnshiftArg(arg string) { + evt.RawArgs = arg + " " + evt.RawArgs + evt.Args = append([]string{arg}, evt.Args...) +} + +func (evt *Event[MetaType]) ParseArgs(into any) error { + return json.Unmarshal(evt.StructuredArgs, into) +} + +func ParseArgs[T, MetaType any](evt *Event[MetaType]) (into T, err error) { + err = evt.ParseArgs(&into) + return +} + +func WithParsedArgs[T, MetaType any](fn func(*Event[MetaType], T)) func(*Event[MetaType]) { + return func(evt *Event[MetaType]) { + parsed, err := ParseArgs[T, MetaType](evt) + if err != nil { + evt.Log.Debug().Err(err).Msg("Failed to parse structured args into struct") + // TODO better error, usage info? deduplicate with Process + evt.Reply("Failed to parse arguments: %v", err) + return + } + fn(evt, parsed) + } +} diff --git a/commands/handler.go b/commands/handler.go new file mode 100644 index 00000000..56f27f06 --- /dev/null +++ b/commands/handler.go @@ -0,0 +1,105 @@ +// Copyright (c) 2026 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package commands + +import ( + "strings" + + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/event/cmdschema" +) + +type Handler[MetaType any] struct { + // Func is the function that is called when the command is executed. + Func func(ce *Event[MetaType]) + + // Name is the primary name of the command. It must be lowercase. + Name string + // Aliases are alternative names for the command. They must be lowercase. + Aliases []string + // Subcommands are subcommands of this command. + Subcommands []*Handler[MetaType] + // PreFunc is a function that is called before checking subcommands. + // It can be used to have parameters between subcommands (e.g. `!rooms `). + // Event.ShiftArg will likely be useful for implementing such parameters. + PreFunc func(ce *Event[MetaType]) + + // Description is a short description of the command. + Description *event.ExtensibleTextContainer + // Parameters is a description of structured command parameters. + // If set, the StructuredArgs field of Event will be populated. + Parameters []*cmdschema.Parameter + TailParam string + + parents []*Handler[MetaType] + nestedNameCache []string + subcommandContainer *CommandContainer[MetaType] +} + +func (h *Handler[MetaType]) NestedNames() []string { + if h.nestedNameCache != nil { + return h.nestedNameCache + } + nestedNames := make([]string, 0, (1+len(h.Aliases))*len(h.parents)) + for _, parent := range h.parents { + if parent == nil { + nestedNames = append(nestedNames, h.Name) + nestedNames = append(nestedNames, h.Aliases...) + } else { + for _, parentName := range parent.NestedNames() { + nestedNames = append(nestedNames, parentName+" "+h.Name) + for _, alias := range h.Aliases { + nestedNames = append(nestedNames, parentName+" "+alias) + } + } + } + } + h.nestedNameCache = nestedNames + return nestedNames +} + +func (h *Handler[MetaType]) Spec() *cmdschema.EventContent { + names := h.NestedNames() + return &cmdschema.EventContent{ + Command: names[0], + Aliases: names[1:], + Parameters: h.Parameters, + Description: h.Description, + TailParam: h.TailParam, + } +} + +func (h *Handler[MetaType]) CopyFrom(other *Handler[MetaType]) { + if h.Parameters == nil { + h.Parameters = other.Parameters + h.TailParam = other.TailParam + } + h.Func = other.Func +} + +func (h *Handler[MetaType]) initSubcommandContainer() { + if len(h.Subcommands) > 0 { + h.subcommandContainer = NewCommandContainer[MetaType]() + h.subcommandContainer.parent = h + h.subcommandContainer.Register(h.Subcommands...) + } else { + h.subcommandContainer = nil + } +} + +func MakeUnknownCommandHandler[MetaType any](prefix string) *Handler[MetaType] { + return &Handler[MetaType]{ + Name: UnknownCommandName, + Func: func(ce *Event[MetaType]) { + if len(ce.ParentCommands) == 0 { + ce.Reply("Unknown command `%s%s`", prefix, ce.Command) + } else { + ce.Reply("Unknown subcommand `%s%s %s`", prefix, strings.Join(ce.ParentCommands, " "), ce.Command) + } + }, + } +} diff --git a/commands/prevalidate.go b/commands/prevalidate.go new file mode 100644 index 00000000..facca4da --- /dev/null +++ b/commands/prevalidate.go @@ -0,0 +1,84 @@ +// Copyright (c) 2025 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package commands + +import ( + "strings" +) + +// A PreValidator contains a function that takes an Event and returns true if the event should be processed further. +// +// The [PreValidator] field in [Processor] is called before the handler of the command is checked. +// It can be used to modify the command or arguments, or to skip the command entirely. +// +// The primary use case is removing a static command prefix, such as requiring all commands start with `!`. +type PreValidator[MetaType any] interface { + Validate(*Event[MetaType]) bool +} + +// FuncPreValidator is a simple function that implements the PreValidator interface. +type FuncPreValidator[MetaType any] func(*Event[MetaType]) bool + +func (f FuncPreValidator[MetaType]) Validate(ce *Event[MetaType]) bool { + return f(ce) +} + +// AllPreValidator can be used to combine multiple PreValidators, such that +// all of them must return true for the command to be processed further. +type AllPreValidator[MetaType any] []PreValidator[MetaType] + +func (f AllPreValidator[MetaType]) Validate(ce *Event[MetaType]) bool { + for _, validator := range f { + if !validator.Validate(ce) { + return false + } + } + return true +} + +// AnyPreValidator can be used to combine multiple PreValidators, such that +// at least one of them must return true for the command to be processed further. +type AnyPreValidator[MetaType any] []PreValidator[MetaType] + +func (f AnyPreValidator[MetaType]) Validate(ce *Event[MetaType]) bool { + for _, validator := range f { + if validator.Validate(ce) { + return true + } + } + return false +} + +// ValidatePrefixCommand checks that the first word in the input is exactly the given string, +// and if so, removes it from the command and sets the command to the next word. +// +// For example, `ValidateCommandPrefix("!mybot")` would only allow commands in the form `!mybot foo`, +// where `foo` would be used to look up the command handler. +func ValidatePrefixCommand[MetaType any](prefix string) PreValidator[MetaType] { + return FuncPreValidator[MetaType](func(ce *Event[MetaType]) bool { + if ce.Command == prefix && len(ce.Args) > 0 { + ce.Command = strings.ToLower(ce.ShiftArg()) + return true + } + return false + }) +} + +// ValidatePrefixSubstring checks that the command starts with the given prefix, +// and if so, removes it from the command. +// +// For example, `ValidatePrefixSubstring("!")` would only allow commands in the form `!foo`, +// where `foo` would be used to look up the command handler. +func ValidatePrefixSubstring[MetaType any](prefix string) PreValidator[MetaType] { + return FuncPreValidator[MetaType](func(ce *Event[MetaType]) bool { + if strings.HasPrefix(ce.Command, prefix) { + ce.Command = ce.Command[len(prefix):] + return true + } + return false + }) +} diff --git a/commands/processor.go b/commands/processor.go new file mode 100644 index 00000000..80f6745d --- /dev/null +++ b/commands/processor.go @@ -0,0 +1,152 @@ +// Copyright (c) 2026 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package commands + +import ( + "context" + "runtime/debug" + "strings" + + "github.com/rs/zerolog" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/event" +) + +// Processor implements boilerplate code for splitting messages into a command and arguments, +// and finding the appropriate handler for the command. +type Processor[MetaType any] struct { + *CommandContainer[MetaType] + + Client *mautrix.Client + LogArgs bool + PreValidator PreValidator[MetaType] + Meta MetaType + + ReactionCommandPrefix string +} + +// UnknownCommandName is the name of the fallback handler which is used if no other handler is found. +// If even the unknown command handler is not found, the command is ignored. +const UnknownCommandName = "__unknown-command__" + +func NewProcessor[MetaType any](cli *mautrix.Client) *Processor[MetaType] { + proc := &Processor[MetaType]{ + CommandContainer: NewCommandContainer[MetaType](), + Client: cli, + PreValidator: ValidatePrefixSubstring[MetaType]("!"), + } + proc.Register(MakeUnknownCommandHandler[MetaType]("!")) + return proc +} + +func (proc *Processor[MetaType]) Process(ctx context.Context, evt *event.Event) { + log := zerolog.Ctx(ctx).With(). + Stringer("sender", evt.Sender). + Stringer("room_id", evt.RoomID). + Stringer("event_id", evt.ID). + Logger() + defer func() { + panicErr := recover() + if panicErr != nil { + logEvt := log.Error(). + Bytes(zerolog.ErrorStackFieldName, debug.Stack()) + if realErr, ok := panicErr.(error); ok { + logEvt = logEvt.Err(realErr) + } else { + logEvt = logEvt.Any(zerolog.ErrorFieldName, panicErr) + } + logEvt.Msg("Panic in command handler") + _, err := proc.Client.SendReaction(ctx, evt.RoomID, evt.ID, "💥") + if err != nil { + log.Err(err).Msg("Failed to send reaction after panic") + } + } + }() + var parsed *Event[MetaType] + switch evt.Type { + case event.EventReaction: + parsed = proc.ParseReaction(ctx, evt) + case event.EventMessage: + parsed = proc.ParseEvent(ctx, evt) + } + if parsed == nil || (!proc.PreValidator.Validate(parsed) && parsed.StructuredArgs == nil) { + return + } + parsed.Proc = proc + parsed.Meta = proc.Meta + parsed.Ctx = ctx + + handler := proc.GetHandler(parsed.Command) + if handler == nil { + return + } + parsed.Handler = handler + if handler.PreFunc != nil { + handler.PreFunc(parsed) + } + handlerChain := zerolog.Arr() + handlerChain.Str(handler.Name) + for handler.subcommandContainer != nil && len(parsed.Args) > 0 { + subHandler := handler.subcommandContainer.GetHandler(strings.ToLower(parsed.Args[0])) + if subHandler != nil { + parsed.ParentCommands = append(parsed.ParentCommands, parsed.Command) + parsed.ParentHandlers = append(parsed.ParentHandlers, handler) + handler = subHandler + handlerChain.Str(subHandler.Name) + parsed.Command = strings.ToLower(parsed.ShiftArg()) + parsed.Handler = subHandler + if subHandler.PreFunc != nil { + subHandler.PreFunc(parsed) + } + } else { + break + } + } + if parsed.StructuredArgs != nil && len(parsed.Args) > 0 { + // TODO allow unknown command handlers to be called? + // The client sent MSC4391 data, but the target command wasn't found + log.Debug().Msg("Didn't find handler for MSC4391 command") + return + } + + logWith := log.With(). + Str("command", parsed.Command). + Array("handler", handlerChain) + if len(parsed.ParentCommands) > 0 { + logWith = logWith.Strs("parent_commands", parsed.ParentCommands) + } + if proc.LogArgs { + logWith = logWith.Strs("args", parsed.Args) + if parsed.StructuredArgs != nil { + logWith = logWith.RawJSON("structured_args", parsed.StructuredArgs) + } + } + log = logWith.Logger() + parsed.Ctx = log.WithContext(ctx) + parsed.Log = &log + + if handler.Parameters != nil && parsed.StructuredArgs == nil { + // The handler wants structured parameters, but the client didn't send MSC4391 data + var err error + parsed.StructuredArgs, err = handler.Spec().ParseArguments(parsed.RawArgs) + if err != nil { + log.Debug().Err(err).Msg("Failed to parse structured arguments") + // TODO better error, usage info? deduplicate with WithParsedArgs + parsed.Reply("Failed to parse arguments: %v", err) + return + } + if proc.LogArgs { + log.UpdateContext(func(c zerolog.Context) zerolog.Context { + return c.RawJSON("structured_args", parsed.StructuredArgs) + }) + } + } + + log.Debug().Msg("Processing command") + handler.Func(parsed) +} diff --git a/commands/reactions.go b/commands/reactions.go new file mode 100644 index 00000000..0d316219 --- /dev/null +++ b/commands/reactions.go @@ -0,0 +1,143 @@ +// Copyright (c) 2026 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package commands + +import ( + "context" + "encoding/json" + "strings" + + "github.com/rs/zerolog" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/event" +) + +const ReactionCommandsKey = "fi.mau.reaction_commands" +const ReactionMultiUseKey = "fi.mau.reaction_multi_use" + +type ReactionCommandData struct { + Command string `json:"command"` + Args any `json:"args,omitempty"` +} + +func (proc *Processor[MetaType]) ParseReaction(ctx context.Context, evt *event.Event) *Event[MetaType] { + content, ok := evt.Content.Parsed.(*event.ReactionEventContent) + if !ok { + return nil + } + evtID := content.RelatesTo.EventID + if evtID == "" || !strings.HasPrefix(content.RelatesTo.Key, proc.ReactionCommandPrefix) { + return nil + } + targetEvt, err := proc.Client.GetEvent(ctx, evt.RoomID, evtID) + if err != nil { + zerolog.Ctx(ctx).Err(err).Stringer("target_event_id", evtID).Msg("Failed to get target event for reaction") + return nil + } else if targetEvt.Sender != proc.Client.UserID || targetEvt.Unsigned.RedactedBecause != nil { + return nil + } + if targetEvt.Type == event.EventEncrypted { + if proc.Client.Crypto == nil { + zerolog.Ctx(ctx).Warn(). + Stringer("target_event_id", evtID). + Msg("Received reaction to encrypted event, but don't have crypto helper in client") + return nil + } + _ = targetEvt.Content.ParseRaw(targetEvt.Type) + targetEvt, err = proc.Client.Crypto.Decrypt(ctx, targetEvt) + if err != nil { + zerolog.Ctx(ctx).Err(err). + Stringer("target_event_id", evtID). + Msg("Failed to decrypt target event for reaction") + return nil + } + } + reactionCommands, ok := targetEvt.Content.Raw[ReactionCommandsKey].(map[string]any) + if !ok { + zerolog.Ctx(ctx).Trace(). + Stringer("target_event_id", evtID). + Msg("Reaction target event doesn't have commands key") + return nil + } + isMultiUse, _ := targetEvt.Content.Raw[ReactionMultiUseKey].(bool) + rawCmd, ok := reactionCommands[content.RelatesTo.Key] + if !ok { + zerolog.Ctx(ctx).Debug(). + Stringer("target_event_id", evtID). + Str("reaction_key", content.RelatesTo.Key). + Msg("Reaction command not found in target event") + return nil + } + var wrappedEvt *Event[MetaType] + switch typedCmd := rawCmd.(type) { + case string: + wrappedEvt = RawTextToEvent[MetaType](ctx, evt, typedCmd) + case map[string]any: + var input event.MSC4391BotCommandInput + if marshaled, err := json.Marshal(typedCmd); err != nil { + + } else if err = json.Unmarshal(marshaled, &input); err != nil { + + } else { + wrappedEvt = StructuredCommandToEvent[MetaType](ctx, evt, &input) + } + } + if wrappedEvt == nil { + zerolog.Ctx(ctx).Debug(). + Stringer("target_event_id", evtID). + Str("reaction_key", content.RelatesTo.Key). + Msg("Reaction command data is invalid") + return nil + } + wrappedEvt.Proc = proc + wrappedEvt.Redact() + if !isMultiUse { + DeleteAllReactions(ctx, proc.Client, evt) + } + if wrappedEvt.Command == "" { + return nil + } + return wrappedEvt +} + +func DeleteAllReactionsCommandFunc[MetaType any](ce *Event[MetaType]) { + DeleteAllReactions(ce.Ctx, ce.Proc.Client, ce.Event) +} + +func DeleteAllReactions(ctx context.Context, client *mautrix.Client, evt *event.Event) { + rel, ok := evt.Content.Parsed.(event.Relatable) + if !ok { + return + } + relation := rel.OptionalGetRelatesTo() + if relation == nil { + return + } + targetEvt := relation.GetReplyTo() + if targetEvt == "" { + targetEvt = relation.GetAnnotationID() + } + if targetEvt == "" { + return + } + relations, err := client.GetRelations(ctx, evt.RoomID, targetEvt, &mautrix.ReqGetRelations{ + RelationType: event.RelAnnotation, + EventType: event.EventReaction, + Limit: 20, + }) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to get reactions to delete") + return + } + for _, relEvt := range relations.Chunk { + _, err = client.RedactEvent(ctx, relEvt.RoomID, relEvt.ID) + if err != nil { + zerolog.Ctx(ctx).Err(err).Stringer("event_id", relEvt.ID).Msg("Failed to redact reaction event") + } + } +} diff --git a/crypto/account.go b/crypto/account.go index d242df6f..0bd09ecf 100644 --- a/crypto/account.go +++ b/crypto/account.go @@ -7,7 +7,12 @@ package crypto import ( + "encoding/json" + + "github.com/tidwall/sjson" + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/crypto/canonicaljson" "maunium.net/go/mautrix/crypto/olm" "maunium.net/go/mautrix/crypto/signatures" "maunium.net/go/mautrix/id" @@ -22,32 +27,61 @@ type OlmAccount struct { } func NewOlmAccount() *OlmAccount { + account, err := olm.NewAccount() + if err != nil { + panic(err) + } return &OlmAccount{ - Internal: *olm.NewAccount(), + Internal: account, } } func (account *OlmAccount) Keys() (id.SigningKey, id.IdentityKey) { if len(account.signingKey) == 0 || len(account.identityKey) == 0 { - account.signingKey, account.identityKey = account.Internal.IdentityKeys() + var err error + account.signingKey, account.identityKey, err = account.Internal.IdentityKeys() + if err != nil { + panic(err) + } } return account.signingKey, account.identityKey } func (account *OlmAccount) SigningKey() id.SigningKey { if len(account.signingKey) == 0 { - account.signingKey, account.identityKey = account.Internal.IdentityKeys() + var err error + account.signingKey, account.identityKey, err = account.Internal.IdentityKeys() + if err != nil { + panic(err) + } } return account.signingKey } func (account *OlmAccount) IdentityKey() id.IdentityKey { if len(account.identityKey) == 0 { - account.signingKey, account.identityKey = account.Internal.IdentityKeys() + var err error + account.signingKey, account.identityKey, err = account.Internal.IdentityKeys() + if err != nil { + panic(err) + } } return account.identityKey } +// SignJSON signs the given JSON object following the Matrix specification: +// https://matrix.org/docs/spec/appendices#signing-json +func (account *OlmAccount) SignJSON(obj any) (string, error) { + objJSON, err := json.Marshal(obj) + if err != nil { + return "", err + } + objJSON, _ = sjson.DeleteBytes(objJSON, "unsigned") + objJSON, _ = sjson.DeleteBytes(objJSON, "signatures") + signed, err := account.Internal.Sign(canonicaljson.CanonicalJSONAssumeValid(objJSON)) + return string(signed), err +} + func (account *OlmAccount) getInitialKeys(userID id.UserID, deviceID id.DeviceID) *mautrix.DeviceKeys { deviceKeys := &mautrix.DeviceKeys{ UserID: userID, @@ -59,7 +93,7 @@ func (account *OlmAccount) getInitialKeys(userID id.UserID, deviceID id.DeviceID }, } - signature, err := account.Internal.SignJSON(deviceKeys) + signature, err := account.SignJSON(deviceKeys) if err != nil { panic(err) } @@ -74,13 +108,16 @@ func (account *OlmAccount) getOneTimeKeys(userID id.UserID, deviceID id.DeviceID account.Internal.GenOneTimeKeys(uint(newCount)) } oneTimeKeys := make(map[id.KeyID]mautrix.OneTimeKey) - for keyID, key := range account.Internal.OneTimeKeys() { + internalKeys, err := account.Internal.OneTimeKeys() + if err != nil { + panic(err) + } + for keyID, key := range internalKeys { key := mautrix.OneTimeKey{Key: key} - signature, _ := account.Internal.SignJSON(key) + signature, _ := account.SignJSON(key) key.Signatures = signatures.NewSingleSignature(userID, id.KeyAlgorithmEd25519, deviceID.String(), signature) key.IsSigned = true oneTimeKeys[id.NewKeyID(id.KeyAlgorithmSignedCurve25519, keyID)] = key } - account.Internal.MarkKeysAsPublished() return oneTimeKeys } diff --git a/crypto/aescbc/aes_cbc_test.go b/crypto/aescbc/aes_cbc_test.go index bb03f706..d6611dc9 100644 --- a/crypto/aescbc/aes_cbc_test.go +++ b/crypto/aescbc/aes_cbc_test.go @@ -7,11 +7,13 @@ package aescbc_test import ( - "bytes" "crypto/aes" "crypto/rand" "testing" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "maunium.net/go/mautrix/crypto/aescbc" ) @@ -22,32 +24,23 @@ func TestAESCBC(t *testing.T) { // The key length can be 32, 24, 16 bytes (OR in bits: 128, 192 or 256) key := make([]byte, 32) _, err = rand.Read(key) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) iv := make([]byte, aes.BlockSize) _, err = rand.Read(iv) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) plaintext = []byte("secret message for testing") //increase to next block size for len(plaintext)%8 != 0 { plaintext = append(plaintext, []byte("-")...) } - if ciphertext, err = aescbc.Encrypt(key, iv, plaintext); err != nil { - t.Fatal(err) - } + ciphertext, err = aescbc.Encrypt(key, iv, plaintext) + require.NoError(t, err) resultPlainText, err := aescbc.Decrypt(key, iv, ciphertext) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) - if string(resultPlainText) != string(plaintext) { - t.Fatalf("message '%s' (length %d) != '%s'", resultPlainText, len(resultPlainText), plaintext) - } + assert.Equal(t, string(resultPlainText), string(plaintext)) } func TestAESCBCCase1(t *testing.T) { @@ -61,18 +54,10 @@ func TestAESCBCCase1(t *testing.T) { key := make([]byte, 32) iv := make([]byte, aes.BlockSize) encrypted, err := aescbc.Encrypt(key, iv, input) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(expected, encrypted) { - t.Fatalf("encrypted did not match expected:\n%v\n%v\n", encrypted, expected) - } + require.NoError(t, err) + assert.Equal(t, expected, encrypted, "encrypted output does not match expected") decrypted, err := aescbc.Decrypt(key, iv, encrypted) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(input, decrypted) { - t.Fatalf("decrypted did not match expected:\n%v\n%v\n", decrypted, input) - } + require.NoError(t, err) + assert.Equal(t, input, decrypted, "decrypted output does not match input") } diff --git a/crypto/attachment/attachments.go b/crypto/attachment/attachments.go index e516fded..727aacbf 100644 --- a/crypto/attachment/attachments.go +++ b/crypto/attachment/attachments.go @@ -9,9 +9,11 @@ package attachment import ( "crypto/aes" "crypto/cipher" + "crypto/hmac" "crypto/sha256" "encoding/base64" "errors" + "fmt" "hash" "io" @@ -19,13 +21,24 @@ import ( ) var ( - HashMismatch = errors.New("mismatching SHA-256 digest") - UnsupportedVersion = errors.New("unsupported Matrix file encryption version") - UnsupportedAlgorithm = errors.New("unsupported JWK encryption algorithm") - InvalidKey = errors.New("failed to decode key") - InvalidInitVector = errors.New("failed to decode initialization vector") - InvalidHash = errors.New("failed to decode SHA-256 hash") - ReaderClosed = errors.New("encrypting reader was already closed") + ErrHashMismatch = errors.New("mismatching SHA-256 digest") + ErrUnsupportedVersion = errors.New("unsupported Matrix file encryption version") + ErrUnsupportedAlgorithm = errors.New("unsupported JWK encryption algorithm") + ErrInvalidKey = errors.New("failed to decode key") + ErrInvalidInitVector = errors.New("failed to decode initialization vector") + ErrInvalidHash = errors.New("failed to decode SHA-256 hash") + ErrReaderClosed = errors.New("encrypting reader was already closed") +) + +// Deprecated: use variables prefixed with Err +var ( + HashMismatch = ErrHashMismatch + UnsupportedVersion = ErrUnsupportedVersion + UnsupportedAlgorithm = ErrUnsupportedAlgorithm + InvalidKey = ErrInvalidKey + InvalidInitVector = ErrInvalidInitVector + InvalidHash = ErrInvalidHash + ReaderClosed = ErrReaderClosed ) var ( @@ -83,25 +96,25 @@ func (ef *EncryptedFile) decodeKeys(includeHash bool) error { if ef.decoded != nil { return nil } else if len(ef.Key.Key) != keyBase64Length { - return InvalidKey + return ErrInvalidKey } else if len(ef.InitVector) != ivBase64Length { - return InvalidInitVector + return ErrInvalidInitVector } else if includeHash && len(ef.Hashes.SHA256) != hashBase64Length { - return InvalidHash + return ErrInvalidHash } ef.decoded = &decodedKeys{} _, err := base64.RawURLEncoding.Decode(ef.decoded.key[:], []byte(ef.Key.Key)) if err != nil { - return InvalidKey + return ErrInvalidKey } _, err = base64.RawStdEncoding.Decode(ef.decoded.iv[:], []byte(ef.InitVector)) if err != nil { - return InvalidInitVector + return ErrInvalidInitVector } if includeHash { _, err = base64.RawStdEncoding.Decode(ef.decoded.sha256[:], []byte(ef.Hashes.SHA256)) if err != nil { - return InvalidHash + return ErrInvalidHash } } return nil @@ -126,6 +139,43 @@ func (ef *EncryptedFile) EncryptInPlace(data []byte) { ef.Hashes.SHA256 = base64.RawStdEncoding.EncodeToString(checksum[:]) } +type ReadWriterAt interface { + io.WriterAt + io.Reader +} + +// EncryptFile encrypts the given file in-place and updates the SHA256 hash in the EncryptedFile struct. +func (ef *EncryptedFile) EncryptFile(file ReadWriterAt) error { + err := ef.decodeKeys(false) + if err != nil { + return err + } + block, _ := aes.NewCipher(ef.decoded.key[:]) + stream := cipher.NewCTR(block, ef.decoded.iv[:]) + hasher := sha256.New() + buf := make([]byte, 32*1024) + var writePtr int64 + var n int + for { + n, err = file.Read(buf) + if err != nil && !errors.Is(err, io.EOF) { + return err + } + if n == 0 { + break + } + stream.XORKeyStream(buf[:n], buf[:n]) + _, err = file.WriteAt(buf[:n], writePtr) + if err != nil { + return err + } + writePtr += int64(n) + hasher.Write(buf[:n]) + } + ef.Hashes.SHA256 = base64.RawStdEncoding.EncodeToString(hasher.Sum(nil)) + return nil +} + type encryptingReader struct { stream cipher.Stream hash hash.Hash @@ -136,17 +186,45 @@ type encryptingReader struct { isDecrypting bool } +var _ io.ReadSeekCloser = (*encryptingReader)(nil) + +func (r *encryptingReader) Seek(offset int64, whence int) (int64, error) { + if r.closed { + return 0, ErrReaderClosed + } + if offset != 0 || whence != io.SeekStart { + return 0, fmt.Errorf("attachments.EncryptStream: only seeking to the beginning is supported") + } + seeker, ok := r.source.(io.ReadSeeker) + if !ok { + return 0, fmt.Errorf("attachments.EncryptStream: source reader (%T) is not an io.ReadSeeker", r.source) + } + n, err := seeker.Seek(offset, whence) + if err != nil { + return 0, err + } + block, _ := aes.NewCipher(r.file.decoded.key[:]) + r.stream = cipher.NewCTR(block, r.file.decoded.iv[:]) + r.hash.Reset() + return n, nil +} + func (r *encryptingReader) Read(dst []byte) (n int, err error) { if r.closed { - return 0, ReaderClosed + return 0, ErrReaderClosed } else if r.isDecrypting && r.file.decoded == nil { if err = r.file.PrepareForDecryption(); err != nil { return } } n, err = r.source.Read(dst) + if r.isDecrypting { + r.hash.Write(dst[:n]) + } r.stream.XORKeyStream(dst[:n], dst[:n]) - r.hash.Write(dst[:n]) + if !r.isDecrypting { + r.hash.Write(dst[:n]) + } return } @@ -156,10 +234,8 @@ func (r *encryptingReader) Close() (err error) { err = closer.Close() } if r.isDecrypting { - var downloadedChecksum [utils.SHAHashLength]byte - r.hash.Sum(downloadedChecksum[:]) - if downloadedChecksum != r.file.decoded.sha256 { - return HashMismatch + if !hmac.Equal(r.hash.Sum(nil), r.file.decoded.sha256[:]) { + return ErrHashMismatch } } else { r.file.Hashes.SHA256 = base64.RawStdEncoding.EncodeToString(r.hash.Sum(nil)) @@ -173,7 +249,7 @@ func (r *encryptingReader) Close() (err error) { // The Close() method of the returned io.ReadCloser must be called for the SHA256 hash // in the EncryptedFile struct to be updated. The metadata is not valid before the hash // is filled. -func (ef *EncryptedFile) EncryptStream(reader io.Reader) io.ReadCloser { +func (ef *EncryptedFile) EncryptStream(reader io.Reader) io.ReadSeekCloser { ef.decodeKeys(false) block, _ := aes.NewCipher(ef.decoded.key[:]) return &encryptingReader{ @@ -200,9 +276,9 @@ func (ef *EncryptedFile) Decrypt(ciphertext []byte) ([]byte, error) { // DecryptInPlace will always call this automatically, so calling this manually is not necessary when using that function. func (ef *EncryptedFile) PrepareForDecryption() error { if ef.Version != "v2" { - return UnsupportedVersion + return ErrUnsupportedVersion } else if ef.Key.Algorithm != "A256CTR" { - return UnsupportedAlgorithm + return ErrUnsupportedAlgorithm } else if err := ef.decodeKeys(true); err != nil { return err } @@ -213,12 +289,13 @@ func (ef *EncryptedFile) PrepareForDecryption() error { func (ef *EncryptedFile) DecryptInPlace(data []byte) error { if err := ef.PrepareForDecryption(); err != nil { return err - } else if ef.decoded.sha256 != sha256.Sum256(data) { - return HashMismatch - } else { - utils.XorA256CTR(data, ef.decoded.key, ef.decoded.iv) - return nil } + dataHash := sha256.Sum256(data) + if !hmac.Equal(ef.decoded.sha256[:], dataHash[:]) { + return ErrHashMismatch + } + utils.XorA256CTR(data, ef.decoded.key, ef.decoded.iv) + return nil } // DecryptStream wraps the given io.Reader in order to decrypt the data. @@ -228,12 +305,13 @@ func (ef *EncryptedFile) DecryptInPlace(data []byte) error { // // The Close call will validate the hash and return an error if it doesn't match. // In this case, the written data should be considered compromised and should not be used further. -func (ef *EncryptedFile) DecryptStream(reader io.Reader) io.ReadCloser { +func (ef *EncryptedFile) DecryptStream(reader io.Reader) io.ReadSeekCloser { block, _ := aes.NewCipher(ef.decoded.key[:]) return &encryptingReader{ - stream: cipher.NewCTR(block, ef.decoded.iv[:]), - hash: sha256.New(), - source: reader, - file: ef, + isDecrypting: true, + stream: cipher.NewCTR(block, ef.decoded.iv[:]), + hash: sha256.New(), + source: reader, + file: ef, } } diff --git a/crypto/attachment/attachments_test.go b/crypto/attachment/attachments_test.go index d7f1394a..9fe929ab 100644 --- a/crypto/attachment/attachments_test.go +++ b/crypto/attachment/attachments_test.go @@ -53,33 +53,33 @@ func TestUnsupportedVersion(t *testing.T) { file := parseHelloWorld() file.Version = "foo" err := file.DecryptInPlace([]byte(helloWorldCiphertext)) - assert.ErrorIs(t, err, UnsupportedVersion) + assert.ErrorIs(t, err, ErrUnsupportedVersion) } func TestUnsupportedAlgorithm(t *testing.T) { file := parseHelloWorld() file.Key.Algorithm = "bar" err := file.DecryptInPlace([]byte(helloWorldCiphertext)) - assert.ErrorIs(t, err, UnsupportedAlgorithm) + assert.ErrorIs(t, err, ErrUnsupportedAlgorithm) } func TestHashMismatch(t *testing.T) { file := parseHelloWorld() file.Hashes.SHA256 = base64.RawStdEncoding.EncodeToString([]byte(random32Bytes)) err := file.DecryptInPlace([]byte(helloWorldCiphertext)) - assert.ErrorIs(t, err, HashMismatch) + assert.ErrorIs(t, err, ErrHashMismatch) } func TestTooLongHash(t *testing.T) { file := parseHelloWorld() file.Hashes.SHA256 = "TG9yZW0gaXBzdW0gZG9sb3Igc2l0IGFtZXQsIGNvbnNlY3RldHVlciBhZGlwaXNjaW5nIGVsaXQuIFNlZCBwb3N1ZXJlIGludGVyZHVtIHNlbS4gUXVpc3F1ZSBsaWd1bGEgZXJvcyB1bGxhbWNvcnBlciBxdWlzLCBsYWNpbmlhIHF1aXMgZmFjaWxpc2lzIHNlZCBzYXBpZW4uCg" err := file.DecryptInPlace([]byte(helloWorldCiphertext)) - assert.ErrorIs(t, err, InvalidHash) + assert.ErrorIs(t, err, ErrInvalidHash) } func TestTooShortHash(t *testing.T) { file := parseHelloWorld() file.Hashes.SHA256 = "5/Gy1JftyyQ" err := file.DecryptInPlace([]byte(helloWorldCiphertext)) - assert.ErrorIs(t, err, InvalidHash) + assert.ErrorIs(t, err, ErrInvalidHash) } diff --git a/crypto/backup/encryptedsessiondata.go b/crypto/backup/encryptedsessiondata.go index 37b0a6c8..25250178 100644 --- a/crypto/backup/encryptedsessiondata.go +++ b/crypto/backup/encryptedsessiondata.go @@ -47,33 +47,31 @@ func calculateEncryptionParameters(sharedSecret []byte) (key, macKey, iv []byte, return encryptionParams[:32], encryptionParams[32:64], encryptionParams[64:], nil } -// calculateCompatMAC calculates the MAC for compatibility with Olm and -// Vodozemac which do not actually write the ciphertext when computing the MAC. +// calculateCompatMAC calculates the MAC as described in step 5 of according to +// [Section 11.12.3.2.2] of the Spec which was updated in spec version 1.10 to +// reflect what is actually implemented in libolm and Vodozemac. // -// Deprecated: Use [calculateMAC] instead. +// Libolm implemented the MAC functionality incorrectly. The MAC is computed +// over an empty string rather than the ciphertext. Vodozemac implemented this +// functionality the same way as libolm for compatibility. In version 1.10 of +// the spec, the description of step 5 was updated to reflect the de-facto +// standard of libolm and Vodozemac. +// +// [Section 11.12.3.2.2]: https://spec.matrix.org/v1.11/client-server-api/#backup-algorithm-mmegolm_backupv1curve25519-aes-sha2 func calculateCompatMAC(macKey []byte) []byte { hash := hmac.New(sha256.New, macKey) return hash.Sum(nil)[:8] } -// calculateMAC calculates the MAC as described in step 5 of according to -// [Section 11.12.3.2.2] of the Spec. -// -// [Section 11.12.3.2.2]: https://spec.matrix.org/v1.9/client-server-api/#backup-algorithm-mmegolm_backupv1curve25519-aes-sha2 -func calculateMAC(macKey, ciphertext []byte) []byte { - hash := hmac.New(sha256.New, macKey) - _, err := hash.Write(ciphertext) - if err != nil { - panic(err) - } - return hash.Sum(nil)[:8] -} - // EncryptSessionData encrypts the given session data with the given recovery // key as defined in [Section 11.12.3.2.2 of the Spec]. // // [Section 11.12.3.2.2 of the Spec]: https://spec.matrix.org/v1.9/client-server-api/#backup-algorithm-mmegolm_backupv1curve25519-aes-sha2 func EncryptSessionData[T any](backupKey *MegolmBackupKey, sessionData T) (*EncryptedSessionData[T], error) { + return EncryptSessionDataWithPubkey(backupKey.PublicKey(), sessionData) +} + +func EncryptSessionDataWithPubkey[T any](pubkey *ecdh.PublicKey, sessionData T) (*EncryptedSessionData[T], error) { sessionJSON, err := json.Marshal(sessionData) if err != nil { return nil, err @@ -84,7 +82,7 @@ func EncryptSessionData[T any](backupKey *MegolmBackupKey, sessionData T) (*Encr return nil, err } - sharedSecret, err := ephemeralKey.ECDH(backupKey.PublicKey()) + sharedSecret, err := ephemeralKey.ECDH(pubkey) if err != nil { return nil, err } diff --git a/crypto/canonicaljson/json_test.go b/crypto/canonicaljson/json_test.go index d1a7f0a5..36476aa4 100644 --- a/crypto/canonicaljson/json_test.go +++ b/crypto/canonicaljson/json_test.go @@ -17,31 +17,43 @@ package canonicaljson import ( "testing" + + "github.com/stretchr/testify/assert" ) -func testSortJSON(t *testing.T, input, want string) { - got := SortJSON([]byte(input), nil) - - // Squash out the whitespace before comparing the JSON in case SortJSON had inserted whitespace. - if string(CompactJSON(got, nil)) != want { - t.Errorf("SortJSON(%q): want %q got %q", input, want, got) - } -} - func TestSortJSON(t *testing.T) { - testSortJSON(t, `[{"b":"two","a":1}]`, `[{"a":1,"b":"two"}]`) - testSortJSON(t, `{"B":{"4":4,"3":3},"A":{"1":1,"2":2}}`, - `{"A":{"1":1,"2":2},"B":{"3":3,"4":4}}`) - testSortJSON(t, `[true,false,null]`, `[true,false,null]`) - testSortJSON(t, `[9007199254740991]`, `[9007199254740991]`) - testSortJSON(t, "\t\n[9007199254740991]", `[9007199254740991]`) + var tests = []struct { + input string + want string + }{ + {"{}", "{}"}, + {`[{"b":"two","a":1}]`, `[{"a":1,"b":"two"}]`}, + {`{"B":{"4":4,"3":3},"A":{"1":1,"2":2}}`, `{"A":{"1":1,"2":2},"B":{"3":3,"4":4}}`}, + {`[true,false,null]`, `[true,false,null]`}, + {`[9007199254740991]`, `[9007199254740991]`}, + {"\t\n[9007199254740991]", `[9007199254740991]`}, + {`[true,false,null]`, `[true,false,null]`}, + {`[{"b":"two","a":1}]`, `[{"a":1,"b":"two"}]`}, + {`{"B":{"4":4,"3":3},"A":{"1":1,"2":2}}`, `{"A":{"1":1,"2":2},"B":{"3":3,"4":4}}`}, + {`[true,false,null]`, `[true,false,null]`}, + {`[9007199254740991]`, `[9007199254740991]`}, + {"\t\n[9007199254740991]", `[9007199254740991]`}, + {`[true,false,null]`, `[true,false,null]`}, + } + for _, test := range tests { + t.Run(test.input, func(t *testing.T) { + got := SortJSON([]byte(test.input), nil) + + // Squash out the whitespace before comparing the JSON in case SortJSON had inserted whitespace. + assert.EqualValues(t, test.want, string(CompactJSON(got, nil))) + }) + } } func testCompactJSON(t *testing.T, input, want string) { + t.Helper() got := string(CompactJSON([]byte(input), nil)) - if got != want { - t.Errorf("CompactJSON(%q): want %q got %q", input, want, got) - } + assert.EqualValues(t, want, got) } func TestCompactJSON(t *testing.T) { @@ -74,18 +86,23 @@ func TestCompactJSON(t *testing.T) { testCompactJSON(t, `["\"\\\/"]`, `["\"\\/"]`) } -func testReadHex(t *testing.T, input string, want uint32) { - got := readHexDigits([]byte(input)) - if want != got { - t.Errorf("readHexDigits(%q): want 0x%x got 0x%x", input, want, got) +func TestReadHex(t *testing.T) { + tests := []struct { + input string + want uint32 + }{ + + {"0123", 0x0123}, + {"4567", 0x4567}, + {"89AB", 0x89AB}, + {"CDEF", 0xCDEF}, + {"89ab", 0x89AB}, + {"cdef", 0xCDEF}, + } + for _, test := range tests { + t.Run(test.input, func(t *testing.T) { + got := readHexDigits([]byte(test.input)) + assert.Equal(t, test.want, got) + }) } } - -func TestReadHex(t *testing.T) { - testReadHex(t, "0123", 0x0123) - testReadHex(t, "4567", 0x4567) - testReadHex(t, "89AB", 0x89AB) - testReadHex(t, "CDEF", 0xCDEF) - testReadHex(t, "89ab", 0x89AB) - testReadHex(t, "cdef", 0xCDEF) -} diff --git a/crypto/cross_sign_key.go b/crypto/cross_sign_key.go index f7dc08cb..5d9bf5b3 100644 --- a/crypto/cross_sign_key.go +++ b/crypto/cross_sign_key.go @@ -11,6 +11,8 @@ import ( "context" "fmt" + "go.mau.fi/util/jsonbytes" + "maunium.net/go/mautrix" "maunium.net/go/mautrix/crypto/olm" "maunium.net/go/mautrix/crypto/signatures" @@ -33,9 +35,9 @@ func (cskc *CrossSigningKeysCache) PublicKeys() *CrossSigningPublicKeysCache { } type CrossSigningSeeds struct { - MasterKey []byte - SelfSigningKey []byte - UserSigningKey []byte + MasterKey jsonbytes.UnpaddedURLBytes `json:"m.cross_signing.master"` + SelfSigningKey jsonbytes.UnpaddedURLBytes `json:"m.cross_signing.self_signing"` + UserSigningKey jsonbytes.UnpaddedURLBytes `json:"m.cross_signing.user_signing"` } func (mach *OlmMachine) ExportCrossSigningKeys() CrossSigningSeeds { @@ -101,7 +103,7 @@ func (mach *OlmMachine) PublishCrossSigningKeys(ctx context.Context, keys *Cross masterKeyID: keys.MasterKey.PublicKey(), }, } - masterSig, err := mach.account.Internal.SignJSON(masterKey) + masterSig, err := mach.account.SignJSON(masterKey) if err != nil { return fmt.Errorf("failed to sign master key: %w", err) } @@ -133,7 +135,7 @@ func (mach *OlmMachine) PublishCrossSigningKeys(ctx context.Context, keys *Cross } userKey.Signatures = signatures.NewSingleSignature(userID, id.KeyAlgorithmEd25519, keys.MasterKey.PublicKey().String(), userSig) - err = mach.Client.UploadCrossSigningKeys(ctx, &mautrix.UploadCrossSigningKeysReq{ + err = mach.Client.UploadCrossSigningKeys(ctx, &mautrix.UploadCrossSigningKeysReq[any]{ Master: masterKey, SelfSigning: selfKey, UserSigning: userKey, @@ -142,6 +144,16 @@ func (mach *OlmMachine) PublishCrossSigningKeys(ctx context.Context, keys *Cross return err } + if err := mach.CryptoStore.PutSignature(ctx, userID, keys.MasterKey.PublicKey(), userID, mach.account.SigningKey(), masterSig); err != nil { + return fmt.Errorf("error storing signature of master key by device signing key in crypto store: %w", err) + } + if err := mach.CryptoStore.PutSignature(ctx, userID, keys.SelfSigningKey.PublicKey(), userID, keys.MasterKey.PublicKey(), selfSig); err != nil { + return fmt.Errorf("error storing signature of self-signing key by master key in crypto store: %w", err) + } + if err := mach.CryptoStore.PutSignature(ctx, userID, keys.UserSigningKey.PublicKey(), userID, keys.MasterKey.PublicKey(), userSig); err != nil { + return fmt.Errorf("error storing signature of user-signing key by master key in crypto store: %w", err) + } + mach.CrossSigningKeys = keys mach.crossSigningPubkeys = keys.PublicKeys() diff --git a/crypto/cross_sign_pubkey.go b/crypto/cross_sign_pubkey.go index 77efab5b..223fc7b5 100644 --- a/crypto/cross_sign_pubkey.go +++ b/crypto/cross_sign_pubkey.go @@ -20,6 +20,20 @@ type CrossSigningPublicKeysCache struct { UserSigningKey id.Ed25519 } +func (mach *OlmMachine) GetOwnVerificationStatus(ctx context.Context) (hasKeys, isVerified bool, err error) { + pubkeys := mach.GetOwnCrossSigningPublicKeys(ctx) + if pubkeys != nil { + hasKeys = true + isVerified, err = mach.CryptoStore.IsKeySignedBy( + ctx, mach.Client.UserID, mach.GetAccount().SigningKey(), mach.Client.UserID, pubkeys.SelfSigningKey, + ) + if err != nil { + err = fmt.Errorf("failed to check if current device is signed by own self-signing key: %w", err) + } + } + return +} + func (mach *OlmMachine) GetOwnCrossSigningPublicKeys(ctx context.Context) *CrossSigningPublicKeysCache { if mach.crossSigningPubkeys != nil { return mach.crossSigningPubkeys @@ -49,8 +63,8 @@ func (mach *OlmMachine) GetCrossSigningPublicKeys(ctx context.Context, userID id if len(dbKeys) > 0 { masterKey, ok := dbKeys[id.XSUsageMaster] if ok { - selfSigning, _ := dbKeys[id.XSUsageSelfSigning] - userSigning, _ := dbKeys[id.XSUsageUserSigning] + selfSigning := dbKeys[id.XSUsageSelfSigning] + userSigning := dbKeys[id.XSUsageUserSigning] return &CrossSigningPublicKeysCache{ MasterKey: masterKey.Key, SelfSigningKey: selfSigning.Key, diff --git a/crypto/cross_sign_signing.go b/crypto/cross_sign_signing.go index 86920728..ae3d1eb1 100644 --- a/crypto/cross_sign_signing.go +++ b/crypto/cross_sign_signing.go @@ -19,15 +19,15 @@ import ( ) var ( - ErrCrossSigningKeysNotCached = errors.New("cross-signing private keys not in cache") - ErrUserSigningKeyNotCached = errors.New("user-signing private key not in cache") - ErrSelfSigningKeyNotCached = errors.New("self-signing private key not in cache") - ErrSignatureUploadFail = errors.New("server-side failure uploading signatures") - ErrCantSignOwnMasterKey = errors.New("signing your own master key is not allowed") - ErrCantSignOtherDevice = errors.New("signing other users' devices is not allowed") - ErrUserNotInQueryResponse = errors.New("could not find user in query keys response") - ErrDeviceNotInQueryResponse = errors.New("could not find device in query keys response") - ErrOlmAccountNotLoaded = errors.New("olm account has not been loaded") + ErrCrossSigningPubkeysNotCached = errors.New("cross-signing public keys not in cache") + ErrUserSigningKeyNotCached = errors.New("user-signing private key not in cache") + ErrSelfSigningKeyNotCached = errors.New("self-signing private key not in cache") + ErrSignatureUploadFail = errors.New("server-side failure uploading signatures") + ErrCantSignOwnMasterKey = errors.New("signing your own master key is not allowed") + ErrCantSignOtherDevice = errors.New("signing other users' devices is not allowed") + ErrUserNotInQueryResponse = errors.New("could not find user in query keys response") + ErrDeviceNotInQueryResponse = errors.New("could not find device in query keys response") + ErrOlmAccountNotLoaded = errors.New("olm account has not been loaded") ErrCrossSigningMasterKeyNotFound = errors.New("cross-signing master key not found") ErrMasterKeyMACNotFound = errors.New("found cross-signing master key, but didn't find corresponding MAC in verification request") @@ -69,15 +69,16 @@ func (mach *OlmMachine) SignUser(ctx context.Context, userID id.UserID, masterKe // SignOwnMasterKey uses the current account for signing the current user's master key and uploads the signature. func (mach *OlmMachine) SignOwnMasterKey(ctx context.Context) error { - if mach.CrossSigningKeys == nil { - return ErrCrossSigningKeysNotCached + crossSigningPubkeys := mach.GetOwnCrossSigningPublicKeys(ctx) + if crossSigningPubkeys == nil { + return ErrCrossSigningPubkeysNotCached } else if mach.account == nil { return ErrOlmAccountNotLoaded } userID := mach.Client.UserID deviceID := mach.Client.DeviceID - masterKey := mach.CrossSigningKeys.MasterKey.PublicKey() + masterKey := crossSigningPubkeys.MasterKey masterKeyObj := mautrix.ReqKeysSignatures{ UserID: userID, @@ -86,7 +87,7 @@ func (mach *OlmMachine) SignOwnMasterKey(ctx context.Context) error { id.NewKeyID(id.KeyAlgorithmEd25519, masterKey.String()): masterKey.String(), }, } - signature, err := mach.account.Internal.SignJSON(masterKeyObj) + signature, err := mach.account.SignJSON(masterKeyObj) if err != nil { return fmt.Errorf("failed to sign JSON: %w", err) } diff --git a/crypto/cross_sign_ssss.go b/crypto/cross_sign_ssss.go index 389a9fd2..fd42880d 100644 --- a/crypto/cross_sign_ssss.go +++ b/crypto/cross_sign_ssss.go @@ -8,6 +8,7 @@ package crypto import ( "context" + "errors" "fmt" "maunium.net/go/mautrix" @@ -71,6 +72,46 @@ func (mach *OlmMachine) GenerateAndUploadCrossSigningKeysWithPassword(ctx contex }, passphrase) } +func (mach *OlmMachine) VerifyWithRecoveryKey(ctx context.Context, recoveryKey string) error { + keyID, keyData, err := mach.SSSS.GetDefaultKeyData(ctx) + if err != nil { + return fmt.Errorf("failed to get default SSSS key data: %w", err) + } + key, err := keyData.VerifyRecoveryKey(keyID, recoveryKey) + if errors.Is(err, ssss.ErrUnverifiableKey) { + mach.machOrContextLog(ctx).Warn(). + Str("key_id", keyID). + Msg("SSSS key is unverifiable, trying to use without verifying") + } else if err != nil { + return err + } + err = mach.FetchCrossSigningKeysFromSSSS(ctx, key) + if err != nil { + return fmt.Errorf("failed to fetch cross-signing keys from SSSS: %w", err) + } + err = mach.SignOwnDevice(ctx, mach.OwnIdentity()) + if err != nil { + return fmt.Errorf("failed to sign own device: %w", err) + } + err = mach.SignOwnMasterKey(ctx) + if err != nil { + return fmt.Errorf("failed to sign own master key: %w", err) + } + return nil +} + +func (mach *OlmMachine) GenerateAndVerifyWithRecoveryKey(ctx context.Context) (recoveryKey string, err error) { + recoveryKey, _, err = mach.GenerateAndUploadCrossSigningKeys(ctx, nil, "") + if err != nil { + err = fmt.Errorf("failed to generate and upload cross-signing keys: %w", err) + } else if err = mach.SignOwnDevice(ctx, mach.OwnIdentity()); err != nil { + err = fmt.Errorf("failed to sign own device: %w", err) + } else if err = mach.SignOwnMasterKey(ctx); err != nil { + err = fmt.Errorf("failed to sign own master key: %w", err) + } + return +} + // GenerateAndUploadCrossSigningKeys generates a new key with all corresponding cross-signing keys. // // A passphrase can be provided to generate the SSSS key. If the passphrase is empty, a random key @@ -97,12 +138,12 @@ func (mach *OlmMachine) GenerateAndUploadCrossSigningKeys(ctx context.Context, u // Publish cross-signing keys err = mach.PublishCrossSigningKeys(ctx, keysCache, uiaCallback) if err != nil { - return "", nil, fmt.Errorf("failed to publish cross-signing keys: %w", err) + return key.RecoveryKey(), keysCache, fmt.Errorf("failed to publish cross-signing keys: %w", err) } err = mach.SSSS.SetDefaultKeyID(ctx, key.ID) if err != nil { - return "", nil, fmt.Errorf("failed to mark %s as the default key: %w", key.ID, err) + return key.RecoveryKey(), keysCache, fmt.Errorf("failed to mark %s as the default key: %w", key.ID, err) } return key.RecoveryKey(), keysCache, nil diff --git a/crypto/cross_sign_store.go b/crypto/cross_sign_store.go index 28d0bad0..57406b11 100644 --- a/crypto/cross_sign_store.go +++ b/crypto/cross_sign_store.go @@ -10,6 +10,8 @@ package crypto import ( "context" + "go.mau.fi/util/exzerolog" + "maunium.net/go/mautrix" "maunium.net/go/mautrix/crypto/signatures" "maunium.net/go/mautrix/id" @@ -18,38 +20,36 @@ import ( func (mach *OlmMachine) storeCrossSigningKeys(ctx context.Context, crossSigningKeys map[id.UserID]mautrix.CrossSigningKeys, deviceKeys map[id.UserID]map[id.DeviceID]mautrix.DeviceKeys) { log := mach.machOrContextLog(ctx) for userID, userKeys := range crossSigningKeys { - log := log.With().Str("user_id", userID.String()).Logger() + log := log.With().Stringer("user_id", userID).Logger() currentKeys, err := mach.CryptoStore.GetCrossSigningKeys(ctx, userID) if err != nil { log.Error().Err(err). Msg("Error fetching current cross-signing keys of user") } - if currentKeys != nil { - for curKeyUsage, curKey := range currentKeys { - log := log.With().Str("old_key", curKey.Key.String()).Str("old_key_usage", string(curKeyUsage)).Logger() - // got a new key with the same usage as an existing key - for _, newKeyUsage := range userKeys.Usage { - if newKeyUsage == curKeyUsage { - if _, ok := userKeys.Keys[id.NewKeyID(id.KeyAlgorithmEd25519, curKey.Key.String())]; !ok { - // old key is not in the new key map, so we drop signatures made by it - if count, err := mach.CryptoStore.DropSignaturesByKey(ctx, userID, curKey.Key); err != nil { - log.Error().Err(err).Msg("Error deleting old signatures made by user") - } else { - log.Debug(). - Int64("signature_count", count). - Msg("Dropped signatures made by old key as it has been replaced") - } + for curKeyUsage, curKey := range currentKeys { + log := log.With().Stringer("old_key", curKey.Key).Str("old_key_usage", string(curKeyUsage)).Logger() + // got a new key with the same usage as an existing key + for _, newKeyUsage := range userKeys.Usage { + if newKeyUsage == curKeyUsage { + if _, ok := userKeys.Keys[id.NewKeyID(id.KeyAlgorithmEd25519, curKey.Key.String())]; !ok { + // old key is not in the new key map, so we drop signatures made by it + if count, err := mach.CryptoStore.DropSignaturesByKey(ctx, userID, curKey.Key); err != nil { + log.Error().Err(err).Msg("Error deleting old signatures made by user") + } else { + log.Debug(). + Int64("signature_count", count). + Msg("Dropped signatures made by old key as it has been replaced") } - break } + break } } } for _, key := range userKeys.Keys { - log := log.With().Str("key", key.String()).Strs("usages", strishArray(userKeys.Usage)).Logger() + log := log.With().Stringer("key", key).Array("usages", exzerolog.ArrayOfStrs(userKeys.Usage)).Logger() for _, usage := range userKeys.Usage { - log.Debug().Str("usage", string(usage)).Msg("Storing cross-signing key") + log.Trace().Str("usage", string(usage)).Msg("Storing cross-signing key") if err = mach.CryptoStore.PutCrossSigningKey(ctx, userID, usage, key); err != nil { log.Error().Err(err).Msg("Error storing cross-signing key") } @@ -75,16 +75,16 @@ func (mach *OlmMachine) storeCrossSigningKeys(ctx context.Context, crossSigningK } } if len(signingKey) != 43 { - log.Debug().Msg("Cross-signing key has a signature from an unknown key") + log.Trace().Msg("Cross-signing key has a signature from an unknown key") continue } - log.Debug().Msg("Verifying cross-signing key signature") + log.Trace().Msg("Verifying cross-signing key signature") if verified, err := signatures.VerifySignatureJSON(userKeys, signUserID, signKeyName, signingKey); err != nil { log.Warn().Err(err).Msg("Error verifying cross-signing key signature") } else { if verified { - log.Debug().Err(err).Msg("Cross-signing key signature verified") + log.Trace().Err(err).Msg("Cross-signing key signature verified") err = mach.CryptoStore.PutSignature(ctx, userID, key, signUserID, signingKey, signature) if err != nil { log.Error().Err(err).Msg("Error storing cross-signing key signature") @@ -96,5 +96,12 @@ func (mach *OlmMachine) storeCrossSigningKeys(ctx context.Context, crossSigningK } } } + + // Clear internal cache so that it refreshes from crypto store + if userID == mach.Client.UserID && mach.crossSigningPubkeys != nil { + log.Debug().Msg("Resetting internal cross-signing key cache") + mach.crossSigningPubkeys = nil + mach.crossSigningPubkeysFetched = false + } } } diff --git a/crypto/cross_sign_test.go b/crypto/cross_sign_test.go index e11fb018..b70370a2 100644 --- a/crypto/cross_sign_test.go +++ b/crypto/cross_sign_test.go @@ -13,6 +13,8 @@ import ( "testing" "github.com/rs/zerolog" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "go.mau.fi/util/dbutil" "maunium.net/go/mautrix" @@ -24,17 +26,12 @@ var noopLogger = zerolog.Nop() func getOlmMachine(t *testing.T) *OlmMachine { rawDB, err := sql.Open("sqlite3", ":memory:?_busy_timeout=5000") - if err != nil { - t.Fatalf("Error opening db: %v", err) - } + require.NoError(t, err, "Error opening raw database") db, err := dbutil.NewWithDB(rawDB, "sqlite3") - if err != nil { - t.Fatalf("Error opening db: %v", err) - } + require.NoError(t, err, "Error creating database wrapper") sqlStore := NewSQLCryptoStore(db, nil, "accid", id.DeviceID("dev"), []byte("test")) - if err = sqlStore.DB.Upgrade(context.TODO()); err != nil { - t.Fatalf("Error creating tables: %v", err) - } + err = sqlStore.DB.Upgrade(context.TODO()) + require.NoError(t, err, "Error upgrading database") userID := id.UserID("@mautrix") mk, _ := olm.NewPKSigning() @@ -66,29 +63,25 @@ func TestTrustOwnDevice(t *testing.T) { DeviceID: "device", SigningKey: id.Ed25519("deviceKey"), } - if m.IsDeviceTrusted(ownDevice) { - t.Error("Own device trusted while it shouldn't be") - } + assert.False(t, m.IsDeviceTrusted(context.TODO(), ownDevice), "Own device trusted while it shouldn't be") m.CryptoStore.PutSignature(context.TODO(), ownDevice.UserID, m.CrossSigningKeys.SelfSigningKey.PublicKey(), ownDevice.UserID, m.CrossSigningKeys.MasterKey.PublicKey(), "sig1") m.CryptoStore.PutSignature(context.TODO(), ownDevice.UserID, ownDevice.SigningKey, ownDevice.UserID, m.CrossSigningKeys.SelfSigningKey.PublicKey(), "sig2") - if trusted, _ := m.IsUserTrusted(context.TODO(), ownDevice.UserID); !trusted { - t.Error("Own user not trusted while they should be") - } - if !m.IsDeviceTrusted(ownDevice) { - t.Error("Own device not trusted while it should be") - } + trusted, err := m.IsUserTrusted(context.TODO(), ownDevice.UserID) + require.NoError(t, err, "Error checking if own user is trusted") + assert.True(t, trusted, "Own user not trusted while they should be") + assert.True(t, m.IsDeviceTrusted(context.TODO(), ownDevice), "Own device not trusted while it should be") } func TestTrustOtherUser(t *testing.T) { m := getOlmMachine(t) otherUser := id.UserID("@user") - if trusted, _ := m.IsUserTrusted(context.TODO(), otherUser); trusted { - t.Error("Other user trusted while they shouldn't be") - } + trusted, err := m.IsUserTrusted(context.TODO(), otherUser) + require.NoError(t, err, "Error checking if other user is trusted") + assert.False(t, trusted, "Other user trusted while they shouldn't be") theirMasterKey, _ := olm.NewPKSigning() m.CryptoStore.PutCrossSigningKey(context.TODO(), otherUser, id.XSUsageMaster, theirMasterKey.PublicKey()) @@ -100,16 +93,16 @@ func TestTrustOtherUser(t *testing.T) { m.CryptoStore.PutSignature(context.TODO(), otherUser, theirMasterKey.PublicKey(), m.Client.UserID, m.CrossSigningKeys.SelfSigningKey.PublicKey(), "invalid_sig") - if trusted, _ := m.IsUserTrusted(context.TODO(), otherUser); trusted { - t.Error("Other user trusted before their master key has been signed with our user-signing key") - } + trusted, err = m.IsUserTrusted(context.TODO(), otherUser) + require.NoError(t, err, "Error checking if other user is trusted") + assert.False(t, trusted, "Other user trusted before their master key has been signed with our user-signing key") m.CryptoStore.PutSignature(context.TODO(), otherUser, theirMasterKey.PublicKey(), m.Client.UserID, m.CrossSigningKeys.UserSigningKey.PublicKey(), "sig2") - if trusted, _ := m.IsUserTrusted(context.TODO(), otherUser); !trusted { - t.Error("Other user not trusted while they should be") - } + trusted, err = m.IsUserTrusted(context.TODO(), otherUser) + require.NoError(t, err, "Error checking if other user is trusted") + assert.True(t, trusted, "Other user not trusted while they should be") } func TestTrustOtherDevice(t *testing.T) { @@ -120,12 +113,11 @@ func TestTrustOtherDevice(t *testing.T) { DeviceID: "theirDevice", SigningKey: id.Ed25519("theirDeviceKey"), } - if trusted, _ := m.IsUserTrusted(context.TODO(), otherUser); trusted { - t.Error("Other user trusted while they shouldn't be") - } - if m.IsDeviceTrusted(theirDevice) { - t.Error("Other device trusted while it shouldn't be") - } + + trusted, err := m.IsUserTrusted(context.TODO(), otherUser) + require.NoError(t, err, "Error checking if other user is trusted") + assert.False(t, trusted, "Other user trusted while they shouldn't be") + assert.False(t, m.IsDeviceTrusted(context.TODO(), theirDevice), "Other device trusted while it shouldn't be") theirMasterKey, _ := olm.NewPKSigning() m.CryptoStore.PutCrossSigningKey(context.TODO(), otherUser, id.XSUsageMaster, theirMasterKey.PublicKey()) @@ -137,21 +129,17 @@ func TestTrustOtherDevice(t *testing.T) { m.CryptoStore.PutSignature(context.TODO(), otherUser, theirMasterKey.PublicKey(), m.Client.UserID, m.CrossSigningKeys.UserSigningKey.PublicKey(), "sig2") - if trusted, _ := m.IsUserTrusted(context.TODO(), otherUser); !trusted { - t.Error("Other user not trusted while they should be") - } + trusted, err = m.IsUserTrusted(context.TODO(), otherUser) + require.NoError(t, err, "Error checking if other user is trusted") + assert.True(t, trusted, "Other user not trusted while they should be") m.CryptoStore.PutSignature(context.TODO(), otherUser, theirSSK.PublicKey(), otherUser, theirMasterKey.PublicKey(), "sig3") - if m.IsDeviceTrusted(theirDevice) { - t.Error("Other device trusted before it has been signed with user's SSK") - } + assert.False(t, m.IsDeviceTrusted(context.TODO(), theirDevice), "Other device trusted before it has been signed with user's SSK") m.CryptoStore.PutSignature(context.TODO(), otherUser, theirDevice.SigningKey, otherUser, theirSSK.PublicKey(), "sig4") - if !m.IsDeviceTrusted(theirDevice) { - t.Error("Other device not trusted while it should be") - } + assert.True(t, m.IsDeviceTrusted(context.TODO(), theirDevice), "Other device not trusted after it has been signed with user's SSK") } diff --git a/crypto/cross_sign_validation.go b/crypto/cross_sign_validation.go index ff2452ec..4cdf0dd5 100644 --- a/crypto/cross_sign_validation.go +++ b/crypto/cross_sign_validation.go @@ -13,6 +13,9 @@ import ( "maunium.net/go/mautrix/id" ) +// ResolveTrust resolves the trust state of the device from cross-signing. +// +// Deprecated: This method doesn't take a context. Use [OlmMachine.ResolveTrustContext] instead. func (mach *OlmMachine) ResolveTrust(device *id.Device) id.TrustState { state, _ := mach.ResolveTrustContext(context.Background(), device) return state @@ -32,7 +35,7 @@ func (mach *OlmMachine) ResolveTrustContext(ctx context.Context, device *id.Devi } theirMSK, ok := theirKeys[id.XSUsageMaster] if !ok { - mach.machOrContextLog(ctx).Error(). + mach.machOrContextLog(ctx).Debug(). Str("user_id", device.UserID.String()). Msg("Master key of user not found") return id.TrustStateUnset, nil @@ -77,8 +80,12 @@ func (mach *OlmMachine) ResolveTrustContext(ctx context.Context, device *id.Devi } // IsDeviceTrusted returns whether a device has been determined to be trusted either through verification or cross-signing. -func (mach *OlmMachine) IsDeviceTrusted(device *id.Device) bool { - switch mach.ResolveTrust(device) { +// +// Note: this will return false if resolving the trust state fails due to database errors. +// Use [OlmMachine.ResolveTrustContext] if special error handling is required. +func (mach *OlmMachine) IsDeviceTrusted(ctx context.Context, device *id.Device) bool { + trust, _ := mach.ResolveTrustContext(ctx, device) + switch trust { case id.TrustStateVerified, id.TrustStateCrossSignedTOFU, id.TrustStateCrossSignedVerified: return true default: diff --git a/crypto/cryptohelper/cryptohelper.go b/crypto/cryptohelper/cryptohelper.go index a0065012..b62dc128 100644 --- a/crypto/cryptohelper/cryptohelper.go +++ b/crypto/cryptohelper/cryptohelper.go @@ -15,6 +15,7 @@ import ( "github.com/rs/zerolog" "go.mau.fi/util/dbutil" + _ "go.mau.fi/util/dbutil/litestream" "maunium.net/go/mautrix" "maunium.net/go/mautrix/crypto" @@ -36,8 +37,12 @@ type CryptoHelper struct { DecryptErrorCallback func(*event.Event, error) + MSC4190 bool LoginAs *mautrix.ReqLogin + ASEventProcessor crypto.ASEventProcessor + CustomPostDecrypt func(context.Context, *event.Event) + DBAccountID string } @@ -58,7 +63,7 @@ func NewCryptoHelper(cli *mautrix.Client, pickleKey []byte, store any) (*CryptoH return nil, fmt.Errorf("pickle key must be provided") } _, isExtensible := cli.Syncer.(mautrix.ExtensibleSyncer) - if !isExtensible { + if !cli.SetAppServiceDeviceID && !isExtensible { return nil, fmt.Errorf("the client syncer must implement ExtensibleSyncer") } @@ -74,7 +79,7 @@ func NewCryptoHelper(cli *mautrix.Client, pickleKey []byte, store any) (*CryptoH } unmanagedCryptoStore = typedStore case string: - db, err := dbutil.NewWithDialect(typedStore, "sqlite3") + db, err := dbutil.NewWithDialect(fmt.Sprintf("file:%s?_txlock=immediate", typedStore), "sqlite3-fk-wal") if err != nil { return nil, err } @@ -111,7 +116,9 @@ func (helper *CryptoHelper) Init(ctx context.Context) error { } syncer, ok := helper.client.Syncer.(mautrix.ExtensibleSyncer) if !ok { - return fmt.Errorf("the client syncer must implement ExtensibleSyncer") + if !helper.client.SetAppServiceDeviceID { + return fmt.Errorf("the client syncer must implement ExtensibleSyncer") + } } var stateStore crypto.StateStore @@ -136,11 +143,42 @@ func (helper *CryptoHelper) Init(ctx context.Context) error { if err != nil { return fmt.Errorf("failed to upgrade crypto state store: %w", err) } - storedDeviceID, err := managedCryptoStore.FindDeviceID(ctx) + cryptoStore = managedCryptoStore + } else { + cryptoStore = helper.unmanagedCryptoStore + } + shouldFindDeviceID := helper.LoginAs != nil || helper.unmanagedCryptoStore == nil + if rawCryptoStore, ok := cryptoStore.(*crypto.SQLCryptoStore); ok && shouldFindDeviceID { + storedDeviceID, err := rawCryptoStore.FindDeviceID(ctx) if err != nil { return fmt.Errorf("failed to find existing device ID: %w", err) } - if helper.LoginAs != nil { + if helper.MSC4190 { + helper.log.Debug().Msg("Creating bot device with MSC4190") + err = helper.client.CreateDeviceMSC4190(ctx, storedDeviceID, helper.LoginAs.InitialDeviceDisplayName) + if err != nil { + return fmt.Errorf("failed to create device for bot: %w", err) + } + rawCryptoStore.DeviceID = helper.client.DeviceID + } else if helper.LoginAs != nil && helper.LoginAs.Type == mautrix.AuthTypeAppservice && helper.client.SetAppServiceDeviceID { + if storedDeviceID == "" { + helper.log.Debug(). + Str("username", helper.LoginAs.Identifier.User). + Msg("Logging in with appservice") + var resp *mautrix.RespLogin + resp, err = helper.client.Login(ctx, helper.LoginAs) + if err != nil { + return err + } + helper.client.DeviceID = resp.DeviceID + } else { + helper.log.Debug(). + Str("username", helper.LoginAs.Identifier.User). + Stringer("device_id", storedDeviceID). + Msg("Using existing device") + helper.client.DeviceID = storedDeviceID + } + } else if helper.LoginAs != nil { if storedDeviceID != "" { helper.LoginAs.DeviceID = storedDeviceID } @@ -153,18 +191,12 @@ func (helper *CryptoHelper) Init(ctx context.Context) error { if err != nil { return err } - if storedDeviceID == "" { - managedCryptoStore.DeviceID = helper.client.DeviceID - } } else if storedDeviceID != "" && storedDeviceID != helper.client.DeviceID { return fmt.Errorf("mismatching device ID in client and crypto store (%q != %q)", storedDeviceID, helper.client.DeviceID) } - cryptoStore = managedCryptoStore - } else { - if helper.LoginAs != nil { - return fmt.Errorf("LoginAs can only be used with a managed crypto store") - } - cryptoStore = helper.unmanagedCryptoStore + rawCryptoStore.DeviceID = helper.client.DeviceID + } else if helper.LoginAs != nil { + return fmt.Errorf("LoginAs can only be used with a managed crypto store") } if helper.client.DeviceID == "" || helper.client.UserID == "" { return fmt.Errorf("the client must be logged in") @@ -177,16 +209,22 @@ func (helper *CryptoHelper) Init(ctx context.Context) error { return err } - syncer.OnSync(helper.mach.ProcessSyncResponse) - syncer.OnEventType(event.StateMember, helper.mach.HandleMemberEvent) - if _, ok = helper.client.Syncer.(mautrix.DispatchableSyncer); ok { - syncer.OnEventType(event.EventEncrypted, helper.HandleEncrypted) - } else { - helper.log.Warn().Msg("Client syncer does not implement DispatchableSyncer. Events will not be decrypted automatically.") - } - if helper.managedStateStore != nil { - syncer.OnEvent(helper.client.StateStoreSyncHandler) + if syncer != nil { + syncer.OnSync(helper.mach.ProcessSyncResponse) + syncer.OnEventType(event.StateMember, helper.mach.HandleMemberEvent) + if _, ok = helper.client.Syncer.(mautrix.DispatchableSyncer); ok { + syncer.OnEventType(event.EventEncrypted, helper.HandleEncrypted) + } else { + helper.log.Warn().Msg("Client syncer does not implement DispatchableSyncer. Events will not be decrypted automatically.") + } + if helper.managedStateStore != nil { + syncer.OnEvent(helper.client.StateStoreSyncHandler) + } + } else if helper.ASEventProcessor != nil { + helper.mach.AddAppserviceListener(helper.ASEventProcessor) + helper.ASEventProcessor.On(event.EventEncrypted, helper.HandleEncrypted) } + return nil } @@ -223,24 +261,24 @@ func (helper *CryptoHelper) verifyDeviceKeysOnServer(ctx context.Context) error if !ok || len(device.Keys) == 0 { if isShared { return fmt.Errorf("olm account is marked as shared, keys seem to have disappeared from the server") - } else { - helper.log.Debug().Msg("Olm account not shared and keys not on server, so device is probably fine") - return nil } + helper.log.Debug().Msg("Olm account not shared and keys not on server, sharing initial keys") + err = helper.mach.ShareKeys(ctx, -1) + if err != nil { + return fmt.Errorf("failed to share keys: %w", err) + } + return nil } else if !isShared { return fmt.Errorf("olm account is not marked as shared, but there are keys on the server") } else if ed := device.Keys.GetEd25519(helper.client.DeviceID); ownID.SigningKey != ed { return fmt.Errorf("mismatching identity key on server (%q != %q)", ownID.SigningKey, ed) - } - if !isShared { - helper.log.Debug().Msg("Olm account not marked as shared, but keys on server match?") } else { helper.log.Debug().Msg("Olm account marked as shared and keys on server match, device is fine") + return nil } - return nil } -var NoSessionFound = crypto.NoSessionFound +var NoSessionFound = crypto.ErrNoSessionFound const initialSessionWaitTimeout = 3 * time.Second const extendedSessionWaitTimeout = 22 * time.Second @@ -259,29 +297,25 @@ func (helper *CryptoHelper) HandleEncrypted(ctx context.Context, evt *event.Even ctx = log.WithContext(ctx) decrypted, err := helper.Decrypt(ctx, evt) - if errors.Is(err, NoSessionFound) { - log.Debug(). - Int("wait_seconds", int(initialSessionWaitTimeout.Seconds())). - Msg("Couldn't find session, waiting for keys to arrive...") - if helper.mach.WaitForSession(ctx, evt.RoomID, content.SenderKey, content.SessionID, initialSessionWaitTimeout) { - log.Debug().Msg("Got keys after waiting, trying to decrypt event again") - decrypted, err = helper.Decrypt(ctx, evt) - } else { - go helper.waitLongerForSession(ctx, log, evt) - return - } - } - if err != nil { + if errors.Is(err, NoSessionFound) && ctx.Value(mautrix.SyncTokenContextKey) != "" { + go helper.waitForSession(ctx, evt) + } else if err != nil { log.Warn().Err(err).Msg("Failed to decrypt event") helper.DecryptErrorCallback(evt, err) - return + } else { + helper.postDecrypt(ctx, decrypted) } - helper.postDecrypt(ctx, decrypted) } func (helper *CryptoHelper) postDecrypt(ctx context.Context, decrypted *event.Event) { decrypted.Mautrix.EventSource |= event.SourceDecrypted - helper.client.Syncer.(mautrix.DispatchableSyncer).Dispatch(ctx, decrypted) + if helper.CustomPostDecrypt != nil { + helper.CustomPostDecrypt(ctx, decrypted) + } else if helper.ASEventProcessor != nil { + helper.ASEventProcessor.Dispatch(ctx, decrypted) + } else { + helper.client.Syncer.(mautrix.DispatchableSyncer).Dispatch(ctx, decrypted) + } } func (helper *CryptoHelper) RequestSession(ctx context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, userID id.UserID, deviceID id.DeviceID) { @@ -311,10 +345,33 @@ func (helper *CryptoHelper) RequestSession(ctx context.Context, roomID id.RoomID } } -func (helper *CryptoHelper) waitLongerForSession(ctx context.Context, log zerolog.Logger, evt *event.Event) { +func (helper *CryptoHelper) waitForSession(ctx context.Context, evt *event.Event) { + log := zerolog.Ctx(ctx) + content := evt.Content.AsEncrypted() + + log.Debug(). + Int("wait_seconds", int(initialSessionWaitTimeout.Seconds())). + Msg("Couldn't find session, waiting for keys to arrive...") + if helper.mach.WaitForSession(ctx, evt.RoomID, content.SenderKey, content.SessionID, initialSessionWaitTimeout) { + log.Debug().Msg("Got keys after waiting, trying to decrypt event again") + decrypted, err := helper.Decrypt(ctx, evt) + if err != nil { + log.Warn().Err(err).Msg("Failed to decrypt event") + helper.DecryptErrorCallback(evt, err) + } else { + helper.postDecrypt(ctx, decrypted) + } + } else { + go helper.waitLongerForSession(ctx, evt) + } +} + +func (helper *CryptoHelper) waitLongerForSession(ctx context.Context, evt *event.Event) { + log := zerolog.Ctx(ctx) content := evt.Content.AsEncrypted() log.Debug().Int("wait_seconds", int(extendedSessionWaitTimeout.Seconds())).Msg("Couldn't find session, requesting keys and waiting longer...") + //lint:ignore SA1019 RequestSession will gracefully request from all devices if DeviceID is blank go helper.RequestSession(context.TODO(), evt.RoomID, content.SenderKey, content.SessionID, evt.Sender, content.DeviceID) if !helper.mach.WaitForSession(ctx, evt.RoomID, content.SenderKey, content.SessionID, extendedSessionWaitTimeout) { @@ -351,14 +408,18 @@ func (helper *CryptoHelper) Decrypt(ctx context.Context, evt *event.Event) (*eve } func (helper *CryptoHelper) Encrypt(ctx context.Context, roomID id.RoomID, evtType event.Type, content any) (encrypted *event.EncryptedEventContent, err error) { + return helper.EncryptWithStateKey(ctx, roomID, evtType, nil, content) +} + +func (helper *CryptoHelper) EncryptWithStateKey(ctx context.Context, roomID id.RoomID, evtType event.Type, stateKey *string, content any) (encrypted *event.EncryptedEventContent, err error) { if helper == nil { return nil, fmt.Errorf("crypto helper is nil") } helper.lock.RLock() defer helper.lock.RUnlock() - encrypted, err = helper.mach.EncryptMegolmEvent(ctx, roomID, evtType, content) + encrypted, err = helper.mach.EncryptMegolmEventWithStateKey(ctx, roomID, evtType, stateKey, content) if err != nil { - if !errors.Is(err, crypto.SessionExpired) && err != crypto.NoGroupSession && !errors.Is(err, crypto.SessionNotShared) { + if !errors.Is(err, crypto.ErrSessionExpired) && err != crypto.ErrNoGroupSession && !errors.Is(err, crypto.ErrSessionNotShared) { return } helper.log.Debug(). @@ -371,7 +432,7 @@ func (helper *CryptoHelper) Encrypt(ctx context.Context, roomID id.RoomID, evtTy err = fmt.Errorf("failed to get room member list: %w", err) } else if err = helper.mach.ShareGroupSession(ctx, roomID, users); err != nil { err = fmt.Errorf("failed to share group session: %w", err) - } else if encrypted, err = helper.mach.EncryptMegolmEvent(ctx, roomID, evtType, content); err != nil { + } else if encrypted, err = helper.mach.EncryptMegolmEventWithStateKey(ctx, roomID, evtType, stateKey, content); err != nil { err = fmt.Errorf("failed to encrypt event after re-sharing group session: %w", err) } } diff --git a/crypto/decryptmegolm.go b/crypto/decryptmegolm.go index abe01871..457d5a0c 100644 --- a/crypto/decryptmegolm.go +++ b/crypto/decryptmegolm.go @@ -14,6 +14,9 @@ import ( "strings" "github.com/rs/zerolog" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" + "go.mau.fi/util/exgjson" "maunium.net/go/mautrix/crypto/olm" "maunium.net/go/mautrix/event" @@ -21,28 +24,61 @@ import ( ) var ( - IncorrectEncryptedContentType = errors.New("event content is not instance of *event.EncryptedEventContent") - NoSessionFound = errors.New("failed to decrypt megolm event: no session with given ID found") - DuplicateMessageIndex = errors.New("duplicate megolm message index") - WrongRoom = errors.New("encrypted megolm event is not intended for this room") - DeviceKeyMismatch = errors.New("device keys in event and verified device info do not match") - SenderKeyMismatch = errors.New("sender keys in content and megolm session do not match") - RatchetError = errors.New("failed to ratchet session after use") + ErrIncorrectEncryptedContentType = errors.New("event content is not instance of *event.EncryptedEventContent") + ErrNoSessionFound = errors.New("failed to decrypt megolm event: no session with given ID found") + ErrDuplicateMessageIndex = errors.New("duplicate megolm message index") + ErrWrongRoom = errors.New("encrypted megolm event is not intended for this room") + ErrDeviceKeyMismatch = errors.New("device keys in event and verified device info do not match") + ErrRatchetError = errors.New("failed to ratchet session after use") + ErrCorruptedMegolmPayload = errors.New("corrupted megolm payload") +) + +// Deprecated: use variables prefixed with Err +var ( + IncorrectEncryptedContentType = ErrIncorrectEncryptedContentType + NoSessionFound = ErrNoSessionFound + DuplicateMessageIndex = ErrDuplicateMessageIndex + WrongRoom = ErrWrongRoom + DeviceKeyMismatch = ErrDeviceKeyMismatch + RatchetError = ErrRatchetError ) type megolmEvent struct { - RoomID id.RoomID `json:"room_id"` - Type event.Type `json:"type"` - Content event.Content `json:"content"` + RoomID id.RoomID `json:"room_id"` + Type event.Type `json:"type"` + StateKey *string `json:"state_key"` + Content event.Content `json:"content"` +} + +var ( + relatesToContentPath = exgjson.Path("m.relates_to") + relatesToTopLevelPath = exgjson.Path("content", "m.relates_to") +) + +const sessionIDLength = 43 + +func validateCiphertextCharacters(ciphertext []byte) bool { + for _, b := range ciphertext { + if (b < 'a' || b > 'z') && (b < 'A' || b > 'Z') && (b < '0' || b > '9') && b != '+' && b != '/' { + return false + } + } + return true } // DecryptMegolmEvent decrypts an m.room.encrypted event where the algorithm is m.megolm.v1.aes-sha2 func (mach *OlmMachine) DecryptMegolmEvent(ctx context.Context, evt *event.Event) (*event.Event, error) { content, ok := evt.Content.Parsed.(*event.EncryptedEventContent) if !ok { - return nil, IncorrectEncryptedContentType + return nil, ErrIncorrectEncryptedContentType } else if content.Algorithm != id.AlgorithmMegolmV1 { - return nil, UnsupportedAlgorithm + return nil, ErrUnsupportedAlgorithm + } else if len(content.MegolmCiphertext) < 74 { + return nil, fmt.Errorf("%w: ciphertext too short (%d bytes)", ErrCorruptedMegolmPayload, len(content.MegolmCiphertext)) + } else if len(content.SessionID) != sessionIDLength { + return nil, fmt.Errorf("%w: invalid session ID length %d", ErrCorruptedMegolmPayload, len(content.SessionID)) + } else if !validateCiphertextCharacters(content.MegolmCiphertext) { + return nil, fmt.Errorf("%w: invalid characters in ciphertext", ErrCorruptedMegolmPayload) } log := mach.machOrContextLog(ctx).With(). Str("action", "decrypt megolm event"). @@ -88,16 +124,28 @@ func (mach *OlmMachine) DecryptMegolmEvent(ctx context.Context, evt *event.Event Msg("Couldn't resolve trust level of session: sent by unknown device") trustLevel = id.TrustStateUnknownDevice } else if device.SigningKey != sess.SigningKey || device.IdentityKey != sess.SenderKey { - return nil, DeviceKeyMismatch + log.Debug(). + Stringer("session_sender_key", sess.SenderKey). + Stringer("device_sender_key", device.IdentityKey). + Stringer("session_signing_key", sess.SigningKey). + Stringer("device_signing_key", device.SigningKey). + Msg("Device keys don't match keys in session, marking as untrusted") + trustLevel = id.TrustStateDeviceKeyMismatch } else { - trustLevel = mach.ResolveTrust(device) + trustLevel, err = mach.ResolveTrustContext(ctx, device) + if err != nil { + return nil, err + } } } else { forwardedKeys = true lastChainItem := sess.ForwardingChains[len(sess.ForwardingChains)-1] device, _ = mach.CryptoStore.FindDeviceByKey(ctx, evt.Sender, id.IdentityKey(lastChainItem)) if device != nil { - trustLevel = mach.ResolveTrust(device) + trustLevel, err = mach.ResolveTrustContext(ctx, device) + if err != nil { + return nil, err + } } else { log.Debug(). Str("forward_last_sender_key", lastChainItem). @@ -107,41 +155,54 @@ func (mach *OlmMachine) DecryptMegolmEvent(ctx context.Context, evt *event.Event } } + if content.RelatesTo != nil { + relation := gjson.GetBytes(evt.Content.VeryRaw, relatesToContentPath) + if relation.Exists() && !gjson.GetBytes(plaintext, relatesToTopLevelPath).IsObject() { + var raw []byte + if relation.Index > 0 { + raw = evt.Content.VeryRaw[relation.Index : relation.Index+len(relation.Raw)] + } else { + raw = []byte(relation.Raw) + } + updatedPlaintext, err := sjson.SetRawBytes(plaintext, relatesToTopLevelPath, raw) + if err != nil { + log.Warn().Msg("Failed to copy m.relates_to to decrypted payload") + } else if updatedPlaintext != nil { + plaintext = updatedPlaintext + } + } else if !relation.Exists() { + log.Warn().Msg("Failed to find m.relates_to in raw encrypted event even though it was present in parsed content") + } + } + megolmEvt := &megolmEvent{} err = json.Unmarshal(plaintext, &megolmEvt) if err != nil { return nil, fmt.Errorf("failed to parse megolm payload: %w", err) } else if megolmEvt.RoomID != encryptionRoomID { - return nil, WrongRoom + return nil, ErrWrongRoom + } + if evt.StateKey != nil && megolmEvt.StateKey != nil && mach.AllowEncryptedState { + megolmEvt.Type.Class = event.StateEventType + } else { + megolmEvt.Type.Class = evt.Type.Class + megolmEvt.StateKey = nil } - megolmEvt.Type.Class = evt.Type.Class log = log.With().Str("decrypted_event_type", megolmEvt.Type.Repr()).Logger() err = megolmEvt.Content.ParseRaw(megolmEvt.Type) if err != nil { if errors.Is(err, event.ErrUnsupportedContentType) { log.Warn().Msg("Unsupported event type in encrypted event") - } else { + } else if !mach.IgnorePostDecryptionParseErrors { return nil, fmt.Errorf("failed to parse content of megolm payload event: %w", err) } } - if content.RelatesTo != nil { - relatable, ok := megolmEvt.Content.Parsed.(event.Relatable) - if ok { - if relatable.OptionalGetRelatesTo() == nil { - relatable.SetRelatesTo(content.RelatesTo) - } else { - log.Trace().Msg("Not overriding relation data as encrypted payload already has it") - } - } - if _, hasRelation := megolmEvt.Content.Raw["m.relates_to"]; !hasRelation { - megolmEvt.Content.Raw["m.relates_to"] = evt.Content.Raw["m.relates_to"] - } - } log.Debug().Msg("Event decrypted successfully") megolmEvt.Type.Class = evt.Type.Class return &event.Event{ Sender: evt.Sender, Type: megolmEvt.Type, + StateKey: megolmEvt.StateKey, Timestamp: evt.Timestamp, ID: evt.ID, RoomID: evt.RoomID, @@ -152,6 +213,7 @@ func (mach *OlmMachine) DecryptMegolmEvent(ctx context.Context, evt *event.Event TrustSource: device, ForwardedKeys: forwardedKeys, WasEncrypted: true, + EventSource: evt.Mautrix.EventSource | event.SourceDecrypted, ReceivedAt: evt.Mautrix.ReceivedAt, }, }, nil @@ -170,39 +232,37 @@ const missedIndexCutoff = 10 func (mach *OlmMachine) checkUndecryptableMessageIndexDuplication(ctx context.Context, sess *InboundGroupSession, evt *event.Event, content *event.EncryptedEventContent) (uint, error) { log := *zerolog.Ctx(ctx) - messageIndex, decodeErr := parseMessageIndex(content.MegolmCiphertext) + messageIndex, decodeErr := ParseMegolmMessageIndex(content.MegolmCiphertext) if decodeErr != nil { log.Warn().Err(decodeErr).Msg("Failed to parse message index to check if it's a duplicate for message that failed to decrypt") - return 0, fmt.Errorf("%w (also failed to parse message index)", olm.UnknownMessageIndex) + return 0, fmt.Errorf("%w (also failed to parse message index)", olm.ErrUnknownMessageIndex) } firstKnown := sess.Internal.FirstKnownIndex() log = log.With().Uint("message_index", messageIndex).Uint32("first_known_index", firstKnown).Logger() if ok, err := mach.CryptoStore.ValidateMessageIndex(ctx, sess.SenderKey, content.SessionID, evt.ID, messageIndex, evt.Timestamp); err != nil { log.Debug().Err(err).Msg("Failed to check if message index is duplicate") - return messageIndex, fmt.Errorf("%w (failed to check if index is duplicate; received: %d, earliest known: %d)", olm.UnknownMessageIndex, messageIndex, firstKnown) + return messageIndex, fmt.Errorf("%w (failed to check if index is duplicate; received: %d, earliest known: %d)", olm.ErrUnknownMessageIndex, messageIndex, firstKnown) } else if !ok { log.Debug().Msg("Failed to decrypt message due to unknown index and found duplicate") - return messageIndex, fmt.Errorf("%w %d (also failed to decrypt because earliest known index is %d)", DuplicateMessageIndex, messageIndex, firstKnown) + return messageIndex, fmt.Errorf("%w %d (also failed to decrypt because earliest known index is %d)", ErrDuplicateMessageIndex, messageIndex, firstKnown) } log.Debug().Msg("Failed to decrypt message due to unknown index, but index is not duplicate") - return messageIndex, fmt.Errorf("%w (not duplicate index; received: %d, earliest known: %d)", olm.UnknownMessageIndex, messageIndex, firstKnown) + return messageIndex, fmt.Errorf("%w (not duplicate index; received: %d, earliest known: %d)", olm.ErrUnknownMessageIndex, messageIndex, firstKnown) } func (mach *OlmMachine) actuallyDecryptMegolmEvent(ctx context.Context, evt *event.Event, encryptionRoomID id.RoomID, content *event.EncryptedEventContent) (*InboundGroupSession, []byte, uint, error) { mach.megolmDecryptLock.Lock() defer mach.megolmDecryptLock.Unlock() - sess, err := mach.CryptoStore.GetGroupSession(ctx, encryptionRoomID, content.SenderKey, content.SessionID) + sess, err := mach.CryptoStore.GetGroupSession(ctx, encryptionRoomID, content.SessionID) if err != nil { return nil, nil, 0, fmt.Errorf("failed to get group session: %w", err) } else if sess == nil { - return nil, nil, 0, fmt.Errorf("%w (ID %s)", NoSessionFound, content.SessionID) - } else if content.SenderKey != "" && content.SenderKey != sess.SenderKey { - return sess, nil, 0, SenderKeyMismatch + return nil, nil, 0, fmt.Errorf("%w (ID %s)", ErrNoSessionFound, content.SessionID) } plaintext, messageIndex, err := sess.Internal.Decrypt(content.MegolmCiphertext) if err != nil { - if errors.Is(err, olm.UnknownMessageIndex) && mach.RatchetKeysOnDecrypt { + if errors.Is(err, olm.ErrUnknownMessageIndex) && mach.RatchetKeysOnDecrypt { messageIndex, err = mach.checkUndecryptableMessageIndexDuplication(ctx, sess, evt, content) return sess, nil, messageIndex, fmt.Errorf("failed to decrypt megolm event: %w", err) } @@ -210,7 +270,12 @@ func (mach *OlmMachine) actuallyDecryptMegolmEvent(ctx context.Context, evt *eve } else if ok, err := mach.CryptoStore.ValidateMessageIndex(ctx, sess.SenderKey, content.SessionID, evt.ID, messageIndex, evt.Timestamp); err != nil { return sess, nil, messageIndex, fmt.Errorf("failed to check if message index is duplicate: %w", err) } else if !ok { - return sess, nil, messageIndex, fmt.Errorf("%w %d", DuplicateMessageIndex, messageIndex) + return sess, nil, messageIndex, fmt.Errorf("%w %d", ErrDuplicateMessageIndex, messageIndex) + } + + // Normal clients don't care about tracking the ratchet state, so let them bypass the rest of the function + if mach.DisableRatchetTracking { + return sess, plaintext, messageIndex, nil } expectedMessageIndex := sess.RatchetSafety.NextIndex @@ -254,27 +319,27 @@ func (mach *OlmMachine) actuallyDecryptMegolmEvent(ctx context.Context, evt *eve Int("max_messages", sess.MaxMessages). Logger() if sess.MaxMessages > 0 && int(ratchetTargetIndex) >= sess.MaxMessages && len(sess.RatchetSafety.MissedIndices) == 0 && mach.DeleteFullyUsedKeysOnDecrypt { - err = mach.CryptoStore.RedactGroupSession(ctx, sess.RoomID, sess.SenderKey, sess.ID(), "maximum messages reached") + err = mach.CryptoStore.RedactGroupSession(ctx, sess.RoomID, sess.ID(), "maximum messages reached") if err != nil { log.Err(err).Msg("Failed to delete fully used session") - return sess, plaintext, messageIndex, RatchetError + return sess, plaintext, messageIndex, ErrRatchetError } else { log.Info().Msg("Deleted fully used session") } } else if ratchetCurrentIndex < ratchetTargetIndex && mach.RatchetKeysOnDecrypt { if err = sess.RatchetTo(ratchetTargetIndex); err != nil { log.Err(err).Msg("Failed to ratchet session") - return sess, plaintext, messageIndex, RatchetError - } else if err = mach.CryptoStore.PutGroupSession(ctx, sess.RoomID, sess.SenderKey, sess.ID(), sess); err != nil { + return sess, plaintext, messageIndex, ErrRatchetError + } else if err = mach.CryptoStore.PutGroupSession(ctx, sess); err != nil { log.Err(err).Msg("Failed to store ratcheted session") - return sess, plaintext, messageIndex, RatchetError + return sess, plaintext, messageIndex, ErrRatchetError } else { log.Info().Msg("Ratcheted session forward") } } else if didModify { - if err = mach.CryptoStore.PutGroupSession(ctx, sess.RoomID, sess.SenderKey, sess.ID(), sess); err != nil { + if err = mach.CryptoStore.PutGroupSession(ctx, sess); err != nil { log.Err(err).Msg("Failed to store updated ratchet safety data") - return sess, plaintext, messageIndex, RatchetError + return sess, plaintext, messageIndex, ErrRatchetError } else { log.Debug().Msg("Ratchet safety data changed (ratchet state didn't change)") } diff --git a/crypto/decryptolm.go b/crypto/decryptolm.go index 55614b76..aea5e6dc 100644 --- a/crypto/decryptolm.go +++ b/crypto/decryptolm.go @@ -8,26 +8,45 @@ package crypto import ( "context" + "crypto/sha256" + "encoding/base64" "encoding/json" "errors" "fmt" + "slices" "time" "github.com/rs/zerolog" + "go.mau.fi/util/exerrors" + "go.mau.fi/util/ptr" + "maunium.net/go/mautrix/crypto/goolm/account" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" ) var ( - UnsupportedAlgorithm = errors.New("unsupported event encryption algorithm") - NotEncryptedForMe = errors.New("olm event doesn't contain ciphertext for this device") - UnsupportedOlmMessageType = errors.New("unsupported olm message type") - DecryptionFailedWithMatchingSession = errors.New("decryption failed with matching session") - DecryptionFailedForNormalMessage = errors.New("decryption failed for normal message") - SenderMismatch = errors.New("mismatched sender in olm payload") - RecipientMismatch = errors.New("mismatched recipient in olm payload") - RecipientKeyMismatch = errors.New("mismatched recipient key in olm payload") + ErrUnsupportedAlgorithm = errors.New("unsupported event encryption algorithm") + ErrNotEncryptedForMe = errors.New("olm event doesn't contain ciphertext for this device") + ErrUnsupportedOlmMessageType = errors.New("unsupported olm message type") + ErrDecryptionFailedWithMatchingSession = errors.New("decryption failed with matching session") + ErrDecryptionFailedForNormalMessage = errors.New("decryption failed for normal message") + ErrSenderMismatch = errors.New("mismatched sender in olm payload") + ErrRecipientMismatch = errors.New("mismatched recipient in olm payload") + ErrRecipientKeyMismatch = errors.New("mismatched recipient key in olm payload") + ErrDuplicateMessage = errors.New("duplicate olm message") +) + +// Deprecated: use variables prefixed with Err +var ( + UnsupportedAlgorithm = ErrUnsupportedAlgorithm + NotEncryptedForMe = ErrNotEncryptedForMe + UnsupportedOlmMessageType = ErrUnsupportedOlmMessageType + DecryptionFailedWithMatchingSession = ErrDecryptionFailedWithMatchingSession + DecryptionFailedForNormalMessage = ErrDecryptionFailedForNormalMessage + SenderMismatch = ErrSenderMismatch + RecipientMismatch = ErrRecipientMismatch + RecipientKeyMismatch = ErrRecipientKeyMismatch ) // DecryptedOlmEvent represents an event that was decrypted from an event encrypted with the m.olm.v1.curve25519-aes-sha2 algorithm. @@ -49,13 +68,13 @@ type DecryptedOlmEvent struct { func (mach *OlmMachine) decryptOlmEvent(ctx context.Context, evt *event.Event) (*DecryptedOlmEvent, error) { content, ok := evt.Content.Parsed.(*event.EncryptedEventContent) if !ok { - return nil, IncorrectEncryptedContentType + return nil, ErrIncorrectEncryptedContentType } else if content.Algorithm != id.AlgorithmOlmV1 { - return nil, UnsupportedAlgorithm + return nil, ErrUnsupportedAlgorithm } ownContent, ok := content.OlmCiphertext[mach.account.IdentityKey()] if !ok { - return nil, NotEncryptedForMe + return nil, ErrNotEncryptedForMe } decrypted, err := mach.decryptAndParseOlmCiphertext(ctx, evt, content.SenderKey, ownContent.Type, ownContent.Body) if err != nil { @@ -71,9 +90,14 @@ type OlmEventKeys struct { func (mach *OlmMachine) decryptAndParseOlmCiphertext(ctx context.Context, evt *event.Event, senderKey id.SenderKey, olmType id.OlmMsgType, ciphertext string) (*DecryptedOlmEvent, error) { if olmType != id.OlmMsgTypePreKey && olmType != id.OlmMsgTypeMsg { - return nil, UnsupportedOlmMessageType + return nil, ErrUnsupportedOlmMessageType } + log := mach.machOrContextLog(ctx).With(). + Stringer("sender_key", senderKey). + Int("olm_msg_type", int(olmType)). + Logger() + ctx = log.WithContext(ctx) endTimeTrace := mach.timeTrace(ctx, "decrypting olm ciphertext", 5*time.Second) plaintext, err := mach.tryDecryptOlmCiphertext(ctx, evt.Sender, senderKey, olmType, ciphertext) endTimeTrace() @@ -90,16 +114,18 @@ func (mach *OlmMachine) decryptAndParseOlmCiphertext(ctx context.Context, evt *e } olmEvt.Type.Class = evt.Type.Class if evt.Sender != olmEvt.Sender { - return nil, SenderMismatch + return nil, ErrSenderMismatch } else if mach.Client.UserID != olmEvt.Recipient { - return nil, RecipientMismatch + return nil, ErrRecipientMismatch } else if mach.account.SigningKey() != olmEvt.RecipientKeys.Ed25519 { - return nil, RecipientKeyMismatch + return nil, ErrRecipientKeyMismatch } - err = olmEvt.Content.ParseRaw(olmEvt.Type) - if err != nil && !errors.Is(err, event.ErrUnsupportedContentType) { - return nil, fmt.Errorf("failed to parse content of olm payload event: %w", err) + if len(olmEvt.Content.VeryRaw) > 0 { + err = olmEvt.Content.ParseRaw(olmEvt.Type) + if err != nil && !errors.Is(err, event.ErrUnsupportedContentType) { + return nil, fmt.Errorf("failed to parse content of olm payload event: %w", err) + } } olmEvt.SenderKey = senderKey @@ -107,16 +133,40 @@ func (mach *OlmMachine) decryptAndParseOlmCiphertext(ctx context.Context, evt *e return &olmEvt, nil } +func olmMessageHash(ciphertext string) ([32]byte, error) { + if ciphertext == "" { + return [32]byte{}, fmt.Errorf("empty ciphertext") + } + ciphertextBytes, err := base64.RawStdEncoding.DecodeString(ciphertext) + return sha256.Sum256(ciphertextBytes), err +} + func (mach *OlmMachine) tryDecryptOlmCiphertext(ctx context.Context, sender id.UserID, senderKey id.SenderKey, olmType id.OlmMsgType, ciphertext string) ([]byte, error) { + ciphertextHash, err := olmMessageHash(ciphertext) + if err != nil { + return nil, fmt.Errorf("failed to hash olm ciphertext: %w", err) + } + log := *zerolog.Ctx(ctx) endTimeTrace := mach.timeTrace(ctx, "waiting for olm lock", 5*time.Second) mach.olmLock.Lock() endTimeTrace() defer mach.olmLock.Unlock() - plaintext, err := mach.tryDecryptOlmCiphertextWithExistingSession(ctx, senderKey, olmType, ciphertext) + duplicateTS, err := mach.CryptoStore.GetOlmHash(ctx, ciphertextHash) if err != nil { - if err == DecryptionFailedWithMatchingSession { + log.Warn().Err(err).Msg("Failed to check for duplicate olm message") + } else if !duplicateTS.IsZero() { + log.Warn(). + Hex("ciphertext_hash", ciphertextHash[:]). + Time("duplicate_ts", duplicateTS). + Msg("Ignoring duplicate olm message") + return nil, ErrDuplicateMessage + } + + plaintext, err := mach.tryDecryptOlmCiphertextWithExistingSession(ctx, senderKey, olmType, ciphertext, ciphertextHash) + if err != nil { + if err == ErrDecryptionFailedWithMatchingSession { log.Warn().Msg("Found matching session, but decryption failed") go mach.unwedgeDevice(log, sender, senderKey) } @@ -134,9 +184,10 @@ func (mach *OlmMachine) tryDecryptOlmCiphertext(ctx context.Context, sender id.U // if it isn't one at this point in time anymore, so return early. if olmType != id.OlmMsgTypePreKey { go mach.unwedgeDevice(log, sender, senderKey) - return nil, DecryptionFailedForNormalMessage + return nil, ErrDecryptionFailedForNormalMessage } + accountBackup, _ := mach.account.Internal.Pickle([]byte("tmp")) log.Trace().Msg("Trying to create inbound session") endTimeTrace = mach.timeTrace(ctx, "creating inbound olm session", time.Second) session, err := mach.createInboundSession(ctx, senderKey, ciphertext) @@ -147,6 +198,8 @@ func (mach *OlmMachine) tryDecryptOlmCiphertext(ctx context.Context, sender id.U } log = log.With().Str("new_olm_session_id", session.ID().String()).Logger() log.Debug(). + Hex("ciphertext_hash", ciphertextHash[:]). + Hex("ciphertext_hash_repeat", ptr.Ptr(exerrors.Must(olmMessageHash(ciphertext)))[:]). Str("olm_session_description", session.Describe()). Msg("Created inbound olm session") ctx = log.WithContext(ctx) @@ -155,11 +208,28 @@ func (mach *OlmMachine) tryDecryptOlmCiphertext(ctx context.Context, sender id.U plaintext, err = session.Decrypt(ciphertext, olmType) endTimeTrace() if err != nil { + log.Debug(). + Hex("ciphertext_hash", ciphertextHash[:]). + Hex("ciphertext_hash_repeat", ptr.Ptr(exerrors.Must(olmMessageHash(ciphertext)))[:]). + Str("ciphertext", ciphertext). + Str("olm_session_description", session.Describe()). + Msg("DEBUG: Failed to decrypt prekey olm message with newly created session") + err2 := mach.goolmRetryHack(ctx, senderKey, ciphertext, accountBackup) + if err2 != nil { + log.Debug().Err(err2).Msg("Goolm confirmed decryption failure") + } else { + log.Warn().Msg("Goolm decryption was successful after libolm failure?") + } + go mach.unwedgeDevice(log, sender, senderKey) return nil, fmt.Errorf("failed to decrypt olm event with session created from prekey message: %w", err) } endTimeTrace = mach.timeTrace(ctx, "updating new session in database", time.Second) + err = mach.CryptoStore.PutOlmHash(ctx, ciphertextHash, time.Now()) + if err != nil { + log.Warn().Err(err).Msg("Failed to store olm message hash after decrypting") + } err = mach.CryptoStore.UpdateSession(ctx, senderKey, session) endTimeTrace() if err != nil { @@ -168,7 +238,28 @@ func (mach *OlmMachine) tryDecryptOlmCiphertext(ctx context.Context, sender id.U return plaintext, nil } -func (mach *OlmMachine) tryDecryptOlmCiphertextWithExistingSession(ctx context.Context, senderKey id.SenderKey, olmType id.OlmMsgType, ciphertext string) ([]byte, error) { +func (mach *OlmMachine) goolmRetryHack(ctx context.Context, senderKey id.SenderKey, ciphertext string, accountBackup []byte) error { + acc, err := account.AccountFromPickled(accountBackup, []byte("tmp")) + if err != nil { + return fmt.Errorf("failed to unpickle olm account: %w", err) + } + sess, err := acc.NewInboundSessionFrom(&senderKey, ciphertext) + if err != nil { + return fmt.Errorf("failed to create inbound session: %w", err) + } + _, err = sess.Decrypt(ciphertext, id.OlmMsgTypePreKey) + if err != nil { + // This is the expected result if libolm failed + return fmt.Errorf("failed to decrypt with new session: %w", err) + } + return nil +} + +const MaxOlmSessionsPerDevice = 5 + +func (mach *OlmMachine) tryDecryptOlmCiphertextWithExistingSession( + ctx context.Context, senderKey id.SenderKey, olmType id.OlmMsgType, ciphertext string, ciphertextHash [32]byte, +) ([]byte, error) { log := *zerolog.Ctx(ctx) endTimeTrace := mach.timeTrace(ctx, "getting sessions with sender key", time.Second) sessions, err := mach.CryptoStore.GetSessions(ctx, senderKey) @@ -176,6 +267,32 @@ func (mach *OlmMachine) tryDecryptOlmCiphertextWithExistingSession(ctx context.C if err != nil { return nil, fmt.Errorf("failed to get session for %s: %w", senderKey, err) } + if len(sessions) > MaxOlmSessionsPerDevice*2 { + // SQL store sorts sessions, but other implementations may not, so re-sort just in case + slices.SortFunc(sessions, func(a, b *OlmSession) int { + return b.LastDecryptedTime.Compare(a.LastDecryptedTime) + }) + log.Warn(). + Int("session_count", len(sessions)). + Time("newest_last_decrypted_at", sessions[0].LastDecryptedTime). + Time("oldest_last_decrypted_at", sessions[len(sessions)-1].LastDecryptedTime). + Msg("Too many sessions, deleting old ones") + for i := MaxOlmSessionsPerDevice; i < len(sessions); i++ { + err = mach.CryptoStore.DeleteSession(ctx, senderKey, sessions[i]) + if err != nil { + log.Warn().Err(err). + Stringer("olm_session_id", sessions[i].ID()). + Time("last_decrypt", sessions[i].LastDecryptedTime). + Msg("Failed to delete olm session") + } else { + log.Debug(). + Stringer("olm_session_id", sessions[i].ID()). + Time("last_decrypt", sessions[i].LastDecryptedTime). + Msg("Deleted olm session") + } + } + sessions = sessions[:MaxOlmSessionsPerDevice] + } for _, session := range sessions { log := log.With().Str("olm_session_id", session.ID().String()).Logger() @@ -190,22 +307,33 @@ func (mach *OlmMachine) tryDecryptOlmCiphertextWithExistingSession(ctx context.C continue } } - log.Debug().Str("session_description", session.Describe()).Msg("Trying to decrypt olm message") endTimeTrace = mach.timeTrace(ctx, "decrypting olm message", time.Second) plaintext, err := session.Decrypt(ciphertext, olmType) endTimeTrace() if err != nil { + log.Warn().Err(err). + Hex("ciphertext_hash", ciphertextHash[:]). + Hex("ciphertext_hash_repeat", ptr.Ptr(exerrors.Must(olmMessageHash(ciphertext)))[:]). + Str("session_description", session.Describe()). + Msg("Failed to decrypt olm message") if olmType == id.OlmMsgTypePreKey { - return nil, DecryptionFailedWithMatchingSession + return nil, ErrDecryptionFailedWithMatchingSession } } else { endTimeTrace = mach.timeTrace(ctx, "updating session in database", time.Second) + err = mach.CryptoStore.PutOlmHash(ctx, ciphertextHash, time.Now()) + if err != nil { + log.Warn().Err(err).Msg("Failed to store olm message hash after decrypting") + } err = mach.CryptoStore.UpdateSession(ctx, senderKey, session) endTimeTrace() if err != nil { log.Warn().Err(err).Msg("Failed to update olm session in crypto store after decrypting") } - log.Debug().Msg("Decrypted olm message") + log.Debug(). + Hex("ciphertext_hash", ciphertextHash[:]). + Str("session_description", session.Describe()). + Msg("Decrypted olm message") return plaintext, nil } } @@ -229,10 +357,10 @@ const MinUnwedgeInterval = 1 * time.Hour func (mach *OlmMachine) unwedgeDevice(log zerolog.Logger, sender id.UserID, senderKey id.SenderKey) { log = log.With().Str("action", "unwedge olm session").Logger() - ctx := log.WithContext(context.TODO()) + ctx := log.WithContext(mach.backgroundCtx) mach.recentlyUnwedgedLock.Lock() prevUnwedge, ok := mach.recentlyUnwedged[senderKey] - delta := time.Now().Sub(prevUnwedge) + delta := time.Since(prevUnwedge) if ok && delta < MinUnwedgeInterval { log.Debug(). Str("previous_recreation", delta.String()). @@ -243,6 +371,17 @@ func (mach *OlmMachine) unwedgeDevice(log zerolog.Logger, sender id.UserID, send mach.recentlyUnwedged[senderKey] = time.Now() mach.recentlyUnwedgedLock.Unlock() + lastCreatedAt, err := mach.CryptoStore.GetNewestSessionCreationTS(ctx, senderKey) + if err != nil { + log.Warn().Err(err).Msg("Failed to get newest session creation timestamp") + return + } else if time.Since(lastCreatedAt) < MinUnwedgeInterval { + log.Debug(). + Time("last_created_at", lastCreatedAt). + Msg("Not creating new Olm session as it was already recreated recently") + return + } + deviceIdentity, err := mach.GetOrFetchDeviceByKey(ctx, sender, senderKey) if err != nil { log.Error().Err(err).Msg("Failed to find device info by identity key") @@ -252,7 +391,10 @@ func (mach *OlmMachine) unwedgeDevice(log zerolog.Logger, sender id.UserID, send return } - log.Debug().Str("device_id", deviceIdentity.DeviceID.String()).Msg("Creating new Olm session") + log.Debug(). + Time("last_created", lastCreatedAt). + Stringer("device_id", deviceIdentity.DeviceID). + Msg("Creating new Olm session") mach.devicesToUnwedgeLock.Lock() mach.devicesToUnwedge[senderKey] = true mach.devicesToUnwedgeLock.Unlock() diff --git a/crypto/devicelist.go b/crypto/devicelist.go index 16c4164e..f0d2b129 100644 --- a/crypto/devicelist.go +++ b/crypto/devicelist.go @@ -10,8 +10,11 @@ import ( "context" "errors" "fmt" + "slices" + "strings" "github.com/rs/zerolog" + "go.mau.fi/util/exzerolog" "maunium.net/go/mautrix" "maunium.net/go/mautrix/crypto/signatures" @@ -19,12 +22,23 @@ import ( ) var ( - MismatchingDeviceID = errors.New("mismatching device ID in parameter and keys object") - MismatchingUserID = errors.New("mismatching user ID in parameter and keys object") - MismatchingSigningKey = errors.New("received update for device with different signing key") - NoSigningKeyFound = errors.New("didn't find ed25519 signing key") - NoIdentityKeyFound = errors.New("didn't find curve25519 identity key") - InvalidKeySignature = errors.New("invalid signature on device keys") + ErrMismatchingDeviceID = errors.New("mismatching device ID in parameter and keys object") + ErrMismatchingUserID = errors.New("mismatching user ID in parameter and keys object") + ErrMismatchingSigningKey = errors.New("received update for device with different signing key") + ErrNoSigningKeyFound = errors.New("didn't find ed25519 signing key") + ErrNoIdentityKeyFound = errors.New("didn't find curve25519 identity key") + ErrInvalidKeySignature = errors.New("invalid signature on device keys") + ErrUserNotTracked = errors.New("user is not tracked") +) + +// Deprecated: use variables prefixed with Err +var ( + MismatchingDeviceID = ErrMismatchingDeviceID + MismatchingUserID = ErrMismatchingUserID + MismatchingSigningKey = ErrMismatchingSigningKey + NoSigningKeyFound = ErrNoSigningKeyFound + NoIdentityKeyFound = ErrNoIdentityKeyFound + InvalidKeySignature = ErrInvalidKeySignature ) func (mach *OlmMachine) LoadDevices(ctx context.Context, user id.UserID) (keys map[id.DeviceID]*id.Device) { @@ -39,6 +53,81 @@ func (mach *OlmMachine) LoadDevices(ctx context.Context, user id.UserID) (keys m return nil } +type CachedDevices struct { + Devices []*id.Device + MasterKey *id.CrossSigningKey + HasValidSelfSigningKey bool + MasterKeySignedByUs bool +} + +func (mach *OlmMachine) GetCachedDevices(ctx context.Context, userID id.UserID) (*CachedDevices, error) { + userIDs, err := mach.CryptoStore.FilterTrackedUsers(ctx, []id.UserID{userID}) + if err != nil { + return nil, fmt.Errorf("failed to check if user's devices are tracked: %w", err) + } else if len(userIDs) == 0 { + return nil, ErrUserNotTracked + } + ownKeys := mach.GetOwnCrossSigningPublicKeys(ctx) + var ownUserSigningKey id.Ed25519 + if ownKeys != nil { + ownUserSigningKey = ownKeys.UserSigningKey + } + var resp CachedDevices + csKeys, err := mach.CryptoStore.GetCrossSigningKeys(ctx, userID) + theirMasterKey := csKeys[id.XSUsageMaster] + theirSelfSignKey := csKeys[id.XSUsageSelfSigning] + if err != nil { + return nil, fmt.Errorf("failed to get cross-signing keys: %w", err) + } else if csKeys != nil && theirMasterKey.Key != "" { + resp.MasterKey = &theirMasterKey + if theirSelfSignKey.Key != "" { + resp.HasValidSelfSigningKey, err = mach.CryptoStore.IsKeySignedBy(ctx, userID, theirSelfSignKey.Key, userID, theirMasterKey.Key) + if err != nil { + return nil, fmt.Errorf("failed to check if self-signing key is signed by master key: %w", err) + } + } + } + devices, err := mach.CryptoStore.GetDevices(ctx, userID) + if err != nil { + return nil, fmt.Errorf("failed to get devices: %w", err) + } + if userID == mach.Client.UserID { + if ownKeys != nil && ownKeys.MasterKey == theirMasterKey.Key { + resp.MasterKeySignedByUs, err = mach.CryptoStore.IsKeySignedBy(ctx, userID, theirMasterKey.Key, userID, mach.OwnIdentity().SigningKey) + } + } else if ownUserSigningKey != "" && theirMasterKey.Key != "" { + // TODO should own master key and user-signing key signatures be checked here too? + resp.MasterKeySignedByUs, err = mach.CryptoStore.IsKeySignedBy(ctx, userID, theirMasterKey.Key, mach.Client.UserID, ownUserSigningKey) + } + if err != nil { + return nil, fmt.Errorf("failed to check if user is trusted: %w", err) + } + resp.Devices = make([]*id.Device, len(devices)) + i := 0 + for _, device := range devices { + resp.Devices[i] = device + if resp.HasValidSelfSigningKey && device.Trust == id.TrustStateUnset { + signed, err := mach.CryptoStore.IsKeySignedBy(ctx, device.UserID, device.SigningKey, device.UserID, theirSelfSignKey.Key) + if err != nil { + return nil, fmt.Errorf("failed to check if device %s is signed by self-signing key: %w", device.DeviceID, err) + } else if signed { + if resp.MasterKeySignedByUs { + device.Trust = id.TrustStateCrossSignedVerified + } else if theirMasterKey.Key == theirMasterKey.First { + device.Trust = id.TrustStateCrossSignedTOFU + } else { + device.Trust = id.TrustStateCrossSignedUntrusted + } + } + } + i++ + } + slices.SortFunc(resp.Devices, func(a, b *id.Device) int { + return strings.Compare(a.DeviceID.String(), b.DeviceID.String()) + }) + return &resp, nil +} + func (mach *OlmMachine) storeDeviceSelfSignatures(ctx context.Context, userID id.UserID, deviceID id.DeviceID, resp *mautrix.RespQueryKeys) { log := zerolog.Ctx(ctx) deviceKeys := resp.DeviceKeys[userID][deviceID] @@ -93,6 +182,10 @@ func (mach *OlmMachine) storeDeviceSelfSignatures(ctx context.Context, userID id } } +// FetchKeys fetches the devices of a list of other users. If includeUntracked +// is set to false, then the users are filtered to to only include user IDs +// whose device lists have been stored with the PutDevices function on the +// [Store]. See the FilterTrackedUsers function on [Store] for details. func (mach *OlmMachine) FetchKeys(ctx context.Context, users []id.UserID, includeUntracked bool) (data map[id.UserID]map[id.DeviceID]*id.Device, err error) { req := &mautrix.ReqQueryKeys{ DeviceKeys: mautrix.DeviceKeysRequest{}, @@ -111,7 +204,7 @@ func (mach *OlmMachine) FetchKeys(ctx context.Context, users []id.UserID, includ for _, userID := range users { req.DeviceKeys[userID] = mautrix.DeviceIDList{} } - log.Debug().Strs("users", strishArray(users)).Msg("Querying keys for users") + log.Debug().Array("users", exzerolog.ArrayOfStrs(users)).Msg("Querying keys for users") resp, err := mach.Client.QueryKeys(ctx, req) if err != nil { return nil, fmt.Errorf("failed to query keys: %w", err) @@ -122,7 +215,7 @@ func (mach *OlmMachine) FetchKeys(ctx context.Context, users []id.UserID, includ log.Trace().Int("user_count", len(resp.DeviceKeys)).Msg("Query key result received") data = make(map[id.UserID]map[id.DeviceID]*id.Device) for userID, devices := range resp.DeviceKeys { - log := log.With().Str("user_id", userID.String()).Logger() + log := log.With().Stringer("user_id", userID).Logger() delete(req.DeviceKeys, userID) newDevices := make(map[id.DeviceID]*id.Device) @@ -138,7 +231,7 @@ func (mach *OlmMachine) FetchKeys(ctx context.Context, users []id.UserID, includ Msg("Updating devices in store") changed := false for deviceID, deviceKeys := range devices { - log := log.With().Str("device_id", deviceID.String()).Logger() + log := log.With().Stringer("device_id", deviceID).Logger() existing, ok := existingDevices[deviceID] if !ok { // New device @@ -177,7 +270,7 @@ func (mach *OlmMachine) FetchKeys(ctx context.Context, users []id.UserID, includ log.Err(err).Msg("Failed to redact megolm sessions from deleted device") } else { log.Info(). - Strs("session_ids", stringifyArray(sessionIDs)). + Array("session_ids", exzerolog.ArrayOfStrs(sessionIDs)). Msg("Redacted megolm sessions from deleted device") } } @@ -186,7 +279,7 @@ func (mach *OlmMachine) FetchKeys(ctx context.Context, users []id.UserID, includ } } for userID := range req.DeviceKeys { - log.Warn().Str("user_id", userID.String()).Msg("Didn't get any keys for user") + log.Warn().Stringer("user_id", userID).Msg("Didn't get any keys for user") } mach.storeCrossSigningKeys(ctx, resp.MasterKeys, resp.DeviceKeys) @@ -228,28 +321,28 @@ func (mach *OlmMachine) OnDevicesChanged(ctx context.Context, userID id.UserID) func (mach *OlmMachine) validateDevice(userID id.UserID, deviceID id.DeviceID, deviceKeys mautrix.DeviceKeys, existing *id.Device) (*id.Device, error) { if deviceID != deviceKeys.DeviceID { - return nil, fmt.Errorf("%w (expected %s, got %s)", MismatchingDeviceID, deviceID, deviceKeys.DeviceID) + return nil, fmt.Errorf("%w (expected %s, got %s)", ErrMismatchingDeviceID, deviceID, deviceKeys.DeviceID) } else if userID != deviceKeys.UserID { - return nil, fmt.Errorf("%w (expected %s, got %s)", MismatchingUserID, userID, deviceKeys.UserID) + return nil, fmt.Errorf("%w (expected %s, got %s)", ErrMismatchingUserID, userID, deviceKeys.UserID) } signingKey := deviceKeys.Keys.GetEd25519(deviceID) identityKey := deviceKeys.Keys.GetCurve25519(deviceID) if signingKey == "" { - return nil, NoSigningKeyFound + return nil, ErrNoSigningKeyFound } else if identityKey == "" { - return nil, NoIdentityKeyFound + return nil, ErrNoIdentityKeyFound } if existing != nil && existing.SigningKey != signingKey { - return existing, fmt.Errorf("%w (expected %s, got %s)", MismatchingSigningKey, existing.SigningKey, signingKey) + return existing, fmt.Errorf("%w (expected %s, got %s)", ErrMismatchingSigningKey, existing.SigningKey, signingKey) } ok, err := signatures.VerifySignatureJSON(deviceKeys, userID, deviceID.String(), signingKey) if err != nil { return existing, fmt.Errorf("failed to verify signature: %w", err) } else if !ok { - return existing, InvalidKeySignature + return existing, ErrInvalidKeySignature } name, ok := deviceKeys.Unsigned["device_display_name"].(string) diff --git a/crypto/ed25519/ed25519.go b/crypto/ed25519/ed25519.go new file mode 100644 index 00000000..327cbb3c --- /dev/null +++ b/crypto/ed25519/ed25519.go @@ -0,0 +1,302 @@ +// Copyright 2024 Sumner Evans. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package ed25519 implements the Ed25519 signature algorithm. See +// https://ed25519.cr.yp.to/. +// +// This package stores the private key in the NaCl format, which is a different +// format than that used by the [crypto/ed25519] package in the standard +// library. +// +// This picture will help with the rest of the explanation: +// https://blog.mozilla.org/warner/files/2011/11/key-formats.png +// +// The private key in the [crypto/ed25519] package is a 64-byte value where the +// first 32-bytes are the seed and the last 32-bytes are the public key. +// +// The private key in this package is stored in the NaCl format. That is, the +// left 32-bytes are the private scalar A and the right 32-bytes are the right +// half of the SHA512 result. +// +// The contents of this package are mostly copied from the standard library, +// and as such the source code is licensed under the BSD license of the +// standard library implementation. +// +// Other notable changes from the standard library include: +// +// - The Seed function of the standard library is not implemented in this +// package because there is no way to recover the seed after hashing it. +package ed25519 + +import ( + "crypto" + "crypto/ed25519" + cryptorand "crypto/rand" + "crypto/sha512" + "crypto/subtle" + "errors" + "io" + "strconv" + + "filippo.io/edwards25519" +) + +const ( + // PublicKeySize is the size, in bytes, of public keys as used in this package. + PublicKeySize = 32 + // PrivateKeySize is the size, in bytes, of private keys as used in this package. + PrivateKeySize = 64 + // SignatureSize is the size, in bytes, of signatures generated and verified by this package. + SignatureSize = 64 + // SeedSize is the size, in bytes, of private key seeds. These are the private key representations used by RFC 8032. + SeedSize = 32 +) + +// PublicKey is the type of Ed25519 public keys. +type PublicKey []byte + +// Any methods implemented on PublicKey might need to also be implemented on +// PrivateKey, as the latter embeds the former and will expose its methods. + +// Equal reports whether pub and x have the same value. +func (pub PublicKey) Equal(x crypto.PublicKey) bool { + switch x := x.(type) { + case PublicKey: + return subtle.ConstantTimeCompare(pub, x) == 1 + case ed25519.PublicKey: + return subtle.ConstantTimeCompare(pub, x) == 1 + default: + return false + } +} + +// PrivateKey is the type of Ed25519 private keys. It implements [crypto.Signer]. +type PrivateKey []byte + +// Public returns the [PublicKey] corresponding to priv. +// +// This method differs from the standard library because it calculates the +// public key instead of returning the right half of the private key (which +// contains the public key in the standard library). +func (priv PrivateKey) Public() crypto.PublicKey { + s, err := edwards25519.NewScalar().SetBytesWithClamping(priv[:32]) + if err != nil { + panic("ed25519: internal error: setting scalar failed") + } + return (&edwards25519.Point{}).ScalarBaseMult(s).Bytes() +} + +// Equal reports whether priv and x have the same value. +func (priv PrivateKey) Equal(x crypto.PrivateKey) bool { + // TODO do we have any need to check equality with standard library ed25519 + // private keys? + xx, ok := x.(PrivateKey) + if !ok { + return false + } + return subtle.ConstantTimeCompare(priv, xx) == 1 +} + +// Sign signs the given message with priv. rand is ignored and can be nil. +// +// If opts.HashFunc() is [crypto.SHA512], the pre-hashed variant Ed25519ph is used +// and message is expected to be a SHA-512 hash, otherwise opts.HashFunc() must +// be [crypto.Hash](0) and the message must not be hashed, as Ed25519 performs two +// passes over messages to be signed. +// +// A value of type [Options] can be used as opts, or crypto.Hash(0) or +// crypto.SHA512 directly to select plain Ed25519 or Ed25519ph, respectively. +func (priv PrivateKey) Sign(rand io.Reader, message []byte, opts crypto.SignerOpts) (signature []byte, err error) { + hash := opts.HashFunc() + context := "" + if opts, ok := opts.(*Options); ok { + context = opts.Context + } + switch { + case hash == crypto.SHA512: // Ed25519ph + if l := len(message); l != sha512.Size { + return nil, errors.New("ed25519: bad Ed25519ph message hash length: " + strconv.Itoa(l)) + } + if l := len(context); l > 255 { + return nil, errors.New("ed25519: bad Ed25519ph context length: " + strconv.Itoa(l)) + } + signature := make([]byte, SignatureSize) + sign(signature, priv, message, domPrefixPh, context) + return signature, nil + case hash == crypto.Hash(0) && context != "": // Ed25519ctx + if l := len(context); l > 255 { + return nil, errors.New("ed25519: bad Ed25519ctx context length: " + strconv.Itoa(l)) + } + signature := make([]byte, SignatureSize) + sign(signature, priv, message, domPrefixCtx, context) + return signature, nil + case hash == crypto.Hash(0): // Ed25519 + return Sign(priv, message), nil + default: + return nil, errors.New("ed25519: expected opts.HashFunc() zero (unhashed message, for standard Ed25519) or SHA-512 (for Ed25519ph)") + } +} + +// Options can be used with [PrivateKey.Sign] or [VerifyWithOptions] +// to select Ed25519 variants. +type Options struct { + // Hash can be zero for regular Ed25519, or crypto.SHA512 for Ed25519ph. + Hash crypto.Hash + + // Context, if not empty, selects Ed25519ctx or provides the context string + // for Ed25519ph. It can be at most 255 bytes in length. + Context string +} + +// HashFunc returns o.Hash. +func (o *Options) HashFunc() crypto.Hash { return o.Hash } + +// GenerateKey generates a public/private key pair using entropy from rand. +// If rand is nil, [crypto/rand.Reader] will be used. +// +// The output of this function is deterministic, and equivalent to reading +// [SeedSize] bytes from rand, and passing them to [NewKeyFromSeed]. +func GenerateKey(rand io.Reader) (PublicKey, PrivateKey, error) { + if rand == nil { + rand = cryptorand.Reader + } + + seed := make([]byte, SeedSize) + if _, err := io.ReadFull(rand, seed); err != nil { + return nil, nil, err + } + + privateKey := NewKeyFromSeed(seed) + return PublicKey(privateKey.Public().([]byte)), privateKey, nil +} + +// NewKeyFromSeed calculates a private key from a seed. It will panic if +// len(seed) is not [SeedSize]. This function is provided for interoperability +// with RFC 8032. RFC 8032's private keys correspond to seeds in this +// package. +func NewKeyFromSeed(seed []byte) PrivateKey { + // Outline the function body so that the returned key can be stack-allocated. + privateKey := make([]byte, PrivateKeySize) + newKeyFromSeed(privateKey, seed) + return privateKey +} + +func newKeyFromSeed(privateKey, seed []byte) { + if l := len(seed); l != SeedSize { + panic("ed25519: bad seed length: " + strconv.Itoa(l)) + } + + h := sha512.Sum512(seed) + + // Apply clamping to get A in the left half, and leave the right half + // as-is. This gets the private key into the NaCl format. + h[0] &= 248 + h[31] &= 63 + h[31] |= 64 + copy(privateKey, h[:]) +} + +// Sign signs the message with privateKey and returns a signature. It will +// panic if len(privateKey) is not [PrivateKeySize]. +func Sign(privateKey PrivateKey, message []byte) []byte { + // Outline the function body so that the returned signature can be + // stack-allocated. + signature := make([]byte, SignatureSize) + sign(signature, privateKey, message, domPrefixPure, "") + return signature +} + +// Domain separation prefixes used to disambiguate Ed25519/Ed25519ph/Ed25519ctx. +// See RFC 8032, Section 2 and Section 5.1. +const ( + // domPrefixPure is empty for pure Ed25519. + domPrefixPure = "" + // domPrefixPh is dom2(phflag=1) for Ed25519ph. It must be followed by the + // uint8-length prefixed context. + domPrefixPh = "SigEd25519 no Ed25519 collisions\x01" + // domPrefixCtx is dom2(phflag=0) for Ed25519ctx. It must be followed by the + // uint8-length prefixed context. + domPrefixCtx = "SigEd25519 no Ed25519 collisions\x00" +) + +func sign(signature []byte, privateKey PrivateKey, message []byte, domPrefix, context string) { + if l := len(privateKey); l != PrivateKeySize { + panic("ed25519: bad private key length: " + strconv.Itoa(l)) + } + // We have to extract the public key from the private key. + publicKey := privateKey.Public().([]byte) + // The private key is already the hashed value of the seed. + h := privateKey + + s, err := edwards25519.NewScalar().SetBytesWithClamping(h[:32]) + if err != nil { + panic("ed25519: internal error: setting scalar failed") + } + prefix := h[32:] + + mh := sha512.New() + if domPrefix != domPrefixPure { + mh.Write([]byte(domPrefix)) + mh.Write([]byte{byte(len(context))}) + mh.Write([]byte(context)) + } + mh.Write(prefix) + mh.Write(message) + messageDigest := make([]byte, 0, sha512.Size) + messageDigest = mh.Sum(messageDigest) + r, err := edwards25519.NewScalar().SetUniformBytes(messageDigest) + if err != nil { + panic("ed25519: internal error: setting scalar failed") + } + + R := (&edwards25519.Point{}).ScalarBaseMult(r) + + kh := sha512.New() + if domPrefix != domPrefixPure { + kh.Write([]byte(domPrefix)) + kh.Write([]byte{byte(len(context))}) + kh.Write([]byte(context)) + } + kh.Write(R.Bytes()) + kh.Write(publicKey) + kh.Write(message) + hramDigest := make([]byte, 0, sha512.Size) + hramDigest = kh.Sum(hramDigest) + k, err := edwards25519.NewScalar().SetUniformBytes(hramDigest) + if err != nil { + panic("ed25519: internal error: setting scalar failed") + } + + S := edwards25519.NewScalar().MultiplyAdd(k, s, r) + + copy(signature[:32], R.Bytes()) + copy(signature[32:], S.Bytes()) +} + +// Verify reports whether sig is a valid signature of message by publicKey. It +// will panic if len(publicKey) is not [PublicKeySize]. +// +// This is just a wrapper around [ed25519.Verify] from the standard library. +func Verify(publicKey PublicKey, message, sig []byte) bool { + return ed25519.Verify(ed25519.PublicKey(publicKey), message, sig) +} + +// VerifyWithOptions reports whether sig is a valid signature of message by +// publicKey. A valid signature is indicated by returning a nil error. It will +// panic if len(publicKey) is not [PublicKeySize]. +// +// If opts.Hash is [crypto.SHA512], the pre-hashed variant Ed25519ph is used and +// message is expected to be a SHA-512 hash, otherwise opts.Hash must be +// [crypto.Hash](0) and the message must not be hashed, as Ed25519 performs two +// passes over messages to be signed. +// +// This is just a wrapper around [ed25519.VerifyWithOptions] from the standard +// library. +func VerifyWithOptions(publicKey PublicKey, message, sig []byte, opts *Options) error { + return ed25519.VerifyWithOptions(ed25519.PublicKey(publicKey), message, sig, &ed25519.Options{ + Hash: opts.Hash, + Context: opts.Context, + }) +} diff --git a/crypto/ed25519/ed25519_test.go b/crypto/ed25519/ed25519_test.go new file mode 100644 index 00000000..931c06f6 --- /dev/null +++ b/crypto/ed25519/ed25519_test.go @@ -0,0 +1,20 @@ +package ed25519_test + +import ( + stdlibed25519 "crypto/ed25519" + "testing" + + "github.com/stretchr/testify/assert" + "go.mau.fi/util/random" + + "maunium.net/go/mautrix/crypto/ed25519" +) + +func TestPubkeyEqual(t *testing.T) { + pubkeyBytes := random.Bytes(32) + pubkey := ed25519.PublicKey(pubkeyBytes) + pubkey2 := ed25519.PublicKey(pubkeyBytes) + stdlibPubkey := stdlibed25519.PublicKey(pubkeyBytes) + assert.True(t, pubkey.Equal(pubkey2)) + assert.True(t, pubkey.Equal(stdlibPubkey)) +} diff --git a/crypto/encryptmegolm.go b/crypto/encryptmegolm.go index d592bd1c..88f9c8d4 100644 --- a/crypto/encryptmegolm.go +++ b/crypto/encryptmegolm.go @@ -15,6 +15,9 @@ import ( "fmt" "github.com/rs/zerolog" + "github.com/tidwall/gjson" + "go.mau.fi/util/exgjson" + "go.mau.fi/util/exzerolog" "maunium.net/go/mautrix" "maunium.net/go/mautrix/event" @@ -22,11 +25,32 @@ import ( ) var ( - AlreadyShared = errors.New("group session already shared") - NoGroupSession = errors.New("no group session created") + ErrNoGroupSession = errors.New("no group session created") ) -func getRelatesTo(content interface{}) *event.RelatesTo { +// Deprecated: use variables prefixed with Err +var ( + NoGroupSession = ErrNoGroupSession +) + +func getRawJSON[T any](content json.RawMessage, path ...string) *T { + value := gjson.GetBytes(content, exgjson.Path(path...)) + if !value.IsObject() { + return nil + } + var result T + err := json.Unmarshal([]byte(value.Raw), &result) + if err != nil { + return nil + } + return &result +} + +func getRelatesTo(content any, plaintext json.RawMessage) *event.RelatesTo { + contentJSON, ok := content.(json.RawMessage) + if ok { + return getRawJSON[event.RelatesTo](contentJSON, "m.relates_to") + } contentStruct, ok := content.(*event.Content) if ok { content = contentStruct.Parsed @@ -35,10 +59,14 @@ func getRelatesTo(content interface{}) *event.RelatesTo { if ok { return relatable.OptionalGetRelatesTo() } - return nil + return getRawJSON[event.RelatesTo](plaintext, "content", "m.relates_to") } -func getMentions(content interface{}) *event.Mentions { +func getMentions(content any) *event.Mentions { + contentJSON, ok := content.(json.RawMessage) + if ok { + return getRawJSON[event.Mentions](contentJSON, "m.mentions") + } contentStruct, ok := content.(*event.Content) if ok { content = contentStruct.Parsed @@ -51,22 +79,28 @@ func getMentions(content interface{}) *event.Mentions { } type rawMegolmEvent struct { - RoomID id.RoomID `json:"room_id"` - Type event.Type `json:"type"` - Content interface{} `json:"content"` + RoomID id.RoomID `json:"room_id"` + Type event.Type `json:"type"` + StateKey *string `json:"state_key,omitempty"` + Content interface{} `json:"content"` } // IsShareError returns true if the error is caused by the lack of an outgoing megolm session and can be solved with OlmMachine.ShareGroupSession func IsShareError(err error) bool { - return err == SessionExpired || err == SessionNotShared || err == NoGroupSession + return err == ErrSessionExpired || err == ErrSessionNotShared || err == ErrNoGroupSession } -func parseMessageIndex(ciphertext []byte) (uint, error) { +func ParseMegolmMessageIndex(ciphertext []byte) (uint, error) { + if len(ciphertext) == 0 { + return 0, fmt.Errorf("empty ciphertext") + } decoded := make([]byte, base64.RawStdEncoding.DecodedLen(len(ciphertext))) var err error _, err = base64.RawStdEncoding.Decode(decoded, ciphertext) if err != nil { return 0, err + } else if len(decoded) < 2+binary.MaxVarintLen64 { + return 0, fmt.Errorf("decoded ciphertext too short: %d bytes", len(decoded)) } else if decoded[0] != 3 || decoded[1] != 8 { return 0, fmt.Errorf("unexpected initial bytes %d and %d", decoded[0], decoded[1]) } @@ -82,24 +116,34 @@ func parseMessageIndex(ciphertext []byte) (uint, error) { // If you use the event.Content struct, make sure you pass a pointer to the struct, // as JSON serialization will not work correctly otherwise. func (mach *OlmMachine) EncryptMegolmEvent(ctx context.Context, roomID id.RoomID, evtType event.Type, content interface{}) (*event.EncryptedEventContent, error) { + return mach.EncryptMegolmEventWithStateKey(ctx, roomID, evtType, nil, content) +} + +// EncryptMegolmEventWithStateKey encrypts data with the m.megolm.v1.aes-sha2 algorithm. +// +// If you use the event.Content struct, make sure you pass a pointer to the struct, +// as JSON serialization will not work correctly otherwise. +func (mach *OlmMachine) EncryptMegolmEventWithStateKey(ctx context.Context, roomID id.RoomID, evtType event.Type, stateKey *string, content interface{}) (*event.EncryptedEventContent, error) { mach.megolmEncryptLock.Lock() defer mach.megolmEncryptLock.Unlock() session, err := mach.CryptoStore.GetOutboundGroupSession(ctx, roomID) if err != nil { return nil, fmt.Errorf("failed to get outbound group session: %w", err) } else if session == nil { - return nil, NoGroupSession + return nil, ErrNoGroupSession } plaintext, err := json.Marshal(&rawMegolmEvent{ - RoomID: roomID, - Type: evtType, - Content: content, + RoomID: roomID, + Type: evtType, + StateKey: stateKey, + Content: content, }) if err != nil { return nil, err } log := mach.machOrContextLog(ctx).With(). Str("event_type", evtType.Type). + Any("state_key", stateKey). Str("room_id", roomID.String()). Str("session_id", session.ID().String()). Uint("expected_index", session.Internal.MessageIndex()). @@ -109,7 +153,7 @@ func (mach *OlmMachine) EncryptMegolmEvent(ctx context.Context, roomID id.RoomID if err != nil { return nil, err } - idx, err := parseMessageIndex(ciphertext) + idx, err := ParseMegolmMessageIndex(ciphertext) if err != nil { log.Warn().Err(err).Msg("Failed to get megolm message index of encrypted event") } else { @@ -124,31 +168,47 @@ func (mach *OlmMachine) EncryptMegolmEvent(ctx context.Context, roomID id.RoomID Algorithm: id.AlgorithmMegolmV1, SessionID: session.ID(), MegolmCiphertext: ciphertext, - RelatesTo: getRelatesTo(content), + RelatesTo: getRelatesTo(content, plaintext), // These are deprecated SenderKey: mach.account.IdentityKey(), DeviceID: mach.Client.DeviceID, } + if mach.MSC4392Relations && encrypted.RelatesTo != nil { + // When MSC4392 mode is enabled, reply and reaction metadata is stripped from the unencrypted content. + // Other relations like threads are still left unencrypted. + encrypted.RelatesTo.InReplyTo = nil + encrypted.RelatesTo.IsFallingBack = false + if evtType == event.EventReaction || encrypted.RelatesTo.Type == "" { + encrypted.RelatesTo = nil + } + } if mach.PlaintextMentions { encrypted.Mentions = getMentions(content) } return encrypted, nil } -func (mach *OlmMachine) newOutboundGroupSession(ctx context.Context, roomID id.RoomID) *OutboundGroupSession { +func (mach *OlmMachine) newOutboundGroupSession(ctx context.Context, roomID id.RoomID) (*OutboundGroupSession, error) { encryptionEvent, err := mach.StateStore.GetEncryptionEvent(ctx, roomID) if err != nil { mach.machOrContextLog(ctx).Err(err). Stringer("room_id", roomID). Msg("Failed to get encryption event in room") + return nil, fmt.Errorf("failed to get encryption event in room %s: %w", roomID, err) + } + session, err := NewOutboundGroupSession(roomID, encryptionEvent) + if err != nil { + return nil, err } - session := NewOutboundGroupSession(roomID, encryptionEvent) if !mach.DontStoreOutboundKeys { signingKey, idKey := mach.account.Keys() - mach.createGroupSession(ctx, idKey, signingKey, roomID, session.ID(), session.Internal.Key(), session.MaxAge, session.MaxMessages, false) + err := mach.createGroupSession(ctx, idKey, signingKey, roomID, session.ID(), session.Internal.Key(), session.MaxAge, session.MaxMessages, false) + if err != nil { + return nil, err + } } - return session + return session, err } type deviceSessionWrapper struct { @@ -156,14 +216,6 @@ type deviceSessionWrapper struct { identity *id.Device } -func strishArray[T ~string](arr []T) []string { - out := make([]string, len(arr)) - for i, item := range arr { - out[i] = string(item) - } - return out -} - // ShareGroupSession shares a group session for a specific room with all the devices of the given user list. // // For devices with TrustStateBlacklisted, a m.room_key.withheld event with code=m.blacklisted is sent. @@ -175,7 +227,8 @@ func (mach *OlmMachine) ShareGroupSession(ctx context.Context, roomID id.RoomID, if err != nil { return fmt.Errorf("failed to get previous outbound group session: %w", err) } else if session != nil && session.Shared && !session.Expired() { - return AlreadyShared + mach.machOrContextLog(ctx).Debug().Stringer("room_id", roomID).Msg("Not re-sharing group session, already shared") + return nil } log := mach.machOrContextLog(ctx).With(). Str("room_id", roomID.String()). @@ -183,27 +236,30 @@ func (mach *OlmMachine) ShareGroupSession(ctx context.Context, roomID id.RoomID, Logger() ctx = log.WithContext(ctx) if session == nil || session.Expired() { - session = mach.newOutboundGroupSession(ctx, roomID) + if session, err = mach.newOutboundGroupSession(ctx, roomID); err != nil { + return err + } } log = log.With().Str("session_id", session.ID().String()).Logger() ctx = log.WithContext(ctx) - log.Debug().Strs("users", strishArray(users)).Msg("Sharing group session for room") + log.Debug().Array("users", exzerolog.ArrayOfStrs(users)).Msg("Sharing group session for room") withheldCount := 0 toDeviceWithheld := &mautrix.ReqSendToDevice{Messages: make(map[id.UserID]map[id.DeviceID]*event.Content)} olmSessions := make(map[id.UserID]map[id.DeviceID]deviceSessionWrapper) missingSessions := make(map[id.UserID]map[id.DeviceID]*id.Device) missingUserSessions := make(map[id.DeviceID]*id.Device) - var fetchKeys []id.UserID + var fetchKeysForUsers []id.UserID for _, userID := range users { - log := log.With().Str("target_user_id", userID.String()).Logger() + log := log.With().Stringer("target_user_id", userID).Logger() devices, err := mach.CryptoStore.GetDevices(ctx, userID) if err != nil { - log.Error().Err(err).Msg("Failed to get devices of user") + log.Err(err).Msg("Failed to get devices of user") + return fmt.Errorf("failed to get devices of user %s: %w", userID, err) } else if devices == nil { log.Debug().Msg("GetDevices returned nil, will fetch keys and retry") - fetchKeys = append(fetchKeys, userID) + fetchKeysForUsers = append(fetchKeysForUsers, userID) } else if len(devices) == 0 { log.Trace().Msg("User has no devices, skipping") } else { @@ -227,18 +283,19 @@ func (mach *OlmMachine) ShareGroupSession(ctx context.Context, roomID id.RoomID, } } - if len(fetchKeys) > 0 { - log.Debug().Strs("users", strishArray(fetchKeys)).Msg("Fetching missing keys") - if keys, err := mach.FetchKeys(ctx, fetchKeys, true); err != nil { - log.Err(err).Strs("users", strishArray(fetchKeys)).Msg("Failed to fetch missing keys") - } else if keys != nil { - for userID, devices := range keys { - log.Debug(). - Int("device_count", len(devices)). - Str("target_user_id", userID.String()). - Msg("Got device keys for user") - missingSessions[userID] = devices - } + if len(fetchKeysForUsers) > 0 { + log.Debug().Array("users", exzerolog.ArrayOfStrs(fetchKeysForUsers)).Msg("Fetching missing keys") + keys, err := mach.FetchKeys(ctx, fetchKeysForUsers, true) + if err != nil { + log.Err(err).Array("users", exzerolog.ArrayOfStrs(fetchKeysForUsers)).Msg("Failed to fetch missing keys") + return fmt.Errorf("failed to fetch missing keys: %w", err) + } + for userID, devices := range keys { + log.Debug(). + Int("device_count", len(devices)). + Str("target_user_id", userID.String()). + Msg("Got device keys for user") + missingSessions[userID] = devices } } @@ -246,7 +303,8 @@ func (mach *OlmMachine) ShareGroupSession(ctx context.Context, roomID id.RoomID, log.Debug().Msg("Creating missing olm sessions") err = mach.createOutboundSessions(ctx, missingSessions) if err != nil { - log.Error().Err(err).Msg("Failed to create missing olm sessions") + log.Err(err).Msg("Failed to create missing olm sessions") + return fmt.Errorf("failed to create missing olm sessions: %w", err) } } @@ -266,7 +324,7 @@ func (mach *OlmMachine) ShareGroupSession(ctx context.Context, roomID id.RoomID, toDeviceWithheld.Messages[userID] = withheld } - log := log.With().Str("target_user_id", userID.String()).Logger() + log := log.With().Stringer("target_user_id", userID).Logger() log.Trace().Msg("Trying to find olm session to encrypt megolm session for user (post-fetch retry)") mach.findOlmSessionsForUser(ctx, session, userID, devices, output, withheld, nil) log.Debug(). @@ -312,41 +370,39 @@ func (mach *OlmMachine) encryptAndSendGroupSession(ctx context.Context, session log.Trace().Msg("Encrypting group session for all found devices") deviceCount := 0 toDevice := &mautrix.ReqSendToDevice{Messages: make(map[id.UserID]map[id.DeviceID]*event.Content)} + logUsers := zerolog.Dict() for userID, sessions := range olmSessions { if len(sessions) == 0 { continue } + logDevices := zerolog.Dict() output := make(map[id.DeviceID]*event.Content) toDevice.Messages[userID] = output for deviceID, device := range sessions { - log.Trace(). - Str("target_user_id", userID.String()). - Str("target_device_id", deviceID.String()). - Msg("Encrypting group session for device") content := mach.encryptOlmEvent(ctx, device.session, device.identity, event.ToDeviceRoomKey, session.ShareContent()) output[deviceID] = &event.Content{Parsed: content} + logDevices.Str(string(deviceID), string(device.identity.IdentityKey)) deviceCount++ - log.Debug(). - Str("target_user_id", userID.String()). - Str("target_device_id", deviceID.String()). - Msg("Encrypted group session for device") if !mach.DisableSharedGroupSessionTracking { err := mach.CryptoStore.MarkOutboundGroupSessionShared(ctx, userID, device.identity.IdentityKey, session.id) if err != nil { log.Warn(). Err(err). - Str("target_user_id", userID.String()). - Str("target_device_id", deviceID.String()). + Stringer("target_user_id", userID). + Stringer("target_device_id", deviceID). + Stringer("target_identity_key", device.identity.IdentityKey). Stringer("target_session_id", session.id). Msg("Failed to mark outbound group session shared") } } } + logUsers.Dict(string(userID), logDevices) } log.Debug(). Int("device_count", deviceCount). Int("user_count", len(toDevice.Messages)). + Dict("destination_map", logUsers). Msg("Sending to-device messages to share group session") _, err := mach.Client.SendToDevice(ctx, event.ToDeviceEncrypted, toDevice) return err @@ -355,8 +411,9 @@ func (mach *OlmMachine) encryptAndSendGroupSession(ctx context.Context, session func (mach *OlmMachine) findOlmSessionsForUser(ctx context.Context, session *OutboundGroupSession, userID id.UserID, devices map[id.DeviceID]*id.Device, output map[id.DeviceID]deviceSessionWrapper, withheld map[id.DeviceID]*event.Content, missingOutput map[id.DeviceID]*id.Device) { for deviceID, device := range devices { log := zerolog.Ctx(ctx).With(). - Str("target_user_id", userID.String()). - Str("target_device_id", deviceID.String()). + Stringer("target_user_id", userID). + Stringer("target_device_id", deviceID). + Stringer("target_identity_key", device.IdentityKey). Logger() userKey := UserDevice{UserID: userID, DeviceID: deviceID} if state := session.Users[userKey]; state != OGSNotShared { @@ -374,7 +431,7 @@ func (mach *OlmMachine) findOlmSessionsForUser(ctx context.Context, session *Out Reason: "Device is blacklisted", }} session.Users[userKey] = OGSIgnored - } else if trustState := mach.ResolveTrust(device); trustState < mach.SendKeysMinTrust { + } else if trustState, _ := mach.ResolveTrustContext(ctx, device); trustState < mach.SendKeysMinTrust { log.Debug(). Str("min_trust", mach.SendKeysMinTrust.String()). Str("device_trust", trustState.String()). diff --git a/crypto/encryptolm.go b/crypto/encryptolm.go index 15e9df29..765307af 100644 --- a/crypto/encryptolm.go +++ b/crypto/encryptolm.go @@ -17,6 +17,70 @@ import ( "maunium.net/go/mautrix/id" ) +func (mach *OlmMachine) EncryptToDevices(ctx context.Context, eventType event.Type, req *mautrix.ReqSendToDevice) (*mautrix.ReqSendToDevice, error) { + devicesToCreateSessions := make(map[id.UserID]map[id.DeviceID]*id.Device) + for userID, devices := range req.Messages { + for deviceID := range devices { + device, err := mach.GetOrFetchDevice(ctx, userID, deviceID) + if err != nil { + return nil, fmt.Errorf("failed to get device %s of user %s: %w", deviceID, userID, err) + } + + if _, ok := devicesToCreateSessions[userID]; !ok { + devicesToCreateSessions[userID] = make(map[id.DeviceID]*id.Device) + } + devicesToCreateSessions[userID][deviceID] = device + } + } + if err := mach.createOutboundSessions(ctx, devicesToCreateSessions); err != nil { + return nil, fmt.Errorf("failed to create outbound sessions: %w", err) + } + + mach.olmLock.Lock() + defer mach.olmLock.Unlock() + + encryptedReq := &mautrix.ReqSendToDevice{ + Messages: make(map[id.UserID]map[id.DeviceID]*event.Content), + } + + log := mach.machOrContextLog(ctx) + + for userID, devices := range req.Messages { + encryptedReq.Messages[userID] = make(map[id.DeviceID]*event.Content) + + for deviceID, content := range devices { + device := devicesToCreateSessions[userID][deviceID] + + olmSess, err := mach.CryptoStore.GetLatestSession(ctx, device.IdentityKey) + if err != nil { + return nil, fmt.Errorf("failed to get latest session for device %s of %s: %w", deviceID, userID, err) + } else if olmSess == nil { + log.Warn(). + Str("target_user_id", userID.String()). + Str("target_device_id", deviceID.String()). + Str("identity_key", device.IdentityKey.String()). + Msg("No outbound session found for device") + continue + } + + encrypted := mach.encryptOlmEvent(ctx, olmSess, device, eventType, *content) + encryptedContent := &event.Content{Parsed: &encrypted} + + log.Debug(). + Str("decrypted_type", eventType.Type). + Str("target_user_id", userID.String()). + Str("target_device_id", deviceID.String()). + Str("target_identity_key", device.IdentityKey.String()). + Str("olm_session_id", olmSess.ID().String()). + Msg("Encrypted to-device event") + + encryptedReq.Messages[userID][deviceID] = encryptedContent + } + } + + return encryptedReq, nil +} + func (mach *OlmMachine) encryptOlmEvent(ctx context.Context, session *OlmSession, recipient *id.Device, evtType event.Type, content event.Content) *event.EncryptedEventContent { evt := &DecryptedOlmEvent{ Sender: mach.Client.UserID, @@ -32,12 +96,19 @@ func (mach *OlmMachine) encryptOlmEvent(ctx context.Context, session *OlmSession panic(err) } log := mach.machOrContextLog(ctx) + msgType, ciphertext, err := session.Encrypt(plaintext) + if err != nil { + panic(err) + } + ciphertextStr := string(ciphertext) + ciphertextHash, _ := olmMessageHash(ciphertextStr) log.Debug(). + Stringer("event_type", evtType). Str("recipient_identity_key", recipient.IdentityKey.String()). Str("olm_session_id", session.ID().String()). Str("olm_session_description", session.Describe()). - Msg("Encrypting olm message") - msgType, ciphertext := session.Encrypt(plaintext) + Hex("ciphertext_hash", ciphertextHash[:]). + Msg("Encrypted olm message") err = mach.CryptoStore.UpdateSession(ctx, recipient.IdentityKey, session) if err != nil { log.Error().Err(err).Msg("Failed to update olm session in crypto store after encrypting") @@ -48,7 +119,7 @@ func (mach *OlmMachine) encryptOlmEvent(ctx context.Context, session *OlmSession OlmCiphertext: event.OlmCiphertexts{ recipient.IdentityKey: { Type: msgType, - Body: string(ciphertext), + Body: ciphertextStr, }, }, } diff --git a/crypto/goolm/account/account.go b/crypto/goolm/account/account.go index 4057543a..b48843a4 100644 --- a/crypto/goolm/account/account.go +++ b/crypto/goolm/account/account.go @@ -4,18 +4,14 @@ package account import ( "encoding/base64" "encoding/json" - "errors" "fmt" - "io" "maunium.net/go/mautrix/id" - "maunium.net/go/mautrix/crypto/goolm" - "maunium.net/go/mautrix/crypto/goolm/cipher" "maunium.net/go/mautrix/crypto/goolm/crypto" "maunium.net/go/mautrix/crypto/goolm/libolmpickle" "maunium.net/go/mautrix/crypto/goolm/session" - "maunium.net/go/mautrix/crypto/goolm/utilities" + "maunium.net/go/mautrix/crypto/olm" ) const ( @@ -39,41 +35,36 @@ type Account struct { NumFallbackKeys uint8 `json:"number_fallback_keys"` } +// Ensure that Account adheres to the olm.Account interface. +var _ olm.Account = (*Account)(nil) + // AccountFromJSONPickled loads the Account details from a pickled base64 string. The input is decrypted with the supplied key. func AccountFromJSONPickled(pickled, key []byte) (*Account, error) { if len(pickled) == 0 { - return nil, fmt.Errorf("accountFromPickled: %w", goolm.ErrEmptyInput) + return nil, fmt.Errorf("accountFromPickled: %w", olm.ErrEmptyInput) } a := &Account{} - err := a.UnpickleAsJSON(pickled, key) - if err != nil { - return nil, err - } - return a, nil + return a, a.UnpickleAsJSON(pickled, key) } // AccountFromPickled loads the Account details from a pickled base64 string. The input is decrypted with the supplied key. func AccountFromPickled(pickled, key []byte) (*Account, error) { if len(pickled) == 0 { - return nil, fmt.Errorf("accountFromPickled: %w", goolm.ErrEmptyInput) + return nil, fmt.Errorf("accountFromPickled: %w", olm.ErrEmptyInput) } a := &Account{} - err := a.Unpickle(pickled, key) - if err != nil { - return nil, err - } - return a, nil + return a, a.Unpickle(pickled, key) } -// NewAccount creates a new Account. If reader is nil, crypto/rand is used for the key creation. -func NewAccount(reader io.Reader) (*Account, error) { +// NewAccount creates a new Account. +func NewAccount() (*Account, error) { a := &Account{} - kPEd25519, err := crypto.Ed25519GenerateKey(reader) + kPEd25519, err := crypto.Ed25519GenerateKey() if err != nil { return nil, err } a.IdKeys.Ed25519 = kPEd25519 - kPCurve25519, err := crypto.Curve25519GenerateKey(reader) + kPCurve25519, err := crypto.Curve25519GenerateKey() if err != nil { return nil, err } @@ -82,72 +73,60 @@ func NewAccount(reader io.Reader) (*Account, error) { } // PickleAsJSON returns an Account as a base64 string encrypted using the supplied key. The unencrypted representation of the Account is in JSON format. -func (a Account) PickleAsJSON(key []byte) ([]byte, error) { - return utilities.PickleAsJSON(a, accountPickleVersionJSON, key) +func (a *Account) PickleAsJSON(key []byte) ([]byte, error) { + return libolmpickle.PickleAsJSON(a, accountPickleVersionJSON, key) } // UnpickleAsJSON updates an Account by a base64 encrypted string using the supplied key. The unencrypted representation has to be in JSON format. func (a *Account) UnpickleAsJSON(pickled, key []byte) error { - return utilities.UnpickleAsJSON(a, pickled, key, accountPickleVersionJSON) + return libolmpickle.UnpickleAsJSON(a, pickled, key, accountPickleVersionJSON) } // IdentityKeysJSON returns the public parts of the identity keys for the Account in a JSON string. -func (a Account) IdentityKeysJSON() ([]byte, error) { +func (a *Account) IdentityKeysJSON() ([]byte, error) { res := struct { Ed25519 string `json:"ed25519"` Curve25519 string `json:"curve25519"` }{} - ed25519, curve25519 := a.IdentityKeys() + ed25519, curve25519, err := a.IdentityKeys() + if err != nil { + return nil, err + } res.Ed25519 = string(ed25519) res.Curve25519 = string(curve25519) return json.Marshal(res) } // IdentityKeys returns the public parts of the Ed25519 and Curve25519 identity keys for the Account. -func (a Account) IdentityKeys() (id.Ed25519, id.Curve25519) { +func (a *Account) IdentityKeys() (id.Ed25519, id.Curve25519, error) { ed25519 := id.Ed25519(base64.RawStdEncoding.EncodeToString(a.IdKeys.Ed25519.PublicKey)) curve25519 := id.Curve25519(base64.RawStdEncoding.EncodeToString(a.IdKeys.Curve25519.PublicKey)) - return ed25519, curve25519 + return ed25519, curve25519, nil } // Sign returns the base64-encoded signature of a message using the Ed25519 key // for this Account. -func (a Account) Sign(message []byte) ([]byte, error) { +func (a *Account) Sign(message []byte) ([]byte, error) { if len(message) == 0 { - return nil, fmt.Errorf("sign: %w", goolm.ErrEmptyInput) + return nil, fmt.Errorf("sign: %w", olm.ErrEmptyInput) + } else if signature, err := a.IdKeys.Ed25519.Sign(message); err != nil { + return nil, err + } else { + return []byte(base64.RawStdEncoding.EncodeToString(signature)), nil } - return []byte(base64.RawStdEncoding.EncodeToString(a.IdKeys.Ed25519.Sign(message))), nil } // OneTimeKeys returns the public parts of the unpublished one time keys of the Account. // // The returned data is a map with the mapping of key id to base64-encoded Curve25519 key. -func (a Account) OneTimeKeys() map[string]id.Curve25519 { +func (a *Account) OneTimeKeys() (map[string]id.Curve25519, error) { oneTimeKeys := make(map[string]id.Curve25519) for _, curKey := range a.OTKeys { if !curKey.Published { - oneTimeKeys[curKey.KeyIDEncoded()] = id.Curve25519(curKey.PublicKeyEncoded()) + oneTimeKeys[curKey.KeyIDEncoded()] = curKey.Key.PublicKey.B64Encoded() } } - return oneTimeKeys -} - -//OneTimeKeysJSON returns the public parts of the unpublished one time keys of the Account as a JSON string. -// -//The returned JSON is of format: -/* - { - Curve25519: { - "AAAAAA": "wo76WcYtb0Vk/pBOdmduiGJ0wIEjW4IBMbbQn7aSnTo", - "AAAAAB": "LRvjo46L1X2vx69sS9QNFD29HWulxrmW11Up5AfAjgU" - } - } -*/ -func (a Account) OneTimeKeysJSON() ([]byte, error) { - res := make(map[string]map[string]id.Curve25519) - otKeys := a.OneTimeKeys() - res["Curve25519"] = otKeys - return json.Marshal(res) + return oneTimeKeys, nil } // MarkKeysAsPublished marks the current set of one time keys and the fallback key as being @@ -163,14 +142,14 @@ func (a *Account) MarkKeysAsPublished() { // GenOneTimeKeys generates a number of new one time keys. If the total number // of keys stored by this Account exceeds MaxOneTimeKeys then the older -// keys are discarded. If reader is nil, crypto/rand is used for the key creation. -func (a *Account) GenOneTimeKeys(reader io.Reader, num uint) error { +// keys are discarded. +func (a *Account) GenOneTimeKeys(num uint) error { for i := uint(0); i < num; i++ { key := crypto.OneTimeKey{ Published: false, ID: a.NextOneTimeKeyID, } - newKP, err := crypto.Curve25519GenerateKey(reader) + newKP, err := crypto.Curve25519GenerateKey() if err != nil { return err } @@ -186,9 +165,9 @@ func (a *Account) GenOneTimeKeys(reader io.Reader, num uint) error { // NewOutboundSession creates a new outbound session to a // given curve25519 identity Key and one time key. -func (a Account) NewOutboundSession(theirIdentityKey, theirOneTimeKey id.Curve25519) (*session.OlmSession, error) { +func (a *Account) NewOutboundSession(theirIdentityKey, theirOneTimeKey id.Curve25519) (olm.Session, error) { if len(theirIdentityKey) == 0 || len(theirOneTimeKey) == 0 { - return nil, fmt.Errorf("outbound session: %w", goolm.ErrEmptyInput) + return nil, fmt.Errorf("outbound session: %w", olm.ErrEmptyInput) } theirIdentityKeyDecoded, err := base64.RawStdEncoding.DecodeString(string(theirIdentityKey)) if err != nil { @@ -198,20 +177,21 @@ func (a Account) NewOutboundSession(theirIdentityKey, theirOneTimeKey id.Curve25 if err != nil { return nil, err } - s, err := session.NewOutboundOlmSession(a.IdKeys.Curve25519, theirIdentityKeyDecoded, theirOneTimeKeyDecoded) - if err != nil { - return nil, err - } - return s, nil + return session.NewOutboundOlmSession(a.IdKeys.Curve25519, theirIdentityKeyDecoded, theirOneTimeKeyDecoded) } -// NewInboundSession creates a new inbound session from an incoming PRE_KEY message. -func (a Account) NewInboundSession(theirIdentityKey *id.Curve25519, oneTimeKeyMsg []byte) (*session.OlmSession, error) { +// NewInboundSession creates a new in-bound session for sending/receiving +// messages from an incoming PRE_KEY message. Returns error on failure. +func (a *Account) NewInboundSession(oneTimeKeyMsg string) (olm.Session, error) { + return a.NewInboundSessionFrom(nil, oneTimeKeyMsg) +} + +// NewInboundSessionFrom creates a new inbound session from an incoming PRE_KEY message. +func (a *Account) NewInboundSessionFrom(theirIdentityKey *id.Curve25519, oneTimeKeyMsg string) (olm.Session, error) { if len(oneTimeKeyMsg) == 0 { - return nil, fmt.Errorf("inbound session: %w", goolm.ErrEmptyInput) + return nil, fmt.Errorf("inbound session: %w", olm.ErrEmptyInput) } var theirIdentityKeyDecoded *crypto.Curve25519PublicKey - var err error if theirIdentityKey != nil { theirIdentityKeyDecodedByte, err := base64.RawStdEncoding.DecodeString(string(*theirIdentityKey)) if err != nil { @@ -221,14 +201,10 @@ func (a Account) NewInboundSession(theirIdentityKey *id.Curve25519, oneTimeKeyMs theirIdentityKeyDecoded = &theirIdentityKeyCurve } - s, err := session.NewInboundOlmSession(theirIdentityKeyDecoded, oneTimeKeyMsg, a.searchOTKForOur, a.IdKeys.Curve25519) - if err != nil { - return nil, err - } - return s, nil + return session.NewInboundOlmSession(theirIdentityKeyDecoded, []byte(oneTimeKeyMsg), a.searchOTKForOur, a.IdKeys.Curve25519) } -func (a Account) searchOTKForOur(toFind crypto.Curve25519PublicKey) *crypto.OneTimeKey { +func (a *Account) searchOTKForOur(toFind crypto.Curve25519PublicKey) *crypto.OneTimeKey { for curIndex := range a.OTKeys { if a.OTKeys[curIndex].Key.PublicKey.Equal(toFind) { return &a.OTKeys[curIndex] @@ -244,27 +220,29 @@ func (a Account) searchOTKForOur(toFind crypto.Curve25519PublicKey) *crypto.OneT } // RemoveOneTimeKeys removes the one time key in this Account which matches the one time key in the session s. -func (a *Account) RemoveOneTimeKeys(s *session.OlmSession) { - toFind := s.BobOneTimeKey +func (a *Account) RemoveOneTimeKeys(s olm.Session) error { + toFind := s.(*session.OlmSession).BobOneTimeKey for curIndex := range a.OTKeys { if a.OTKeys[curIndex].Key.PublicKey.Equal(toFind) { //Remove and return a.OTKeys[curIndex] = a.OTKeys[len(a.OTKeys)-1] a.OTKeys = a.OTKeys[:len(a.OTKeys)-1] - return + return nil } } + return nil //if the key is a fallback or prevFallback, don't remove it } -// GenFallbackKey generates a new fallback key. The old fallback key is stored in a.PrevFallbackKey overwriting any previous PrevFallbackKey. If reader is nil, crypto/rand is used for the key creation. -func (a *Account) GenFallbackKey(reader io.Reader) error { +// GenFallbackKey generates a new fallback key. The old fallback key is stored +// in a.PrevFallbackKey overwriting any previous PrevFallbackKey. +func (a *Account) GenFallbackKey() error { a.PrevFallbackKey = a.CurrentFallbackKey key := crypto.OneTimeKey{ Published: false, ID: a.NextOneTimeKeyID, } - newKP, err := crypto.Curve25519GenerateKey(reader) + newKP, err := crypto.Curve25519GenerateKey() if err != nil { return err } @@ -279,10 +257,10 @@ func (a *Account) GenFallbackKey(reader io.Reader) error { // FallbackKey returns the public part of the current fallback key of the Account. // The returned data is a map with the mapping of key id to base64-encoded Curve25519 key. -func (a Account) FallbackKey() map[string]id.Curve25519 { +func (a *Account) FallbackKey() map[string]id.Curve25519 { keys := make(map[string]id.Curve25519) if a.NumFallbackKeys >= 1 { - keys[a.CurrentFallbackKey.KeyIDEncoded()] = id.Curve25519(a.CurrentFallbackKey.PublicKeyEncoded()) + keys[a.CurrentFallbackKey.KeyIDEncoded()] = a.CurrentFallbackKey.Key.PublicKey.B64Encoded() } return keys } @@ -297,7 +275,7 @@ func (a Account) FallbackKey() map[string]id.Curve25519 { } } */ -func (a Account) FallbackKeyJSON() ([]byte, error) { +func (a *Account) FallbackKeyJSON() ([]byte, error) { res := make(map[string]map[string]id.Curve25519) fbk := a.FallbackKey() res["curve25519"] = fbk @@ -306,10 +284,10 @@ func (a Account) FallbackKeyJSON() ([]byte, error) { // FallbackKeyUnpublished returns the public part of the current fallback key of the Account only if it is unpublished. // The returned data is a map with the mapping of key id to base64-encoded Curve25519 key. -func (a Account) FallbackKeyUnpublished() map[string]id.Curve25519 { +func (a *Account) FallbackKeyUnpublished() map[string]id.Curve25519 { keys := make(map[string]id.Curve25519) if a.NumFallbackKeys >= 1 && !a.CurrentFallbackKey.Published { - keys[a.CurrentFallbackKey.KeyIDEncoded()] = id.Curve25519(a.CurrentFallbackKey.PublicKeyEncoded()) + keys[a.CurrentFallbackKey.KeyIDEncoded()] = a.CurrentFallbackKey.Key.PublicKey.B64Encoded() } return keys } @@ -324,7 +302,7 @@ func (a Account) FallbackKeyUnpublished() map[string]id.Curve25519 { } } */ -func (a Account) FallbackKeyUnpublishedJSON() ([]byte, error) { +func (a *Account) FallbackKeyUnpublishedJSON() ([]byte, error) { res := make(map[string]map[string]id.Curve25519) fbk := a.FallbackKeyUnpublished() res["curve25519"] = fbk @@ -342,69 +320,50 @@ func (a *Account) ForgetOldFallbackKey() { // Unpickle decodes the base64 encoded string and decrypts the result with the key. // The decrypted value is then passed to UnpickleLibOlm. func (a *Account) Unpickle(pickled, key []byte) error { - decrypted, err := cipher.Unpickle(key, pickled) + decrypted, err := libolmpickle.Unpickle(key, pickled) if err != nil { return err } - _, err = a.UnpickleLibOlm(decrypted) - return err + return a.UnpickleLibOlm(decrypted) } -// UnpickleLibOlm decodes the unencryted value and populates the Account accordingly. It returns the number of bytes read. -func (a *Account) UnpickleLibOlm(value []byte) (int, error) { - //First 4 bytes are the accountPickleVersion - pickledVersion, curPos, err := libolmpickle.UnpickleUInt32(value) +// UnpickleLibOlm unpickles the unencryted value and populates the [Account] accordingly. +func (a *Account) UnpickleLibOlm(buf []byte) error { + decoder := libolmpickle.NewDecoder(buf) + pickledVersion, err := decoder.ReadUInt32() if err != nil { - return 0, err + return err + } else if pickledVersion != accountPickleVersionLibOLM && pickledVersion != 3 && pickledVersion != 2 { + return fmt.Errorf("unpickle account: %w (found version %d)", olm.ErrUnknownOlmPickleVersion, pickledVersion) + } else if err = a.IdKeys.Ed25519.UnpickleLibOlm(decoder); err != nil { // read the ed25519 key pair + return err + } else if err = a.IdKeys.Curve25519.UnpickleLibOlm(decoder); err != nil { // read curve25519 key pair + return err } - switch pickledVersion { - case accountPickleVersionLibOLM, 3, 2: - default: - return 0, fmt.Errorf("unpickle account: %w", goolm.ErrBadVersion) - } - //read ed25519 key pair - readBytes, err := a.IdKeys.Ed25519.UnpickleLibOlm(value[curPos:]) + + otkCount, err := decoder.ReadUInt32() if err != nil { - return 0, err + return err } - curPos += readBytes - //read curve25519 key pair - readBytes, err = a.IdKeys.Curve25519.UnpickleLibOlm(value[curPos:]) - if err != nil { - return 0, err - } - curPos += readBytes - //Read number of onetimeKeys - numberOTKeys, readBytes, err := libolmpickle.UnpickleUInt32(value[curPos:]) - if err != nil { - return 0, err - } - curPos += readBytes - //Read i one time keys - a.OTKeys = make([]crypto.OneTimeKey, numberOTKeys) - for i := uint32(0); i < numberOTKeys; i++ { - readBytes, err := a.OTKeys[i].UnpickleLibOlm(value[curPos:]) - if err != nil { - return 0, err + + a.OTKeys = make([]crypto.OneTimeKey, otkCount) + for i := uint32(0); i < otkCount; i++ { + if err := a.OTKeys[i].UnpickleLibOlm(decoder); err != nil { + return err } - curPos += readBytes } + if pickledVersion <= 2 { // version 2 did not have fallback keys a.NumFallbackKeys = 0 } else if pickledVersion == 3 { // version 3 used the published flag to indicate how many fallback keys // were present (we'll have to assume that the keys were published) - readBytes, err := a.CurrentFallbackKey.UnpickleLibOlm(value[curPos:]) - if err != nil { - return 0, err + if err = a.CurrentFallbackKey.UnpickleLibOlm(decoder); err != nil { + return err + } else if err = a.PrevFallbackKey.UnpickleLibOlm(decoder); err != nil { + return err } - curPos += readBytes - readBytes, err = a.PrevFallbackKey.UnpickleLibOlm(value[curPos:]) - if err != nil { - return 0, err - } - curPos += readBytes if a.CurrentFallbackKey.Published { if a.PrevFallbackKey.Published { a.NumFallbackKeys = 2 @@ -415,109 +374,70 @@ func (a *Account) UnpickleLibOlm(value []byte) (int, error) { a.NumFallbackKeys = 0 } } else { - //Read number of fallback keys - numFallbackKeys, readBytes, err := libolmpickle.UnpickleUInt8(value[curPos:]) + // Read number of fallback keys + a.NumFallbackKeys, err = decoder.ReadUInt8() if err != nil { - return 0, err + return err } - curPos += readBytes - a.NumFallbackKeys = numFallbackKeys - if a.NumFallbackKeys >= 1 { - readBytes, err := a.CurrentFallbackKey.UnpickleLibOlm(value[curPos:]) - if err != nil { - return 0, err - } - curPos += readBytes - if a.NumFallbackKeys >= 2 { - readBytes, err := a.PrevFallbackKey.UnpickleLibOlm(value[curPos:]) - if err != nil { - return 0, err + for i := 0; i < int(a.NumFallbackKeys); i++ { + switch i { + case 0: + if err = a.CurrentFallbackKey.UnpickleLibOlm(decoder); err != nil { + return err + } + case 1: + if err = a.PrevFallbackKey.UnpickleLibOlm(decoder); err != nil { + return err + } + default: + // Just drain any remaining fallback keys + if err = (&crypto.OneTimeKey{}).UnpickleLibOlm(decoder); err != nil { + return err } - curPos += readBytes } } } - //Read next onetime key id - nextOTKeyID, readBytes, err := libolmpickle.UnpickleUInt32(value[curPos:]) - if err != nil { - return 0, err - } - curPos += readBytes - a.NextOneTimeKeyID = nextOTKeyID - return curPos, nil + + //Read next onetime key ID + a.NextOneTimeKeyID, err = decoder.ReadUInt32() + return err } // Pickle returns a base64 encoded and with key encrypted pickled account using PickleLibOlm(). -func (a Account) Pickle(key []byte) ([]byte, error) { - pickeledBytes := make([]byte, a.PickleLen()) - written, err := a.PickleLibOlm(pickeledBytes) - if err != nil { - return nil, err +func (a *Account) Pickle(key []byte) ([]byte, error) { + if len(key) == 0 { + return nil, olm.ErrNoKeyProvided } - if written != len(pickeledBytes) { - return nil, errors.New("number of written bytes not correct") - } - encrypted, err := cipher.Pickle(key, pickeledBytes) - if err != nil { - return nil, err - } - return encrypted, nil + return libolmpickle.Pickle(key, a.PickleLibOlm()) } -// PickleLibOlm encodes the Account into target. target has to have a size of at least PickleLen() and is written to from index 0. -// It returns the number of bytes written. -func (a Account) PickleLibOlm(target []byte) (int, error) { - if len(target) < a.PickleLen() { - return 0, fmt.Errorf("pickle account: %w", goolm.ErrValueTooShort) - } - written := libolmpickle.PickleUInt32(accountPickleVersionLibOLM, target) - writtenEdKey, err := a.IdKeys.Ed25519.PickleLibOlm(target[written:]) - if err != nil { - return 0, fmt.Errorf("pickle account: %w", err) - } - written += writtenEdKey - writtenCurveKey, err := a.IdKeys.Curve25519.PickleLibOlm(target[written:]) - if err != nil { - return 0, fmt.Errorf("pickle account: %w", err) - } - written += writtenCurveKey - written += libolmpickle.PickleUInt32(uint32(len(a.OTKeys)), target[written:]) +// PickleLibOlm pickles the [Account] and returns the raw bytes. +func (a *Account) PickleLibOlm() []byte { + encoder := libolmpickle.NewEncoder() + encoder.WriteUInt32(accountPickleVersionLibOLM) + a.IdKeys.Ed25519.PickleLibOlm(encoder) + a.IdKeys.Curve25519.PickleLibOlm(encoder) + + // One-Time Keys + encoder.WriteUInt32(uint32(len(a.OTKeys))) for _, curOTKey := range a.OTKeys { - writtenOT, err := curOTKey.PickleLibOlm(target[written:]) - if err != nil { - return 0, fmt.Errorf("pickle account: %w", err) - } - written += writtenOT + curOTKey.PickleLibOlm(encoder) } - written += libolmpickle.PickleUInt8(a.NumFallbackKeys, target[written:]) + + // Fallback Keys + encoder.WriteUInt8(a.NumFallbackKeys) if a.NumFallbackKeys >= 1 { - writtenOT, err := a.CurrentFallbackKey.PickleLibOlm(target[written:]) - if err != nil { - return 0, fmt.Errorf("pickle account: %w", err) - } - written += writtenOT - + a.CurrentFallbackKey.PickleLibOlm(encoder) if a.NumFallbackKeys >= 2 { - writtenOT, err := a.PrevFallbackKey.PickleLibOlm(target[written:]) - if err != nil { - return 0, fmt.Errorf("pickle account: %w", err) - } - written += writtenOT + a.PrevFallbackKey.PickleLibOlm(encoder) } } - written += libolmpickle.PickleUInt32(a.NextOneTimeKeyID, target[written:]) - return written, nil + encoder.WriteUInt32(a.NextOneTimeKeyID) + return encoder.Bytes() } -// PickleLen returns the number of bytes the pickled Account will have. -func (a Account) PickleLen() int { - length := libolmpickle.PickleUInt32Len(accountPickleVersionLibOLM) - length += a.IdKeys.Ed25519.PickleLen() - length += a.IdKeys.Curve25519.PickleLen() - length += libolmpickle.PickleUInt32Len(uint32(len(a.OTKeys))) - length += (len(a.OTKeys) * (&crypto.OneTimeKey{}).PickleLen()) - length += libolmpickle.PickleUInt8Len(a.NumFallbackKeys) - length += (int(a.NumFallbackKeys) * (&crypto.OneTimeKey{}).PickleLen()) - length += libolmpickle.PickleUInt32Len(a.NextOneTimeKeyID) - return length +// MaxNumberOfOneTimeKeys returns the largest number of one time keys this +// Account can store. +func (a *Account) MaxNumberOfOneTimeKeys() uint { + return uint(MaxOneTimeKeys) } diff --git a/crypto/goolm/account/account_test.go b/crypto/goolm/account/account_test.go index 943d8570..d0dec5f0 100644 --- a/crypto/goolm/account/account_test.go +++ b/crypto/goolm/account/account_test.go @@ -1,92 +1,56 @@ package account_test import ( - "bytes" "encoding/base64" - "errors" "testing" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" "maunium.net/go/mautrix/id" - "maunium.net/go/mautrix/crypto/goolm" "maunium.net/go/mautrix/crypto/goolm/account" + "maunium.net/go/mautrix/crypto/olm" "maunium.net/go/mautrix/crypto/signatures" ) func TestAccount(t *testing.T) { - firstAccount, err := account.NewAccount(nil) - if err != nil { - t.Fatal(err) - } - err = firstAccount.GenFallbackKey(nil) - if err != nil { - t.Fatal(err) - } - err = firstAccount.GenOneTimeKeys(nil, 2) - if err != nil { - t.Fatal(err) - } + firstAccount, err := account.NewAccount() + assert.NoError(t, err) + err = firstAccount.GenFallbackKey() + assert.NoError(t, err) + err = firstAccount.GenOneTimeKeys(2) + assert.NoError(t, err) encryptionKey := []byte("testkey") + //now pickle account in JSON format pickled, err := firstAccount.PickleAsJSON(encryptionKey) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) + //now unpickle into new Account unpickledAccount, err := account.AccountFromJSONPickled(pickled, encryptionKey) - if err != nil { - t.Fatal(err) - } - //check if accounts are the same - if firstAccount.NextOneTimeKeyID != unpickledAccount.NextOneTimeKeyID { - t.Fatal("NextOneTimeKeyID unequal") - } - if !firstAccount.CurrentFallbackKey.Equal(unpickledAccount.CurrentFallbackKey) { - t.Fatal("CurrentFallbackKey unequal") - } - if !firstAccount.PrevFallbackKey.Equal(unpickledAccount.PrevFallbackKey) { - t.Fatal("PrevFallbackKey unequal") - } - if len(firstAccount.OTKeys) != len(unpickledAccount.OTKeys) { - t.Fatal("OneTimeKeysunequal") - } - for i := range firstAccount.OTKeys { - if !firstAccount.OTKeys[i].Equal(unpickledAccount.OTKeys[i]) { - t.Fatalf("OneTimeKeys %d unequal", i) - } - } - if !firstAccount.IdKeys.Curve25519.PrivateKey.Equal(unpickledAccount.IdKeys.Curve25519.PrivateKey) { - t.Fatal("IdentityKeys Curve25519 private unequal") - } - if !firstAccount.IdKeys.Curve25519.PublicKey.Equal(unpickledAccount.IdKeys.Curve25519.PublicKey) { - t.Fatal("IdentityKeys Curve25519 public unequal") - } - if !firstAccount.IdKeys.Ed25519.PrivateKey.Equal(unpickledAccount.IdKeys.Ed25519.PrivateKey) { - t.Fatal("IdentityKeys Ed25519 private unequal") - } - if !firstAccount.IdKeys.Ed25519.PublicKey.Equal(unpickledAccount.IdKeys.Ed25519.PublicKey) { - t.Fatal("IdentityKeys Ed25519 public unequal") - } + assert.NoError(t, err) - if len(firstAccount.OneTimeKeys()) != 2 { - t.Fatal("should get 2 unpublished oneTimeKeys") - } - if len(firstAccount.FallbackKeyUnpublished()) == 0 { - t.Fatal("should get fallbackKey") - } + //check if accounts are the same + assert.Equal(t, firstAccount.NextOneTimeKeyID, unpickledAccount.NextOneTimeKeyID) + assert.Equal(t, firstAccount.CurrentFallbackKey, unpickledAccount.CurrentFallbackKey) + assert.Equal(t, firstAccount.PrevFallbackKey, unpickledAccount.PrevFallbackKey) + assert.Equal(t, firstAccount.OTKeys, unpickledAccount.OTKeys) + assert.Equal(t, firstAccount.IdKeys, unpickledAccount.IdKeys) + + // Ensure that all of the keys are unpublished right now + otks, err := firstAccount.OneTimeKeys() + assert.NoError(t, err) + assert.Len(t, otks, 2) + assert.Len(t, firstAccount.FallbackKeyUnpublished(), 1) + + // Now, publish the key and make sure that they are published firstAccount.MarkKeysAsPublished() - if len(firstAccount.FallbackKey()) == 0 { - t.Fatal("should get fallbackKey") - } - if len(firstAccount.FallbackKeyUnpublished()) != 0 { - t.Fatal("should get no fallbackKey") - } - if len(firstAccount.OneTimeKeys()) != 0 { - t.Fatal("should get no oneTimeKeys") - } + + assert.Len(t, firstAccount.FallbackKeyUnpublished(), 0) + assert.Len(t, firstAccount.FallbackKey(), 1) + otks, err = firstAccount.OneTimeKeys() + assert.NoError(t, err) + assert.Len(t, otks, 0) } func TestAccountPickleJSON(t *testing.T) { @@ -104,109 +68,49 @@ func TestAccountPickleJSON(t *testing.T) { pickledData := []byte("6POkBWwbNl20fwvZWsOu0jgbHy4jkA5h0Ji+XCag59+ifWIRPDrqtgQi9HmkLiSF6wUhhYaV4S73WM+Hh+dlCuZRuXhTQr8yGPTifjcjq8birdAhObbEqHrYEdqaQkrgBLr/rlS5sibXeDqbkhVu4LslvootU9DkcCbd4b/0Flh7iugxqkcCs5GDndTEx9IzTVJzmK82Y0Q1Z1Z9Vuc2Iw746PtBJLtZjite6fSMp2NigPX/ZWWJ3OnwcJo0Vvjy8hgptZEWkamOHdWbUtelbHyjDIZlvxOC25D3rFif0zzPkF9qdpBPqVCWPPzGFmgnqKau6CHrnPfq7GLsM3BrprD7sHN1Js28ex14gXQPjBT7KTUo6H0e4gQMTMRp4qb8btNXDeId8xIFIElTh2SXZBTDmSq/ziVNJinEvYV8mGPvJZjDQQU+SyoS/HZ8uMc41tH0BOGDbFMHbfLMiz61E429gOrx2klu5lqyoyet7//HKi0ed5w2dQ") account, err := account.AccountFromJSONPickled(pickledData, key) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) expectedJSON := `{"ed25519":"qWvNB6Ztov5/AOsP073op0O32KJ8/tgSNarT7MaYgQE","curve25519":"TFUB6M6zwgyWhBEp2m1aUodl2AsnsrIuBr8l9AvwGS8"}` jsonData, err := account.IdentityKeysJSON() - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(jsonData, []byte(expectedJSON)) { - t.Fatalf("Expected '%s' but got '%s'", expectedJSON, jsonData) - } + assert.NoError(t, err) + assert.Equal(t, expectedJSON, string(jsonData)) } func TestSessions(t *testing.T) { - aliceAccount, err := account.NewAccount(nil) - if err != nil { - t.Fatal(err) - } - err = aliceAccount.GenOneTimeKeys(nil, 5) - if err != nil { - t.Fatal(err) - } - bobAccount, err := account.NewAccount(nil) - if err != nil { - t.Fatal(err) - } - err = bobAccount.GenOneTimeKeys(nil, 5) - if err != nil { - t.Fatal(err) - } + aliceAccount, err := account.NewAccount() + assert.NoError(t, err) + err = aliceAccount.GenOneTimeKeys(5) + assert.NoError(t, err) + bobAccount, err := account.NewAccount() + assert.NoError(t, err) + err = bobAccount.GenOneTimeKeys(5) + assert.NoError(t, err) aliceSession, err := aliceAccount.NewOutboundSession(bobAccount.IdKeys.Curve25519.B64Encoded(), bobAccount.OTKeys[2].Key.B64Encoded()) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) plaintext := []byte("test message") - msgType, crypttext, err := aliceSession.Encrypt(plaintext, nil) - if err != nil { - t.Fatal(err) - } - if msgType != id.OlmMsgTypePreKey { - t.Fatal("wrong message type") - } + msgType, crypttext, err := aliceSession.Encrypt(plaintext) + assert.NoError(t, err) + assert.Equal(t, id.OlmMsgTypePreKey, msgType) - bobSession, err := bobAccount.NewInboundSession(nil, crypttext) - if err != nil { - t.Fatal(err) - } - decodedText, err := bobSession.Decrypt(crypttext, msgType) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(plaintext, decodedText) { - t.Fatalf("expected '%s' but got '%s'", string(plaintext), string(decodedText)) - } + bobSession, err := bobAccount.NewInboundSession(string(crypttext)) + assert.NoError(t, err) + decodedText, err := bobSession.Decrypt(string(crypttext), msgType) + assert.NoError(t, err) + assert.Equal(t, plaintext, decodedText) } func TestAccountPickle(t *testing.T) { pickleKey := []byte("secret_key") account, err := account.AccountFromPickled(pickledDataFromLibOlm, pickleKey) - if err != nil { - t.Fatal(err) - } - if !expectedEd25519KeyPairPickleLibOLM.PrivateKey.Equal(account.IdKeys.Ed25519.PrivateKey) { - t.Fatal("keys not equal") - } - if !expectedEd25519KeyPairPickleLibOLM.PublicKey.Equal(account.IdKeys.Ed25519.PublicKey) { - t.Fatal("keys not equal") - } - if !expectedCurve25519KeyPairPickleLibOLM.PrivateKey.Equal(account.IdKeys.Curve25519.PrivateKey) { - t.Fatal("keys not equal") - } - if !expectedCurve25519KeyPairPickleLibOLM.PublicKey.Equal(account.IdKeys.Curve25519.PublicKey) { - t.Fatal("keys not equal") - } - if account.NextOneTimeKeyID != 42 { - t.Fatal("wrong next otKey id") - } - if len(account.OTKeys) != len(expectedOTKeysPickleLibOLM) { - t.Fatal("wrong number of otKeys") - } - if account.NumFallbackKeys != 0 { - t.Fatal("fallback keys set but not in pickle") - } - for curIndex, curValue := range account.OTKeys { - curExpected := expectedOTKeysPickleLibOLM[curIndex] - if curExpected.ID != curValue.ID { - t.Fatal("OTKey id not correct") - } - if !curExpected.Key.PublicKey.Equal(curValue.Key.PublicKey) { - t.Fatal("OTKey public key not correct") - } - if !curExpected.Key.PrivateKey.Equal(curValue.Key.PrivateKey) { - t.Fatal("OTKey private key not correct") - } - } + assert.NoError(t, err) + assert.Equal(t, expectedEd25519KeyPairPickleLibOLM, account.IdKeys.Ed25519) + assert.Equal(t, expectedCurve25519KeyPairPickleLibOLM, account.IdKeys.Curve25519) + assert.EqualValues(t, 42, account.NextOneTimeKeyID) + assert.Equal(t, account.OTKeys, expectedOTKeysPickleLibOLM) + assert.EqualValues(t, 0, account.NumFallbackKeys) targetPickled, err := account.Pickle(pickleKey) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(targetPickled, pickledDataFromLibOlm) { - t.Fatal("repickled value does not equal given value") - } + assert.NoError(t, err) + assert.Equal(t, pickledDataFromLibOlm, targetPickled) } func TestOldAccountPickle(t *testing.T) { @@ -217,356 +121,213 @@ func TestOldAccountPickle(t *testing.T) { "K/A/8TOu9iK2hDFszy6xETiousHnHgh2ZGbRUh4pQx+YMm8ZdNZeRnwFGLnrWyf9" + "O5TmXua1FcU") pickleKey := []byte("") - account, err := account.NewAccount(nil) - if err != nil { - t.Fatal(err) - } + account, err := account.NewAccount() + assert.NoError(t, err) err = account.Unpickle(pickled, pickleKey) - if err == nil { - t.Fatal("expected error") - } else { - if !errors.Is(err, goolm.ErrBadVersion) { - t.Fatal(err) - } - } + assert.ErrorIs(t, err, olm.ErrUnknownOlmPickleVersion) } func TestLoopback(t *testing.T) { - accountA, err := account.NewAccount(nil) - if err != nil { - t.Fatal(err) - } + accountA, err := account.NewAccount() + assert.NoError(t, err) - accountB, err := account.NewAccount(nil) - if err != nil { - t.Fatal(err) - } - err = accountB.GenOneTimeKeys(nil, 42) - if err != nil { - t.Fatal(err) - } + accountB, err := account.NewAccount() + assert.NoError(t, err) + err = accountB.GenOneTimeKeys(42) + assert.NoError(t, err) aliceSession, err := accountA.NewOutboundSession(accountB.IdKeys.Curve25519.B64Encoded(), accountB.OTKeys[0].Key.B64Encoded()) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) plainText := []byte("Hello, World") - msgType, message1, err := aliceSession.Encrypt(plainText, nil) - if err != nil { - t.Fatal(err) - } - if msgType != id.OlmMsgTypePreKey { - t.Fatal("wrong message type") - } + msgType, message1, err := aliceSession.Encrypt(plainText) + assert.NoError(t, err) + assert.Equal(t, id.OlmMsgTypePreKey, msgType) - bobSession, err := accountB.NewInboundSession(nil, message1) - if err != nil { - t.Fatal(err) - } + bobSession, err := accountB.NewInboundSession(string(message1)) + assert.NoError(t, err) // Check that the inbound session matches the message it was created from. - sessionIsOK, err := bobSession.MatchesInboundSessionFrom(nil, message1) - if err != nil { - t.Fatal(err) - } - if !sessionIsOK { - t.Fatal("session was not detected to be valid") - } + sessionIsOK, err := bobSession.MatchesInboundSessionFrom("", string(message1)) + assert.NoError(t, err) + assert.True(t, sessionIsOK, "session was not detected to be valid") + // Check that the inbound session matches the key this message is supposed to be from. aIDKey := accountA.IdKeys.Curve25519.PublicKey.B64Encoded() - sessionIsOK, err = bobSession.MatchesInboundSessionFrom(&aIDKey, message1) - if err != nil { - t.Fatal(err) - } - if !sessionIsOK { - t.Fatal("session is sad to be not from a but it should") - } + sessionIsOK, err = bobSession.MatchesInboundSessionFrom(string(aIDKey), string(message1)) + assert.NoError(t, err) + assert.True(t, sessionIsOK, "session is sad to be not from a but it should") + // Check that the inbound session isn't from a different user. bIDKey := accountB.IdKeys.Curve25519.PublicKey.B64Encoded() - sessionIsOK, err = bobSession.MatchesInboundSessionFrom(&bIDKey, message1) - if err != nil { - t.Fatal(err) - } - if sessionIsOK { - t.Fatal("session is sad to be from b but is from a") - } + sessionIsOK, err = bobSession.MatchesInboundSessionFrom(string(bIDKey), string(message1)) + assert.NoError(t, err) + assert.False(t, sessionIsOK, "session is sad to be from b but is from a") + // Check that we can decrypt the message. - decryptedMessage, err := bobSession.Decrypt(message1, msgType) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(decryptedMessage, plainText) { - t.Fatal("messages are not the same") - } + decryptedMessage, err := bobSession.Decrypt(string(message1), msgType) + assert.NoError(t, err) + assert.Equal(t, plainText, decryptedMessage) - msgTyp2, message2, err := bobSession.Encrypt(plainText, nil) - if err != nil { - t.Fatal(err) - } - if msgTyp2 == id.OlmMsgTypePreKey { - t.Fatal("wrong message type") - } + msgTyp2, message2, err := bobSession.Encrypt(plainText) + assert.NoError(t, err) + assert.Equal(t, id.OlmMsgTypeMsg, msgTyp2) - decryptedMessage2, err := aliceSession.Decrypt(message2, msgTyp2) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(decryptedMessage2, plainText) { - t.Fatal("messages are not the same") - } + decryptedMessage2, err := aliceSession.Decrypt(string(message2), msgTyp2) + assert.NoError(t, err) + assert.Equal(t, plainText, decryptedMessage2) //decrypting again should fail, as the chain moved on - _, err = aliceSession.Decrypt(message2, msgTyp2) - if err == nil { - t.Fatal("expected error") - } + _, err = aliceSession.Decrypt(string(message2), msgTyp2) + assert.Error(t, err) + assert.ErrorIs(t, err, olm.ErrMessageKeyNotFound) //compare sessionIDs - if aliceSession.ID() != bobSession.ID() { - t.Fatal("sessionIDs are not equal") - } + assert.Equal(t, aliceSession.ID(), bobSession.ID()) } func TestMoreMessages(t *testing.T) { - accountA, err := account.NewAccount(nil) - if err != nil { - t.Fatal(err) - } + accountA, err := account.NewAccount() + assert.NoError(t, err) - accountB, err := account.NewAccount(nil) - if err != nil { - t.Fatal(err) - } - err = accountB.GenOneTimeKeys(nil, 42) - if err != nil { - t.Fatal(err) - } + accountB, err := account.NewAccount() + assert.NoError(t, err) + err = accountB.GenOneTimeKeys(42) + assert.NoError(t, err) aliceSession, err := accountA.NewOutboundSession(accountB.IdKeys.Curve25519.B64Encoded(), accountB.OTKeys[0].Key.B64Encoded()) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) plainText := []byte("Hello, World") - msgType, message1, err := aliceSession.Encrypt(plainText, nil) - if err != nil { - t.Fatal(err) - } - if msgType != id.OlmMsgTypePreKey { - t.Fatal("wrong message type") - } + msgType, message1, err := aliceSession.Encrypt(plainText) + assert.NoError(t, err) + assert.Equal(t, id.OlmMsgTypePreKey, msgType) - bobSession, err := accountB.NewInboundSession(nil, message1) - if err != nil { - t.Fatal(err) - } - decryptedMessage, err := bobSession.Decrypt(message1, msgType) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(decryptedMessage, plainText) { - t.Fatal("messages are not the same") - } + bobSession, err := accountB.NewInboundSession(string(message1)) + assert.NoError(t, err) + decryptedMessage, err := bobSession.Decrypt(string(message1), msgType) + assert.NoError(t, err) + assert.Equal(t, plainText, decryptedMessage) for i := 0; i < 8; i++ { //alice sends, bob reveices - msgType, message, err := aliceSession.Encrypt(plainText, nil) - if err != nil { - t.Fatal(err) - } + msgType, message, err := aliceSession.Encrypt(plainText) + assert.NoError(t, err) if i == 0 { //The first time should still be a preKeyMessage as bob has not yet send a message to alice - if msgType != id.OlmMsgTypePreKey { - t.Fatal("wrong message type") - } + assert.Equal(t, id.OlmMsgTypePreKey, msgType) } else { - if msgType == id.OlmMsgTypePreKey { - t.Fatal("wrong message type") - } - } - decryptedMessage, err := bobSession.Decrypt(message, msgType) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(decryptedMessage, plainText) { - t.Fatal("messages are not the same") + assert.Equal(t, id.OlmMsgTypeMsg, msgType) } + decryptedMessage, err := bobSession.Decrypt(string(message), msgType) + assert.NoError(t, err) + assert.Equal(t, plainText, decryptedMessage) + //now bob sends, alice receives - msgType, message, err = bobSession.Encrypt(plainText, nil) - if err != nil { - t.Fatal(err) - } - if msgType == id.OlmMsgTypePreKey { - t.Fatal("wrong message type") - } - decryptedMessage, err = aliceSession.Decrypt(message, msgType) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(decryptedMessage, plainText) { - t.Fatal("messages are not the same") - } + msgType, message, err = bobSession.Encrypt(plainText) + assert.NoError(t, err) + assert.Equal(t, id.OlmMsgTypeMsg, msgType) + + decryptedMessage, err = aliceSession.Decrypt(string(message), msgType) + assert.NoError(t, err) + assert.Equal(t, plainText, decryptedMessage) } } func TestFallbackKey(t *testing.T) { - accountA, err := account.NewAccount(nil) - if err != nil { - t.Fatal(err) - } + accountA, err := account.NewAccount() + assert.NoError(t, err) - accountB, err := account.NewAccount(nil) - if err != nil { - t.Fatal(err) - } - err = accountB.GenFallbackKey(nil) - if err != nil { - t.Fatal(err) - } + accountB, err := account.NewAccount() + assert.NoError(t, err) + err = accountB.GenFallbackKey() + assert.NoError(t, err) fallBackKeys := accountB.FallbackKeyUnpublished() var fallbackKey id.Curve25519 for _, fbKey := range fallBackKeys { fallbackKey = fbKey } aliceSession, err := accountA.NewOutboundSession(accountB.IdKeys.Curve25519.B64Encoded(), fallbackKey) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) plainText := []byte("Hello, World") - msgType, message1, err := aliceSession.Encrypt(plainText, nil) - if err != nil { - t.Fatal(err) - } - if msgType != id.OlmMsgTypePreKey { - t.Fatal("wrong message type") - } + msgType, message1, err := aliceSession.Encrypt(plainText) + assert.NoError(t, err) + assert.Equal(t, id.OlmMsgTypePreKey, msgType) - bobSession, err := accountB.NewInboundSession(nil, message1) - if err != nil { - t.Fatal(err) - } + bobSession, err := accountB.NewInboundSession(string(message1)) + assert.NoError(t, err) // Check that the inbound session matches the message it was created from. - sessionIsOK, err := bobSession.MatchesInboundSessionFrom(nil, message1) - if err != nil { - t.Fatal(err) - } - if !sessionIsOK { - t.Fatal("session was not detected to be valid") - } + sessionIsOK, err := bobSession.MatchesInboundSessionFrom("", string(message1)) + assert.NoError(t, err) + assert.True(t, sessionIsOK, "session was not detected to be valid") + // Check that the inbound session matches the key this message is supposed to be from. aIDKey := accountA.IdKeys.Curve25519.PublicKey.B64Encoded() - sessionIsOK, err = bobSession.MatchesInboundSessionFrom(&aIDKey, message1) - if err != nil { - t.Fatal(err) - } - if !sessionIsOK { - t.Fatal("session is sad to be not from a but it should") - } + sessionIsOK, err = bobSession.MatchesInboundSessionFrom(string(aIDKey), string(message1)) + assert.NoError(t, err) + assert.True(t, sessionIsOK, "session is sad to be not from a but it should") + // Check that the inbound session isn't from a different user. bIDKey := accountB.IdKeys.Curve25519.PublicKey.B64Encoded() - sessionIsOK, err = bobSession.MatchesInboundSessionFrom(&bIDKey, message1) - if err != nil { - t.Fatal(err) - } - if sessionIsOK { - t.Fatal("session is sad to be from b but is from a") - } + sessionIsOK, err = bobSession.MatchesInboundSessionFrom(string(bIDKey), string(message1)) + assert.NoError(t, err) + assert.False(t, sessionIsOK, "session is sad to be from b but is from a") + // Check that we can decrypt the message. - decryptedMessage, err := bobSession.Decrypt(message1, msgType) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(decryptedMessage, plainText) { - t.Fatal("messages are not the same") - } + decryptedMessage, err := bobSession.Decrypt(string(message1), msgType) + assert.NoError(t, err) + assert.Equal(t, plainText, decryptedMessage) // create a new fallback key for B (the old fallback should still be usable) - err = accountB.GenFallbackKey(nil) - if err != nil { - t.Fatal(err) - } + err = accountB.GenFallbackKey() + assert.NoError(t, err) // start another session and encrypt a message aliceSession2, err := accountA.NewOutboundSession(accountB.IdKeys.Curve25519.B64Encoded(), fallbackKey) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) + + msgType2, message2, err := aliceSession2.Encrypt(plainText) + assert.NoError(t, err) + assert.Equal(t, id.OlmMsgTypePreKey, msgType2) - msgType2, message2, err := aliceSession2.Encrypt(plainText, nil) - if err != nil { - t.Fatal(err) - } - if msgType2 != id.OlmMsgTypePreKey { - t.Fatal("wrong message type") - } // bobSession should not be valid for the message2 // Check that the inbound session matches the message it was created from. - sessionIsOK, err = bobSession.MatchesInboundSessionFrom(nil, message2) - if err != nil { - t.Fatal(err) - } - if sessionIsOK { - t.Fatal("session was detected to be valid but should not") - } - bobSession2, err := accountB.NewInboundSession(nil, message2) - if err != nil { - t.Fatal(err) - } + sessionIsOK, err = bobSession.MatchesInboundSessionFrom("", string(message2)) + assert.NoError(t, err) + assert.False(t, sessionIsOK, "session was detected to be valid but should not") + + bobSession2, err := accountB.NewInboundSession(string(message2)) + assert.NoError(t, err) // Check that the inbound session matches the message it was created from. - sessionIsOK, err = bobSession2.MatchesInboundSessionFrom(nil, message2) - if err != nil { - t.Fatal(err) - } - if !sessionIsOK { - t.Fatal("session was not detected to be valid") - } + sessionIsOK, err = bobSession2.MatchesInboundSessionFrom("", string(message2)) + assert.NoError(t, err) + assert.True(t, sessionIsOK, "session was not detected to be valid") + // Check that the inbound session matches the key this message is supposed to be from. - sessionIsOK, err = bobSession2.MatchesInboundSessionFrom(&aIDKey, message2) - if err != nil { - t.Fatal(err) - } - if !sessionIsOK { - t.Fatal("session is sad to be not from a but it should") - } + sessionIsOK, err = bobSession2.MatchesInboundSessionFrom(string(aIDKey), string(message2)) + assert.NoError(t, err) + assert.True(t, sessionIsOK, "session is sad to be not from a but it should") + // Check that the inbound session isn't from a different user. - sessionIsOK, err = bobSession2.MatchesInboundSessionFrom(&bIDKey, message2) - if err != nil { - t.Fatal(err) - } - if sessionIsOK { - t.Fatal("session is sad to be from b but is from a") - } + sessionIsOK, err = bobSession2.MatchesInboundSessionFrom(string(bIDKey), string(message2)) + assert.NoError(t, err) + assert.False(t, sessionIsOK, "session is sad to be from b but is from a") + // Check that we can decrypt the message. - decryptedMessage2, err := bobSession2.Decrypt(message2, msgType2) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(decryptedMessage2, plainText) { - t.Fatal("messages are not the same") - } + decryptedMessage2, err := bobSession2.Decrypt(string(message2), msgType2) + assert.NoError(t, err) + assert.Equal(t, plainText, decryptedMessage2) //Forget the old fallback key -- creating a new session should fail now accountB.ForgetOldFallbackKey() // start another session and encrypt a message aliceSession3, err := accountA.NewOutboundSession(accountB.IdKeys.Curve25519.B64Encoded(), fallbackKey) - if err != nil { - t.Fatal(err) - } - msgType3, message3, err := aliceSession3.Encrypt(plainText, nil) - if err != nil { - t.Fatal(err) - } - if msgType3 != id.OlmMsgTypePreKey { - t.Fatal("wrong message type") - } - _, err = accountB.NewInboundSession(nil, message3) - if err == nil { - t.Fatal("expected error") - } - if !errors.Is(err, goolm.ErrBadMessageKeyID) { - t.Fatal(err) - } + assert.NoError(t, err) + msgType3, message3, err := aliceSession3.Encrypt(plainText) + assert.NoError(t, err) + assert.Equal(t, id.OlmMsgTypePreKey, msgType3) + _, err = accountB.NewInboundSession(string(message3)) + assert.ErrorIs(t, err, olm.ErrBadMessageKeyID) } func TestOldV3AccountPickle(t *testing.T) { @@ -582,33 +343,23 @@ func TestOldV3AccountPickle(t *testing.T) { expectedUnpublishedFallbackJSON := []byte("{\"curve25519\":{}}") account, err := account.AccountFromPickled(pickledData, pickleKey) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) fallbackJSON, err := account.FallbackKeyJSON() - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(fallbackJSON, expectedFallbackJSON) { - t.Fatalf("expected not as result:\n%s\n%s\n", expectedFallbackJSON, fallbackJSON) - } + assert.NoError(t, err) + assert.Equal(t, expectedFallbackJSON, fallbackJSON) fallbackJSONUnpublished, err := account.FallbackKeyUnpublishedJSON() - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(fallbackJSONUnpublished, expectedUnpublishedFallbackJSON) { - t.Fatalf("expected not as result:\n%s\n%s\n", expectedUnpublishedFallbackJSON, fallbackJSONUnpublished) - } + assert.NoError(t, err) + assert.Equal(t, expectedUnpublishedFallbackJSON, fallbackJSONUnpublished) } func TestAccountSign(t *testing.T) { - accountA, err := account.NewAccount(nil) - require.NoError(t, err) + accountA, err := account.NewAccount() + assert.NoError(t, err) plainText := []byte("Hello, World") signatureB64, err := accountA.Sign(plainText) - require.NoError(t, err) + assert.NoError(t, err) signature, err := base64.RawStdEncoding.DecodeString(string(signatureB64)) - require.NoError(t, err) + assert.NoError(t, err) verified, err := signatures.VerifySignature(plainText, accountA.IdKeys.Ed25519.B64Encoded(), signature) assert.NoError(t, err) diff --git a/crypto/goolm/account/register.go b/crypto/goolm/account/register.go new file mode 100644 index 00000000..ec392d7e --- /dev/null +++ b/crypto/goolm/account/register.go @@ -0,0 +1,23 @@ +// Copyright (c) 2024 Sumner Evans +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package account + +import ( + "maunium.net/go/mautrix/crypto/olm" +) + +func Register() { + olm.InitNewAccount = func() (olm.Account, error) { + return NewAccount() + } + olm.InitBlankAccount = func() olm.Account { + return &Account{} + } + olm.InitNewAccountFromPickled = func(pickled, key []byte) (olm.Account, error) { + return AccountFromPickled(pickled, key) + } +} diff --git a/crypto/goolm/aessha2/aessha2.go b/crypto/goolm/aessha2/aessha2.go new file mode 100644 index 00000000..42d9811b --- /dev/null +++ b/crypto/goolm/aessha2/aessha2.go @@ -0,0 +1,59 @@ +// Copyright (c) 2024 Sumner Evans +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +// Package aessha2 implements the m.megolm.v1.aes-sha2 encryption algorithm +// described in [Section 10.12.4.3] in the Spec +// +// [Section 10.12.4.3]: https://spec.matrix.org/v1.12/client-server-api/#mmegolmv1aes-sha2 +package aessha2 + +import ( + "crypto/hmac" + "crypto/sha256" + "crypto/subtle" + "io" + + "golang.org/x/crypto/hkdf" + + "maunium.net/go/mautrix/crypto/aescbc" +) + +type AESSHA2 struct { + aesKey, hmacKey, iv []byte +} + +func NewAESSHA2(secret, info []byte) (AESSHA2, error) { + kdf := hkdf.New(sha256.New, secret, nil, info) + keymatter := make([]byte, 80) + _, err := io.ReadFull(kdf, keymatter) + return AESSHA2{ + keymatter[:32], // AES Key + keymatter[32:64], // HMAC Key + keymatter[64:], // IV + }, err +} + +func (a *AESSHA2) Encrypt(plaintext []byte) ([]byte, error) { + return aescbc.Encrypt(a.aesKey, a.iv, plaintext) +} + +func (a *AESSHA2) Decrypt(ciphertext []byte) ([]byte, error) { + return aescbc.Decrypt(a.aesKey, a.iv, ciphertext) +} + +func (a *AESSHA2) MAC(ciphertext []byte) ([]byte, error) { + hash := hmac.New(sha256.New, a.hmacKey) + _, err := hash.Write(ciphertext) + return hash.Sum(nil), err +} + +func (a *AESSHA2) VerifyMAC(ciphertext, theirMAC []byte) (bool, error) { + if mac, err := a.MAC(ciphertext); err != nil { + return false, err + } else { + return subtle.ConstantTimeCompare(mac[:len(theirMAC)], theirMAC) == 1, nil + } +} diff --git a/crypto/goolm/aessha2/aessha2_test.go b/crypto/goolm/aessha2/aessha2_test.go new file mode 100644 index 00000000..b2cfe8aa --- /dev/null +++ b/crypto/goolm/aessha2/aessha2_test.go @@ -0,0 +1,33 @@ +package aessha2_test + +import ( + "crypto/aes" + "testing" + + "github.com/stretchr/testify/assert" + + "maunium.net/go/mautrix/crypto/goolm/aessha2" +) + +func TestCipherAESSha256(t *testing.T) { + key := []byte("test key") + cipher, err := aessha2.NewAESSHA2(key, []byte("testKDFinfo")) + assert.NoError(t, err) + message := []byte("this is a random message for testing the implementation") + //increase to next block size + for len(message)%aes.BlockSize != 0 { + message = append(message, []byte("-")...) + } + encrypted, err := cipher.Encrypt([]byte(message)) + assert.NoError(t, err) + mac, err := cipher.MAC(encrypted) + assert.NoError(t, err) + + verified, err := cipher.VerifyMAC(encrypted, mac[:8]) + assert.NoError(t, err) + assert.True(t, verified, "signature verification failed") + + resultPlainText, err := cipher.Decrypt(encrypted) + assert.NoError(t, err) + assert.Equal(t, message, resultPlainText) +} diff --git a/crypto/goolm/cipher/aes_sha256.go b/crypto/goolm/cipher/aes_sha256.go deleted file mode 100644 index 2d2d58d5..00000000 --- a/crypto/goolm/cipher/aes_sha256.go +++ /dev/null @@ -1,98 +0,0 @@ -package cipher - -import ( - "bytes" - "crypto/aes" - "io" - - "maunium.net/go/mautrix/crypto/aescbc" - "maunium.net/go/mautrix/crypto/goolm/crypto" -) - -// derivedAESKeys stores the derived keys for the AESSHA256 cipher -type derivedAESKeys struct { - key []byte - hmacKey []byte - iv []byte -} - -// deriveAESKeys derives three keys for the AESSHA256 cipher -func deriveAESKeys(kdfInfo []byte, key []byte) (*derivedAESKeys, error) { - hkdf := crypto.HKDFSHA256(key, nil, kdfInfo) - keys := &derivedAESKeys{ - key: make([]byte, 32), - hmacKey: make([]byte, 32), - iv: make([]byte, 16), - } - if _, err := io.ReadFull(hkdf, keys.key); err != nil { - return nil, err - } - if _, err := io.ReadFull(hkdf, keys.hmacKey); err != nil { - return nil, err - } - if _, err := io.ReadFull(hkdf, keys.iv); err != nil { - return nil, err - } - return keys, nil -} - -// AESSha512BlockSize resturns the blocksize of the cipher AESSHA256. -func AESSha512BlockSize() int { - return aes.BlockSize -} - -// AESSHA256 is a valid cipher using AES with CBC and HKDFSha256. -type AESSHA256 struct { - kdfInfo []byte -} - -// NewAESSHA256 returns a new AESSHA256 cipher with the key derive function info (kdfInfo). -func NewAESSHA256(kdfInfo []byte) *AESSHA256 { - return &AESSHA256{ - kdfInfo: kdfInfo, - } -} - -// Encrypt encrypts the plaintext with the key. The key is used to derive the actual encryption key (32 bytes) as well as the iv (16 bytes). -func (c AESSHA256) Encrypt(key, plaintext []byte) (ciphertext []byte, err error) { - keys, err := deriveAESKeys(c.kdfInfo, key) - if err != nil { - return nil, err - } - ciphertext, err = aescbc.Encrypt(keys.key, keys.iv, plaintext) - if err != nil { - return nil, err - } - return ciphertext, nil -} - -// Decrypt decrypts the ciphertext with the key. The key is used to derive the actual encryption key (32 bytes) as well as the iv (16 bytes). -func (c AESSHA256) Decrypt(key, ciphertext []byte) (plaintext []byte, err error) { - keys, err := deriveAESKeys(c.kdfInfo, key) - if err != nil { - return nil, err - } - plaintext, err = aescbc.Decrypt(keys.key, keys.iv, ciphertext) - if err != nil { - return nil, err - } - return plaintext, nil -} - -// MAC returns the MAC for the message using the key. The key is used to derive the actual mac key (32 bytes). -func (c AESSHA256) MAC(key, message []byte) ([]byte, error) { - keys, err := deriveAESKeys(c.kdfInfo, key) - if err != nil { - return nil, err - } - return crypto.HMACSHA256(keys.hmacKey, message), nil -} - -// Verify checks the MAC of the message using the key against the givenMAC. The key is used to derive the actual mac key (32 bytes). -func (c AESSHA256) Verify(key, message, givenMAC []byte) (bool, error) { - mac, err := c.MAC(key, message) - if err != nil { - return false, err - } - return bytes.Equal(givenMAC, mac[:len(givenMAC)]), nil -} diff --git a/crypto/goolm/cipher/aes_sha256_test.go b/crypto/goolm/cipher/aes_sha256_test.go deleted file mode 100644 index d2f49cb1..00000000 --- a/crypto/goolm/cipher/aes_sha256_test.go +++ /dev/null @@ -1,83 +0,0 @@ -package cipher - -import ( - "bytes" - "crypto/aes" - "testing" -) - -func TestDeriveAESKeys(t *testing.T) { - kdfInfo := []byte("test") - key := []byte("test key") - derivedKeys, err := deriveAESKeys(kdfInfo, key) - if err != nil { - t.Fatal(err) - } - derivedKeys2, err := deriveAESKeys(kdfInfo, key) - if err != nil { - t.Fatal(err) - } - //derivedKeys and derivedKeys2 should be identical - if !bytes.Equal(derivedKeys.key, derivedKeys2.key) || - !bytes.Equal(derivedKeys.iv, derivedKeys2.iv) || - !bytes.Equal(derivedKeys.hmacKey, derivedKeys2.hmacKey) { - t.Fail() - } - //changing kdfInfo - kdfInfo = []byte("other kdf") - derivedKeys2, err = deriveAESKeys(kdfInfo, key) - if err != nil { - t.Fatal(err) - } - //derivedKeys and derivedKeys2 should now be different - if bytes.Equal(derivedKeys.key, derivedKeys2.key) || - bytes.Equal(derivedKeys.iv, derivedKeys2.iv) || - bytes.Equal(derivedKeys.hmacKey, derivedKeys2.hmacKey) { - t.Fail() - } - //changing key - key = []byte("other test key") - derivedKeys, err = deriveAESKeys(kdfInfo, key) - if err != nil { - t.Fatal(err) - } - //derivedKeys and derivedKeys2 should now be different - if bytes.Equal(derivedKeys.key, derivedKeys2.key) || - bytes.Equal(derivedKeys.iv, derivedKeys2.iv) || - bytes.Equal(derivedKeys.hmacKey, derivedKeys2.hmacKey) { - t.Fail() - } -} - -func TestCipherAESSha256(t *testing.T) { - key := []byte("test key") - cipher := NewAESSHA256([]byte("testKDFinfo")) - message := []byte("this is a random message for testing the implementation") - //increase to next block size - for len(message)%aes.BlockSize != 0 { - message = append(message, []byte("-")...) - } - encrypted, err := cipher.Encrypt(key, []byte(message)) - if err != nil { - t.Fatal(err) - } - mac, err := cipher.MAC(key, encrypted) - if err != nil { - t.Fatal(err) - } - - verified, err := cipher.Verify(key, encrypted, mac[:8]) - if err != nil { - t.Fatal(err) - } - if !verified { - t.Fatal("signature verification failed") - } - resultPlainText, err := cipher.Decrypt(key, encrypted) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(message, resultPlainText) { - t.Fail() - } -} diff --git a/crypto/goolm/cipher/cipher.go b/crypto/goolm/cipher/cipher.go deleted file mode 100644 index 43580b0b..00000000 --- a/crypto/goolm/cipher/cipher.go +++ /dev/null @@ -1,18 +0,0 @@ -// Package cipher provides the methods and structs to do encryptions for -// olm/megolm. -package cipher - -// Cipher defines a valid cipher. -type Cipher interface { - // Encrypt encrypts the plaintext. - Encrypt(key, plaintext []byte) (ciphertext []byte, err error) - - // Decrypt decrypts the ciphertext. - Decrypt(key, ciphertext []byte) (plaintext []byte, err error) - - //MAC returns the MAC of the message calculated with the key. - MAC(key, message []byte) ([]byte, error) - - //Verify checks the MAC of the message calculated with the key against the givenMAC. - Verify(key, message, givenMAC []byte) (bool, error) -} diff --git a/crypto/goolm/cipher/pickle.go b/crypto/goolm/cipher/pickle.go deleted file mode 100644 index 670ff6ff..00000000 --- a/crypto/goolm/cipher/pickle.go +++ /dev/null @@ -1,58 +0,0 @@ -package cipher - -import ( - "fmt" - - "maunium.net/go/mautrix/crypto/goolm" -) - -const ( - kdfPickle = "Pickle" //used to derive the keys for encryption - pickleMACLength = 8 -) - -// PickleBlockSize returns the blocksize of the used cipher. -func PickleBlockSize() int { - return AESSha512BlockSize() -} - -// Pickle encrypts the input with the key and the cipher AESSHA256. The result is then encoded in base64. -func Pickle(key, input []byte) ([]byte, error) { - pickleCipher := NewAESSHA256([]byte(kdfPickle)) - ciphertext, err := pickleCipher.Encrypt(key, input) - if err != nil { - return nil, err - } - mac, err := pickleCipher.MAC(key, ciphertext) - if err != nil { - return nil, err - } - ciphertext = append(ciphertext, mac[:pickleMACLength]...) - encoded := goolm.Base64Encode(ciphertext) - return encoded, nil -} - -// Unpickle decodes the input from base64 and decrypts the decoded input with the key and the cipher AESSHA256. -func Unpickle(key, input []byte) ([]byte, error) { - pickleCipher := NewAESSHA256([]byte(kdfPickle)) - ciphertext, err := goolm.Base64Decode(input) - if err != nil { - return nil, err - } - //remove mac and check - verified, err := pickleCipher.Verify(key, ciphertext[:len(ciphertext)-pickleMACLength], ciphertext[len(ciphertext)-pickleMACLength:]) - if err != nil { - return nil, err - } - if !verified { - return nil, fmt.Errorf("decrypt pickle: %w", goolm.ErrBadMAC) - } - //Set to next block size - targetCipherText := make([]byte, int(len(ciphertext)/PickleBlockSize())*PickleBlockSize()) - copy(targetCipherText, ciphertext) - plaintext, err := pickleCipher.Decrypt(key, targetCipherText) - if err != nil { - return nil, err - } - return plaintext, nil -} diff --git a/crypto/goolm/cipher/pickle_test.go b/crypto/goolm/cipher/pickle_test.go deleted file mode 100644 index b47bf3ea..00000000 --- a/crypto/goolm/cipher/pickle_test.go +++ /dev/null @@ -1,33 +0,0 @@ -package cipher_test - -import ( - "bytes" - "crypto/aes" - "testing" - - "maunium.net/go/mautrix/crypto/goolm/cipher" -) - -func TestEncoding(t *testing.T) { - key := []byte("test key") - input := []byte("test") - //pad marshaled to get block size - toEncrypt := input - if len(input)%aes.BlockSize != 0 { - padding := aes.BlockSize - len(input)%aes.BlockSize - toEncrypt = make([]byte, len(input)+padding) - copy(toEncrypt, input) - } - encoded, err := cipher.Pickle(key, toEncrypt) - if err != nil { - t.Fatal(err) - } - - decoded, err := cipher.Unpickle(key, encoded) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(decoded, toEncrypt) { - t.Fatalf("Expected '%s' but got '%s'", toEncrypt, decoded) - } -} diff --git a/crypto/goolm/crypto/curve25519.go b/crypto/goolm/crypto/curve25519.go index 125e1bfd..6e42d886 100644 --- a/crypto/goolm/crypto/curve25519.go +++ b/crypto/goolm/crypto/curve25519.go @@ -1,69 +1,51 @@ package crypto import ( - "bytes" "crypto/rand" + "crypto/subtle" "encoding/base64" - "fmt" - "io" "golang.org/x/crypto/curve25519" - "maunium.net/go/mautrix/crypto/goolm" "maunium.net/go/mautrix/crypto/goolm/libolmpickle" "maunium.net/go/mautrix/id" ) const ( - Curve25519KeyLength = curve25519.ScalarSize //The length of the private key. - curve25519PubKeyLength = 32 + Curve25519PrivateKeyLength = curve25519.ScalarSize //The length of the private key. + Curve25519PublicKeyLength = 32 ) -// Curve25519GenerateKey creates a new curve25519 key pair. If reader is nil, the random data is taken from crypto/rand. -func Curve25519GenerateKey(reader io.Reader) (Curve25519KeyPair, error) { - privateKeyByte := make([]byte, Curve25519KeyLength) - if reader == nil { - _, err := rand.Read(privateKeyByte) - if err != nil { - return Curve25519KeyPair{}, err - } - } else { - _, err := reader.Read(privateKeyByte) - if err != nil { - return Curve25519KeyPair{}, err - } - } - - privateKey := Curve25519PrivateKey(privateKeyByte) - - publicKey, err := privateKey.PubKey() - if err != nil { - return Curve25519KeyPair{}, err - } - return Curve25519KeyPair{ - PrivateKey: Curve25519PrivateKey(privateKey), - PublicKey: Curve25519PublicKey(publicKey), - }, nil -} - -// Curve25519GenerateFromPrivate creates a new curve25519 key pair with the private key given. -func Curve25519GenerateFromPrivate(private Curve25519PrivateKey) (Curve25519KeyPair, error) { - publicKey, err := private.PubKey() - if err != nil { - return Curve25519KeyPair{}, err - } - return Curve25519KeyPair{ - PrivateKey: private, - PublicKey: Curve25519PublicKey(publicKey), - }, nil -} - // Curve25519KeyPair stores both parts of a curve25519 key. type Curve25519KeyPair struct { PrivateKey Curve25519PrivateKey `json:"private,omitempty"` PublicKey Curve25519PublicKey `json:"public,omitempty"` } +// Curve25519GenerateKey creates a new curve25519 key pair. +func Curve25519GenerateKey() (Curve25519KeyPair, error) { + privateKeyByte := make([]byte, Curve25519PrivateKeyLength) + if _, err := rand.Read(privateKeyByte); err != nil { + return Curve25519KeyPair{}, err + } + + privateKey := Curve25519PrivateKey(privateKeyByte) + publicKey, err := privateKey.PubKey() + return Curve25519KeyPair{ + PrivateKey: Curve25519PrivateKey(privateKey), + PublicKey: Curve25519PublicKey(publicKey), + }, err +} + +// Curve25519GenerateFromPrivate creates a new curve25519 key pair with the private key given. +func Curve25519GenerateFromPrivate(private Curve25519PrivateKey) (Curve25519KeyPair, error) { + publicKey, err := private.PubKey() + return Curve25519KeyPair{ + PrivateKey: private, + PublicKey: Curve25519PublicKey(publicKey), + }, err +} + // B64Encoded returns a base64 encoded string of the public key. func (c Curve25519KeyPair) B64Encoded() id.Curve25519 { return c.PublicKey.B64Encoded() @@ -71,53 +53,30 @@ func (c Curve25519KeyPair) B64Encoded() id.Curve25519 { // SharedSecret returns the shared secret between the key pair and the given public key. func (c Curve25519KeyPair) SharedSecret(pubKey Curve25519PublicKey) ([]byte, error) { + // Note: the standard library checks that the output is non-zero return c.PrivateKey.SharedSecret(pubKey) } -// PickleLibOlm encodes the key pair into target. target has to have a size of at least PickleLen() and is written to from index 0. -// It returns the number of bytes written. -func (c Curve25519KeyPair) PickleLibOlm(target []byte) (int, error) { - if len(target) < c.PickleLen() { - return 0, fmt.Errorf("pickle curve25519 key pair: %w", goolm.ErrValueTooShort) - } - written, err := c.PublicKey.PickleLibOlm(target) - if err != nil { - return 0, fmt.Errorf("pickle curve25519 key pair: %w", err) - } - if len(c.PrivateKey) != Curve25519KeyLength { - written += libolmpickle.PickleBytes(make([]byte, Curve25519KeyLength), target[written:]) +// PickleLibOlm pickles the key pair into the encoder. +func (c Curve25519KeyPair) PickleLibOlm(encoder *libolmpickle.Encoder) { + c.PublicKey.PickleLibOlm(encoder) + if len(c.PrivateKey) == Curve25519PrivateKeyLength { + encoder.Write(c.PrivateKey) } else { - written += libolmpickle.PickleBytes(c.PrivateKey, target[written:]) + encoder.WriteEmptyBytes(Curve25519PrivateKeyLength) } - return written, nil } // UnpickleLibOlm decodes the unencryted value and populates the key pair accordingly. It returns the number of bytes read. -func (c *Curve25519KeyPair) UnpickleLibOlm(value []byte) (int, error) { - //unpickle PubKey - read, err := c.PublicKey.UnpickleLibOlm(value) - if err != nil { - return 0, err - } - //unpickle PrivateKey - privKey, readPriv, err := libolmpickle.UnpickleBytes(value[read:], Curve25519KeyLength) - if err != nil { - return read, err - } - c.PrivateKey = privKey - return read + readPriv, nil -} - -// PickleLen returns the number of bytes the pickled key pair will have. -func (c Curve25519KeyPair) PickleLen() int { - lenPublic := c.PublicKey.PickleLen() - var lenPrivate int - if len(c.PrivateKey) != Curve25519KeyLength { - lenPrivate = libolmpickle.PickleBytesLen(make([]byte, Curve25519KeyLength)) +func (c *Curve25519KeyPair) UnpickleLibOlm(decoder *libolmpickle.Decoder) error { + if err := c.PublicKey.UnpickleLibOlm(decoder); err != nil { + return err + } else if privKey, err := decoder.ReadBytes(Curve25519PrivateKeyLength); err != nil { + return err } else { - lenPrivate = libolmpickle.PickleBytesLen(c.PrivateKey) + c.PrivateKey = privKey + return nil } - return lenPublic + lenPrivate } // Curve25519PrivateKey represents the private key for curve25519 usage @@ -125,16 +84,12 @@ type Curve25519PrivateKey []byte // Equal compares the private key to the given private key. func (c Curve25519PrivateKey) Equal(x Curve25519PrivateKey) bool { - return bytes.Equal(c, x) + return subtle.ConstantTimeCompare(c, x) == 1 } // PubKey returns the public key derived from the private key. func (c Curve25519PrivateKey) PubKey() (Curve25519PublicKey, error) { - publicKey, err := curve25519.X25519(c, curve25519.Basepoint) - if err != nil { - return nil, err - } - return publicKey, nil + return curve25519.X25519(c, curve25519.Basepoint) } // SharedSecret returns the shared secret between the private key and the given public key. @@ -147,7 +102,7 @@ type Curve25519PublicKey []byte // Equal compares the public key to the given public key. func (c Curve25519PublicKey) Equal(x Curve25519PublicKey) bool { - return bytes.Equal(c, x) + return subtle.ConstantTimeCompare(c, x) == 1 } // B64Encoded returns a base64 encoded string of the public key. @@ -155,32 +110,18 @@ func (c Curve25519PublicKey) B64Encoded() id.Curve25519 { return id.Curve25519(base64.RawStdEncoding.EncodeToString(c)) } -// PickleLibOlm encodes the public key into target. target has to have a size of at least PickleLen() and is written to from index 0. -// It returns the number of bytes written. -func (c Curve25519PublicKey) PickleLibOlm(target []byte) (int, error) { - if len(target) < c.PickleLen() { - return 0, fmt.Errorf("pickle curve25519 public key: %w", goolm.ErrValueTooShort) +// PickleLibOlm pickles the public key into the encoder. +func (c Curve25519PublicKey) PickleLibOlm(encoder *libolmpickle.Encoder) { + if len(c) == Curve25519PublicKeyLength { + encoder.Write(c) + } else { + encoder.WriteEmptyBytes(Curve25519PublicKeyLength) } - if len(c) != curve25519PubKeyLength { - return libolmpickle.PickleBytes(make([]byte, curve25519PubKeyLength), target), nil - } - return libolmpickle.PickleBytes(c, target), nil } // UnpickleLibOlm decodes the unencryted value and populates the public key accordingly. It returns the number of bytes read. -func (c *Curve25519PublicKey) UnpickleLibOlm(value []byte) (int, error) { - unpickled, readBytes, err := libolmpickle.UnpickleBytes(value, curve25519PubKeyLength) - if err != nil { - return 0, err - } - *c = unpickled - return readBytes, nil -} - -// PickleLen returns the number of bytes the pickled public key will have. -func (c Curve25519PublicKey) PickleLen() int { - if len(c) != curve25519PubKeyLength { - return libolmpickle.PickleBytesLen(make([]byte, curve25519PubKeyLength)) - } - return libolmpickle.PickleBytesLen(c) +func (c *Curve25519PublicKey) UnpickleLibOlm(decoder *libolmpickle.Decoder) error { + pubkey, err := decoder.ReadBytes(Curve25519PublicKeyLength) + *c = pubkey + return err } diff --git a/crypto/goolm/crypto/curve25519_test.go b/crypto/goolm/crypto/curve25519_test.go index f7df5edc..2550f15e 100644 --- a/crypto/goolm/crypto/curve25519_test.go +++ b/crypto/goolm/crypto/curve25519_test.go @@ -1,39 +1,32 @@ package crypto_test import ( - "bytes" "testing" + "github.com/stretchr/testify/assert" + "maunium.net/go/mautrix/crypto/goolm/crypto" + "maunium.net/go/mautrix/crypto/goolm/libolmpickle" ) +const curve25519KeyPairPickleLength = crypto.Curve25519PublicKeyLength + // Public Key + crypto.Curve25519PrivateKeyLength // Private Key + func TestCurve25519(t *testing.T) { - firstKeypair, err := crypto.Curve25519GenerateKey(nil) - if err != nil { - t.Fatal(err) - } - secondKeypair, err := crypto.Curve25519GenerateKey(nil) - if err != nil { - t.Fatal(err) - } + firstKeypair, err := crypto.Curve25519GenerateKey() + assert.NoError(t, err) + secondKeypair, err := crypto.Curve25519GenerateKey() + assert.NoError(t, err) sharedSecretFromFirst, err := firstKeypair.SharedSecret(secondKeypair.PublicKey) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) sharedSecretFromSecond, err := secondKeypair.SharedSecret(firstKeypair.PublicKey) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(sharedSecretFromFirst, sharedSecretFromSecond) { - t.Fatal("shared secret not equal") - } + assert.NoError(t, err) + assert.Equal(t, sharedSecretFromFirst, sharedSecretFromSecond, "shared secret not equal") fromPrivate, err := crypto.Curve25519GenerateFromPrivate(firstKeypair.PrivateKey) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(fromPrivate.PublicKey, firstKeypair.PublicKey) { - t.Fatal("public keys not equal") - } + assert.NoError(t, err) + assert.Equal(t, fromPrivate, firstKeypair) + _, err = secondKeypair.SharedSecret(make([]byte, crypto.Curve25519PublicKeyLength)) + assert.Error(t, err) } func TestCurve25519Case1(t *testing.T) { @@ -76,112 +69,57 @@ func TestCurve25519Case1(t *testing.T) { PublicKey: bobPublic, } agreementFromAlice, err := aliceKeyPair.SharedSecret(bobKeyPair.PublicKey) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(agreementFromAlice, expectedAgreement) { - t.Fatal("expected agreement does not match agreement from Alice's view") - } + assert.NoError(t, err) + assert.Equal(t, expectedAgreement, agreementFromAlice, "expected agreement does not match agreement from Alice's view") agreementFromBob, err := bobKeyPair.SharedSecret(aliceKeyPair.PublicKey) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(agreementFromBob, expectedAgreement) { - t.Fatal("expected agreement does not match agreement from Bob's view") - } + assert.NoError(t, err) + assert.Equal(t, expectedAgreement, agreementFromBob, "expected agreement does not match agreement from Bob's view") } func TestCurve25519Pickle(t *testing.T) { //create keypair - keyPair, err := crypto.Curve25519GenerateKey(nil) - if err != nil { - t.Fatal(err) - } - target := make([]byte, keyPair.PickleLen()) - writtenBytes, err := keyPair.PickleLibOlm(target) - if err != nil { - t.Fatal(err) - } - if writtenBytes != len(target) { - t.Fatal("written bytes not correct") - } + keyPair, err := crypto.Curve25519GenerateKey() + assert.NoError(t, err) + + encoder := libolmpickle.NewEncoder() + keyPair.PickleLibOlm(encoder) + assert.Len(t, encoder.Bytes(), curve25519KeyPairPickleLength) unpickledKeyPair := crypto.Curve25519KeyPair{} - readBytes, err := unpickledKeyPair.UnpickleLibOlm(target) - if err != nil { - t.Fatal(err) - } - if readBytes != len(target) { - t.Fatal("read bytes not correct") - } - if !bytes.Equal(keyPair.PrivateKey, unpickledKeyPair.PrivateKey) { - t.Fatal("private keys not correct") - } - if !bytes.Equal(keyPair.PublicKey, unpickledKeyPair.PublicKey) { - t.Fatal("public keys not correct") - } + err = unpickledKeyPair.UnpickleLibOlm(libolmpickle.NewDecoder(encoder.Bytes())) + assert.NoError(t, err) + assert.Equal(t, keyPair, unpickledKeyPair) } func TestCurve25519PicklePubKeyOnly(t *testing.T) { //create keypair - keyPair, err := crypto.Curve25519GenerateKey(nil) - if err != nil { - t.Fatal(err) - } + keyPair, err := crypto.Curve25519GenerateKey() + assert.NoError(t, err) + //Remove privateKey keyPair.PrivateKey = nil - target := make([]byte, keyPair.PickleLen()) - writtenBytes, err := keyPair.PickleLibOlm(target) - if err != nil { - t.Fatal(err) - } - if writtenBytes != len(target) { - t.Fatal("written bytes not correct") - } + + encoder := libolmpickle.NewEncoder() + keyPair.PickleLibOlm(encoder) + assert.Len(t, encoder.Bytes(), curve25519KeyPairPickleLength) + unpickledKeyPair := crypto.Curve25519KeyPair{} - readBytes, err := unpickledKeyPair.UnpickleLibOlm(target) - if err != nil { - t.Fatal(err) - } - if readBytes != len(target) { - t.Fatal("read bytes not correct") - } - if !bytes.Equal(keyPair.PrivateKey, unpickledKeyPair.PrivateKey) { - t.Fatal("private keys not correct") - } - if !bytes.Equal(keyPair.PublicKey, unpickledKeyPair.PublicKey) { - t.Fatal("public keys not correct") - } + err = unpickledKeyPair.UnpickleLibOlm(libolmpickle.NewDecoder(encoder.Bytes())) + assert.NoError(t, err) + assert.Equal(t, keyPair, unpickledKeyPair) } func TestCurve25519PicklePrivKeyOnly(t *testing.T) { //create keypair - keyPair, err := crypto.Curve25519GenerateKey(nil) - if err != nil { - t.Fatal(err) - } + keyPair, err := crypto.Curve25519GenerateKey() + assert.NoError(t, err) //Remove public keyPair.PublicKey = nil - target := make([]byte, keyPair.PickleLen()) - writtenBytes, err := keyPair.PickleLibOlm(target) - if err != nil { - t.Fatal(err) - } - if writtenBytes != len(target) { - t.Fatal("written bytes not correct") - } + encoder := libolmpickle.NewEncoder() + keyPair.PickleLibOlm(encoder) + assert.Len(t, encoder.Bytes(), curve25519KeyPairPickleLength) unpickledKeyPair := crypto.Curve25519KeyPair{} - readBytes, err := unpickledKeyPair.UnpickleLibOlm(target) - if err != nil { - t.Fatal(err) - } - if readBytes != len(target) { - t.Fatal("read bytes not correct") - } - if !bytes.Equal(keyPair.PrivateKey, unpickledKeyPair.PrivateKey) { - t.Fatal("private keys not correct") - } - if !bytes.Equal(keyPair.PublicKey, unpickledKeyPair.PublicKey) { - t.Fatal("public keys not correct") - } + err = unpickledKeyPair.UnpickleLibOlm(libolmpickle.NewDecoder(encoder.Bytes())) + assert.NoError(t, err) + assert.Equal(t, keyPair, unpickledKeyPair) } diff --git a/crypto/goolm/crypto/ed25519.go b/crypto/goolm/crypto/ed25519.go index bc21300c..a3345ba9 100644 --- a/crypto/goolm/crypto/ed25519.go +++ b/crypto/goolm/crypto/ed25519.go @@ -1,31 +1,24 @@ package crypto import ( - "bytes" - "crypto/ed25519" "encoding/base64" - "fmt" - "io" - "maunium.net/go/mautrix/crypto/goolm" + "maunium.net/go/mautrix/crypto/ed25519" "maunium.net/go/mautrix/crypto/goolm/libolmpickle" "maunium.net/go/mautrix/id" ) const ( - ED25519SignatureSize = ed25519.SignatureSize //The length of a signature + Ed25519SignatureSize = ed25519.SignatureSize //The length of a signature ) -// Ed25519GenerateKey creates a new ed25519 key pair. If reader is nil, the random data is taken from crypto/rand. -func Ed25519GenerateKey(reader io.Reader) (Ed25519KeyPair, error) { - publicKey, privateKey, err := ed25519.GenerateKey(reader) - if err != nil { - return Ed25519KeyPair{}, err - } +// Ed25519GenerateKey creates a new ed25519 key pair. +func Ed25519GenerateKey() (Ed25519KeyPair, error) { + publicKey, privateKey, err := ed25519.GenerateKey(nil) return Ed25519KeyPair{ PrivateKey: Ed25519PrivateKey(privateKey), PublicKey: Ed25519PublicKey(publicKey), - }, nil + }, err } // Ed25519GenerateFromPrivate creates a new ed25519 key pair with the private key given. @@ -57,7 +50,7 @@ func (c Ed25519KeyPair) B64Encoded() id.Ed25519 { } // Sign returns the signature for the message. -func (c Ed25519KeyPair) Sign(message []byte) []byte { +func (c Ed25519KeyPair) Sign(message []byte) ([]byte, error) { return c.PrivateKey.Sign(message) } @@ -66,51 +59,26 @@ func (c Ed25519KeyPair) Verify(message, givenSignature []byte) bool { return c.PublicKey.Verify(message, givenSignature) } -// PickleLibOlm encodes the key pair into target. target has to have a size of at least PickleLen() and is written to from index 0. -// It returns the number of bytes written. -func (c Ed25519KeyPair) PickleLibOlm(target []byte) (int, error) { - if len(target) < c.PickleLen() { - return 0, fmt.Errorf("pickle ed25519 key pair: %w", goolm.ErrValueTooShort) - } - written, err := c.PublicKey.PickleLibOlm(target) - if err != nil { - return 0, fmt.Errorf("pickle ed25519 key pair: %w", err) - } - - if len(c.PrivateKey) != ed25519.PrivateKeySize { - written += libolmpickle.PickleBytes(make([]byte, ed25519.PrivateKeySize), target[written:]) +// PickleLibOlm pickles the key pair into the encoder. +func (c Ed25519KeyPair) PickleLibOlm(encoder *libolmpickle.Encoder) { + c.PublicKey.PickleLibOlm(encoder) + if len(c.PrivateKey) == ed25519.PrivateKeySize { + encoder.Write(c.PrivateKey) } else { - written += libolmpickle.PickleBytes(c.PrivateKey, target[written:]) + encoder.WriteEmptyBytes(ed25519.PrivateKeySize) } - return written, nil } -// UnpickleLibOlm decodes the unencryted value and populates the key pair accordingly. It returns the number of bytes read. -func (c *Ed25519KeyPair) UnpickleLibOlm(value []byte) (int, error) { - //unpickle PubKey - read, err := c.PublicKey.UnpickleLibOlm(value) - if err != nil { - return 0, err - } - //unpickle PrivateKey - privKey, readPriv, err := libolmpickle.UnpickleBytes(value[read:], ed25519.PrivateKeySize) - if err != nil { - return read, err - } - c.PrivateKey = privKey - return read + readPriv, nil -} - -// PickleLen returns the number of bytes the pickled key pair will have. -func (c Ed25519KeyPair) PickleLen() int { - lenPublic := c.PublicKey.PickleLen() - var lenPrivate int - if len(c.PrivateKey) != ed25519.PrivateKeySize { - lenPrivate = libolmpickle.PickleBytesLen(make([]byte, ed25519.PrivateKeySize)) +// UnpickleLibOlm unpickles the unencryted value and populates the key pair accordingly. +func (c *Ed25519KeyPair) UnpickleLibOlm(decoder *libolmpickle.Decoder) error { + if err := c.PublicKey.UnpickleLibOlm(decoder); err != nil { + return err + } else if privKey, err := decoder.ReadBytes(ed25519.PrivateKeySize); err != nil { + return err } else { - lenPrivate = libolmpickle.PickleBytesLen(c.PrivateKey) + c.PrivateKey = privKey + return nil } - return lenPublic + lenPrivate } // Curve25519PrivateKey represents the private key for ed25519 usage. This is just a wrapper. @@ -118,18 +86,18 @@ type Ed25519PrivateKey ed25519.PrivateKey // Equal compares the private key to the given private key. func (c Ed25519PrivateKey) Equal(x Ed25519PrivateKey) bool { - return bytes.Equal(c, x) + return ed25519.PrivateKey(c).Equal(ed25519.PrivateKey(x)) } // PubKey returns the public key derived from the private key. func (c Ed25519PrivateKey) PubKey() Ed25519PublicKey { publicKey := ed25519.PrivateKey(c).Public() - return Ed25519PublicKey(publicKey.(ed25519.PublicKey)) + return Ed25519PublicKey(publicKey.([]byte)) } // Sign returns the signature for the message. -func (c Ed25519PrivateKey) Sign(message []byte) []byte { - return ed25519.Sign(ed25519.PrivateKey(c), message) +func (c Ed25519PrivateKey) Sign(message []byte) ([]byte, error) { + return ed25519.PrivateKey(c).Sign(nil, message, &ed25519.Options{}) } // Ed25519PublicKey represents the public key for ed25519 usage. This is just a wrapper. @@ -137,7 +105,7 @@ type Ed25519PublicKey ed25519.PublicKey // Equal compares the public key to the given public key. func (c Ed25519PublicKey) Equal(x Ed25519PublicKey) bool { - return bytes.Equal(c, x) + return ed25519.PublicKey(c).Equal(ed25519.PublicKey(x)) } // B64Encoded returns a base64 encoded string of the public key. @@ -150,32 +118,19 @@ func (c Ed25519PublicKey) Verify(message, givenSignature []byte) bool { return ed25519.Verify(ed25519.PublicKey(c), message, givenSignature) } -// PickleLibOlm encodes the public key into target. target has to have a size of at least PickleLen() and is written to from index 0. -// It returns the number of bytes written. -func (c Ed25519PublicKey) PickleLibOlm(target []byte) (int, error) { - if len(target) < c.PickleLen() { - return 0, fmt.Errorf("pickle ed25519 public key: %w", goolm.ErrValueTooShort) +// PickleLibOlm pickles the public key into the encoder. +func (c Ed25519PublicKey) PickleLibOlm(encoder *libolmpickle.Encoder) { + if len(c) == ed25519.PublicKeySize { + encoder.Write(c) + } else { + encoder.WriteEmptyBytes(ed25519.PublicKeySize) } - if len(c) != ed25519.PublicKeySize { - return libolmpickle.PickleBytes(make([]byte, ed25519.PublicKeySize), target), nil - } - return libolmpickle.PickleBytes(c, target), nil } -// UnpickleLibOlm decodes the unencryted value and populates the public key accordingly. It returns the number of bytes read. -func (c *Ed25519PublicKey) UnpickleLibOlm(value []byte) (int, error) { - unpickled, readBytes, err := libolmpickle.UnpickleBytes(value, ed25519.PublicKeySize) - if err != nil { - return 0, err - } - *c = unpickled - return readBytes, nil -} - -// PickleLen returns the number of bytes the pickled public key will have. -func (c Ed25519PublicKey) PickleLen() int { - if len(c) != ed25519.PublicKeySize { - return libolmpickle.PickleBytesLen(make([]byte, ed25519.PublicKeySize)) - } - return libolmpickle.PickleBytesLen(c) +// UnpickleLibOlm unpickles the unencryted value and populates the public key +// accordingly. +func (c *Ed25519PublicKey) UnpickleLibOlm(decoder *libolmpickle.Decoder) error { + key, err := decoder.ReadBytes(ed25519.PublicKeySize) + *c = key + return err } diff --git a/crypto/goolm/crypto/ed25519_test.go b/crypto/goolm/crypto/ed25519_test.go index 391de912..610b8f3e 100644 --- a/crypto/goolm/crypto/ed25519_test.go +++ b/crypto/goolm/crypto/ed25519_test.go @@ -1,140 +1,89 @@ package crypto_test import ( - "bytes" "testing" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "maunium.net/go/mautrix/crypto/ed25519" "maunium.net/go/mautrix/crypto/goolm/crypto" + "maunium.net/go/mautrix/crypto/goolm/libolmpickle" ) +const ed25519KeyPairPickleLength = ed25519.PublicKeySize + // PublicKey + ed25519.PrivateKeySize // Private Key + func TestEd25519(t *testing.T) { - keypair, err := crypto.Ed25519GenerateKey(nil) - if err != nil { - t.Fatal(err) - } + keypair, err := crypto.Ed25519GenerateKey() + assert.NoError(t, err) message := []byte("test message") - signature := keypair.Sign(message) - if !keypair.Verify(message, signature) { - t.Fail() - } + signature, err := keypair.Sign(message) + require.NoError(t, err) + assert.True(t, keypair.Verify(message, signature)) } func TestEd25519Case1(t *testing.T) { //64 bytes for ed25519 package - keyPair, err := crypto.Ed25519GenerateKey(nil) - if err != nil { - t.Fatal(err) - } + keyPair, err := crypto.Ed25519GenerateKey() + assert.NoError(t, err) message := []byte("Hello, World") keyPair2 := crypto.Ed25519GenerateFromPrivate(keyPair.PrivateKey) - if !bytes.Equal(keyPair.PublicKey, keyPair2.PublicKey) { - t.Fatal("not equal key pairs") - } - signature := keyPair.Sign(message) + assert.Equal(t, keyPair, keyPair2, "not equal key pairs") + signature, err := keyPair.Sign(message) + require.NoError(t, err) verified := keyPair.Verify(message, signature) - if !verified { - t.Fatal("message did not verify although it should") - } + assert.True(t, verified, "message did not verify although it should") + //Now change the message and verify again message = append(message, []byte("a")...) verified = keyPair.Verify(message, signature) - if verified { - t.Fatal("message did verify although it should not") - } + assert.False(t, verified, "message did verify although it should not") } func TestEd25519Pickle(t *testing.T) { //create keypair - keyPair, err := crypto.Ed25519GenerateKey(nil) - if err != nil { - t.Fatal(err) - } - target := make([]byte, keyPair.PickleLen()) - writtenBytes, err := keyPair.PickleLibOlm(target) - if err != nil { - t.Fatal(err) - } - if writtenBytes != len(target) { - t.Fatal("written bytes not correct") - } + keyPair, err := crypto.Ed25519GenerateKey() + assert.NoError(t, err) + encoder := libolmpickle.NewEncoder() + keyPair.PickleLibOlm(encoder) + assert.Len(t, encoder.Bytes(), ed25519KeyPairPickleLength) unpickledKeyPair := crypto.Ed25519KeyPair{} - readBytes, err := unpickledKeyPair.UnpickleLibOlm(target) - if err != nil { - t.Fatal(err) - } - if readBytes != len(target) { - t.Fatal("read bytes not correct") - } - if !bytes.Equal(keyPair.PrivateKey, unpickledKeyPair.PrivateKey) { - t.Fatal("private keys not correct") - } - if !bytes.Equal(keyPair.PublicKey, unpickledKeyPair.PublicKey) { - t.Fatal("public keys not correct") - } + err = unpickledKeyPair.UnpickleLibOlm(libolmpickle.NewDecoder(encoder.Bytes())) + assert.NoError(t, err) + assert.Equal(t, keyPair, unpickledKeyPair) } func TestEd25519PicklePubKeyOnly(t *testing.T) { //create keypair - keyPair, err := crypto.Ed25519GenerateKey(nil) - if err != nil { - t.Fatal(err) - } + keyPair, err := crypto.Ed25519GenerateKey() + assert.NoError(t, err) //Remove privateKey keyPair.PrivateKey = nil - target := make([]byte, keyPair.PickleLen()) - writtenBytes, err := keyPair.PickleLibOlm(target) - if err != nil { - t.Fatal(err) - } - if writtenBytes != len(target) { - t.Fatal("written bytes not correct") - } + encoder := libolmpickle.NewEncoder() + keyPair.PickleLibOlm(encoder) + assert.Len(t, encoder.Bytes(), ed25519KeyPairPickleLength) + unpickledKeyPair := crypto.Ed25519KeyPair{} - readBytes, err := unpickledKeyPair.UnpickleLibOlm(target) - if err != nil { - t.Fatal(err) - } - if readBytes != len(target) { - t.Fatal("read bytes not correct") - } - if !bytes.Equal(keyPair.PrivateKey, unpickledKeyPair.PrivateKey) { - t.Fatal("private keys not correct") - } - if !bytes.Equal(keyPair.PublicKey, unpickledKeyPair.PublicKey) { - t.Fatal("public keys not correct") - } + err = unpickledKeyPair.UnpickleLibOlm(libolmpickle.NewDecoder(encoder.Bytes())) + assert.NoError(t, err) + assert.Equal(t, keyPair, unpickledKeyPair) } func TestEd25519PicklePrivKeyOnly(t *testing.T) { //create keypair - keyPair, err := crypto.Ed25519GenerateKey(nil) - if err != nil { - t.Fatal(err) - } + keyPair, err := crypto.Ed25519GenerateKey() + assert.NoError(t, err) //Remove public keyPair.PublicKey = nil - target := make([]byte, keyPair.PickleLen()) - writtenBytes, err := keyPair.PickleLibOlm(target) - if err != nil { - t.Fatal(err) - } - if writtenBytes != len(target) { - t.Fatal("written bytes not correct") - } + encoder := libolmpickle.NewEncoder() + keyPair.PickleLibOlm(encoder) + assert.Len(t, encoder.Bytes(), ed25519KeyPairPickleLength) + unpickledKeyPair := crypto.Ed25519KeyPair{} - readBytes, err := unpickledKeyPair.UnpickleLibOlm(target) - if err != nil { - t.Fatal(err) - } - if readBytes != len(target) { - t.Fatal("read bytes not correct") - } - if !bytes.Equal(keyPair.PrivateKey, unpickledKeyPair.PrivateKey) { - t.Fatal("private keys not correct") - } - if !bytes.Equal(keyPair.PublicKey, unpickledKeyPair.PublicKey) { - t.Fatal("public keys not correct") - } + err = unpickledKeyPair.UnpickleLibOlm(libolmpickle.NewDecoder(encoder.Bytes())) + assert.NoError(t, err) + assert.Equal(t, keyPair, unpickledKeyPair) } diff --git a/crypto/goolm/crypto/hmac.go b/crypto/goolm/crypto/hmac.go deleted file mode 100644 index 8542f7cb..00000000 --- a/crypto/goolm/crypto/hmac.go +++ /dev/null @@ -1,29 +0,0 @@ -package crypto - -import ( - "crypto/hmac" - "crypto/sha256" - "io" - - "golang.org/x/crypto/hkdf" -) - -// HMACSHA256 returns the hash message authentication code with SHA-256 of the input with the key. -func HMACSHA256(key, input []byte) []byte { - hash := hmac.New(sha256.New, key) - hash.Write(input) - return hash.Sum(nil) -} - -// SHA256 return the SHA-256 of the value. -func SHA256(value []byte) []byte { - hash := sha256.New() - hash.Write(value) - return hash.Sum(nil) -} - -// HKDFSHA256 is the key deivation function based on HMAC and returns a reader based on input. salt and info can both be nil. -// The reader can be used to read an arbitary length of bytes which are based on all parameters. -func HKDFSHA256(input, salt, info []byte) io.Reader { - return hkdf.New(sha256.New, input, salt, info) -} diff --git a/crypto/goolm/crypto/hmac_test.go b/crypto/goolm/crypto/hmac_test.go deleted file mode 100644 index 95c0bfd5..00000000 --- a/crypto/goolm/crypto/hmac_test.go +++ /dev/null @@ -1,114 +0,0 @@ -package crypto_test - -import ( - "bytes" - "encoding/base64" - "io" - "testing" - - "maunium.net/go/mautrix/crypto/goolm/crypto" -) - -func TestHMACSha256(t *testing.T) { - key := []byte("test key") - message := []byte("test message") - hash := crypto.HMACSHA256(key, message) - if !bytes.Equal(hash, crypto.HMACSHA256(key, message)) { - t.Fail() - } - str := "A4M0ovdiWHaZ5msdDFbrvtChFwZIoIaRSVGmv8bmPtc" - result, err := base64.RawStdEncoding.DecodeString(str) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(result, hash) { - t.Fail() - } -} - -func TestHKDFSha256(t *testing.T) { - message := []byte("test content") - hkdf := crypto.HKDFSHA256(message, nil, nil) - hkdf2 := crypto.HKDFSHA256(message, nil, nil) - result := make([]byte, 32) - if _, err := io.ReadFull(hkdf, result); err != nil { - t.Fatal(err) - } - result2 := make([]byte, 32) - if _, err := io.ReadFull(hkdf2, result2); err != nil { - t.Fatal(err) - } - if !bytes.Equal(result, result2) { - t.Fail() - } -} - -func TestSha256Case1(t *testing.T) { - input := make([]byte, 0) - expected := []byte{ - 0xE3, 0xB0, 0xC4, 0x42, 0x98, 0xFC, 0x1C, 0x14, - 0x9A, 0xFB, 0xF4, 0xC8, 0x99, 0x6F, 0xB9, 0x24, - 0x27, 0xAE, 0x41, 0xE4, 0x64, 0x9B, 0x93, 0x4C, - 0xA4, 0x95, 0x99, 0x1B, 0x78, 0x52, 0xB8, 0x55, - } - result := crypto.SHA256(input) - if !bytes.Equal(expected, result) { - t.Fatalf("result not as expected:\n%v\n%v\n", result, expected) - } -} - -func TestHMACCase1(t *testing.T) { - input := make([]byte, 0) - expected := []byte{ - 0xb6, 0x13, 0x67, 0x9a, 0x08, 0x14, 0xd9, 0xec, - 0x77, 0x2f, 0x95, 0xd7, 0x78, 0xc3, 0x5f, 0xc5, - 0xff, 0x16, 0x97, 0xc4, 0x93, 0x71, 0x56, 0x53, - 0xc6, 0xc7, 0x12, 0x14, 0x42, 0x92, 0xc5, 0xad, - } - result := crypto.HMACSHA256(input, input) - if !bytes.Equal(expected, result) { - t.Fatalf("result not as expected:\n%v\n%v\n", result, expected) - } -} - -func TestHDKFCase1(t *testing.T) { - input := []byte{ - 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, - 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, - 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, - } - salt := []byte{ - 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, - 0x08, 0x09, 0x0a, 0x0b, 0x0c, - } - info := []byte{ - 0xf0, 0xf1, 0xf2, 0xf3, 0xf4, 0xf5, 0xf6, 0xf7, - 0xf8, 0xf9, - } - expectedHMAC := []byte{ - 0x07, 0x77, 0x09, 0x36, 0x2c, 0x2e, 0x32, 0xdf, - 0x0d, 0xdc, 0x3f, 0x0d, 0xc4, 0x7b, 0xba, 0x63, - 0x90, 0xb6, 0xc7, 0x3b, 0xb5, 0x0f, 0x9c, 0x31, - 0x22, 0xec, 0x84, 0x4a, 0xd7, 0xc2, 0xb3, 0xe5, - } - result := crypto.HMACSHA256(salt, input) - if !bytes.Equal(expectedHMAC, result) { - t.Fatalf("result not as expected:\n%v\n%v\n", result, expectedHMAC) - } - expectedHDKF := []byte{ - 0x3c, 0xb2, 0x5f, 0x25, 0xfa, 0xac, 0xd5, 0x7a, - 0x90, 0x43, 0x4f, 0x64, 0xd0, 0x36, 0x2f, 0x2a, - 0x2d, 0x2d, 0x0a, 0x90, 0xcf, 0x1a, 0x5a, 0x4c, - 0x5d, 0xb0, 0x2d, 0x56, 0xec, 0xc4, 0xc5, 0xbf, - 0x34, 0x00, 0x72, 0x08, 0xd5, 0xb8, 0x87, 0x18, - 0x58, 0x65, - } - resultReader := crypto.HKDFSHA256(input, salt, info) - result = make([]byte, len(expectedHDKF)) - if _, err := io.ReadFull(resultReader, result); err != nil { - t.Fatal(err) - } - if !bytes.Equal(expectedHDKF, result) { - t.Fatalf("result not as expected:\n%v\n%v\n", result, expectedHDKF) - } -} diff --git a/crypto/goolm/crypto/one_time_key.go b/crypto/goolm/crypto/one_time_key.go index 67465563..888b1749 100644 --- a/crypto/goolm/crypto/one_time_key.go +++ b/crypto/goolm/crypto/one_time_key.go @@ -3,11 +3,8 @@ package crypto import ( "encoding/base64" "encoding/binary" - "fmt" - "maunium.net/go/mautrix/crypto/goolm" "maunium.net/go/mautrix/crypto/goolm/libolmpickle" - "maunium.net/go/mautrix/id" ) // OneTimeKey stores the information about a one time key. @@ -18,78 +15,32 @@ type OneTimeKey struct { } // Equal compares the one time key to the given one. -func (otk OneTimeKey) Equal(s OneTimeKey) bool { - if otk.ID != s.ID { - return false - } - if otk.Published != s.Published { - return false - } - if !otk.Key.PrivateKey.Equal(s.Key.PrivateKey) { - return false - } - if !otk.Key.PublicKey.Equal(s.Key.PublicKey) { - return false - } - return true +func (otk OneTimeKey) Equal(other OneTimeKey) bool { + return otk.ID == other.ID && + otk.Published == other.Published && + otk.Key.PrivateKey.Equal(other.Key.PrivateKey) && + otk.Key.PublicKey.Equal(other.Key.PublicKey) } -// PickleLibOlm encodes the key pair into target. target has to have a size of at least PickleLen() and is written to from index 0. -// It returns the number of bytes written. -func (c OneTimeKey) PickleLibOlm(target []byte) (int, error) { - if len(target) < c.PickleLen() { - return 0, fmt.Errorf("pickle one time key: %w", goolm.ErrValueTooShort) - } - written := libolmpickle.PickleUInt32(uint32(c.ID), target) - written += libolmpickle.PickleBool(c.Published, target[written:]) - writtenKey, err := c.Key.PickleLibOlm(target[written:]) - if err != nil { - return 0, fmt.Errorf("pickle one time key: %w", err) - } - written += writtenKey - return written, nil +// PickleLibOlm pickles the key pair into the encoder. +func (c OneTimeKey) PickleLibOlm(encoder *libolmpickle.Encoder) { + encoder.WriteUInt32(c.ID) + encoder.WriteBool(c.Published) + c.Key.PickleLibOlm(encoder) } -// UnpickleLibOlm decodes the unencryted value and populates the OneTimeKey accordingly. It returns the number of bytes read. -func (c *OneTimeKey) UnpickleLibOlm(value []byte) (int, error) { - totalReadBytes := 0 - id, readBytes, err := libolmpickle.UnpickleUInt32(value) - if err != nil { - return 0, err +// UnpickleLibOlm unpickles the unencryted value and populates the [OneTimeKey] +// accordingly. +func (c *OneTimeKey) UnpickleLibOlm(decoder *libolmpickle.Decoder) (err error) { + if c.ID, err = decoder.ReadUInt32(); err != nil { + return + } else if c.Published, err = decoder.ReadBool(); err != nil { + return } - totalReadBytes += readBytes - c.ID = id - published, readBytes, err := libolmpickle.UnpickleBool(value[totalReadBytes:]) - if err != nil { - return 0, err - } - totalReadBytes += readBytes - c.Published = published - readBytes, err = c.Key.UnpickleLibOlm(value[totalReadBytes:]) - if err != nil { - return 0, err - } - totalReadBytes += readBytes - return totalReadBytes, nil + return c.Key.UnpickleLibOlm(decoder) } -// PickleLen returns the number of bytes the pickled OneTimeKey will have. -func (c OneTimeKey) PickleLen() int { - length := 0 - length += libolmpickle.PickleUInt32Len(c.ID) - length += libolmpickle.PickleBoolLen(c.Published) - length += c.Key.PickleLen() - return length -} - -// KeyIDEncoded returns the base64 encoded id. +// KeyIDEncoded returns the base64 encoded key ID. func (c OneTimeKey) KeyIDEncoded() string { - resSlice := make([]byte, 4) - binary.BigEndian.PutUint32(resSlice, c.ID) - return base64.RawStdEncoding.EncodeToString(resSlice) -} - -// PublicKeyEncoded returns the base64 encoded public key -func (c OneTimeKey) PublicKeyEncoded() id.Curve25519 { - return c.Key.PublicKey.B64Encoded() + return base64.RawStdEncoding.EncodeToString(binary.BigEndian.AppendUint32(nil, c.ID)) } diff --git a/crypto/goolm/errors.go b/crypto/goolm/errors.go deleted file mode 100644 index 6539b0f1..00000000 --- a/crypto/goolm/errors.go +++ /dev/null @@ -1,28 +0,0 @@ -package goolm - -import ( - "errors" -) - -// Those are the most common used errors -var ( - ErrBadSignature = errors.New("bad signature") - ErrBadMAC = errors.New("bad mac") - ErrBadMessageFormat = errors.New("bad message format") - ErrBadVerification = errors.New("bad verification") - ErrWrongProtocolVersion = errors.New("wrong protocol version") - ErrEmptyInput = errors.New("empty input") - ErrNoKeyProvided = errors.New("no key") - ErrBadMessageKeyID = errors.New("bad message key id") - ErrRatchetNotAvailable = errors.New("ratchet not available: attempt to decode a message whose index is earlier than our earliest known session key") - ErrMsgIndexTooHigh = errors.New("message index too high") - ErrProtocolViolation = errors.New("not protocol message order") - ErrMessageKeyNotFound = errors.New("message key not found") - ErrChainTooHigh = errors.New("chain index too high") - ErrBadInput = errors.New("bad input") - ErrBadVersion = errors.New("wrong version") - ErrWrongPickleVersion = errors.New("wrong pickle version") - ErrValueTooShort = errors.New("value too short") - ErrInputToSmall = errors.New("input too small (truncated?)") - ErrOverflow = errors.New("overflow") -) diff --git a/crypto/goolm/base64.go b/crypto/goolm/goolmbase64/base64.go similarity index 62% rename from crypto/goolm/base64.go rename to crypto/goolm/goolmbase64/base64.go index 229008cf..58ee26f7 100644 --- a/crypto/goolm/base64.go +++ b/crypto/goolm/goolmbase64/base64.go @@ -1,11 +1,12 @@ -package goolm +package goolmbase64 import ( "encoding/base64" ) -// Deprecated: base64.RawStdEncoding should be used directly -func Base64Decode(input []byte) ([]byte, error) { +// These methods should only be used for raw byte operations, never with string conversion + +func Decode(input []byte) ([]byte, error) { decoded := make([]byte, base64.RawStdEncoding.DecodedLen(len(input))) writtenBytes, err := base64.RawStdEncoding.Decode(decoded, input) if err != nil { @@ -14,8 +15,7 @@ func Base64Decode(input []byte) ([]byte, error) { return decoded[:writtenBytes], nil } -// Deprecated: base64.RawStdEncoding should be used directly -func Base64Encode(input []byte) []byte { +func Encode(input []byte) []byte { encoded := make([]byte, base64.RawStdEncoding.EncodedLen(len(input))) base64.RawStdEncoding.Encode(encoded, input) return encoded diff --git a/crypto/goolm/libolmpickle/encoder.go b/crypto/goolm/libolmpickle/encoder.go new file mode 100644 index 00000000..63e7b09b --- /dev/null +++ b/crypto/goolm/libolmpickle/encoder.go @@ -0,0 +1,40 @@ +package libolmpickle + +import ( + "bytes" + "encoding/binary" + + "go.mau.fi/util/exerrors" +) + +const ( + PickleBoolLength = 1 + PickleUInt8Length = 1 + PickleUInt32Length = 4 +) + +type Encoder struct { + bytes.Buffer +} + +func NewEncoder() *Encoder { return &Encoder{} } + +func (p *Encoder) WriteUInt8(value uint8) { + exerrors.PanicIfNotNil(p.WriteByte(value)) +} + +func (p *Encoder) WriteBool(value bool) { + if value { + exerrors.PanicIfNotNil(p.WriteByte(0x01)) + } else { + exerrors.PanicIfNotNil(p.WriteByte(0x00)) + } +} + +func (p *Encoder) WriteEmptyBytes(count int) { + exerrors.Must(p.Write(make([]byte, count))) +} + +func (p *Encoder) WriteUInt32(value uint32) { + exerrors.PanicIfNotNil(binary.Write(&p.Buffer, binary.BigEndian, value)) +} diff --git a/crypto/goolm/libolmpickle/encoder_test.go b/crypto/goolm/libolmpickle/encoder_test.go new file mode 100644 index 00000000..c7811225 --- /dev/null +++ b/crypto/goolm/libolmpickle/encoder_test.go @@ -0,0 +1,99 @@ +package libolmpickle_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "maunium.net/go/mautrix/crypto/goolm/libolmpickle" +) + +func TestEncoder(t *testing.T) { + var encoder libolmpickle.Encoder + encoder.WriteUInt32(4) + encoder.WriteUInt8(8) + encoder.WriteBool(false) + encoder.WriteEmptyBytes(10) + encoder.WriteBool(true) + encoder.Write([]byte("test")) + encoder.WriteUInt32(420_000) + assert.Equal(t, []byte{ + 0x00, 0x00, 0x00, 0x04, // 4 + 0x08, // 8 + 0x00, // false + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // ten empty bytes + 0x01, //true + 0x74, 0x65, 0x73, 0x74, // "test" (ASCII) + 0x00, 0x06, 0x68, 0xa0, // 420,000 + }, encoder.Bytes()) +} + +func TestPickleUInt32(t *testing.T) { + values := []uint32{ + 0xffffffff, + 0x00ff00ff, + 0xf0000000, + 0xf00f0000, + } + expected := [][]byte{ + {0xff, 0xff, 0xff, 0xff}, + {0x00, 0xff, 0x00, 0xff}, + {0xf0, 0x00, 0x00, 0x00}, + {0xf0, 0x0f, 0x00, 0x00}, + } + for i, value := range values { + var encoder libolmpickle.Encoder + encoder.WriteUInt32(value) + assert.Equal(t, expected[i], encoder.Bytes()) + } +} + +func TestPickleBool(t *testing.T) { + values := []bool{ + true, + false, + } + expected := [][]byte{ + {0x01}, + {0x00}, + } + for i, value := range values { + var encoder libolmpickle.Encoder + encoder.WriteBool(value) + assert.Equal(t, expected[i], encoder.Bytes()) + } +} + +func TestPickleUInt8(t *testing.T) { + values := []uint8{ + 0xff, + 0x1a, + } + expected := [][]byte{ + {0xff}, + {0x1a}, + } + for i, value := range values { + var encoder libolmpickle.Encoder + encoder.WriteUInt8(value) + assert.Equal(t, expected[i], encoder.Bytes()) + } +} + +func TestPickleBytes(t *testing.T) { + values := [][]byte{ + {0xff, 0xff, 0xff, 0xff}, + {0x00, 0xff, 0x00, 0xff}, + {0xf0, 0x00, 0x00, 0x00}, + } + expected := [][]byte{ + {0xff, 0xff, 0xff, 0xff}, + {0x00, 0xff, 0x00, 0xff}, + {0xf0, 0x00, 0x00, 0x00}, + } + for i, value := range values { + var encoder libolmpickle.Encoder + encoder.Write(value) + assert.Equal(t, expected[i], encoder.Bytes()) + } +} diff --git a/crypto/goolm/libolmpickle/pickle.go b/crypto/goolm/libolmpickle/pickle.go index ec125a34..d15358fd 100644 --- a/crypto/goolm/libolmpickle/pickle.go +++ b/crypto/goolm/libolmpickle/pickle.go @@ -1,41 +1,48 @@ package libolmpickle import ( - "encoding/binary" + "crypto/aes" + "fmt" + + "maunium.net/go/mautrix/crypto/goolm/aessha2" + "maunium.net/go/mautrix/crypto/goolm/goolmbase64" + "maunium.net/go/mautrix/crypto/olm" ) -func PickleUInt8(value uint8, target []byte) int { - target[0] = value - return 1 -} -func PickleUInt8Len(value uint8) int { - return 1 -} +const pickleMACLength = 8 -func PickleBool(value bool, target []byte) int { - if value { - target[0] = 0x01 +var kdfPickle = []byte("Pickle") //used to derive the keys for encryption + +// Pickle encrypts the input with the key and the cipher AESSHA256. The result is then encoded in base64. +func Pickle(key, plaintext []byte) ([]byte, error) { + if c, err := aessha2.NewAESSHA2(key, kdfPickle); err != nil { + return nil, err + } else if ciphertext, err := c.Encrypt(plaintext); err != nil { + return nil, err + } else if mac, err := c.MAC(ciphertext); err != nil { + return nil, err } else { - target[0] = 0x00 + return goolmbase64.Encode(append(ciphertext, mac[:pickleMACLength]...)), nil } - return 1 -} -func PickleBoolLen(value bool) int { - return 1 } -func PickleBytes(value, target []byte) int { - return copy(target, value) -} -func PickleBytesLen(value []byte) int { - return len(value) -} - -func PickleUInt32(value uint32, target []byte) int { - res := make([]byte, 4) //4 bytes for int32 - binary.BigEndian.PutUint32(res, value) - return copy(target, res) -} -func PickleUInt32Len(value uint32) int { - return 4 +// Unpickle decodes the input from base64 and decrypts the decoded input with the key and the cipher AESSHA256. +func Unpickle(key, input []byte) ([]byte, error) { + ciphertext, err := goolmbase64.Decode(input) + if err != nil { + return nil, err + } + ciphertext, mac := ciphertext[:len(ciphertext)-pickleMACLength], ciphertext[len(ciphertext)-pickleMACLength:] + if c, err := aessha2.NewAESSHA2(key, kdfPickle); err != nil { + return nil, err + } else if verified, err := c.VerifyMAC(ciphertext, mac); err != nil { + return nil, err + } else if !verified { + return nil, fmt.Errorf("decrypt pickle: %w", olm.ErrBadMAC) + } else { + // Set to next block size + targetCipherText := make([]byte, int(len(ciphertext)/aes.BlockSize)*aes.BlockSize) + copy(targetCipherText, ciphertext) + return c.Decrypt(targetCipherText) + } } diff --git a/crypto/goolm/libolmpickle/pickle_test.go b/crypto/goolm/libolmpickle/pickle_test.go index ce118428..0720e008 100644 --- a/crypto/goolm/libolmpickle/pickle_test.go +++ b/crypto/goolm/libolmpickle/pickle_test.go @@ -1,98 +1,26 @@ -package libolmpickle_test +package libolmpickle import ( - "bytes" + "crypto/aes" "testing" - "maunium.net/go/mautrix/crypto/goolm/libolmpickle" + "github.com/stretchr/testify/assert" ) -func TestPickleUInt32(t *testing.T) { - values := []uint32{ - 0xffffffff, - 0x00ff00ff, - 0xf0000000, - 0xf00f0000, +func TestEncoding(t *testing.T) { + key := []byte("test key") + input := []byte("test") + //pad marshaled to get block size + toEncrypt := input + if len(input)%aes.BlockSize != 0 { + padding := aes.BlockSize - len(input)%aes.BlockSize + toEncrypt = make([]byte, len(input)+padding) + copy(toEncrypt, input) } - expected := [][]byte{ - {0xff, 0xff, 0xff, 0xff}, - {0x00, 0xff, 0x00, 0xff}, - {0xf0, 0x00, 0x00, 0x00}, - {0xf0, 0x0f, 0x00, 0x00}, - } - for curIndex := range values { - response := make([]byte, 4) - resPLen := libolmpickle.PickleUInt32(values[curIndex], response) - if resPLen != libolmpickle.PickleUInt32Len(values[curIndex]) { - t.Fatal("written bytes not correct") - } - if !bytes.Equal(response, expected[curIndex]) { - t.Fatalf("response not as expected:\n%v\n%v\n", response, expected[curIndex]) - } - } -} + encoded, err := Pickle(key, toEncrypt) + assert.NoError(t, err) -func TestPickleBool(t *testing.T) { - values := []bool{ - true, - false, - } - expected := [][]byte{ - {0x01}, - {0x00}, - } - for curIndex := range values { - response := make([]byte, 1) - resPLen := libolmpickle.PickleBool(values[curIndex], response) - if resPLen != libolmpickle.PickleBoolLen(values[curIndex]) { - t.Fatal("written bytes not correct") - } - if !bytes.Equal(response, expected[curIndex]) { - t.Fatalf("response not as expected:\n%v\n%v\n", response, expected[curIndex]) - } - } -} - -func TestPickleUInt8(t *testing.T) { - values := []uint8{ - 0xff, - 0x1a, - } - expected := [][]byte{ - {0xff}, - {0x1a}, - } - for curIndex := range values { - response := make([]byte, 1) - resPLen := libolmpickle.PickleUInt8(values[curIndex], response) - if resPLen != libolmpickle.PickleUInt8Len(values[curIndex]) { - t.Fatal("written bytes not correct") - } - if !bytes.Equal(response, expected[curIndex]) { - t.Fatalf("response not as expected:\n%v\n%v\n", response, expected[curIndex]) - } - } -} - -func TestPickleBytes(t *testing.T) { - values := [][]byte{ - {0xff, 0xff, 0xff, 0xff}, - {0x00, 0xff, 0x00, 0xff}, - {0xf0, 0x00, 0x00, 0x00}, - } - expected := [][]byte{ - {0xff, 0xff, 0xff, 0xff}, - {0x00, 0xff, 0x00, 0xff}, - {0xf0, 0x00, 0x00, 0x00}, - } - for curIndex := range values { - response := make([]byte, len(values[curIndex])) - resPLen := libolmpickle.PickleBytes(values[curIndex], response) - if resPLen != libolmpickle.PickleBytesLen(values[curIndex]) { - t.Fatal("written bytes not correct") - } - if !bytes.Equal(response, expected[curIndex]) { - t.Fatalf("response not as expected:\n%v\n%v\n", response, expected[curIndex]) - } - } + decoded, err := Unpickle(key, encoded) + assert.NoError(t, err) + assert.Equal(t, toEncrypt, decoded) } diff --git a/crypto/goolm/utilities/pickle.go b/crypto/goolm/libolmpickle/picklejson.go similarity index 72% rename from crypto/goolm/utilities/pickle.go rename to crypto/goolm/libolmpickle/picklejson.go index 993366c8..f765391f 100644 --- a/crypto/goolm/utilities/pickle.go +++ b/crypto/goolm/libolmpickle/picklejson.go @@ -1,17 +1,17 @@ -package utilities +package libolmpickle import ( + "crypto/aes" "encoding/json" "fmt" - "maunium.net/go/mautrix/crypto/goolm" - "maunium.net/go/mautrix/crypto/goolm/cipher" + "maunium.net/go/mautrix/crypto/olm" ) // PickleAsJSON returns an object as a base64 string encrypted using the supplied key. The unencrypted representation of the object is in JSON format. func PickleAsJSON(object any, pickleVersion byte, key []byte) ([]byte, error) { if len(key) == 0 { - return nil, fmt.Errorf("pickle: %w", goolm.ErrNoKeyProvided) + return nil, fmt.Errorf("pickle: %w", olm.ErrNoKeyProvided) } marshaled, err := json.Marshal(object) if err != nil { @@ -21,12 +21,12 @@ func PickleAsJSON(object any, pickleVersion byte, key []byte) ([]byte, error) { toEncrypt := make([]byte, len(marshaled)) copy(toEncrypt, marshaled) //pad marshaled to get block size - if len(marshaled)%cipher.PickleBlockSize() != 0 { - padding := cipher.PickleBlockSize() - len(marshaled)%cipher.PickleBlockSize() + if len(marshaled)%aes.BlockSize != 0 { + padding := aes.BlockSize - len(marshaled)%aes.BlockSize toEncrypt = make([]byte, len(marshaled)+padding) copy(toEncrypt, marshaled) } - encrypted, err := cipher.Pickle(key, toEncrypt) + encrypted, err := Pickle(key, toEncrypt) if err != nil { return nil, fmt.Errorf("pickle encrypt: %w", err) } @@ -36,9 +36,9 @@ func PickleAsJSON(object any, pickleVersion byte, key []byte) ([]byte, error) { // UnpickleAsJSON updates the object by a base64 encrypted string using the supplied key. The unencrypted representation has to be in JSON format. func UnpickleAsJSON(object any, pickled, key []byte, pickleVersion byte) error { if len(key) == 0 { - return fmt.Errorf("unpickle: %w", goolm.ErrNoKeyProvided) + return fmt.Errorf("unpickle: %w", olm.ErrNoKeyProvided) } - decrypted, err := cipher.Unpickle(key, pickled) + decrypted, err := Unpickle(key, pickled) if err != nil { return fmt.Errorf("unpickle decrypt: %w", err) } @@ -50,7 +50,7 @@ func UnpickleAsJSON(object any, pickled, key []byte, pickleVersion byte) error { } } if decrypted[0] != pickleVersion { - return fmt.Errorf("unpickle: %w", goolm.ErrWrongPickleVersion) + return fmt.Errorf("unpickle: %w", olm.ErrUnknownJSONPickleVersion) } err = json.Unmarshal(decrypted[1:], object) if err != nil { diff --git a/crypto/goolm/libolmpickle/unpickle.go b/crypto/goolm/libolmpickle/unpickle.go index 9a6a4b62..d13be315 100644 --- a/crypto/goolm/libolmpickle/unpickle.go +++ b/crypto/goolm/libolmpickle/unpickle.go @@ -1,53 +1,52 @@ package libolmpickle import ( + "bytes" + "encoding/binary" "fmt" - - "maunium.net/go/mautrix/crypto/goolm" ) -func isZeroByteSlice(bytes []byte) bool { - b := byte(0) - for _, s := range bytes { - b |= s +func isZeroByteSlice(data []byte) bool { + for _, b := range data { + if b != 0 { + return false + } } - return b == 0 + return true } -func UnpickleUInt8(value []byte) (uint8, int, error) { - if len(value) < 1 { - return 0, 0, fmt.Errorf("unpickle uint8: %w", goolm.ErrValueTooShort) - } - return value[0], 1, nil +type Decoder struct { + buf bytes.Buffer } -func UnpickleBool(value []byte) (bool, int, error) { - if len(value) < 1 { - return false, 0, fmt.Errorf("unpickle bool: %w", goolm.ErrValueTooShort) - } - return value[0] != uint8(0x00), 1, nil +func NewDecoder(buf []byte) *Decoder { + return &Decoder{buf: *bytes.NewBuffer(buf)} } -func UnpickleBytes(value []byte, length int) ([]byte, int, error) { - if len(value) < length { - return nil, 0, fmt.Errorf("unpickle bytes: %w", goolm.ErrValueTooShort) - } - resp := value[:length] - if isZeroByteSlice(resp) { - return nil, length, nil - } - return resp, length, nil +func (d *Decoder) ReadUInt8() (uint8, error) { + return d.buf.ReadByte() } -func UnpickleUInt32(value []byte) (uint32, int, error) { - if len(value) < 4 { - return 0, 0, fmt.Errorf("unpickle uint32: %w", goolm.ErrValueTooShort) - } - var res uint32 - count := 0 - for i := 3; i >= 0; i-- { - res |= uint32(value[count]) << (8 * i) - count++ - } - return res, 4, nil +func (d *Decoder) ReadBool() (bool, error) { + val, err := d.buf.ReadByte() + return val != 0x00, err +} + +func (d *Decoder) ReadBytes(length int) (data []byte, err error) { + data = d.buf.Next(length) + if len(data) != length { + return nil, fmt.Errorf("only %d in buffer, expected %d", len(data), length) + } else if isZeroByteSlice(data) { + return nil, nil + } + return +} + +func (d *Decoder) ReadUInt32() (uint32, error) { + data := d.buf.Next(4) + if len(data) != 4 { + return 0, fmt.Errorf("only %d bytes is buffer, expected 4 for uint32", len(data)) + } else { + return binary.BigEndian.Uint32(data), nil + } } diff --git a/crypto/goolm/libolmpickle/unpickle_test.go b/crypto/goolm/libolmpickle/unpickle_test.go index 937630e5..30355a76 100644 --- a/crypto/goolm/libolmpickle/unpickle_test.go +++ b/crypto/goolm/libolmpickle/unpickle_test.go @@ -1,9 +1,10 @@ package libolmpickle_test import ( - "bytes" "testing" + "github.com/stretchr/testify/assert" + "maunium.net/go/mautrix/crypto/goolm/libolmpickle" ) @@ -19,16 +20,10 @@ func TestUnpickleUInt32(t *testing.T) { {0xf0, 0x00, 0x00, 0x00}, } for curIndex := range values { - response, readLength, err := libolmpickle.UnpickleUInt32(values[curIndex]) - if err != nil { - t.Fatal(err) - } - if readLength != 4 { - t.Fatal("read bytes not correct") - } - if response != expected[curIndex] { - t.Fatalf("response not as expected:\n%v\n%v\n", response, expected[curIndex]) - } + decoder := libolmpickle.NewDecoder(values[curIndex]) + response, err := decoder.ReadUInt32() + assert.NoError(t, err) + assert.Equal(t, expected[curIndex], response) } } @@ -44,16 +39,10 @@ func TestUnpickleBool(t *testing.T) { {0x02}, } for curIndex := range values { - response, readLength, err := libolmpickle.UnpickleBool(values[curIndex]) - if err != nil { - t.Fatal(err) - } - if readLength != 1 { - t.Fatal("read bytes not correct") - } - if response != expected[curIndex] { - t.Fatalf("response not as expected:\n%v\n%v\n", response, expected[curIndex]) - } + decoder := libolmpickle.NewDecoder(values[curIndex]) + response, err := decoder.ReadBool() + assert.NoError(t, err) + assert.Equal(t, expected[curIndex], response) } } @@ -67,16 +56,10 @@ func TestUnpickleUInt8(t *testing.T) { {0x1a}, } for curIndex := range values { - response, readLength, err := libolmpickle.UnpickleUInt8(values[curIndex]) - if err != nil { - t.Fatal(err) - } - if readLength != 1 { - t.Fatal("read bytes not correct") - } - if response != expected[curIndex] { - t.Fatalf("response not as expected:\n%v\n%v\n", response, expected[curIndex]) - } + decoder := libolmpickle.NewDecoder(values[curIndex]) + response, err := decoder.ReadUInt8() + assert.NoError(t, err) + assert.Equal(t, expected[curIndex], response) } } @@ -92,15 +75,9 @@ func TestUnpickleBytes(t *testing.T) { {0xf0, 0x00, 0x00, 0x00}, } for curIndex := range values { - response, readLength, err := libolmpickle.UnpickleBytes(values[curIndex], 4) - if err != nil { - t.Fatal(err) - } - if readLength != 4 { - t.Fatal("read bytes not correct") - } - if !bytes.Equal(response, expected[curIndex]) { - t.Fatalf("response not as expected:\n%v\n%v\n", response, expected[curIndex]) - } + decoder := libolmpickle.NewDecoder(values[curIndex]) + response, err := decoder.ReadBytes(4) + assert.NoError(t, err) + assert.Equal(t, expected[curIndex], response) } } diff --git a/crypto/goolm/megolm/megolm.go b/crypto/goolm/megolm/megolm.go index c3493f7b..3b5f1e4a 100644 --- a/crypto/goolm/megolm/megolm.go +++ b/crypto/goolm/megolm/megolm.go @@ -2,15 +2,17 @@ package megolm import ( + "crypto/hmac" "crypto/rand" + "crypto/sha256" "fmt" - "maunium.net/go/mautrix/crypto/goolm" - "maunium.net/go/mautrix/crypto/goolm/cipher" + "maunium.net/go/mautrix/crypto/goolm/aessha2" "maunium.net/go/mautrix/crypto/goolm/crypto" + "maunium.net/go/mautrix/crypto/goolm/goolmbase64" "maunium.net/go/mautrix/crypto/goolm/libolmpickle" "maunium.net/go/mautrix/crypto/goolm/message" - "maunium.net/go/mautrix/crypto/goolm/utilities" + "maunium.net/go/mautrix/crypto/olm" ) const ( @@ -23,7 +25,7 @@ const ( RatchetPartLength = 256 / 8 // length of each ratchet part in bytes ) -var RatchetCipher = cipher.NewAESSHA256([]byte("MEGOLM_KEYS")) +var megolmKeysKDFInfo = []byte("MEGOLM_KEYS") // hasKeySeed are the seed for the different ratchet parts var hashKeySeeds [RatchetParts][]byte = [RatchetParts][]byte{ @@ -62,8 +64,9 @@ func NewWithRandom(counter uint32) (*Ratchet, error) { // rehashPart rehases the part of the ratchet data with the base defined as from storing into the target to. func (m *Ratchet) rehashPart(from, to int) { - newData := crypto.HMACSHA256(m.Data[from*RatchetPartLength:from*RatchetPartLength+RatchetPartLength], hashKeySeeds[to]) - copy(m.Data[to*RatchetPartLength:], newData[:RatchetPartLength]) + hash := hmac.New(sha256.New, m.Data[from*RatchetPartLength:from*RatchetPartLength+RatchetPartLength]) + hash.Write(hashKeySeeds[to]) + copy(m.Data[to*RatchetPartLength:], hash.Sum(nil)) } // Advance advances the ratchet one step. @@ -132,9 +135,8 @@ func (m *Ratchet) AdvanceTo(target uint32) { // Encrypt encrypts the message in a message.GroupMessage with MAC and signature. // The output is base64 encoded. -func (r *Ratchet) Encrypt(plaintext []byte, key *crypto.Ed25519KeyPair) ([]byte, error) { - var err error - encryptedText, err := RatchetCipher.Encrypt(r.Data[:], plaintext) +func (r *Ratchet) Encrypt(plaintext []byte, key crypto.Ed25519KeyPair) ([]byte, error) { + cipher, err := aessha2.NewAESSHA2(r.Data[:], megolmKeysKDFInfo) if err != nil { return nil, fmt.Errorf("cipher encrypt: %w", err) } @@ -142,9 +144,12 @@ func (r *Ratchet) Encrypt(plaintext []byte, key *crypto.Ed25519KeyPair) ([]byte, message := &message.GroupMessage{} message.Version = protocolVersion message.MessageIndex = r.Counter - message.Ciphertext = encryptedText - //creating the mac and signing is done in encode - output, err := message.EncodeAndMacAndSign(r.Data[:], RatchetCipher, key) + message.Ciphertext, err = cipher.Encrypt(plaintext) + if err != nil { + return nil, err + } + //creating the MAC and signing is done in encode + output, err := message.EncodeAndMACAndSign(cipher, key) if err != nil { return nil, err } @@ -157,8 +162,8 @@ func (r Ratchet) SessionSharingMessage(key crypto.Ed25519KeyPair) ([]byte, error m := message.MegolmSessionSharing{} m.Counter = r.Counter m.RatchetData = r.Data - encoded := m.EncodeAndSign(key) - return goolm.Base64Encode(encoded), nil + encoded, err := m.EncodeAndSign(key) + return goolmbase64.Encode(encoded), err } // SessionExportMessage creates a message in the session export format. @@ -168,67 +173,51 @@ func (r Ratchet) SessionExportMessage(key crypto.Ed25519PublicKey) ([]byte, erro m.RatchetData = r.Data m.PublicKey = key encoded := m.Encode() - return goolm.Base64Encode(encoded), nil + return goolmbase64.Encode(encoded), nil } // Decrypt decrypts the ciphertext and verifies the MAC but not the signature. func (r Ratchet) Decrypt(ciphertext []byte, signingkey *crypto.Ed25519PublicKey, msg *message.GroupMessage) ([]byte, error) { //verify mac - verifiedMAC, err := msg.VerifyMACInline(r.Data[:], RatchetCipher, ciphertext) + cipher, err := aessha2.NewAESSHA2(r.Data[:], megolmKeysKDFInfo) + if err != nil { + return nil, err + } + verifiedMAC, err := msg.VerifyMACInline(cipher, ciphertext) if err != nil { return nil, err } if !verifiedMAC { - return nil, fmt.Errorf("decrypt: %w", goolm.ErrBadMAC) + return nil, fmt.Errorf("decrypt: %w", olm.ErrBadMAC) } - return RatchetCipher.Decrypt(r.Data[:], msg.Ciphertext) + return cipher.Decrypt(msg.Ciphertext) } // PickleAsJSON returns a ratchet as a base64 string encrypted using the supplied key. The unencrypted representation of the Account is in JSON format. func (r Ratchet) PickleAsJSON(key []byte) ([]byte, error) { - return utilities.PickleAsJSON(r, megolmPickleVersion, key) + return libolmpickle.PickleAsJSON(r, megolmPickleVersion, key) } // UnpickleAsJSON updates a ratchet by a base64 encrypted string using the supplied key. The unencrypted representation has to be in JSON format. func (r *Ratchet) UnpickleAsJSON(pickled, key []byte) error { - return utilities.UnpickleAsJSON(r, pickled, key, megolmPickleVersion) + return libolmpickle.UnpickleAsJSON(r, pickled, key, megolmPickleVersion) } // UnpickleLibOlm decodes the unencryted value and populates the Ratchet accordingly. It returns the number of bytes read. -func (r *Ratchet) UnpickleLibOlm(unpickled []byte) (int, error) { - //read ratchet data - curPos := 0 - ratchetData, readBytes, err := libolmpickle.UnpickleBytes(unpickled, RatchetParts*RatchetPartLength) +func (r *Ratchet) UnpickleLibOlm(decoder *libolmpickle.Decoder) error { + ratchetData, err := decoder.ReadBytes(RatchetParts * RatchetPartLength) if err != nil { - return 0, err + return err } copy(r.Data[:], ratchetData) - curPos += readBytes - //Read counter - counter, readBytes, err := libolmpickle.UnpickleUInt32(unpickled[curPos:]) - if err != nil { - return 0, err - } - curPos += readBytes - r.Counter = counter - return curPos, nil + + r.Counter, err = decoder.ReadUInt32() + return err } -// PickleLibOlm encodes the ratchet into target. target has to have a size of at least PickleLen() and is written to from index 0. -// It returns the number of bytes written. -func (r Ratchet) PickleLibOlm(target []byte) (int, error) { - if len(target) < r.PickleLen() { - return 0, fmt.Errorf("pickle account: %w", goolm.ErrValueTooShort) - } - written := libolmpickle.PickleBytes(r.Data[:], target) - written += libolmpickle.PickleUInt32(r.Counter, target[written:]) - return written, nil -} - -// PickleLen returns the number of bytes the pickled ratchet will have. -func (r Ratchet) PickleLen() int { - length := libolmpickle.PickleBytesLen(r.Data[:]) - length += libolmpickle.PickleUInt32Len(r.Counter) - return length +// PickleLibOlm pickles the ratchet into the encoder. +func (r Ratchet) PickleLibOlm(encoder *libolmpickle.Encoder) { + encoder.Write(r.Data[:]) + encoder.WriteUInt32(r.Counter) } diff --git a/crypto/goolm/megolm/megolm_test.go b/crypto/goolm/megolm/megolm_test.go index 40289eaf..a6f7c1a7 100644 --- a/crypto/goolm/megolm/megolm_test.go +++ b/crypto/goolm/megolm/megolm_test.go @@ -1,9 +1,10 @@ package megolm_test import ( - "bytes" "testing" + "github.com/stretchr/testify/assert" + "maunium.net/go/mautrix/crypto/goolm/megolm" ) @@ -19,9 +20,7 @@ func init() { func TestAdvance(t *testing.T) { m, err := megolm.New(0, startData) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) expectedData := [megolm.RatchetParts * megolm.RatchetPartLength]byte{ 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x41, 0x42, 0x43, 0x44, 0x45, 0x46, @@ -34,9 +33,7 @@ func TestAdvance(t *testing.T) { 0x89, 0xbb, 0xb4, 0x23, 0xa1, 0x8f, 0x23, 0x82, 0x8f, 0xb2, 0x09, 0x0d, 0x6e, 0x2a, 0xf8, 0x6a, } m.Advance() - if !bytes.Equal(m.Data[:], expectedData[:]) { - t.Fatal("result after advancing the ratchet is not as expected") - } + assert.Equal(t, m.Data[:], expectedData[:], "result after advancing the ratchet is not as expected") //repeat with complex advance m.Data = startData @@ -51,9 +48,8 @@ func TestAdvance(t *testing.T) { 0x89, 0xbb, 0xb4, 0x23, 0xa1, 0x8f, 0x23, 0x82, 0x8f, 0xb2, 0x09, 0x0d, 0x6e, 0x2a, 0xf8, 0x6a, } m.AdvanceTo(0x1000000) - if !bytes.Equal(m.Data[:], expectedData[:]) { - t.Fatal("result after advancing the ratchet is not as expected") - } + assert.Equal(t, m.Data[:], expectedData[:], "result after advancing the ratchet is not as expected") + expectedData = [megolm.RatchetParts * megolm.RatchetPartLength]byte{ 0x54, 0x02, 0x2d, 0x7d, 0xc0, 0x29, 0x8e, 0x16, 0x37, 0xe2, 0x1c, 0x97, 0x15, 0x30, 0x92, 0xf9, 0x33, 0xc0, 0x56, 0xff, 0x74, 0xfe, 0x1b, 0x92, 0x2d, 0x97, 0x1f, 0x24, 0x82, 0xc2, 0x85, 0x9c, @@ -65,77 +61,45 @@ func TestAdvance(t *testing.T) { 0xd5, 0x6f, 0x03, 0xe2, 0x44, 0x16, 0xb9, 0x8e, 0x1c, 0xfd, 0x97, 0xc2, 0x06, 0xaa, 0x90, 0x7a, } m.AdvanceTo(0x1041506) - if !bytes.Equal(m.Data[:], expectedData[:]) { - t.Fatal("result after advancing the ratchet is not as expected") - } + assert.Equal(t, m.Data[:], expectedData[:], "result after advancing the ratchet is not as expected") } func TestAdvanceWraparound(t *testing.T) { m, err := megolm.New(0xffffffff, startData) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) m.AdvanceTo(0x1000000) - if m.Counter != 0x1000000 { - t.Fatal("counter not correct") - } + assert.EqualValues(t, 0x1000000, m.Counter, "counter not correct") m2, err := megolm.New(0, startData) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) m2.AdvanceTo(0x2000000) - if m2.Counter != 0x2000000 { - t.Fatal("counter not correct") - } - if !bytes.Equal(m.Data[:], m2.Data[:]) { - t.Fatal("result after wrapping the ratchet is not as expected") - } + assert.EqualValues(t, 0x2000000, m2.Counter, "counter not correct") + assert.Equal(t, m.Data, m2.Data, "result after wrapping the ratchet is not as expected") } func TestAdvanceOverflowByOne(t *testing.T) { m, err := megolm.New(0xffffffff, startData) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) m.AdvanceTo(0x0) - if m.Counter != 0x0 { - t.Fatal("counter not correct") - } + assert.EqualValues(t, 0x0, m.Counter, "counter not correct") m2, err := megolm.New(0xffffffff, startData) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) m2.Advance() - if m2.Counter != 0x0 { - t.Fatal("counter not correct") - } - if !bytes.Equal(m.Data[:], m2.Data[:]) { - t.Fatal("result after wrapping the ratchet is not as expected") - } + assert.EqualValues(t, 0x0, m2.Counter, "counter not correct") + assert.Equal(t, m.Data, m2.Data, "result after wrapping the ratchet is not as expected") } func TestAdvanceOverflow(t *testing.T) { m, err := megolm.New(0x1, startData) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) m.AdvanceTo(0x80000000) m.AdvanceTo(0x0) - if m.Counter != 0x0 { - t.Fatal("counter not correct") - } + assert.EqualValues(t, 0x0, m.Counter, "counter not correct") m2, err := megolm.New(0x1, startData) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) m2.AdvanceTo(0x0) - if m2.Counter != 0x0 { - t.Fatal("counter not correct") - } - if !bytes.Equal(m.Data[:], m2.Data[:]) { - t.Fatal("result after wrapping the ratchet is not as expected") - } + assert.EqualValues(t, 0x0, m2.Counter, "counter not correct") + assert.Equal(t, m.Data, m2.Data, "result after wrapping the ratchet is not as expected") } diff --git a/crypto/goolm/message/decoder.go b/crypto/goolm/message/decoder.go index ba49f011..b06756a9 100644 --- a/crypto/goolm/message/decoder.go +++ b/crypto/goolm/message/decoder.go @@ -1,70 +1,33 @@ package message import ( + "bytes" "encoding/binary" + "fmt" - "maunium.net/go/mautrix/crypto/goolm" + "maunium.net/go/mautrix/crypto/olm" ) -// checkDecodeErr checks if there was an error during decode. -func checkDecodeErr(readBytes int) error { - if readBytes == 0 { - //end reached - return goolm.ErrInputToSmall +type Decoder struct { + *bytes.Buffer +} + +func NewDecoder(buf []byte) *Decoder { + return &Decoder{bytes.NewBuffer(buf)} +} + +func (d *Decoder) ReadVarInt() (uint64, error) { + return binary.ReadUvarint(d) +} + +func (d *Decoder) ReadVarBytes() ([]byte, error) { + if n, err := d.ReadVarInt(); err != nil { + return nil, err + } else if n > uint64(d.Len()) { + return nil, fmt.Errorf("%w: var bytes length says %d, but only %d bytes left", olm.ErrInputToSmall, n, d.Available()) + } else { + out := make([]byte, n) + _, err = d.Read(out) + return out, err } - if readBytes < 0 { - return goolm.ErrOverflow - } - return nil -} - -// decodeVarInt decodes a single big-endian encoded varint. -func decodeVarInt(input []byte) (uint32, int) { - value, readBytes := binary.Uvarint(input) - return uint32(value), readBytes -} - -// decodeVarString decodes the length of the string (varint) and returns the actual string -func decodeVarString(input []byte) ([]byte, int) { - stringLen, readBytes := decodeVarInt(input) - if readBytes <= 0 { - return nil, readBytes - } - input = input[readBytes:] - value := input[:stringLen] - readBytes += int(stringLen) - return value, readBytes -} - -// encodeVarIntByteLength returns the number of bytes needed to encode the uint32. -func encodeVarIntByteLength(input uint32) int { - result := 1 - for input >= 128 { - result++ - input >>= 7 - } - return result -} - -// encodeVarStringByteLength returns the number of bytes needed to encode the input. -func encodeVarStringByteLength(input []byte) int { - result := encodeVarIntByteLength(uint32(len(input))) - result += len(input) - return result -} - -// encodeVarInt encodes a single uint32 -func encodeVarInt(input uint32) []byte { - out := make([]byte, encodeVarIntByteLength(input)) - binary.PutUvarint(out, uint64(input)) - return out -} - -// encodeVarString encodes the length of the input (varint) and appends the actual input -func encodeVarString(input []byte) []byte { - out := make([]byte, encodeVarStringByteLength(input)) - length := encodeVarInt(uint32(len(input))) - copy(out, length) - copy(out[len(length):], input) - return out } diff --git a/crypto/goolm/message/encoder.go b/crypto/goolm/message/encoder.go new file mode 100644 index 00000000..95ab6d41 --- /dev/null +++ b/crypto/goolm/message/encoder.go @@ -0,0 +1,24 @@ +package message + +import "encoding/binary" + +type Encoder struct { + buf []byte +} + +func (e *Encoder) Bytes() []byte { + return e.buf +} + +func (e *Encoder) PutByte(val byte) { + e.buf = append(e.buf, val) +} + +func (e *Encoder) PutVarInt(val uint64) { + e.buf = binary.AppendUvarint(e.buf, val) +} + +func (e *Encoder) PutVarBytes(data []byte) { + e.PutVarInt(uint64(len(data))) + e.buf = append(e.buf, data...) +} diff --git a/crypto/goolm/message/decoder_test.go b/crypto/goolm/message/encoder_test.go similarity index 50% rename from crypto/goolm/message/decoder_test.go rename to crypto/goolm/message/encoder_test.go index 39503e3e..1fe2ebdb 100644 --- a/crypto/goolm/message/decoder_test.go +++ b/crypto/goolm/message/encoder_test.go @@ -1,36 +1,13 @@ -package message +package message_test import ( - "bytes" "testing" + + "github.com/stretchr/testify/assert" + + "maunium.net/go/mautrix/crypto/goolm/message" ) -func TestEncodeLengthInt(t *testing.T) { - numbers := []uint32{127, 128, 16383, 16384, 32767} - expected := []int{1, 2, 2, 3, 3} - for curIndex := range numbers { - if result := encodeVarIntByteLength(numbers[curIndex]); result != expected[curIndex] { - t.Fatalf("expected byte length of %d but got %d", expected[curIndex], result) - } - } -} - -func TestEncodeLengthString(t *testing.T) { - var strings [][]byte - var expected []int - strings = append(strings, []byte("test")) - expected = append(expected, 1+4) - strings = append(strings, []byte("this is a long message with a length of 127 so that the varint of the length is just one byte. just needs some padding---------")) - expected = append(expected, 1+127) - strings = append(strings, []byte("this is an even longer message with a length between 128 and 16383 so that the varint of the length needs two byte. just needs some padding again ---------")) - expected = append(expected, 2+155) - for curIndex := range strings { - if result := encodeVarStringByteLength(strings[curIndex]); result != expected[curIndex] { - t.Fatalf("expected byte length of %d but got %d", expected[curIndex], result) - } - } -} - func TestEncodeInt(t *testing.T) { var ints []uint32 var expected [][]byte @@ -43,9 +20,9 @@ func TestEncodeInt(t *testing.T) { ints = append(ints, 16383) expected = append(expected, []byte{0b11111111, 0b01111111}) for curIndex := range ints { - if result := encodeVarInt(ints[curIndex]); !bytes.Equal(result, expected[curIndex]) { - t.Fatalf("expected byte of %b but got %b", expected[curIndex], result) - } + var encoder message.Encoder + encoder.PutVarInt(uint64(ints[curIndex])) + assert.Equal(t, expected[curIndex], encoder.Bytes()) } } @@ -75,8 +52,8 @@ func TestEncodeString(t *testing.T) { res = append(res, curTest...) //Add string itself expected = append(expected, res) for curIndex := range strings { - if result := encodeVarString(strings[curIndex]); !bytes.Equal(result, expected[curIndex]) { - t.Fatalf("expected byte of %b but got %b", expected[curIndex], result) - } + var encoder message.Encoder + encoder.PutVarBytes(strings[curIndex]) + assert.Equal(t, expected[curIndex], encoder.Bytes()) } } diff --git a/crypto/goolm/message/group_message.go b/crypto/goolm/message/group_message.go index ebd5b77e..c83540c1 100644 --- a/crypto/goolm/message/group_message.go +++ b/crypto/goolm/message/group_message.go @@ -2,9 +2,12 @@ package message import ( "bytes" + "fmt" + "io" - "maunium.net/go/mautrix/crypto/goolm/cipher" + "maunium.net/go/mautrix/crypto/goolm/aessha2" "maunium.net/go/mautrix/crypto/goolm/crypto" + "maunium.net/go/mautrix/crypto/olm" ) const ( @@ -22,112 +25,87 @@ type GroupMessage struct { } // Decodes decodes the input and populates the corresponding fileds. MAC and signature are ignored but have to be present. -func (r *GroupMessage) Decode(input []byte) error { +func (r *GroupMessage) Decode(input []byte) (err error) { r.Version = 0 r.MessageIndex = 0 r.Ciphertext = nil if len(input) == 0 { return nil } - //first Byte is always version - r.Version = input[0] - curPos := 1 - for curPos < len(input)-countMACBytesGroupMessage-crypto.ED25519SignatureSize { - //Read Key - curKey, readBytes := decodeVarInt(input[curPos:]) - if err := checkDecodeErr(readBytes); err != nil { - return err - } - curPos += readBytes - if (curKey & 0b111) == 0 { - //The value is of type varint - value, readBytes := decodeVarInt(input[curPos:]) - if err := checkDecodeErr(readBytes); err != nil { - return err + + decoder := NewDecoder(input[:len(input)-countMACBytesGroupMessage-crypto.Ed25519SignatureSize]) + r.Version, err = decoder.ReadByte() // First byte is the version + if err != nil { + return + } + if r.Version != protocolVersion { + return fmt.Errorf("GroupMessage.Decode: %w (got %d, expected %d)", olm.ErrWrongProtocolVersion, r.Version, protocolVersion) + } + + for { + // Read Key + if curKey, err := decoder.ReadVarInt(); err != nil { + if err == io.EOF { + // No more keys to read + return nil } - curPos += readBytes - switch curKey { - case messageIndexTag: - r.MessageIndex = value + return err + } else if (curKey & 0b111) == 0 { + // The value is of type varint + if value, err := decoder.ReadVarInt(); err != nil { + return err + } else if curKey == messageIndexTag { + r.MessageIndex = uint32(value) r.HasMessageIndex = true } } else if (curKey & 0b111) == 2 { - //The value is of type string - value, readBytes := decodeVarString(input[curPos:]) - if err := checkDecodeErr(readBytes); err != nil { + // The value is of type string + if value, err := decoder.ReadVarBytes(); err != nil { return err - } - curPos += readBytes - switch curKey { - case cipherTextTag: + } else if curKey == cipherTextTag { r.Ciphertext = value } } } - - return nil } -// EncodeAndMacAndSign encodes the message, creates the mac with the key and the cipher and signs the message. +// EncodeAndMACAndSign encodes the message, creates the mac with the key and the cipher and signs the message. // If macKey or cipher is nil, no mac is appended. If signKey is nil, no signature is appended. -func (r *GroupMessage) EncodeAndMacAndSign(macKey []byte, cipher cipher.Cipher, signKey *crypto.Ed25519KeyPair) ([]byte, error) { - var lengthOfMessage int - lengthOfMessage += 1 //Version - lengthOfMessage += encodeVarIntByteLength(messageIndexTag) + encodeVarIntByteLength(r.MessageIndex) - lengthOfMessage += encodeVarIntByteLength(cipherTextTag) + encodeVarStringByteLength(r.Ciphertext) - out := make([]byte, lengthOfMessage) - out[0] = r.Version - curPos := 1 - encodedTag := encodeVarInt(messageIndexTag) - copy(out[curPos:], encodedTag) - curPos += len(encodedTag) - encodedValue := encodeVarInt(r.MessageIndex) - copy(out[curPos:], encodedValue) - curPos += len(encodedValue) - encodedTag = encodeVarInt(cipherTextTag) - copy(out[curPos:], encodedTag) - curPos += len(encodedTag) - encodedValue = encodeVarString(r.Ciphertext) - copy(out[curPos:], encodedValue) - curPos += len(encodedValue) - if len(macKey) != 0 && cipher != nil { - mac, err := r.MAC(macKey, cipher, out) - if err != nil { - return nil, err - } - out = append(out, mac[:countMACBytesGroupMessage]...) +func (r *GroupMessage) EncodeAndMACAndSign(cipher aessha2.AESSHA2, signKey crypto.Ed25519KeyPair) ([]byte, error) { + var encoder Encoder + encoder.PutByte(r.Version) + encoder.PutVarInt(messageIndexTag) + encoder.PutVarInt(uint64(r.MessageIndex)) + encoder.PutVarInt(cipherTextTag) + encoder.PutVarBytes(r.Ciphertext) + mac, err := r.MAC(cipher, encoder.Bytes()) + if err != nil { + return nil, err } - if signKey != nil { - signature := signKey.Sign(out) - out = append(out, signature...) - } - return out, nil + ciphertextWithMAC := append(encoder.Bytes(), mac[:countMACBytesGroupMessage]...) + signature, err := signKey.Sign(ciphertextWithMAC) + return append(ciphertextWithMAC, signature...), err } // MAC returns the MAC of the message calculated with cipher and key. The length of the MAC is truncated to the correct length. -func (r *GroupMessage) MAC(key []byte, cipher cipher.Cipher, message []byte) ([]byte, error) { - mac, err := cipher.MAC(key, message) +func (r *GroupMessage) MAC(cipher aessha2.AESSHA2, ciphertext []byte) ([]byte, error) { + mac, err := cipher.MAC(ciphertext) if err != nil { return nil, err } return mac[:countMACBytesGroupMessage], nil } -// VerifySignature verifies the givenSignature to the calculated signature of the message. -func (r *GroupMessage) VerifySignature(key crypto.Ed25519PublicKey, message, givenSignature []byte) bool { - return key.Verify(message, givenSignature) -} - // VerifySignature verifies the signature taken from the message to the calculated signature of the message. func (r *GroupMessage) VerifySignatureInline(key crypto.Ed25519PublicKey, message []byte) bool { - signature := message[len(message)-crypto.ED25519SignatureSize:] - message = message[:len(message)-crypto.ED25519SignatureSize] + signature := message[len(message)-crypto.Ed25519SignatureSize:] + message = message[:len(message)-crypto.Ed25519SignatureSize] return key.Verify(message, signature) } // VerifyMAC verifies the givenMAC to the calculated MAC of the message. -func (r *GroupMessage) VerifyMAC(key []byte, cipher cipher.Cipher, message, givenMAC []byte) (bool, error) { - checkMac, err := r.MAC(key, cipher, message) +func (r *GroupMessage) VerifyMAC(cipher aessha2.AESSHA2, ciphertext, givenMAC []byte) (bool, error) { + checkMac, err := r.MAC(cipher, ciphertext) if err != nil { return false, err } @@ -135,10 +113,10 @@ func (r *GroupMessage) VerifyMAC(key []byte, cipher cipher.Cipher, message, give } // VerifyMACInline verifies the MAC taken from the message to the calculated MAC of the message. -func (r *GroupMessage) VerifyMACInline(key []byte, cipher cipher.Cipher, message []byte) (bool, error) { - startMAC := len(message) - countMACBytesGroupMessage - crypto.ED25519SignatureSize +func (r *GroupMessage) VerifyMACInline(cipher aessha2.AESSHA2, message []byte) (bool, error) { + startMAC := len(message) - countMACBytesGroupMessage - crypto.Ed25519SignatureSize endMAC := startMAC + countMACBytesGroupMessage suplMac := message[startMAC:endMAC] message = message[:startMAC] - return r.VerifyMAC(key, cipher, message, suplMac) + return r.VerifyMAC(cipher, message, suplMac) } diff --git a/crypto/goolm/message/group_message_test.go b/crypto/goolm/message/group_message_test.go index 4ae1f830..272138c4 100644 --- a/crypto/goolm/message/group_message_test.go +++ b/crypto/goolm/message/group_message_test.go @@ -1,9 +1,13 @@ package message_test import ( - "bytes" "testing" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "maunium.net/go/mautrix/crypto/goolm/aessha2" + "maunium.net/go/mautrix/crypto/goolm/crypto" "maunium.net/go/mautrix/crypto/goolm/message" ) @@ -16,22 +20,13 @@ func TestGroupMessageDecode(t *testing.T) { msg := message.GroupMessage{} err := msg.Decode(messageRaw) - if err != nil { - t.Fatal(err) - } - if msg.Version != 3 { - t.Fatalf("Expected Version to be 3 but go %d", msg.Version) - } - if msg.MessageIndex != expectedMessageIndex { - t.Fatalf("Expected message index to be %d but got %d", expectedMessageIndex, msg.MessageIndex) - } - if !bytes.Equal(msg.Ciphertext, expectedCipherText) { - t.Fatalf("expected '%s' but got '%s'", expectedCipherText, msg.Ciphertext) - } + assert.NoError(t, err) + assert.EqualValues(t, 3, msg.Version) + assert.Equal(t, expectedMessageIndex, msg.MessageIndex) + assert.Equal(t, expectedCipherText, msg.Ciphertext) } func TestGroupMessageEncode(t *testing.T) { - expectedRaw := []byte("\x03\x08\xC8\x01\x12\x0aciphertexthmacsha2signature") hmacsha256 := []byte("hmacsha2") sign := []byte("signature") msg := message.GroupMessage{ @@ -39,13 +34,29 @@ func TestGroupMessageEncode(t *testing.T) { MessageIndex: 200, Ciphertext: []byte("ciphertext"), } - encoded, err := msg.EncodeAndMacAndSign(nil, nil, nil) - if err != nil { - t.Fatal(err) - } + + cipher, err := aessha2.NewAESSHA2(nil, nil) + require.NoError(t, err) + encoded, err := msg.EncodeAndMACAndSign(cipher, crypto.Ed25519GenerateFromSeed(make([]byte, 32))) + assert.NoError(t, err) encoded = append(encoded, hmacsha256...) encoded = append(encoded, sign...) - if !bytes.Equal(encoded, expectedRaw) { - t.Fatalf("expected '%s' but got '%s'", expectedRaw, encoded) + expected := []byte{ + 0x03, // Version + 0x08, + 0xC8, // 200 + 0x01, + 0x12, + 0x0a, } + expected = append(expected, []byte("ciphertext")...) + expected = append(expected, []byte{ + 0x6f, 0x95, 0x35, 0x51, 0xdc, 0xdb, 0xcb, 0x03, 0x0b, 0x22, 0xa2, 0xa7, 0xa1, 0xb7, 0x4f, 0x1a, + 0xa3, 0xe9, 0x5c, 0x05, 0x5d, 0x56, 0xdc, 0x5b, 0x87, 0x73, 0x05, 0x42, 0x2a, 0x59, 0x9a, 0x9a, + 0x26, 0x7a, 0x8d, 0xba, 0x65, 0xb2, 0x17, 0x65, 0x51, 0x6f, 0x37, 0xf3, 0x8f, 0xa1, 0x70, 0xd0, + 0xc4, 0x06, 0x05, 0xdc, 0x17, 0x71, 0x5e, 0x63, 0x84, 0xbe, 0xec, 0x7b, 0xa0, 0xc4, 0x08, 0xb8, + 0x9b, 0xc5, 0x08, 0x16, 0xad, 0xe5, 0x43, 0x0c, + }...) + expected = append(expected, []byte("hmacsha2signature")...) + assert.Equal(t, expected, encoded) } diff --git a/crypto/goolm/message/message.go b/crypto/goolm/message/message.go index 8b721aeb..b161a2d1 100644 --- a/crypto/goolm/message/message.go +++ b/crypto/goolm/message/message.go @@ -2,9 +2,12 @@ package message import ( "bytes" + "fmt" + "io" - "maunium.net/go/mautrix/crypto/goolm/cipher" + "maunium.net/go/mautrix/crypto/goolm/aessha2" "maunium.net/go/mautrix/crypto/goolm/crypto" + "maunium.net/go/mautrix/crypto/olm" ) const ( @@ -24,7 +27,7 @@ type Message struct { } // Decodes decodes the input and populates the corresponding fileds. MAC is ignored but has to be present. -func (r *Message) Decode(input []byte) error { +func (r *Message) Decode(input []byte) (err error) { r.Version = 0 r.HasCounter = false r.Counter = 0 @@ -33,89 +36,63 @@ func (r *Message) Decode(input []byte) error { if len(input) == 0 { return nil } - //first Byte is always version - r.Version = input[0] - curPos := 1 - for curPos < len(input)-countMACBytesMessage { - //Read Key - curKey, readBytes := decodeVarInt(input[curPos:]) - if err := checkDecodeErr(readBytes); err != nil { - return err - } - curPos += readBytes - if (curKey & 0b111) == 0 { - //The value is of type varint - value, readBytes := decodeVarInt(input[curPos:]) - if err := checkDecodeErr(readBytes); err != nil { - return err + + decoder := NewDecoder(input[:len(input)-countMACBytesMessage]) + r.Version, err = decoder.ReadByte() // first byte is always version + if err != nil { + return + } + if r.Version != protocolVersion { + return fmt.Errorf("Message.Decode: %w (got %d, expected %d)", olm.ErrWrongProtocolVersion, r.Version, protocolVersion) + } + + for { + // Read Key + if curKey, err := decoder.ReadVarInt(); err != nil { + if err == io.EOF { + // No more keys to read + return nil } - curPos += readBytes - switch curKey { - case counterTag: + return err + } else if (curKey & 0b111) == 0 { + // The value is of type varint + if value, err := decoder.ReadVarInt(); err != nil { + return err + } else if curKey == counterTag { + r.Counter = uint32(value) r.HasCounter = true - r.Counter = value } } else if (curKey & 0b111) == 2 { - //The value is of type string - value, readBytes := decodeVarString(input[curPos:]) - if err := checkDecodeErr(readBytes); err != nil { + // The value is of type string + if value, err := decoder.ReadVarBytes(); err != nil { return err - } - curPos += readBytes - switch curKey { - case ratchetKeyTag: + } else if curKey == ratchetKeyTag { r.RatchetKey = value - case cipherTextKeyTag: + } else if curKey == cipherTextKeyTag { r.Ciphertext = value } } } - - return nil } // EncodeAndMAC encodes the message and creates the MAC with the key and the cipher. // If key or cipher is nil, no MAC is appended. -func (r *Message) EncodeAndMAC(key []byte, cipher cipher.Cipher) ([]byte, error) { - var lengthOfMessage int - lengthOfMessage += 1 //Version - lengthOfMessage += encodeVarIntByteLength(ratchetKeyTag) + encodeVarStringByteLength(r.RatchetKey) - lengthOfMessage += encodeVarIntByteLength(counterTag) + encodeVarIntByteLength(r.Counter) - lengthOfMessage += encodeVarIntByteLength(cipherTextKeyTag) + encodeVarStringByteLength(r.Ciphertext) - out := make([]byte, lengthOfMessage) - out[0] = r.Version - curPos := 1 - encodedTag := encodeVarInt(ratchetKeyTag) - copy(out[curPos:], encodedTag) - curPos += len(encodedTag) - encodedValue := encodeVarString(r.RatchetKey) - copy(out[curPos:], encodedValue) - curPos += len(encodedValue) - encodedTag = encodeVarInt(counterTag) - copy(out[curPos:], encodedTag) - curPos += len(encodedTag) - encodedValue = encodeVarInt(r.Counter) - copy(out[curPos:], encodedValue) - curPos += len(encodedValue) - encodedTag = encodeVarInt(cipherTextKeyTag) - copy(out[curPos:], encodedTag) - curPos += len(encodedTag) - encodedValue = encodeVarString(r.Ciphertext) - copy(out[curPos:], encodedValue) - curPos += len(encodedValue) - if len(key) != 0 && cipher != nil { - mac, err := cipher.MAC(key, out) - if err != nil { - return nil, err - } - out = append(out, mac[:countMACBytesMessage]...) - } - return out, nil +func (r *Message) EncodeAndMAC(cipher aessha2.AESSHA2) ([]byte, error) { + var encoder Encoder + encoder.PutByte(r.Version) + encoder.PutVarInt(ratchetKeyTag) + encoder.PutVarBytes(r.RatchetKey) + encoder.PutVarInt(counterTag) + encoder.PutVarInt(uint64(r.Counter)) + encoder.PutVarInt(cipherTextKeyTag) + encoder.PutVarBytes(r.Ciphertext) + mac, err := cipher.MAC(encoder.Bytes()) + return append(encoder.Bytes(), mac[:countMACBytesMessage]...), err } // VerifyMAC verifies the givenMAC to the calculated MAC of the message. -func (r *Message) VerifyMAC(key []byte, cipher cipher.Cipher, message, givenMAC []byte) (bool, error) { - checkMAC, err := cipher.MAC(key, message) +func (r *Message) VerifyMAC(key []byte, cipher aessha2.AESSHA2, ciphertext, givenMAC []byte) (bool, error) { + checkMAC, err := cipher.MAC(ciphertext) if err != nil { return false, err } @@ -123,7 +100,7 @@ func (r *Message) VerifyMAC(key []byte, cipher cipher.Cipher, message, givenMAC } // VerifyMACInline verifies the MAC taken from the message to the calculated MAC of the message. -func (r *Message) VerifyMACInline(key []byte, cipher cipher.Cipher, message []byte) (bool, error) { +func (r *Message) VerifyMACInline(key []byte, cipher aessha2.AESSHA2, message []byte) (bool, error) { givenMAC := message[len(message)-countMACBytesMessage:] return r.VerifyMAC(key, cipher, message[:len(message)-countMACBytesMessage], givenMAC) } diff --git a/crypto/goolm/message/message_test.go b/crypto/goolm/message/message_test.go index 4a9f29fb..f3aa7108 100644 --- a/crypto/goolm/message/message_test.go +++ b/crypto/goolm/message/message_test.go @@ -1,9 +1,11 @@ package message_test import ( - "bytes" "testing" + "github.com/stretchr/testify/assert" + + "maunium.net/go/mautrix/crypto/goolm/aessha2" "maunium.net/go/mautrix/crypto/goolm/message" ) @@ -14,28 +16,16 @@ func TestMessageDecode(t *testing.T) { msg := message.Message{} err := msg.Decode(messageRaw) - if err != nil { - t.Fatal(err) - } - if msg.Version != 3 { - t.Fatalf("Expected Version to be 3 but go %d", msg.Version) - } - if !msg.HasCounter { - t.Fatal("Expected to have counter") - } - if msg.Counter != 1 { - t.Fatalf("Expected counter to be 1 but got %d", msg.Counter) - } - if !bytes.Equal(msg.Ciphertext, expectedCipherText) { - t.Fatalf("expected '%s' but got '%s'", expectedCipherText, msg.Ciphertext) - } - if !bytes.Equal(msg.RatchetKey, expectedRatchetKey) { - t.Fatalf("expected '%s' but got '%s'", expectedRatchetKey, msg.RatchetKey) - } + assert.NoError(t, err) + assert.EqualValues(t, 3, msg.Version) + assert.True(t, msg.HasCounter) + assert.EqualValues(t, 1, msg.Counter) + assert.Equal(t, expectedCipherText, msg.Ciphertext) + assert.EqualValues(t, expectedRatchetKey, msg.RatchetKey) } func TestMessageEncode(t *testing.T) { - expectedRaw := []byte("\x03\n\nratchetkey\x10\x01\"\nciphertexthmacsha2") + expectedRaw := []byte("\x03\n\nratchetkey\x10\x01\"\nciphertext\x95\x95\x92\x72\x04\x70\x56\xcdhmacsha2") hmacsha256 := []byte("hmacsha2") msg := message.Message{ Version: 3, @@ -43,12 +33,10 @@ func TestMessageEncode(t *testing.T) { RatchetKey: []byte("ratchetkey"), Ciphertext: []byte("ciphertext"), } - encoded, err := msg.EncodeAndMAC(nil, nil) - if err != nil { - t.Fatal(err) - } + cipher, err := aessha2.NewAESSHA2(nil, nil) + assert.NoError(t, err) + encoded, err := msg.EncodeAndMAC(cipher) + assert.NoError(t, err) encoded = append(encoded, hmacsha256...) - if !bytes.Equal(encoded, expectedRaw) { - t.Fatalf("expected '%s' but got '%s'", expectedRaw, encoded) - } + assert.Equal(t, expectedRaw, encoded) } diff --git a/crypto/goolm/message/prekey_message.go b/crypto/goolm/message/prekey_message.go index 6e007e06..4e3d495d 100644 --- a/crypto/goolm/message/prekey_message.go +++ b/crypto/goolm/message/prekey_message.go @@ -1,11 +1,15 @@ package message import ( + "fmt" + "io" + "maunium.net/go/mautrix/crypto/goolm/crypto" + "maunium.net/go/mautrix/crypto/olm" ) const ( - oneTimeKeyIdTag = 0x0A + oneTimeKeyIDTag = 0x0A baseKeyTag = 0x12 identityKeyTag = 0x1A messageTag = 0x22 @@ -19,8 +23,13 @@ type PreKeyMessage struct { Message []byte `json:"message"` } +// TODO deduplicate constant with one in session/olm_session.go +const ( + protocolVersion = 0x3 +) + // Decodes decodes the input and populates the corresponding fileds. -func (r *PreKeyMessage) Decode(input []byte) error { +func (r *PreKeyMessage) Decode(input []byte) (err error) { r.Version = 0 r.IdentityKey = nil r.BaseKey = nil @@ -29,44 +38,55 @@ func (r *PreKeyMessage) Decode(input []byte) error { if len(input) == 0 { return nil } - //first Byte is always version - r.Version = input[0] - curPos := 1 - for curPos < len(input) { - //Read Key - curKey, readBytes := decodeVarInt(input[curPos:]) - if err := checkDecodeErr(readBytes); err != nil { - return err + + decoder := NewDecoder(input) + r.Version, err = decoder.ReadByte() // first byte is always version + if err != nil { + if err == io.EOF { + return olm.ErrInputToSmall } - curPos += readBytes - if (curKey & 0b111) == 0 { - //The value is of type varint - _, readBytes := decodeVarInt(input[curPos:]) - if err := checkDecodeErr(readBytes); err != nil { + return + } + if r.Version != protocolVersion { + return fmt.Errorf("PreKeyMessage.Decode: %w (got %d, expected %d)", olm.ErrWrongProtocolVersion, r.Version, protocolVersion) + } + + for { + // Read Key + if curKey, err := decoder.ReadVarInt(); err != nil { + if err == io.EOF { + return nil + } + return err + } else if (curKey & 0b111) == 0 { + // The value is of type varint + if _, err = decoder.ReadVarInt(); err != nil { + if err == io.EOF { + return olm.ErrInputToSmall + } return err } - curPos += readBytes } else if (curKey & 0b111) == 2 { - //The value is of type string - value, readBytes := decodeVarString(input[curPos:]) - if err := checkDecodeErr(readBytes); err != nil { + // The value is of type string + if value, err := decoder.ReadVarBytes(); err != nil { + if err == io.EOF { + return olm.ErrInputToSmall + } return err - } - curPos += readBytes - switch curKey { - case oneTimeKeyIdTag: - r.OneTimeKey = value - case baseKeyTag: - r.BaseKey = value - case identityKeyTag: - r.IdentityKey = value - case messageTag: - r.Message = value + } else { + switch curKey { + case oneTimeKeyIDTag: + r.OneTimeKey = value + case baseKeyTag: + r.BaseKey = value + case identityKeyTag: + r.IdentityKey = value + case messageTag: + r.Message = value + } } } } - - return nil } // CheckField verifies the fields. If theirIdentityKey is nil, it is not compared to the key in the message. @@ -74,47 +94,25 @@ func (r *PreKeyMessage) CheckFields(theirIdentityKey *crypto.Curve25519PublicKey ok := true ok = ok && (theirIdentityKey != nil || r.IdentityKey != nil) if r.IdentityKey != nil { - ok = ok && (len(r.IdentityKey) == crypto.Curve25519KeyLength) + ok = ok && (len(r.IdentityKey) == crypto.Curve25519PrivateKeyLength) } ok = ok && len(r.Message) != 0 - ok = ok && len(r.BaseKey) == crypto.Curve25519KeyLength - ok = ok && len(r.OneTimeKey) == crypto.Curve25519KeyLength + ok = ok && len(r.BaseKey) == crypto.Curve25519PrivateKeyLength + ok = ok && len(r.OneTimeKey) == crypto.Curve25519PrivateKeyLength return ok } // Encode encodes the message. func (r *PreKeyMessage) Encode() ([]byte, error) { - var lengthOfMessage int - lengthOfMessage += 1 //Version - lengthOfMessage += encodeVarIntByteLength(oneTimeKeyIdTag) + encodeVarStringByteLength(r.OneTimeKey) - lengthOfMessage += encodeVarIntByteLength(identityKeyTag) + encodeVarStringByteLength(r.IdentityKey) - lengthOfMessage += encodeVarIntByteLength(baseKeyTag) + encodeVarStringByteLength(r.BaseKey) - lengthOfMessage += encodeVarIntByteLength(messageTag) + encodeVarStringByteLength(r.Message) - out := make([]byte, lengthOfMessage) - out[0] = r.Version - curPos := 1 - encodedTag := encodeVarInt(oneTimeKeyIdTag) - copy(out[curPos:], encodedTag) - curPos += len(encodedTag) - encodedValue := encodeVarString(r.OneTimeKey) - copy(out[curPos:], encodedValue) - curPos += len(encodedValue) - encodedTag = encodeVarInt(identityKeyTag) - copy(out[curPos:], encodedTag) - curPos += len(encodedTag) - encodedValue = encodeVarString(r.IdentityKey) - copy(out[curPos:], encodedValue) - curPos += len(encodedValue) - encodedTag = encodeVarInt(baseKeyTag) - copy(out[curPos:], encodedTag) - curPos += len(encodedTag) - encodedValue = encodeVarString(r.BaseKey) - copy(out[curPos:], encodedValue) - curPos += len(encodedValue) - encodedTag = encodeVarInt(messageTag) - copy(out[curPos:], encodedTag) - curPos += len(encodedTag) - encodedValue = encodeVarString(r.Message) - copy(out[curPos:], encodedValue) - return out, nil + var encoder Encoder + encoder.PutByte(r.Version) + encoder.PutVarInt(oneTimeKeyIDTag) + encoder.PutVarBytes(r.OneTimeKey) + encoder.PutVarInt(identityKeyTag) + encoder.PutVarBytes(r.IdentityKey) + encoder.PutVarInt(baseKeyTag) + encoder.PutVarBytes(r.BaseKey) + encoder.PutVarInt(messageTag) + encoder.PutVarBytes(r.Message) + return encoder.Bytes(), nil } diff --git a/crypto/goolm/message/prekey_message_test.go b/crypto/goolm/message/prekey_message_test.go index 431d27d5..fe196e31 100644 --- a/crypto/goolm/message/prekey_message_test.go +++ b/crypto/goolm/message/prekey_message_test.go @@ -1,9 +1,10 @@ package message_test import ( - "bytes" "testing" + "github.com/stretchr/testify/assert" + "maunium.net/go/mautrix/crypto/goolm/crypto" "maunium.net/go/mautrix/crypto/goolm/message" ) @@ -19,29 +20,14 @@ func TestPreKeyMessageDecode(t *testing.T) { msg := message.PreKeyMessage{} err := msg.Decode(messageRaw) - if err != nil { - t.Fatal(err) - } - if msg.Version != 3 { - t.Fatalf("Expected Version to be 3 but go %d", msg.Version) - } - if !bytes.Equal(msg.OneTimeKey, expectedOneTimeKey) { - t.Fatalf("expected '%s' but got '%s'", expectedOneTimeKey, msg.OneTimeKey) - } - if !bytes.Equal(msg.IdentityKey, expectedIdKey) { - t.Fatalf("expected '%s' but got '%s'", expectedIdKey, msg.IdentityKey) - } - if !bytes.Equal(msg.BaseKey, expectedbaseKey) { - t.Fatalf("expected '%s' but got '%s'", expectedbaseKey, msg.BaseKey) - } - if !bytes.Equal(msg.Message, expectedmessage) { - t.Fatalf("expected '%s' but got '%s'", expectedmessage, msg.Message) - } + assert.NoError(t, err) + assert.EqualValues(t, 3, msg.Version) + assert.EqualValues(t, expectedOneTimeKey, msg.OneTimeKey) + assert.EqualValues(t, expectedIdKey, msg.IdentityKey) + assert.EqualValues(t, expectedbaseKey, msg.BaseKey) + assert.Equal(t, expectedmessage, msg.Message) theirIDKey := crypto.Curve25519PublicKey(expectedIdKey) - checked := msg.CheckFields(&theirIDKey) - if !checked { - t.Fatal("field check failed") - } + assert.True(t, msg.CheckFields(&theirIDKey), "field check failed") } func TestPreKeyMessageEncode(t *testing.T) { @@ -54,10 +40,6 @@ func TestPreKeyMessageEncode(t *testing.T) { Message: []byte("message"), } encoded, err := msg.Encode() - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(encoded, expectedRaw) { - t.Fatalf("got other than expected:\nExpected:\n%v\nGot:\n%v", expectedRaw, encoded) - } + assert.NoError(t, err) + assert.Equal(t, expectedRaw, encoded) } diff --git a/crypto/goolm/message/session_export.go b/crypto/goolm/message/session_export.go index f539cce5..d58dbb21 100644 --- a/crypto/goolm/message/session_export.go +++ b/crypto/goolm/message/session_export.go @@ -4,8 +4,8 @@ import ( "encoding/binary" "fmt" - "maunium.net/go/mautrix/crypto/goolm" "maunium.net/go/mautrix/crypto/goolm/crypto" + "maunium.net/go/mautrix/crypto/olm" ) const ( @@ -32,10 +32,10 @@ func (s MegolmSessionExport) Encode() []byte { // Decode populates the struct with the data encoded in input. func (s *MegolmSessionExport) Decode(input []byte) error { if len(input) != 165 { - return fmt.Errorf("decrypt: %w", goolm.ErrBadInput) + return fmt.Errorf("decrypt: %w", olm.ErrBadInput) } if input[0] != sessionExportVersion { - return fmt.Errorf("decrypt: %w", goolm.ErrBadVersion) + return fmt.Errorf("decrypt: %w", olm.ErrUnknownOlmPickleVersion) } s.Counter = binary.BigEndian.Uint32(input[1:5]) copy(s.RatchetData[:], input[5:133]) diff --git a/crypto/goolm/message/session_sharing.go b/crypto/goolm/message/session_sharing.go index c5393f50..d04ef15a 100644 --- a/crypto/goolm/message/session_sharing.go +++ b/crypto/goolm/message/session_sharing.go @@ -4,8 +4,8 @@ import ( "encoding/binary" "fmt" - "maunium.net/go/mautrix/crypto/goolm" "maunium.net/go/mautrix/crypto/goolm/crypto" + "maunium.net/go/mautrix/crypto/olm" ) const ( @@ -20,29 +20,29 @@ type MegolmSessionSharing struct { } // Encode returns the encoded message in the correct format with the signature by key appended. -func (s MegolmSessionSharing) EncodeAndSign(key crypto.Ed25519KeyPair) []byte { +func (s MegolmSessionSharing) EncodeAndSign(key crypto.Ed25519KeyPair) ([]byte, error) { output := make([]byte, 229) output[0] = sessionSharingVersion binary.BigEndian.PutUint32(output[1:], s.Counter) copy(output[5:], s.RatchetData[:]) copy(output[133:], key.PublicKey) - signature := key.Sign(output[:165]) + signature, err := key.Sign(output[:165]) copy(output[165:], signature) - return output + return output, err } // VerifyAndDecode verifies the input and populates the struct with the data encoded in input. func (s *MegolmSessionSharing) VerifyAndDecode(input []byte) error { if len(input) != 229 { - return fmt.Errorf("verify: %w", goolm.ErrBadInput) + return fmt.Errorf("verify: %w", olm.ErrBadInput) } publicKey := crypto.Ed25519PublicKey(input[133:165]) if !publicKey.Verify(input[:165], input[165:]) { - return fmt.Errorf("verify: %w", goolm.ErrBadVerification) + return fmt.Errorf("verify: %w", olm.ErrBadVerification) } s.PublicKey = publicKey if input[0] != sessionSharingVersion { - return fmt.Errorf("verify: %w", goolm.ErrBadVersion) + return fmt.Errorf("verify: %w", olm.ErrUnknownOlmPickleVersion) } s.Counter = binary.BigEndian.Uint32(input[1:5]) copy(s.RatchetData[:], input[5:133]) diff --git a/crypto/goolm/olm/chain.go b/crypto/goolm/olm/chain.go deleted file mode 100644 index 403637a4..00000000 --- a/crypto/goolm/olm/chain.go +++ /dev/null @@ -1,258 +0,0 @@ -package olm - -import ( - "fmt" - - "maunium.net/go/mautrix/crypto/goolm" - "maunium.net/go/mautrix/crypto/goolm/crypto" - "maunium.net/go/mautrix/crypto/goolm/libolmpickle" -) - -const ( - chainKeySeed = 0x02 - messageKeyLength = 32 -) - -// chainKey wraps the index and the public key -type chainKey struct { - Index uint32 `json:"index"` - Key crypto.Curve25519PublicKey `json:"key"` -} - -// advance advances the chain -func (c *chainKey) advance() { - c.Key = crypto.HMACSHA256(c.Key, []byte{chainKeySeed}) - c.Index++ -} - -// UnpickleLibOlm decodes the unencryted value and populates the chain key accordingly. It returns the number of bytes read. -func (r *chainKey) UnpickleLibOlm(value []byte) (int, error) { - curPos := 0 - readBytes, err := r.Key.UnpickleLibOlm(value) - if err != nil { - return 0, err - } - curPos += readBytes - r.Index, readBytes, err = libolmpickle.UnpickleUInt32(value[curPos:]) - if err != nil { - return 0, err - } - curPos += readBytes - return curPos, nil -} - -// PickleLibOlm encodes the chain key into target. target has to have a size of at least PickleLen() and is written to from index 0. -// It returns the number of bytes written. -func (r chainKey) PickleLibOlm(target []byte) (int, error) { - if len(target) < r.PickleLen() { - return 0, fmt.Errorf("pickle chain key: %w", goolm.ErrValueTooShort) - } - written, err := r.Key.PickleLibOlm(target) - if err != nil { - return 0, fmt.Errorf("pickle chain key: %w", err) - } - written += libolmpickle.PickleUInt32(r.Index, target[written:]) - return written, nil -} - -// PickleLen returns the number of bytes the pickled chain key will have. -func (r chainKey) PickleLen() int { - length := r.Key.PickleLen() - length += libolmpickle.PickleUInt32Len(r.Index) - return length -} - -// senderChain is a chain for sending messages -type senderChain struct { - RKey crypto.Curve25519KeyPair `json:"ratchet_key"` - CKey chainKey `json:"chain_key"` - IsSet bool `json:"set"` -} - -// newSenderChain returns a sender chain initialized with chainKey and ratchet key pair. -func newSenderChain(key crypto.Curve25519PublicKey, ratchet crypto.Curve25519KeyPair) *senderChain { - return &senderChain{ - RKey: ratchet, - CKey: chainKey{ - Index: 0, - Key: key, - }, - IsSet: true, - } -} - -// advance advances the chain -func (s *senderChain) advance() { - s.CKey.advance() -} - -// ratchetKey returns the ratchet key pair. -func (s senderChain) ratchetKey() crypto.Curve25519KeyPair { - return s.RKey -} - -// chainKey returns the current chainKey. -func (s senderChain) chainKey() chainKey { - return s.CKey -} - -// UnpickleLibOlm decodes the unencryted value and populates the chain accordingly. It returns the number of bytes read. -func (r *senderChain) UnpickleLibOlm(value []byte) (int, error) { - curPos := 0 - readBytes, err := r.RKey.UnpickleLibOlm(value) - if err != nil { - return 0, err - } - curPos += readBytes - readBytes, err = r.CKey.UnpickleLibOlm(value[curPos:]) - if err != nil { - return 0, err - } - curPos += readBytes - return curPos, nil -} - -// PickleLibOlm encodes the chain into target. target has to have a size of at least PickleLen() and is written to from index 0. -// It returns the number of bytes written. -func (r senderChain) PickleLibOlm(target []byte) (int, error) { - if len(target) < r.PickleLen() { - return 0, fmt.Errorf("pickle sender chain: %w", goolm.ErrValueTooShort) - } - written, err := r.RKey.PickleLibOlm(target) - if err != nil { - return 0, fmt.Errorf("pickle sender chain: %w", err) - } - writtenChain, err := r.CKey.PickleLibOlm(target[written:]) - if err != nil { - return 0, fmt.Errorf("pickle sender chain: %w", err) - } - written += writtenChain - return written, nil -} - -// PickleLen returns the number of bytes the pickled chain will have. -func (r senderChain) PickleLen() int { - length := r.RKey.PickleLen() - length += r.CKey.PickleLen() - return length -} - -// senderChain is a chain for receiving messages -type receiverChain struct { - RKey crypto.Curve25519PublicKey `json:"ratchet_key"` - CKey chainKey `json:"chain_key"` -} - -// newReceiverChain returns a receiver chain initialized with chainKey and ratchet public key. -func newReceiverChain(chain crypto.Curve25519PublicKey, ratchet crypto.Curve25519PublicKey) *receiverChain { - return &receiverChain{ - RKey: ratchet, - CKey: chainKey{ - Index: 0, - Key: chain, - }, - } -} - -// advance advances the chain -func (s *receiverChain) advance() { - s.CKey.advance() -} - -// ratchetKey returns the ratchet public key. -func (s receiverChain) ratchetKey() crypto.Curve25519PublicKey { - return s.RKey -} - -// chainKey returns the current chainKey. -func (s receiverChain) chainKey() chainKey { - return s.CKey -} - -// UnpickleLibOlm decodes the unencryted value and populates the chain accordingly. It returns the number of bytes read. -func (r *receiverChain) UnpickleLibOlm(value []byte) (int, error) { - curPos := 0 - readBytes, err := r.RKey.UnpickleLibOlm(value) - if err != nil { - return 0, err - } - curPos += readBytes - readBytes, err = r.CKey.UnpickleLibOlm(value[curPos:]) - if err != nil { - return 0, err - } - curPos += readBytes - return curPos, nil -} - -// PickleLibOlm encodes the chain into target. target has to have a size of at least PickleLen() and is written to from index 0. -// It returns the number of bytes written. -func (r receiverChain) PickleLibOlm(target []byte) (int, error) { - if len(target) < r.PickleLen() { - return 0, fmt.Errorf("pickle sender chain: %w", goolm.ErrValueTooShort) - } - written, err := r.RKey.PickleLibOlm(target) - if err != nil { - return 0, fmt.Errorf("pickle sender chain: %w", err) - } - writtenChain, err := r.CKey.PickleLibOlm(target) - if err != nil { - return 0, fmt.Errorf("pickle sender chain: %w", err) - } - written += writtenChain - return written, nil -} - -// PickleLen returns the number of bytes the pickled chain will have. -func (r receiverChain) PickleLen() int { - length := r.RKey.PickleLen() - length += r.CKey.PickleLen() - return length -} - -// messageKey wraps the index and the key of a message -type messageKey struct { - Index uint32 `json:"index"` - Key []byte `json:"key"` -} - -// UnpickleLibOlm decodes the unencryted value and populates the message key accordingly. It returns the number of bytes read. -func (m *messageKey) UnpickleLibOlm(value []byte) (int, error) { - curPos := 0 - ratchetKey, readBytes, err := libolmpickle.UnpickleBytes(value, messageKeyLength) - if err != nil { - return 0, err - } - m.Key = ratchetKey - curPos += readBytes - keyID, readBytes, err := libolmpickle.UnpickleUInt32(value[:curPos]) - if err != nil { - return 0, err - } - curPos += readBytes - m.Index = keyID - return curPos, nil -} - -// PickleLibOlm encodes the message key into target. target has to have a size of at least PickleLen() and is written to from index 0. -// It returns the number of bytes written. -func (m messageKey) PickleLibOlm(target []byte) (int, error) { - if len(target) < m.PickleLen() { - return 0, fmt.Errorf("pickle message key: %w", goolm.ErrValueTooShort) - } - written := 0 - if len(m.Key) != messageKeyLength { - written += libolmpickle.PickleBytes(make([]byte, messageKeyLength), target) - } else { - written += libolmpickle.PickleBytes(m.Key, target) - } - written += libolmpickle.PickleUInt32(m.Index, target[written:]) - return written, nil -} - -// PickleLen returns the number of bytes the pickled message key will have. -func (r messageKey) PickleLen() int { - length := libolmpickle.PickleBytesLen(make([]byte, messageKeyLength)) - length += libolmpickle.PickleUInt32Len(r.Index) - return length -} diff --git a/crypto/goolm/olm/olm_test.go b/crypto/goolm/olm/olm_test.go deleted file mode 100644 index 974ffc5e..00000000 --- a/crypto/goolm/olm/olm_test.go +++ /dev/null @@ -1,186 +0,0 @@ -package olm_test - -import ( - "bytes" - "encoding/json" - "testing" - - "maunium.net/go/mautrix/crypto/goolm/cipher" - "maunium.net/go/mautrix/crypto/goolm/crypto" - "maunium.net/go/mautrix/crypto/goolm/olm" -) - -var ( - sharedSecret = []byte("A secret") -) - -func initializeRatchets() (*olm.Ratchet, *olm.Ratchet, error) { - olm.KdfInfo = struct { - Root []byte - Ratchet []byte - }{ - Root: []byte("Olm"), - Ratchet: []byte("OlmRatchet"), - } - olm.RatchetCipher = cipher.NewAESSHA256([]byte("OlmMessageKeys")) - aliceRatchet := olm.New() - bobRatchet := olm.New() - - aliceKey, err := crypto.Curve25519GenerateKey(nil) - if err != nil { - return nil, nil, err - } - - aliceRatchet.InitializeAsAlice(sharedSecret, aliceKey) - bobRatchet.InitializeAsBob(sharedSecret, aliceKey.PublicKey) - return aliceRatchet, bobRatchet, nil -} - -func TestSendReceive(t *testing.T) { - aliceRatchet, bobRatchet, err := initializeRatchets() - if err != nil { - t.Fatal(err) - } - - plainText := []byte("Hello Bob") - - //Alice sends Bob a message - encryptedMessage, err := aliceRatchet.Encrypt(plainText, nil) - if err != nil { - t.Fatal(err) - } - - decrypted, err := bobRatchet.Decrypt(encryptedMessage) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(plainText, decrypted) { - t.Fatalf("expected '%v' from decryption but got '%v'", plainText, decrypted) - } - - //Bob sends Alice a message - plainText = []byte("Hello Alice") - encryptedMessage, err = bobRatchet.Encrypt(plainText, nil) - if err != nil { - t.Fatal(err) - } - decrypted, err = aliceRatchet.Decrypt(encryptedMessage) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(plainText, decrypted) { - t.Fatalf("expected '%v' from decryption but got '%v'", plainText, decrypted) - } -} - -func TestOutOfOrder(t *testing.T) { - aliceRatchet, bobRatchet, err := initializeRatchets() - if err != nil { - t.Fatal(err) - } - - plainText1 := []byte("First Message") - plainText2 := []byte("Second Messsage. A bit longer than the first.") - - /* Alice sends Bob two messages and they arrive out of order */ - message1Encrypted, err := aliceRatchet.Encrypt(plainText1, nil) - if err != nil { - t.Fatal(err) - } - message2Encrypted, err := aliceRatchet.Encrypt(plainText2, nil) - if err != nil { - t.Fatal(err) - } - - decrypted2, err := bobRatchet.Decrypt(message2Encrypted) - if err != nil { - t.Fatal(err) - } - decrypted1, err := bobRatchet.Decrypt(message1Encrypted) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(plainText1, decrypted1) { - t.Fatalf("expected '%v' from decryption but got '%v'", plainText1, decrypted1) - } - if !bytes.Equal(plainText2, decrypted2) { - t.Fatalf("expected '%v' from decryption but got '%v'", plainText2, decrypted2) - } -} - -func TestMoreMessages(t *testing.T) { - aliceRatchet, bobRatchet, err := initializeRatchets() - if err != nil { - t.Fatal(err) - } - plainText := []byte("These 15 bytes") - for i := 0; i < 8; i++ { - messageEncrypted, err := aliceRatchet.Encrypt(plainText, nil) - if err != nil { - t.Fatal(err) - } - decrypted, err := bobRatchet.Decrypt(messageEncrypted) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(plainText, decrypted) { - t.Fatalf("expected '%v' from decryption but got '%v'", plainText, decrypted) - } - } - for i := 0; i < 8; i++ { - messageEncrypted, err := bobRatchet.Encrypt(plainText, nil) - if err != nil { - t.Fatal(err) - } - decrypted, err := aliceRatchet.Decrypt(messageEncrypted) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(plainText, decrypted) { - t.Fatalf("expected '%v' from decryption but got '%v'", plainText, decrypted) - } - } - messageEncrypted, err := aliceRatchet.Encrypt(plainText, nil) - if err != nil { - t.Fatal(err) - } - decrypted, err := bobRatchet.Decrypt(messageEncrypted) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(plainText, decrypted) { - t.Fatalf("expected '%v' from decryption but got '%v'", plainText, decrypted) - } -} - -func TestJSONEncoding(t *testing.T) { - aliceRatchet, bobRatchet, err := initializeRatchets() - if err != nil { - t.Fatal(err) - } - marshaled, err := json.Marshal(aliceRatchet) - if err != nil { - t.Fatal(err) - } - - newRatcher := olm.Ratchet{} - err = json.Unmarshal(marshaled, &newRatcher) - if err != nil { - t.Fatal(err) - } - - plainText := []byte("These 15 bytes") - - messageEncrypted, err := newRatcher.Encrypt(plainText, nil) - if err != nil { - t.Fatal(err) - } - decrypted, err := bobRatchet.Decrypt(messageEncrypted) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(plainText, decrypted) { - t.Fatalf("expected '%v' from decryption but got '%v'", plainText, decrypted) - } - -} diff --git a/crypto/goolm/olm/skipped_message.go b/crypto/goolm/olm/skipped_message.go deleted file mode 100644 index 944337f6..00000000 --- a/crypto/goolm/olm/skipped_message.go +++ /dev/null @@ -1,55 +0,0 @@ -package olm - -import ( - "fmt" - - "maunium.net/go/mautrix/crypto/goolm" - "maunium.net/go/mautrix/crypto/goolm/crypto" -) - -// skippedMessageKey stores a skipped message key -type skippedMessageKey struct { - RKey crypto.Curve25519PublicKey `json:"ratchet_key"` - MKey messageKey `json:"message_key"` -} - -// UnpickleLibOlm decodes the unencryted value and populates the chain accordingly. It returns the number of bytes read. -func (r *skippedMessageKey) UnpickleLibOlm(value []byte) (int, error) { - curPos := 0 - readBytes, err := r.RKey.UnpickleLibOlm(value) - if err != nil { - return 0, err - } - curPos += readBytes - readBytes, err = r.MKey.UnpickleLibOlm(value[curPos:]) - if err != nil { - return 0, err - } - curPos += readBytes - return curPos, nil -} - -// PickleLibOlm encodes the chain into target. target has to have a size of at least PickleLen() and is written to from index 0. -// It returns the number of bytes written. -func (r skippedMessageKey) PickleLibOlm(target []byte) (int, error) { - if len(target) < r.PickleLen() { - return 0, fmt.Errorf("pickle sender chain: %w", goolm.ErrValueTooShort) - } - written, err := r.RKey.PickleLibOlm(target) - if err != nil { - return 0, fmt.Errorf("pickle sender chain: %w", err) - } - writtenChain, err := r.MKey.PickleLibOlm(target) - if err != nil { - return 0, fmt.Errorf("pickle sender chain: %w", err) - } - written += writtenChain - return written, nil -} - -// PickleLen returns the number of bytes the pickled chain will have. -func (r skippedMessageKey) PickleLen() int { - length := r.RKey.PickleLen() - length += r.MKey.PickleLen() - return length -} diff --git a/crypto/goolm/pk/decryption.go b/crypto/goolm/pk/decryption.go index d08e09f4..cdb20eb1 100644 --- a/crypto/goolm/pk/decryption.go +++ b/crypto/goolm/pk/decryption.go @@ -2,14 +2,13 @@ package pk import ( "encoding/base64" - "errors" "fmt" - "maunium.net/go/mautrix/crypto/goolm" - "maunium.net/go/mautrix/crypto/goolm/cipher" + "maunium.net/go/mautrix/crypto/goolm/aessha2" "maunium.net/go/mautrix/crypto/goolm/crypto" + "maunium.net/go/mautrix/crypto/goolm/goolmbase64" "maunium.net/go/mautrix/crypto/goolm/libolmpickle" - "maunium.net/go/mautrix/crypto/goolm/utilities" + "maunium.net/go/mautrix/crypto/olm" "maunium.net/go/mautrix/id" ) @@ -25,7 +24,7 @@ type Decryption struct { // NewDecryption returns a new Decryption with a new generated key pair. func NewDecryption() (*Decryption, error) { - keyPair, err := crypto.Curve25519GenerateKey(nil) + keyPair, err := crypto.Curve25519GenerateKey() if err != nil { return nil, err } @@ -56,110 +55,67 @@ func (s Decryption) PrivateKey() crypto.Curve25519PrivateKey { } // Decrypt decrypts the ciphertext and verifies the MAC. The base64 encoded key is used to construct the shared secret. -func (s Decryption) Decrypt(ciphertext, mac []byte, key id.Curve25519) ([]byte, error) { - keyDecoded, err := base64.RawStdEncoding.DecodeString(string(key)) - if err != nil { +func (s Decryption) Decrypt(ephemeralKey, mac, ciphertext []byte) ([]byte, error) { + if keyDecoded, err := base64.RawStdEncoding.DecodeString(string(ephemeralKey)); err != nil { return nil, err - } - sharedSecret, err := s.KeyPair.SharedSecret(keyDecoded) - if err != nil { + } else if sharedSecret, err := s.KeyPair.SharedSecret(keyDecoded); err != nil { return nil, err - } - decodedMAC, err := goolm.Base64Decode(mac) - if err != nil { + } else if decodedMAC, err := goolmbase64.Decode(mac); err != nil { return nil, err - } - cipher := cipher.NewAESSHA256(nil) - verified, err := cipher.Verify(sharedSecret, ciphertext, decodedMAC) - if err != nil { + } else if cipher, err := aessha2.NewAESSHA2(sharedSecret, nil); err != nil { return nil, err - } - if !verified { - return nil, fmt.Errorf("decrypt: %w", goolm.ErrBadMAC) - } - plaintext, err := cipher.Decrypt(sharedSecret, ciphertext) - if err != nil { + } else if verified, err := cipher.VerifyMAC(ciphertext, decodedMAC); err != nil { return nil, err + } else if !verified { + return nil, fmt.Errorf("decrypt: %w", olm.ErrBadMAC) + } else { + return cipher.Decrypt(ciphertext) } - return plaintext, nil } // PickleAsJSON returns an Decryption as a base64 string encrypted using the supplied key. The unencrypted representation of the Account is in JSON format. func (a Decryption) PickleAsJSON(key []byte) ([]byte, error) { - return utilities.PickleAsJSON(a, decryptionPickleVersionJSON, key) + return libolmpickle.PickleAsJSON(a, decryptionPickleVersionJSON, key) } // UnpickleAsJSON updates an Decryption by a base64 encrypted string using the supplied key. The unencrypted representation has to be in JSON format. func (a *Decryption) UnpickleAsJSON(pickled, key []byte) error { - return utilities.UnpickleAsJSON(a, pickled, key, decryptionPickleVersionJSON) + return libolmpickle.UnpickleAsJSON(a, pickled, key, decryptionPickleVersionJSON) } // Unpickle decodes the base64 encoded string and decrypts the result with the key. // The decrypted value is then passed to UnpickleLibOlm. func (a *Decryption) Unpickle(pickled, key []byte) error { - decrypted, err := cipher.Unpickle(key, pickled) + decrypted, err := libolmpickle.Unpickle(key, pickled) if err != nil { return err } - _, err = a.UnpickleLibOlm(decrypted) - return err + return a.UnpickleLibOlm(decrypted) } // UnpickleLibOlm decodes the unencryted value and populates the Decryption accordingly. It returns the number of bytes read. -func (a *Decryption) UnpickleLibOlm(value []byte) (int, error) { - //First 4 bytes are the accountPickleVersion - pickledVersion, curPos, err := libolmpickle.UnpickleUInt32(value) +func (a *Decryption) UnpickleLibOlm(unpickled []byte) error { + decoder := libolmpickle.NewDecoder(unpickled) + pickledVersion, err := decoder.ReadUInt32() if err != nil { - return 0, err + return err } - switch pickledVersion { - case decryptionPickleVersionLibOlm: - default: - return 0, fmt.Errorf("unpickle olmSession: %w", goolm.ErrBadVersion) + if pickledVersion == decryptionPickleVersionLibOlm { + return a.KeyPair.UnpickleLibOlm(decoder) + } else { + return fmt.Errorf("unpickle olmSession: %w (found %d, expected %d)", olm.ErrUnknownOlmPickleVersion, pickledVersion, decryptionPickleVersionLibOlm) } - readBytes, err := a.KeyPair.UnpickleLibOlm(value[curPos:]) - if err != nil { - return 0, err - } - curPos += readBytes - return curPos, nil } // Pickle returns a base64 encoded and with key encrypted pickled Decryption using PickleLibOlm(). func (a Decryption) Pickle(key []byte) ([]byte, error) { - pickeledBytes := make([]byte, a.PickleLen()) - written, err := a.PickleLibOlm(pickeledBytes) - if err != nil { - return nil, err - } - if written != len(pickeledBytes) { - return nil, errors.New("number of written bytes not correct") - } - encrypted, err := cipher.Pickle(key, pickeledBytes) - if err != nil { - return nil, err - } - return encrypted, nil + return libolmpickle.Pickle(key, a.PickleLibOlm()) } -// PickleLibOlm encodes the Decryption into target. target has to have a size of at least PickleLen() and is written to from index 0. -// It returns the number of bytes written. -func (a Decryption) PickleLibOlm(target []byte) (int, error) { - if len(target) < a.PickleLen() { - return 0, fmt.Errorf("pickle Decryption: %w", goolm.ErrValueTooShort) - } - written := libolmpickle.PickleUInt32(decryptionPickleVersionLibOlm, target) - writtenKey, err := a.KeyPair.PickleLibOlm(target[written:]) - if err != nil { - return 0, fmt.Errorf("pickle Decryption: %w", err) - } - written += writtenKey - return written, nil -} - -// PickleLen returns the number of bytes the pickled Decryption will have. -func (a Decryption) PickleLen() int { - length := libolmpickle.PickleUInt32Len(decryptionPickleVersionLibOlm) - length += a.KeyPair.PickleLen() - return length +// PickleLibOlm pickles the [Decryption] into the encoder. +func (a Decryption) PickleLibOlm() []byte { + encoder := libolmpickle.NewEncoder() + encoder.WriteUInt32(decryptionPickleVersionLibOlm) + a.KeyPair.PickleLibOlm(encoder) + return encoder.Bytes() } diff --git a/crypto/goolm/pk/encryption.go b/crypto/goolm/pk/encryption.go index dc50a6bb..2897d9b0 100644 --- a/crypto/goolm/pk/encryption.go +++ b/crypto/goolm/pk/encryption.go @@ -5,9 +5,9 @@ import ( "maunium.net/go/mautrix/id" - "maunium.net/go/mautrix/crypto/goolm" - "maunium.net/go/mautrix/crypto/goolm/cipher" + "maunium.net/go/mautrix/crypto/goolm/aessha2" "maunium.net/go/mautrix/crypto/goolm/crypto" + "maunium.net/go/mautrix/crypto/goolm/goolmbase64" ) // Encryption is used to encrypt pk messages @@ -36,14 +36,14 @@ func (e Encryption) Encrypt(plaintext []byte, privateKey crypto.Curve25519Privat if err != nil { return nil, nil, err } - cipher := cipher.NewAESSHA256(nil) - ciphertext, err = cipher.Encrypt(sharedSecret, plaintext) + cipher, err := aessha2.NewAESSHA2(sharedSecret, nil) if err != nil { return nil, nil, err } - mac, err = cipher.MAC(sharedSecret, ciphertext) + ciphertext, err = cipher.Encrypt(plaintext) if err != nil { return nil, nil, err } - return ciphertext, goolm.Base64Encode(mac), nil + mac, err = cipher.MAC(ciphertext) + return ciphertext, goolmbase64.Encode(mac), err } diff --git a/crypto/goolm/pk/pk_test.go b/crypto/goolm/pk/pk_test.go index 7ac524be..4b247430 100644 --- a/crypto/goolm/pk/pk_test.go +++ b/crypto/goolm/pk/pk_test.go @@ -1,14 +1,13 @@ package pk_test import ( - "bytes" "encoding/base64" "testing" - "maunium.net/go/mautrix/crypto/goolm" + "github.com/stretchr/testify/assert" + "maunium.net/go/mautrix/crypto/goolm/crypto" "maunium.net/go/mautrix/crypto/goolm/pk" - "maunium.net/go/mautrix/id" ) func TestEncryptionDecryption(t *testing.T) { @@ -27,34 +26,20 @@ func TestEncryptionDecryption(t *testing.T) { } bobPublic := []byte("3p7bfXt9wbTTW2HC7OQ1Nz+DQ8hbeGdNrfx+FG+IK08") decryption, err := pk.NewDecryptionFromPrivate(alicePrivate) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal([]byte(decryption.PublicKey()), alicePublic) { - t.Fatal("public key not correct") - } - if !bytes.Equal(decryption.PrivateKey(), alicePrivate) { - t.Fatal("private key not correct") - } + assert.NoError(t, err) + assert.EqualValues(t, alicePublic, decryption.PublicKey(), "public key not correct") + assert.EqualValues(t, alicePrivate, decryption.PrivateKey(), "private key not correct") encryption, err := pk.NewEncryption(decryption.PublicKey()) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) plaintext := []byte("This is a test") ciphertext, mac, err := encryption.Encrypt(plaintext, bobPrivate) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) - decrypted, err := decryption.Decrypt(ciphertext, mac, id.Curve25519(bobPublic)) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(decrypted, plaintext) { - t.Fatal("message not equal") - } + decrypted, err := decryption.Decrypt(bobPublic, mac, ciphertext) + assert.NoError(t, err) + assert.EqualValues(t, plaintext, decrypted, "message not equal") } func TestSigning(t *testing.T) { @@ -67,29 +52,20 @@ func TestSigning(t *testing.T) { message := []byte("We hold these truths to be self-evident, that all men are created equal, that they are endowed by their Creator with certain unalienable Rights, that among these are Life, Liberty and the pursuit of Happiness.") signing, _ := pk.NewSigningFromSeed(seed) signature, err := signing.Sign(message) - if err != nil { - t.Fatal(err) - } - signatureDecoded, err := goolm.Base64Decode(signature) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) + signatureDecoded, err := base64.RawStdEncoding.DecodeString(string(signature)) + assert.NoError(t, err) pubKeyEncoded := signing.PublicKey() pubKeyDecoded, err := base64.RawStdEncoding.DecodeString(string(pubKeyEncoded)) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) pubKey := crypto.Ed25519PublicKey(pubKeyDecoded) verified := pubKey.Verify(message, signatureDecoded) - if !verified { - t.Fatal("signature did not verify") - } + assert.True(t, verified, "signature did not verify") + copy(signatureDecoded[0:], []byte("m")) verified = pubKey.Verify(message, signatureDecoded) - if verified { - t.Fatal("signature did verify") - } + assert.False(t, verified, "signature verified with wrong message") } func TestDecryptionPickling(t *testing.T) { @@ -101,37 +77,19 @@ func TestDecryptionPickling(t *testing.T) { } alicePublic := []byte("hSDwCYkwp1R0i33ctD73Wg2/Og0mOBr066SpjqqbTmo") decryption, err := pk.NewDecryptionFromPrivate(alicePrivate) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal([]byte(decryption.PublicKey()), alicePublic) { - t.Fatal("public key not correct") - } - if !bytes.Equal(decryption.PrivateKey(), alicePrivate) { - t.Fatal("private key not correct") - } + assert.NoError(t, err) + assert.EqualValues(t, alicePublic, decryption.PublicKey(), "public key not correct") + assert.EqualValues(t, alicePrivate, decryption.PrivateKey(), "private key not correct") pickleKey := []byte("secret_key") expectedPickle := []byte("qx37WTQrjZLz5tId/uBX9B3/okqAbV1ofl9UnHKno1eipByCpXleAAlAZoJgYnCDOQZDQWzo3luTSfkF9pU1mOILCbbouubs6TVeDyPfgGD9i86J8irHjA") pickled, err := decryption.Pickle(pickleKey) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(expectedPickle, pickled) { - t.Fatalf("pickle not as expected:\n%v\n%v\n", pickled, expectedPickle) - } + assert.NoError(t, err) + assert.EqualValues(t, expectedPickle, pickled, "pickle not as expected") newDecription, err := pk.NewDecryption() - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) err = newDecription.Unpickle(pickled, pickleKey) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal([]byte(newDecription.PublicKey()), alicePublic) { - t.Fatal("public key not correct") - } - if !bytes.Equal(newDecription.PrivateKey(), alicePrivate) { - t.Fatal("private key not correct") - } + assert.NoError(t, err) + assert.EqualValues(t, alicePublic, newDecription.PublicKey(), "public key not correct") + assert.EqualValues(t, alicePrivate, newDecription.PrivateKey(), "private key not correct") } diff --git a/crypto/goolm/pk/register.go b/crypto/goolm/pk/register.go new file mode 100644 index 00000000..0e27b568 --- /dev/null +++ b/crypto/goolm/pk/register.go @@ -0,0 +1,21 @@ +// Copyright (c) 2024 Sumner Evans +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package pk + +import "maunium.net/go/mautrix/crypto/olm" + +func Register() { + olm.InitNewPKSigningFromSeed = func(seed []byte) (olm.PKSigning, error) { + return NewSigningFromSeed(seed) + } + olm.InitNewPKSigning = func() (olm.PKSigning, error) { + return NewSigning() + } + olm.InitNewPKDecryptionFromPrivateKey = func(privateKey []byte) (olm.PKDecryption, error) { + return NewDecryptionFromPrivate(privateKey) + } +} diff --git a/crypto/goolm/pk/signing.go b/crypto/goolm/pk/signing.go index a98330d5..61b31b6f 100644 --- a/crypto/goolm/pk/signing.go +++ b/crypto/goolm/pk/signing.go @@ -7,8 +7,8 @@ import ( "github.com/tidwall/sjson" "maunium.net/go/mautrix/crypto/canonicaljson" - "maunium.net/go/mautrix/crypto/goolm" "maunium.net/go/mautrix/crypto/goolm/crypto" + "maunium.net/go/mautrix/crypto/goolm/goolmbase64" "maunium.net/go/mautrix/id" ) @@ -48,8 +48,8 @@ func (s Signing) PublicKey() id.Ed25519 { // Sign returns the signature of the message base64 encoded. func (s Signing) Sign(message []byte) ([]byte, error) { - signature := s.keyPair.Sign(message) - return goolm.Base64Encode(signature), nil + signature, err := s.keyPair.Sign(message) + return goolmbase64.Encode(signature), err } // SignJSON creates a signature for the given object after encoding it to @@ -62,8 +62,5 @@ func (s Signing) SignJSON(obj any) (string, error) { objJSON, _ = sjson.DeleteBytes(objJSON, "unsigned") objJSON, _ = sjson.DeleteBytes(objJSON, "signatures") signature, err := s.Sign(canonicaljson.CanonicalJSONAssumeValid(objJSON)) - if err != nil { - return "", err - } - return string(signature), nil + return string(signature), err } diff --git a/crypto/goolm/ratchet/chain.go b/crypto/goolm/ratchet/chain.go new file mode 100644 index 00000000..5deb90f5 --- /dev/null +++ b/crypto/goolm/ratchet/chain.go @@ -0,0 +1,170 @@ +package ratchet + +import ( + "crypto/hmac" + "crypto/sha256" + + "maunium.net/go/mautrix/crypto/goolm/crypto" + "maunium.net/go/mautrix/crypto/goolm/libolmpickle" +) + +const ( + chainKeySeed = 0x02 + messageKeyLength = 32 +) + +// chainKey wraps the index and the public key +type chainKey struct { + Index uint32 `json:"index"` + Key crypto.Curve25519PublicKey `json:"key"` +} + +// advance advances the chain +func (c *chainKey) advance() { + hash := hmac.New(sha256.New, c.Key) + hash.Write([]byte{chainKeySeed}) + c.Key = hash.Sum(nil) + c.Index++ +} + +// UnpickleLibOlm unpickles the unencryted value and populates the chain key accordingly. +func (r *chainKey) UnpickleLibOlm(decoder *libolmpickle.Decoder) error { + err := r.Key.UnpickleLibOlm(decoder) + if err != nil { + return err + } + r.Index, err = decoder.ReadUInt32() + return err +} + +// PickleLibOlm pickles the chain key into the encoder. +func (r chainKey) PickleLibOlm(encoder *libolmpickle.Encoder) { + r.Key.PickleLibOlm(encoder) + encoder.WriteUInt32(r.Index) +} + +// senderChain is a chain for sending messages +type senderChain struct { + RKey crypto.Curve25519KeyPair `json:"ratchet_key"` + CKey chainKey `json:"chain_key"` + IsSet bool `json:"set"` +} + +// newSenderChain returns a sender chain initialized with chainKey and ratchet key pair. +func newSenderChain(key crypto.Curve25519PublicKey, ratchet crypto.Curve25519KeyPair) *senderChain { + return &senderChain{ + RKey: ratchet, + CKey: chainKey{ + Index: 0, + Key: key, + }, + IsSet: true, + } +} + +// advance advances the chain +func (s *senderChain) advance() { + s.CKey.advance() +} + +// ratchetKey returns the ratchet key pair. +func (s senderChain) ratchetKey() crypto.Curve25519KeyPair { + return s.RKey +} + +// chainKey returns the current chainKey. +func (s senderChain) chainKey() chainKey { + return s.CKey +} + +// UnpickleLibOlm unpickles the unencryted value and populates the sender chain +// accordingly. +func (r *senderChain) UnpickleLibOlm(decoder *libolmpickle.Decoder) error { + if err := r.RKey.UnpickleLibOlm(decoder); err != nil { + return err + } + return r.CKey.UnpickleLibOlm(decoder) +} + +// PickleLibOlm pickles the sender chain into the encoder. +func (r senderChain) PickleLibOlm(encoder *libolmpickle.Encoder) { + if r.IsSet { + encoder.WriteUInt32(1) // Length of the sender chain (1 if set) + r.RKey.PickleLibOlm(encoder) + r.CKey.PickleLibOlm(encoder) + } else { + encoder.WriteUInt32(0) + } +} + +// senderChain is a chain for receiving messages +type receiverChain struct { + RKey crypto.Curve25519PublicKey `json:"ratchet_key"` + CKey chainKey `json:"chain_key"` +} + +// newReceiverChain returns a receiver chain initialized with chainKey and ratchet public key. +func newReceiverChain(chain crypto.Curve25519PublicKey, ratchet crypto.Curve25519PublicKey) *receiverChain { + return &receiverChain{ + RKey: ratchet, + CKey: chainKey{ + Index: 0, + Key: chain, + }, + } +} + +// advance advances the chain +func (s *receiverChain) advance() { + s.CKey.advance() +} + +// ratchetKey returns the ratchet public key. +func (s receiverChain) ratchetKey() crypto.Curve25519PublicKey { + return s.RKey +} + +// chainKey returns the current chainKey. +func (s receiverChain) chainKey() chainKey { + return s.CKey +} + +// UnpickleLibOlm unpickles the unencryted value and populates the chain accordingly. +func (r *receiverChain) UnpickleLibOlm(decoder *libolmpickle.Decoder) error { + if err := r.RKey.UnpickleLibOlm(decoder); err != nil { + return err + } + return r.CKey.UnpickleLibOlm(decoder) +} + +// PickleLibOlm pickles the receiver chain into the encoder. +func (r receiverChain) PickleLibOlm(encoder *libolmpickle.Encoder) { + r.RKey.PickleLibOlm(encoder) + r.CKey.PickleLibOlm(encoder) +} + +// messageKey wraps the index and the key of a message +type messageKey struct { + Index uint32 `json:"index"` + Key []byte `json:"key"` +} + +// UnpickleLibOlm unpickles the unencryted value and populates the message key +// accordingly. +func (m *messageKey) UnpickleLibOlm(decoder *libolmpickle.Decoder) (err error) { + if m.Key, err = decoder.ReadBytes(messageKeyLength); err != nil { + return + } + m.Index, err = decoder.ReadUInt32() + return +} + +// PickleLibOlm pickles the message key into the encoder. +func (m messageKey) PickleLibOlm(encoder *libolmpickle.Encoder) { + if len(m.Key) == messageKeyLength { + encoder.Write(m.Key) + } else { + encoder.WriteEmptyBytes(messageKeyLength) + } + encoder.WriteUInt32(m.Index) +} diff --git a/crypto/goolm/olm/olm.go b/crypto/goolm/ratchet/olm.go similarity index 51% rename from crypto/goolm/olm/olm.go rename to crypto/goolm/ratchet/olm.go index 299ec7c4..9901ada8 100644 --- a/crypto/goolm/olm/olm.go +++ b/crypto/goolm/ratchet/olm.go @@ -1,16 +1,19 @@ -// olm provides the ratchet used by the olm protocol -package olm +// Package ratchet provides the ratchet used by the olm protocol +package ratchet import ( + "crypto/hmac" + "crypto/sha256" "fmt" "io" - "maunium.net/go/mautrix/crypto/goolm" - "maunium.net/go/mautrix/crypto/goolm/cipher" + "golang.org/x/crypto/hkdf" + + "maunium.net/go/mautrix/crypto/goolm/aessha2" "maunium.net/go/mautrix/crypto/goolm/crypto" "maunium.net/go/mautrix/crypto/goolm/libolmpickle" "maunium.net/go/mautrix/crypto/goolm/message" - "maunium.net/go/mautrix/crypto/goolm/utilities" + "maunium.net/go/mautrix/crypto/olm" ) const ( @@ -27,6 +30,8 @@ const ( sharedKeyLength = 32 ) +var olmKeysKDFInfo = []byte("OLM_KEYS") + // KdfInfo has the infos used for the kdf var KdfInfo = struct { Root []byte @@ -36,8 +41,6 @@ var KdfInfo = struct { Ratchet: []byte("OLM_RATCHET"), } -var RatchetCipher = cipher.NewAESSHA256([]byte("OLM_KEYS")) - // Ratchet represents the olm ratchet as described in // // https://gitlab.matrix.org/matrix-org/olm/-/blob/master/docs/olm.md @@ -64,13 +67,12 @@ type Ratchet struct { // New creates a new ratchet, setting the kdfInfos and cipher. func New() *Ratchet { - r := &Ratchet{} - return r + return &Ratchet{} } // InitializeAsBob initializes this ratchet from a receiving point of view (only first message). func (r *Ratchet) InitializeAsBob(sharedSecret []byte, theirRatchetKey crypto.Curve25519PublicKey) error { - derivedSecretsReader := crypto.HKDFSHA256(sharedSecret, nil, KdfInfo.Root) + derivedSecretsReader := hkdf.New(sha256.New, sharedSecret, nil, KdfInfo.Root) derivedSecrets := make([]byte, 2*sharedKeyLength) if _, err := io.ReadFull(derivedSecretsReader, derivedSecrets); err != nil { return err @@ -83,7 +85,7 @@ func (r *Ratchet) InitializeAsBob(sharedSecret []byte, theirRatchetKey crypto.Cu // InitializeAsAlice initializes this ratchet from a sending point of view (only first message). func (r *Ratchet) InitializeAsAlice(sharedSecret []byte, ourRatchetKey crypto.Curve25519KeyPair) error { - derivedSecretsReader := crypto.HKDFSHA256(sharedSecret, nil, KdfInfo.Root) + derivedSecretsReader := hkdf.New(sha256.New, sharedSecret, nil, KdfInfo.Root) derivedSecrets := make([]byte, 2*sharedKeyLength) if _, err := io.ReadFull(derivedSecretsReader, derivedSecrets); err != nil { return err @@ -94,11 +96,11 @@ func (r *Ratchet) InitializeAsAlice(sharedSecret []byte, ourRatchetKey crypto.Cu return nil } -// Encrypt encrypts the message in a message.Message with MAC. If reader is nil, crypto/rand is used for key generations. -func (r *Ratchet) Encrypt(plaintext []byte, reader io.Reader) ([]byte, error) { +// Encrypt encrypts the message in a message.Message with MAC. +func (r *Ratchet) Encrypt(plaintext []byte) ([]byte, error) { var err error if !r.SenderChains.IsSet { - newRatchetKey, err := crypto.Curve25519GenerateKey(reader) + newRatchetKey, err := crypto.Curve25519GenerateKey() if err != nil { return nil, err } @@ -113,7 +115,11 @@ func (r *Ratchet) Encrypt(plaintext []byte, reader io.Reader) ([]byte, error) { messageKey := r.createMessageKeys(r.SenderChains.chainKey()) r.SenderChains.advance() - encryptedText, err := RatchetCipher.Encrypt(messageKey.Key, plaintext) + cipher, err := aessha2.NewAESSHA2(messageKey.Key, olmKeysKDFInfo) + if err != nil { + return nil, err + } + encryptedText, err := cipher.Encrypt(plaintext) if err != nil { return nil, fmt.Errorf("cipher encrypt: %w", err) } @@ -124,15 +130,10 @@ func (r *Ratchet) Encrypt(plaintext []byte, reader io.Reader) ([]byte, error) { message.RatchetKey = r.SenderChains.ratchetKey().PublicKey message.Ciphertext = encryptedText //creating the mac is done in encode - output, err := message.EncodeAndMAC(messageKey.Key, RatchetCipher) - if err != nil { - return nil, err - } - - return output, nil + return message.EncodeAndMAC(cipher) } -// Decrypt decrypts the ciphertext and verifies the MAC. If reader is nil, crypto/rand is used for key generations. +// Decrypt decrypts the ciphertext and verifies the MAC. func (r *Ratchet) Decrypt(input []byte) ([]byte, error) { message := &message.Message{} //The mac is not verified here, as we do not know the key yet @@ -141,10 +142,10 @@ func (r *Ratchet) Decrypt(input []byte) ([]byte, error) { return nil, err } if message.Version != protocolVersion { - return nil, fmt.Errorf("decrypt: %w", goolm.ErrWrongProtocolVersion) + return nil, fmt.Errorf("decrypt: %w (got %d, expected %d)", olm.ErrWrongProtocolVersion, message.Version, protocolVersion) } if !message.HasCounter || len(message.RatchetKey) == 0 || len(message.Ciphertext) == 0 { - return nil, fmt.Errorf("decrypt: %w", goolm.ErrBadMessageFormat) + return nil, fmt.Errorf("decrypt: %w", olm.ErrBadMessageFormat) } var receiverChainFromMessage *receiverChain for curChainIndex := range r.ReceiverChains { @@ -153,53 +154,40 @@ func (r *Ratchet) Decrypt(input []byte) ([]byte, error) { break } } - var result []byte if receiverChainFromMessage == nil { //Advancing the chain is done in this method - result, err = r.decryptForNewChain(message, input) - if err != nil { - return nil, err - } + return r.decryptForNewChain(message, input) } else if receiverChainFromMessage.chainKey().Index > message.Counter { // No need to advance the chain // Chain already advanced beyond the key for this message // Check if the message keys are in the skipped key list. - foundSkippedKey := false for curSkippedIndex := range r.SkippedMessageKeys { - if message.Counter == r.SkippedMessageKeys[curSkippedIndex].MKey.Index { - // Found the key for this message. Check the MAC. - verified, err := message.VerifyMACInline(r.SkippedMessageKeys[curSkippedIndex].MKey.Key, RatchetCipher, input) - if err != nil { - return nil, err - } - if !verified { - return nil, fmt.Errorf("decrypt from skipped message keys: %w", goolm.ErrBadMAC) - } - result, err = RatchetCipher.Decrypt(r.SkippedMessageKeys[curSkippedIndex].MKey.Key, message.Ciphertext) - if err != nil { - return nil, fmt.Errorf("cipher decrypt: %w", err) - } - if len(result) != 0 { - // Remove the key from the skipped keys now that we've - // decoded the message it corresponds to. - r.SkippedMessageKeys[curSkippedIndex] = r.SkippedMessageKeys[len(r.SkippedMessageKeys)-1] - r.SkippedMessageKeys = r.SkippedMessageKeys[:len(r.SkippedMessageKeys)-1] - } - foundSkippedKey = true + if message.Counter != r.SkippedMessageKeys[curSkippedIndex].MKey.Index { + continue + } + + // Found the key for this message. Check the MAC. + if cipher, err := aessha2.NewAESSHA2(r.SkippedMessageKeys[curSkippedIndex].MKey.Key, olmKeysKDFInfo); err != nil { + return nil, err + } else if verified, err := message.VerifyMACInline(r.SkippedMessageKeys[curSkippedIndex].MKey.Key, cipher, input); err != nil { + return nil, err + } else if !verified { + return nil, fmt.Errorf("decrypt from skipped message keys: %w", olm.ErrBadMAC) + } else if result, err := cipher.Decrypt(message.Ciphertext); err != nil { + return nil, fmt.Errorf("cipher decrypt: %w", err) + } else if len(result) != 0 { + // Remove the key from the skipped keys now that we've + // decoded the message it corresponds to. + r.SkippedMessageKeys[curSkippedIndex] = r.SkippedMessageKeys[len(r.SkippedMessageKeys)-1] + r.SkippedMessageKeys = r.SkippedMessageKeys[:len(r.SkippedMessageKeys)-1] + return result, nil } } - if !foundSkippedKey { - return nil, fmt.Errorf("decrypt: %w", goolm.ErrMessageKeyNotFound) - } + return nil, fmt.Errorf("decrypt: %w", olm.ErrMessageKeyNotFound) } else { //Advancing the chain is done in this method - result, err = r.decryptForExistingChain(receiverChainFromMessage, message, input) - if err != nil { - return nil, err - } + return r.decryptForExistingChain(receiverChainFromMessage, message, input) } - - return result, nil } // advanceRootKey created the next root key and returns the next chainKey @@ -208,7 +196,7 @@ func (r *Ratchet) advanceRootKey(newRatchetKey crypto.Curve25519KeyPair, oldRatc if err != nil { return nil, err } - derivedSecretsReader := crypto.HKDFSHA256(sharedSecret, r.RootKey, KdfInfo.Ratchet) + derivedSecretsReader := hkdf.New(sha256.New, sharedSecret, r.RootKey, KdfInfo.Ratchet) derivedSecrets := make([]byte, 2*sharedKeyLength) if _, err := io.ReadFull(derivedSecretsReader, derivedSecrets); err != nil { return nil, err @@ -219,20 +207,22 @@ func (r *Ratchet) advanceRootKey(newRatchetKey crypto.Curve25519KeyPair, oldRatc // createMessageKeys returns the messageKey derived from the chainKey func (r Ratchet) createMessageKeys(chainKey chainKey) messageKey { - res := messageKey{} - res.Key = crypto.HMACSHA256(chainKey.Key, []byte{messageKeySeed}) - res.Index = chainKey.Index - return res + hash := hmac.New(sha256.New, chainKey.Key) + hash.Write([]byte{messageKeySeed}) + return messageKey{ + Key: hash.Sum(nil), + Index: chainKey.Index, + } } // decryptForExistingChain returns the decrypted message by using the chain. The MAC of the rawMessage is verified. func (r *Ratchet) decryptForExistingChain(chain *receiverChain, message *message.Message, rawMessage []byte) ([]byte, error) { if message.Counter < chain.CKey.Index { - return nil, fmt.Errorf("decrypt: %w", goolm.ErrChainTooHigh) + return nil, fmt.Errorf("decrypt: %w", olm.ErrChainTooHigh) } // Limit the number of hashes we're prepared to compute if message.Counter-chain.CKey.Index > maxMessageGap { - return nil, fmt.Errorf("decrypt from existing chain: %w", goolm.ErrMsgIndexTooHigh) + return nil, fmt.Errorf("decrypt from existing chain: %w", olm.ErrMsgIndexTooHigh) } for chain.CKey.Index < message.Counter { messageKey := r.createMessageKeys(chain.chainKey()) @@ -245,14 +235,18 @@ func (r *Ratchet) decryptForExistingChain(chain *receiverChain, message *message } messageKey := r.createMessageKeys(chain.chainKey()) chain.advance() - verified, err := message.VerifyMACInline(messageKey.Key, RatchetCipher, rawMessage) + cipher, err := aessha2.NewAESSHA2(messageKey.Key, olmKeysKDFInfo) + if err != nil { + return nil, err + } + verified, err := message.VerifyMACInline(messageKey.Key, cipher, rawMessage) if err != nil { return nil, err } if !verified { - return nil, fmt.Errorf("decrypt from existing chain: %w", goolm.ErrBadMAC) + return nil, fmt.Errorf("decrypt from existing chain: %w", olm.ErrBadMAC) } - return RatchetCipher.Decrypt(messageKey.Key, message.Ciphertext) + return cipher.Decrypt(message.Ciphertext) } // decryptForNewChain returns the decrypted message by creating a new chain and advancing the root key. @@ -260,11 +254,11 @@ func (r *Ratchet) decryptForNewChain(message *message.Message, rawMessage []byte // They shouldn't move to a new chain until we've sent them a message // acknowledging the last one if !r.SenderChains.IsSet { - return nil, fmt.Errorf("decrypt for new chain: %w", goolm.ErrProtocolViolation) + return nil, fmt.Errorf("decrypt for new chain: %w", olm.ErrProtocolViolation) } // Limit the number of hashes we're prepared to compute if message.Counter > maxMessageGap { - return nil, fmt.Errorf("decrypt for new chain: %w", goolm.ErrMsgIndexTooHigh) + return nil, fmt.Errorf("decrypt for new chain: %w", olm.ErrMsgIndexTooHigh) } newChainKey, err := r.advanceRootKey(r.SenderChains.ratchetKey(), message.RatchetKey) @@ -281,152 +275,88 @@ func (r *Ratchet) decryptForNewChain(message *message.Message, rawMessage []byte */ r.SenderChains = senderChain{} - decrypted, err := r.decryptForExistingChain(&r.ReceiverChains[0], message, rawMessage) - if err != nil { - return nil, err - } - return decrypted, nil + return r.decryptForExistingChain(&r.ReceiverChains[0], message, rawMessage) } // PickleAsJSON returns a ratchet as a base64 string encrypted using the supplied key. The unencrypted representation of the Account is in JSON format. func (r Ratchet) PickleAsJSON(key []byte) ([]byte, error) { - return utilities.PickleAsJSON(r, olmPickleVersion, key) + return libolmpickle.PickleAsJSON(r, olmPickleVersion, key) } // UnpickleAsJSON updates a ratchet by a base64 encrypted string using the supplied key. The unencrypted representation has to be in JSON format. func (r *Ratchet) UnpickleAsJSON(pickled, key []byte) error { - return utilities.UnpickleAsJSON(r, pickled, key, olmPickleVersion) + return libolmpickle.UnpickleAsJSON(r, pickled, key, olmPickleVersion) } -// UnpickleLibOlm decodes the unencryted value and populates the Ratchet accordingly. It returns the number of bytes read. -func (r *Ratchet) UnpickleLibOlm(value []byte, includesChainIndex bool) (int, error) { - //read ratchet data - curPos := 0 - readBytes, err := r.RootKey.UnpickleLibOlm(value) - if err != nil { - return 0, err +// UnpickleLibOlm unpickles the unencryted value and populates the [Ratchet] +// accordingly. +func (r *Ratchet) UnpickleLibOlm(decoder *libolmpickle.Decoder, includesChainIndex bool) error { + if err := r.RootKey.UnpickleLibOlm(decoder); err != nil { + return err } - curPos += readBytes - countSenderChains, readBytes, err := libolmpickle.UnpickleUInt32(value[curPos:]) //Length of sender chain + senderChainsCount, err := decoder.ReadUInt32() if err != nil { - return 0, err + return err } - curPos += readBytes - for i := uint32(0); i < countSenderChains; i++ { + + for i := uint32(0); i < senderChainsCount; i++ { if i == 0 { - //only first is stored - readBytes, err := r.SenderChains.UnpickleLibOlm(value[curPos:]) - if err != nil { - return 0, err - } - curPos += readBytes + // only the first sender key is stored + err = r.SenderChains.UnpickleLibOlm(decoder) r.SenderChains.IsSet = true } else { - dummy := senderChain{} - readBytes, err := dummy.UnpickleLibOlm(value[curPos:]) - if err != nil { - return 0, err - } - curPos += readBytes + // just eat the values + err = (&senderChain{}).UnpickleLibOlm(decoder) } - } - countReceivChains, readBytes, err := libolmpickle.UnpickleUInt32(value[curPos:]) //Length of recevier chain - if err != nil { - return 0, err - } - curPos += readBytes - r.ReceiverChains = make([]receiverChain, countReceivChains) - for i := uint32(0); i < countReceivChains; i++ { - readBytes, err := r.ReceiverChains[i].UnpickleLibOlm(value[curPos:]) if err != nil { - return 0, err + return err } - curPos += readBytes } - countSkippedMessageKeys, readBytes, err := libolmpickle.UnpickleUInt32(value[curPos:]) //Length of skippedMessageKeys + + receiverChainCount, err := decoder.ReadUInt32() if err != nil { - return 0, err + return err } - curPos += readBytes - r.SkippedMessageKeys = make([]skippedMessageKey, countSkippedMessageKeys) - for i := uint32(0); i < countSkippedMessageKeys; i++ { - readBytes, err := r.SkippedMessageKeys[i].UnpickleLibOlm(value[curPos:]) - if err != nil { - return 0, err + r.ReceiverChains = make([]receiverChain, receiverChainCount) + for i := uint32(0); i < receiverChainCount; i++ { + if err := r.ReceiverChains[i].UnpickleLibOlm(decoder); err != nil { + return err } - curPos += readBytes } - // pickle v 0x80000001 includes a chain index; pickle v1 does not. + + skippedMessageKeysCount, err := decoder.ReadUInt32() + if err != nil { + return err + } + r.SkippedMessageKeys = make([]skippedMessageKey, skippedMessageKeysCount) + for i := uint32(0); i < skippedMessageKeysCount; i++ { + if err := r.SkippedMessageKeys[i].UnpickleLibOlm(decoder); err != nil { + return err + } + } + + // pickle version 0x80000001 includes a chain index; pickle version 1 does not. if includesChainIndex { - _, readBytes, err := libolmpickle.UnpickleUInt32(value[curPos:]) - if err != nil { - return 0, err - } - curPos += readBytes + _, err = decoder.ReadUInt32() + return err } - return curPos, nil + return nil } -// PickleLibOlm encodes the ratchet into target. target has to have a size of at least PickleLen() and is written to from index 0. -// It returns the number of bytes written. -func (r Ratchet) PickleLibOlm(target []byte) (int, error) { - if len(target) < r.PickleLen() { - return 0, fmt.Errorf("pickle ratchet: %w", goolm.ErrValueTooShort) - } - written, err := r.RootKey.PickleLibOlm(target) - if err != nil { - return 0, fmt.Errorf("pickle ratchet: %w", err) - } - if r.SenderChains.IsSet { - written += libolmpickle.PickleUInt32(1, target[written:]) //Length of sender chain, always 1 - writtenSender, err := r.SenderChains.PickleLibOlm(target[written:]) - if err != nil { - return 0, fmt.Errorf("pickle ratchet: %w", err) - } - written += writtenSender - } else { - written += libolmpickle.PickleUInt32(0, target[written:]) //Length of sender chain - } - written += libolmpickle.PickleUInt32(uint32(len(r.ReceiverChains)), target[written:]) +// PickleLibOlm pickles the ratchet into the encoder. +func (r Ratchet) PickleLibOlm(encoder *libolmpickle.Encoder) { + r.RootKey.PickleLibOlm(encoder) + r.SenderChains.PickleLibOlm(encoder) + + // Receiver Chains + encoder.WriteUInt32(uint32(len(r.ReceiverChains))) for _, curChain := range r.ReceiverChains { - writtenChain, err := curChain.PickleLibOlm(target[written:]) - if err != nil { - return 0, fmt.Errorf("pickle ratchet: %w", err) - } - written += writtenChain + curChain.PickleLibOlm(encoder) } - written += libolmpickle.PickleUInt32(uint32(len(r.SkippedMessageKeys)), target[written:]) + + // Skipped Message Keys + encoder.WriteUInt32(uint32(len(r.SkippedMessageKeys))) for _, curChain := range r.SkippedMessageKeys { - writtenChain, err := curChain.PickleLibOlm(target[written:]) - if err != nil { - return 0, fmt.Errorf("pickle ratchet: %w", err) - } - written += writtenChain + curChain.PickleLibOlm(encoder) } - return written, nil -} - -// PickleLen returns the actual number of bytes the pickled ratchet will have. -func (r Ratchet) PickleLen() int { - length := r.RootKey.PickleLen() - if r.SenderChains.IsSet { - length += libolmpickle.PickleUInt32Len(1) - length += r.SenderChains.PickleLen() - } else { - length += libolmpickle.PickleUInt32Len(0) - } - length += libolmpickle.PickleUInt32Len(uint32(len(r.ReceiverChains))) - length += len(r.ReceiverChains) * receiverChain{}.PickleLen() - length += libolmpickle.PickleUInt32Len(uint32(len(r.SkippedMessageKeys))) - length += len(r.SkippedMessageKeys) * skippedMessageKey{}.PickleLen() - return length -} - -// PickleLen returns the minimum number of bytes the pickled ratchet must have. -func (r Ratchet) PickleLenMin() int { - length := r.RootKey.PickleLen() - length += libolmpickle.PickleUInt32Len(0) - length += libolmpickle.PickleUInt32Len(0) - length += libolmpickle.PickleUInt32Len(0) - return length } diff --git a/crypto/goolm/ratchet/olm_test.go b/crypto/goolm/ratchet/olm_test.go new file mode 100644 index 00000000..2bf7ea0a --- /dev/null +++ b/crypto/goolm/ratchet/olm_test.go @@ -0,0 +1,126 @@ +package ratchet_test + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" + + "maunium.net/go/mautrix/crypto/goolm/crypto" + "maunium.net/go/mautrix/crypto/goolm/ratchet" +) + +var ( + sharedSecret = []byte("A secret") +) + +func initializeRatchets() (*ratchet.Ratchet, *ratchet.Ratchet, error) { + ratchet.KdfInfo = struct { + Root []byte + Ratchet []byte + }{ + Root: []byte("Olm"), + Ratchet: []byte("OlmRatchet"), + } + aliceRatchet := ratchet.New() + bobRatchet := ratchet.New() + + aliceKey, err := crypto.Curve25519GenerateKey() + if err != nil { + return nil, nil, err + } + + aliceRatchet.InitializeAsAlice(sharedSecret, aliceKey) + bobRatchet.InitializeAsBob(sharedSecret, aliceKey.PublicKey) + return aliceRatchet, bobRatchet, nil +} + +func TestSendReceive(t *testing.T) { + aliceRatchet, bobRatchet, err := initializeRatchets() + assert.NoError(t, err) + + plainText := []byte("Hello Bob") + + //Alice sends Bob a message + encryptedMessage, err := aliceRatchet.Encrypt(plainText) + assert.NoError(t, err) + + decrypted, err := bobRatchet.Decrypt(encryptedMessage) + assert.NoError(t, err) + assert.Equal(t, plainText, decrypted) + + //Bob sends Alice a message + plainText = []byte("Hello Alice") + encryptedMessage, err = bobRatchet.Encrypt(plainText) + assert.NoError(t, err) + decrypted, err = aliceRatchet.Decrypt(encryptedMessage) + assert.NoError(t, err) + assert.Equal(t, plainText, decrypted) +} + +func TestOutOfOrder(t *testing.T) { + aliceRatchet, bobRatchet, err := initializeRatchets() + assert.NoError(t, err) + + plainText1 := []byte("First Message") + plainText2 := []byte("Second Messsage. A bit longer than the first.") + + /* Alice sends Bob two messages and they arrive out of order */ + message1Encrypted, err := aliceRatchet.Encrypt(plainText1) + assert.NoError(t, err) + message2Encrypted, err := aliceRatchet.Encrypt(plainText2) + assert.NoError(t, err) + + decrypted2, err := bobRatchet.Decrypt(message2Encrypted) + assert.NoError(t, err) + decrypted1, err := bobRatchet.Decrypt(message1Encrypted) + assert.NoError(t, err) + assert.Equal(t, plainText1, decrypted1) + assert.Equal(t, plainText2, decrypted2) +} + +func TestMoreMessages(t *testing.T) { + aliceRatchet, bobRatchet, err := initializeRatchets() + assert.NoError(t, err) + plainText := []byte("These 15 bytes") + for i := 0; i < 8; i++ { + messageEncrypted, err := aliceRatchet.Encrypt(plainText) + assert.NoError(t, err) + + decrypted, err := bobRatchet.Decrypt(messageEncrypted) + assert.NoError(t, err) + assert.Equal(t, plainText, decrypted) + } + for i := 0; i < 8; i++ { + messageEncrypted, err := bobRatchet.Encrypt(plainText) + assert.NoError(t, err) + + decrypted, err := aliceRatchet.Decrypt(messageEncrypted) + assert.NoError(t, err) + assert.Equal(t, plainText, decrypted) + } + messageEncrypted, err := aliceRatchet.Encrypt(plainText) + assert.NoError(t, err) + decrypted, err := bobRatchet.Decrypt(messageEncrypted) + assert.NoError(t, err) + assert.Equal(t, plainText, decrypted) +} + +func TestJSONEncoding(t *testing.T) { + aliceRatchet, bobRatchet, err := initializeRatchets() + assert.NoError(t, err) + marshaled, err := json.Marshal(aliceRatchet) + assert.NoError(t, err) + + newRatcher := ratchet.Ratchet{} + err = json.Unmarshal(marshaled, &newRatcher) + assert.NoError(t, err) + + plainText := []byte("These 15 bytes") + + messageEncrypted, err := newRatcher.Encrypt(plainText) + assert.NoError(t, err) + decrypted, err := bobRatchet.Decrypt(messageEncrypted) + assert.NoError(t, err) + assert.Equal(t, plainText, decrypted) +} diff --git a/crypto/goolm/ratchet/skipped_message.go b/crypto/goolm/ratchet/skipped_message.go new file mode 100644 index 00000000..2ffaee7b --- /dev/null +++ b/crypto/goolm/ratchet/skipped_message.go @@ -0,0 +1,27 @@ +package ratchet + +import ( + "maunium.net/go/mautrix/crypto/goolm/crypto" + "maunium.net/go/mautrix/crypto/goolm/libolmpickle" +) + +// skippedMessageKey stores a skipped message key +type skippedMessageKey struct { + RKey crypto.Curve25519PublicKey `json:"ratchet_key"` + MKey messageKey `json:"message_key"` +} + +// UnpickleLibOlm unpickles the unencryted value and populates the skipped +// message keys accordingly. +func (r *skippedMessageKey) UnpickleLibOlm(decoder *libolmpickle.Decoder) (err error) { + if err = r.RKey.UnpickleLibOlm(decoder); err != nil { + return + } + return r.MKey.UnpickleLibOlm(decoder) +} + +// PickleLibOlm pickles the skipped message key into the encoder. +func (r skippedMessageKey) PickleLibOlm(encoder *libolmpickle.Encoder) { + r.RKey.PickleLibOlm(encoder) + r.MKey.PickleLibOlm(encoder) +} diff --git a/crypto/goolm/register.go b/crypto/goolm/register.go new file mode 100644 index 00000000..800f567f --- /dev/null +++ b/crypto/goolm/register.go @@ -0,0 +1,29 @@ +// Copyright (c) 2024 Sumner Evans +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package goolm + +import ( + "maunium.net/go/mautrix/crypto/goolm/account" + "maunium.net/go/mautrix/crypto/goolm/pk" + "maunium.net/go/mautrix/crypto/goolm/session" + "maunium.net/go/mautrix/crypto/olm" +) + +func Register() { + olm.Driver = "goolm" + + olm.GetVersion = func() (major, minor, patch uint8) { + return 3, 2, 15 + } + olm.SetPickleKeyImpl = func(key []byte) { + panic("gob and json encoding is deprecated and not supported with goolm") + } + + account.Register() + pk.Register() + session.Register() +} diff --git a/crypto/goolm/session/megolm_inbound_session.go b/crypto/goolm/session/megolm_inbound_session.go index 165f7f16..7ccbd26d 100644 --- a/crypto/goolm/session/megolm_inbound_session.go +++ b/crypto/goolm/session/megolm_inbound_session.go @@ -2,16 +2,14 @@ package session import ( "encoding/base64" - "errors" "fmt" - "maunium.net/go/mautrix/crypto/goolm" - "maunium.net/go/mautrix/crypto/goolm/cipher" "maunium.net/go/mautrix/crypto/goolm/crypto" + "maunium.net/go/mautrix/crypto/goolm/goolmbase64" "maunium.net/go/mautrix/crypto/goolm/libolmpickle" "maunium.net/go/mautrix/crypto/goolm/megolm" "maunium.net/go/mautrix/crypto/goolm/message" - "maunium.net/go/mautrix/crypto/goolm/utilities" + "maunium.net/go/mautrix/crypto/olm" "maunium.net/go/mautrix/id" ) @@ -28,10 +26,14 @@ type MegolmInboundSession struct { SigningKeyVerified bool `json:"signing_key_verified"` //not used for now } +// Ensure that MegolmInboundSession implements the [olm.InboundGroupSession] +// interface. +var _ olm.InboundGroupSession = (*MegolmInboundSession)(nil) + // NewMegolmInboundSession creates a new MegolmInboundSession from a base64 encoded session sharing message. func NewMegolmInboundSession(input []byte) (*MegolmInboundSession, error) { var err error - input, err = goolm.Base64Decode(input) + input, err = goolmbase64.Decode(input) if err != nil { return nil, err } @@ -55,7 +57,7 @@ func NewMegolmInboundSession(input []byte) (*MegolmInboundSession, error) { // NewMegolmInboundSessionFromExport creates a new MegolmInboundSession from a base64 encoded session export message. func NewMegolmInboundSessionFromExport(input []byte) (*MegolmInboundSession, error) { var err error - input, err = goolm.Base64Decode(input) + input, err = goolmbase64.Decode(input) if err != nil { return nil, err } @@ -78,7 +80,7 @@ func NewMegolmInboundSessionFromExport(input []byte) (*MegolmInboundSession, err // MegolmInboundSessionFromPickled loads the MegolmInboundSession details from a pickled base64 string. The input is decrypted with the supplied key. func MegolmInboundSessionFromPickled(pickled, key []byte) (*MegolmInboundSession, error) { if len(pickled) == 0 { - return nil, fmt.Errorf("megolmInboundSessionFromPickled: %w", goolm.ErrEmptyInput) + return nil, fmt.Errorf("megolmInboundSessionFromPickled: %w", olm.ErrEmptyInput) } a := &MegolmInboundSession{} err := a.Unpickle(pickled, key) @@ -89,7 +91,7 @@ func MegolmInboundSessionFromPickled(pickled, key []byte) (*MegolmInboundSession } // getRatchet tries to find the correct ratchet for a messageIndex. -func (o MegolmInboundSession) getRatchet(messageIndex uint32) (*megolm.Ratchet, error) { +func (o *MegolmInboundSession) getRatchet(messageIndex uint32) (*megolm.Ratchet, error) { // pick a megolm instance to use. if we are at or beyond the latest ratchet value, use that if (messageIndex - o.Ratchet.Counter) < uint32(1<<31) { o.Ratchet.AdvanceTo(messageIndex) @@ -97,7 +99,7 @@ func (o MegolmInboundSession) getRatchet(messageIndex uint32) (*megolm.Ratchet, } if (messageIndex - o.InitialRatchet.Counter) >= uint32(1<<31) { // the counter is before our initial ratchet - we can't decode this - return nil, fmt.Errorf("decrypt: %w", goolm.ErrRatchetNotAvailable) + return nil, fmt.Errorf("decrypt: %w", olm.ErrUnknownMessageIndex) } // otherwise, start from the initial ratchet. Take a copy so that we don't overwrite the initial ratchet copiedRatchet := o.InitialRatchet @@ -107,11 +109,14 @@ func (o MegolmInboundSession) getRatchet(messageIndex uint32) (*megolm.Ratchet, } // Decrypt decrypts a base64 encoded group message. -func (o *MegolmInboundSession) Decrypt(ciphertext []byte) ([]byte, uint32, error) { - if o.SigningKey == nil { - return nil, 0, fmt.Errorf("decrypt: %w", goolm.ErrBadMessageFormat) +func (o *MegolmInboundSession) Decrypt(ciphertext []byte) ([]byte, uint, error) { + if len(ciphertext) == 0 { + return nil, 0, olm.ErrEmptyInput } - decoded, err := goolm.Base64Decode(ciphertext) + if o.SigningKey == nil { + return nil, 0, fmt.Errorf("decrypt: %w", olm.ErrBadMessageFormat) + } + decoded, err := goolmbase64.Decode(ciphertext) if err != nil { return nil, 0, err } @@ -121,16 +126,16 @@ func (o *MegolmInboundSession) Decrypt(ciphertext []byte) ([]byte, uint32, error return nil, 0, err } if msg.Version != protocolVersion { - return nil, 0, fmt.Errorf("decrypt: %w", goolm.ErrWrongProtocolVersion) + return nil, 0, fmt.Errorf("decrypt: %w (got %d, expected %d)", olm.ErrWrongProtocolVersion, msg.Version, protocolVersion) } if msg.Ciphertext == nil || !msg.HasMessageIndex { - return nil, 0, fmt.Errorf("decrypt: %w", goolm.ErrBadMessageFormat) + return nil, 0, fmt.Errorf("decrypt: %w", olm.ErrBadMessageFormat) } // verify signature verifiedSignature := msg.VerifySignatureInline(o.SigningKey, decoded) if !verifiedSignature { - return nil, 0, fmt.Errorf("decrypt: %w", goolm.ErrBadSignature) + return nil, 0, fmt.Errorf("decrypt: %w", olm.ErrBadSignature) } targetRatch, err := o.getRatchet(msg.MessageIndex) @@ -143,27 +148,33 @@ func (o *MegolmInboundSession) Decrypt(ciphertext []byte) ([]byte, uint32, error return nil, 0, err } o.SigningKeyVerified = true - return decrypted, msg.MessageIndex, nil + return decrypted, uint(msg.MessageIndex), nil } -// SessionID returns the base64 endoded signing key -func (o MegolmInboundSession) SessionID() id.SessionID { +// ID returns the base64 endoded signing key +func (o *MegolmInboundSession) ID() id.SessionID { return id.SessionID(base64.RawStdEncoding.EncodeToString(o.SigningKey)) } // PickleAsJSON returns an MegolmInboundSession as a base64 string encrypted using the supplied key. The unencrypted representation of the Account is in JSON format. -func (o MegolmInboundSession) PickleAsJSON(key []byte) ([]byte, error) { - return utilities.PickleAsJSON(o, megolmInboundSessionPickleVersionJSON, key) +func (o *MegolmInboundSession) PickleAsJSON(key []byte) ([]byte, error) { + return libolmpickle.PickleAsJSON(o, megolmInboundSessionPickleVersionJSON, key) } // UnpickleAsJSON updates an MegolmInboundSession by a base64 encrypted string using the supplied key. The unencrypted representation has to be in JSON format. func (o *MegolmInboundSession) UnpickleAsJSON(pickled, key []byte) error { - return utilities.UnpickleAsJSON(o, pickled, key, megolmInboundSessionPickleVersionJSON) + return libolmpickle.UnpickleAsJSON(o, pickled, key, megolmInboundSessionPickleVersionJSON) } -// SessionExportMessage creates an base64 encoded export of the session. -func (o MegolmInboundSession) SessionExportMessage(messageIndex uint32) ([]byte, error) { +// Export returns the base64-encoded ratchet key for this session, at the given +// index, in a format which can be used by +// InboundGroupSession.InboundGroupSessionImport(). Encrypts the +// InboundGroupSession using the supplied key. Returns error on failure. +// if we do not have a session key corresponding to the given index (ie, it was +// sent before the session key was shared with us) the error will be +// returned. +func (o *MegolmInboundSession) Export(messageIndex uint32) ([]byte, error) { ratchet, err := o.getRatchet(messageIndex) if err != nil { return nil, err @@ -174,103 +185,75 @@ func (o MegolmInboundSession) SessionExportMessage(messageIndex uint32) ([]byte, // Unpickle decodes the base64 encoded string and decrypts the result with the key. // The decrypted value is then passed to UnpickleLibOlm. func (o *MegolmInboundSession) Unpickle(pickled, key []byte) error { - decrypted, err := cipher.Unpickle(key, pickled) + if len(key) == 0 { + return olm.ErrNoKeyProvided + } else if len(pickled) == 0 { + return olm.ErrEmptyInput + } + decrypted, err := libolmpickle.Unpickle(key, pickled) if err != nil { return err } - _, err = o.UnpickleLibOlm(decrypted) - return err + return o.UnpickleLibOlm(decrypted) } -// UnpickleLibOlm decodes the unencryted value and populates the Session accordingly. It returns the number of bytes read. -func (o *MegolmInboundSession) UnpickleLibOlm(value []byte) (int, error) { - //First 4 bytes are the accountPickleVersion - pickledVersion, curPos, err := libolmpickle.UnpickleUInt32(value) +// UnpickleLibOlm unpickles the unencryted value and populates the [Session] +// accordingly. +func (o *MegolmInboundSession) UnpickleLibOlm(value []byte) error { + decoder := libolmpickle.NewDecoder(value) + pickledVersion, err := decoder.ReadUInt32() if err != nil { - return 0, err + return err } - switch pickledVersion { - case megolmInboundSessionPickleVersionLibOlm, 1: - default: - return 0, fmt.Errorf("unpickle MegolmInboundSession: %w", goolm.ErrBadVersion) + if pickledVersion != megolmInboundSessionPickleVersionLibOlm && pickledVersion != 1 { + return fmt.Errorf("unpickle MegolmInboundSession: %w (found version %d)", olm.ErrUnknownOlmPickleVersion, pickledVersion) } - readBytes, err := o.InitialRatchet.UnpickleLibOlm(value[curPos:]) - if err != nil { - return 0, err + + if err = o.InitialRatchet.UnpickleLibOlm(decoder); err != nil { + return err + } else if err = o.Ratchet.UnpickleLibOlm(decoder); err != nil { + return err + } else if err = o.SigningKey.UnpickleLibOlm(decoder); err != nil { + return err } - curPos += readBytes - readBytes, err = o.Ratchet.UnpickleLibOlm(value[curPos:]) - if err != nil { - return 0, err - } - curPos += readBytes - readBytes, err = o.SigningKey.UnpickleLibOlm(value[curPos:]) - if err != nil { - return 0, err - } - curPos += readBytes + if pickledVersion == 1 { // pickle v1 had no signing_key_verified field (all keyshares were verified at import time) o.SigningKeyVerified = true } else { - o.SigningKeyVerified, readBytes, err = libolmpickle.UnpickleBool(value[curPos:]) - if err != nil { - return 0, err - } - curPos += readBytes + o.SigningKeyVerified, err = decoder.ReadBool() + return err } - return curPos, nil + return nil } // Pickle returns a base64 encoded and with key encrypted pickled MegolmInboundSession using PickleLibOlm(). -func (o MegolmInboundSession) Pickle(key []byte) ([]byte, error) { - pickeledBytes := make([]byte, o.PickleLen()) - written, err := o.PickleLibOlm(pickeledBytes) - if err != nil { - return nil, err +func (o *MegolmInboundSession) Pickle(key []byte) ([]byte, error) { + if len(key) == 0 { + return nil, olm.ErrNoKeyProvided } - if written != len(pickeledBytes) { - return nil, errors.New("number of written bytes not correct") - } - encrypted, err := cipher.Pickle(key, pickeledBytes) - if err != nil { - return nil, err - } - return encrypted, nil + return libolmpickle.Pickle(key, o.PickleLibOlm()) } -// PickleLibOlm encodes the session into target. target has to have a size of at least PickleLen() and is written to from index 0. -// It returns the number of bytes written. -func (o MegolmInboundSession) PickleLibOlm(target []byte) (int, error) { - if len(target) < o.PickleLen() { - return 0, fmt.Errorf("pickle MegolmInboundSession: %w", goolm.ErrValueTooShort) - } - written := libolmpickle.PickleUInt32(megolmInboundSessionPickleVersionLibOlm, target) - writtenInitRatchet, err := o.InitialRatchet.PickleLibOlm(target[written:]) - if err != nil { - return 0, fmt.Errorf("pickle MegolmInboundSession: %w", err) - } - written += writtenInitRatchet - writtenRatchet, err := o.Ratchet.PickleLibOlm(target[written:]) - if err != nil { - return 0, fmt.Errorf("pickle MegolmInboundSession: %w", err) - } - written += writtenRatchet - writtenPubKey, err := o.SigningKey.PickleLibOlm(target[written:]) - if err != nil { - return 0, fmt.Errorf("pickle MegolmInboundSession: %w", err) - } - written += writtenPubKey - written += libolmpickle.PickleBool(o.SigningKeyVerified, target[written:]) - return written, nil +// PickleLibOlm pickles the session returning the raw bytes. +func (o *MegolmInboundSession) PickleLibOlm() []byte { + encoder := libolmpickle.NewEncoder() + encoder.WriteUInt32(megolmInboundSessionPickleVersionLibOlm) + o.InitialRatchet.PickleLibOlm(encoder) + o.Ratchet.PickleLibOlm(encoder) + o.SigningKey.PickleLibOlm(encoder) + encoder.WriteBool(o.SigningKeyVerified) + return encoder.Bytes() } -// PickleLen returns the number of bytes the pickled session will have. -func (o MegolmInboundSession) PickleLen() int { - length := libolmpickle.PickleUInt32Len(megolmInboundSessionPickleVersionLibOlm) - length += o.InitialRatchet.PickleLen() - length += o.Ratchet.PickleLen() - length += o.SigningKey.PickleLen() - length += libolmpickle.PickleBoolLen(o.SigningKeyVerified) - return length +// FirstKnownIndex returns the first message index we know how to decrypt. +func (s *MegolmInboundSession) FirstKnownIndex() uint32 { + return s.InitialRatchet.Counter +} + +// IsVerified check if the session has been verified as a valid session. (A +// session is verified either because the original session share was signed, or +// because we have subsequently successfully decrypted a message.) +func (s *MegolmInboundSession) IsVerified() bool { + return s.SigningKeyVerified } diff --git a/crypto/goolm/session/megolm_outbound_session.go b/crypto/goolm/session/megolm_outbound_session.go index e594258d..7f923534 100644 --- a/crypto/goolm/session/megolm_outbound_session.go +++ b/crypto/goolm/session/megolm_outbound_session.go @@ -3,17 +3,16 @@ package session import ( "crypto/rand" "encoding/base64" - "errors" "fmt" - "maunium.net/go/mautrix/id" + "go.mau.fi/util/exerrors" - "maunium.net/go/mautrix/crypto/goolm" - "maunium.net/go/mautrix/crypto/goolm/cipher" "maunium.net/go/mautrix/crypto/goolm/crypto" + "maunium.net/go/mautrix/crypto/goolm/goolmbase64" "maunium.net/go/mautrix/crypto/goolm/libolmpickle" "maunium.net/go/mautrix/crypto/goolm/megolm" - "maunium.net/go/mautrix/crypto/goolm/utilities" + "maunium.net/go/mautrix/crypto/olm" + "maunium.net/go/mautrix/id" ) const ( @@ -27,11 +26,13 @@ type MegolmOutboundSession struct { SigningKey crypto.Ed25519KeyPair `json:"signing_key"` } +var _ olm.OutboundGroupSession = (*MegolmOutboundSession)(nil) + // NewMegolmOutboundSession creates a new MegolmOutboundSession. func NewMegolmOutboundSession() (*MegolmOutboundSession, error) { o := &MegolmOutboundSession{} var err error - o.SigningKey, err = crypto.Ed25519GenerateKey(nil) + o.SigningKey, err = crypto.Ed25519GenerateKey() if err != nil { return nil, err } @@ -51,121 +52,94 @@ func NewMegolmOutboundSession() (*MegolmOutboundSession, error) { // MegolmOutboundSessionFromPickled loads the MegolmOutboundSession details from a pickled base64 string. The input is decrypted with the supplied key. func MegolmOutboundSessionFromPickled(pickled, key []byte) (*MegolmOutboundSession, error) { if len(pickled) == 0 { - return nil, fmt.Errorf("megolmOutboundSessionFromPickled: %w", goolm.ErrEmptyInput) + return nil, fmt.Errorf("megolmOutboundSessionFromPickled: %w", olm.ErrEmptyInput) } a := &MegolmOutboundSession{} err := a.Unpickle(pickled, key) - if err != nil { - return nil, err - } - return a, nil + return a, err } // Encrypt encrypts the plaintext as a base64 encoded group message. func (o *MegolmOutboundSession) Encrypt(plaintext []byte) ([]byte, error) { - encrypted, err := o.Ratchet.Encrypt(plaintext, &o.SigningKey) - if err != nil { - return nil, err + if len(plaintext) == 0 { + return nil, olm.ErrEmptyInput } - return goolm.Base64Encode(encrypted), nil + encrypted, err := o.Ratchet.Encrypt(plaintext, o.SigningKey) + return goolmbase64.Encode(encrypted), err } // SessionID returns the base64 endoded public signing key -func (o MegolmOutboundSession) SessionID() id.SessionID { +func (o *MegolmOutboundSession) ID() id.SessionID { return id.SessionID(base64.RawStdEncoding.EncodeToString(o.SigningKey.PublicKey)) } // PickleAsJSON returns an Session as a base64 string encrypted using the supplied key. The unencrypted representation of the Account is in JSON format. -func (o MegolmOutboundSession) PickleAsJSON(key []byte) ([]byte, error) { - return utilities.PickleAsJSON(o, megolmOutboundSessionPickleVersion, key) +func (o *MegolmOutboundSession) PickleAsJSON(key []byte) ([]byte, error) { + return libolmpickle.PickleAsJSON(o, megolmOutboundSessionPickleVersion, key) } // UnpickleAsJSON updates an Session by a base64 encrypted string with the key. The unencrypted representation has to be in JSON format. func (o *MegolmOutboundSession) UnpickleAsJSON(pickled, key []byte) error { - return utilities.UnpickleAsJSON(o, pickled, key, megolmOutboundSessionPickleVersion) + return libolmpickle.UnpickleAsJSON(o, pickled, key, megolmOutboundSessionPickleVersion) } // Unpickle decodes the base64 encoded string and decrypts the result with the key. // The decrypted value is then passed to UnpickleLibOlm. func (o *MegolmOutboundSession) Unpickle(pickled, key []byte) error { - decrypted, err := cipher.Unpickle(key, pickled) + if len(key) == 0 { + return olm.ErrNoKeyProvided + } + decrypted, err := libolmpickle.Unpickle(key, pickled) if err != nil { return err } - _, err = o.UnpickleLibOlm(decrypted) - return err + return o.UnpickleLibOlm(decrypted) } -// UnpickleLibOlm decodes the unencryted value and populates the Session accordingly. It returns the number of bytes read. -func (o *MegolmOutboundSession) UnpickleLibOlm(value []byte) (int, error) { - //First 4 bytes are the accountPickleVersion - pickledVersion, curPos, err := libolmpickle.UnpickleUInt32(value) +// UnpickleLibOlm unpickles the unencryted value and populates the +// [MegolmOutboundSession] accordingly. +func (o *MegolmOutboundSession) UnpickleLibOlm(buf []byte) error { + decoder := libolmpickle.NewDecoder(buf) + pickledVersion, err := decoder.ReadUInt32() if err != nil { - return 0, err + return fmt.Errorf("unpickle MegolmOutboundSession: failed to read version: %w", err) + } else if pickledVersion != megolmOutboundSessionPickleVersionLibOlm { + return fmt.Errorf("unpickle MegolmInboundSession: %w (found version %d)", olm.ErrUnknownOlmPickleVersion, pickledVersion) } - switch pickledVersion { - case megolmOutboundSessionPickleVersionLibOlm: - default: - return 0, fmt.Errorf("unpickle MegolmInboundSession: %w", goolm.ErrBadVersion) + if err = o.Ratchet.UnpickleLibOlm(decoder); err != nil { + return err } - readBytes, err := o.Ratchet.UnpickleLibOlm(value[curPos:]) - if err != nil { - return 0, err - } - curPos += readBytes - readBytes, err = o.SigningKey.UnpickleLibOlm(value[curPos:]) - if err != nil { - return 0, err - } - curPos += readBytes - return curPos, nil + return o.SigningKey.UnpickleLibOlm(decoder) } // Pickle returns a base64 encoded and with key encrypted pickled MegolmOutboundSession using PickleLibOlm(). -func (o MegolmOutboundSession) Pickle(key []byte) ([]byte, error) { - pickeledBytes := make([]byte, o.PickleLen()) - written, err := o.PickleLibOlm(pickeledBytes) - if err != nil { - return nil, err +func (o *MegolmOutboundSession) Pickle(key []byte) ([]byte, error) { + if len(key) == 0 { + return nil, olm.ErrNoKeyProvided } - if written != len(pickeledBytes) { - return nil, errors.New("number of written bytes not correct") - } - encrypted, err := cipher.Pickle(key, pickeledBytes) - if err != nil { - return nil, err - } - return encrypted, nil + return libolmpickle.Pickle(key, o.PickleLibOlm()) } -// PickleLibOlm encodes the session into target. target has to have a size of at least PickleLen() and is written to from index 0. -// It returns the number of bytes written. -func (o MegolmOutboundSession) PickleLibOlm(target []byte) (int, error) { - if len(target) < o.PickleLen() { - return 0, fmt.Errorf("pickle MegolmOutboundSession: %w", goolm.ErrValueTooShort) - } - written := libolmpickle.PickleUInt32(megolmOutboundSessionPickleVersionLibOlm, target) - writtenRatchet, err := o.Ratchet.PickleLibOlm(target[written:]) - if err != nil { - return 0, fmt.Errorf("pickle MegolmOutboundSession: %w", err) - } - written += writtenRatchet - writtenPubKey, err := o.SigningKey.PickleLibOlm(target[written:]) - if err != nil { - return 0, fmt.Errorf("pickle MegolmOutboundSession: %w", err) - } - written += writtenPubKey - return written, nil +// PickleLibOlm pickles the session returning the raw bytes. +func (o *MegolmOutboundSession) PickleLibOlm() []byte { + encoder := libolmpickle.NewEncoder() + encoder.WriteUInt32(megolmOutboundSessionPickleVersionLibOlm) + o.Ratchet.PickleLibOlm(encoder) + o.SigningKey.PickleLibOlm(encoder) + return encoder.Bytes() } -// PickleLen returns the number of bytes the pickled session will have. -func (o MegolmOutboundSession) PickleLen() int { - length := libolmpickle.PickleUInt32Len(megolmOutboundSessionPickleVersionLibOlm) - length += o.Ratchet.PickleLen() - length += o.SigningKey.PickleLen() - return length -} - -func (o MegolmOutboundSession) SessionSharingMessage() ([]byte, error) { +func (o *MegolmOutboundSession) SessionSharingMessage() ([]byte, error) { return o.Ratchet.SessionSharingMessage(o.SigningKey) } + +// MessageIndex returns the message index for this session. Each message is +// sent with an increasing index; this returns the index for the next message. +func (s *MegolmOutboundSession) MessageIndex() uint { + return uint(s.Ratchet.Counter) +} + +// Key returns the base64-encoded current ratchet key for this session. +func (s *MegolmOutboundSession) Key() string { + return string(exerrors.Must(s.SessionSharingMessage())) +} diff --git a/crypto/goolm/session/megolm_session_test.go b/crypto/goolm/session/megolm_session_test.go index 9b3f56b5..72d8857b 100644 --- a/crypto/goolm/session/megolm_session_test.go +++ b/crypto/goolm/session/megolm_session_test.go @@ -1,92 +1,57 @@ package session_test import ( - "bytes" "crypto/rand" - "errors" + "encoding/base64" "testing" - "maunium.net/go/mautrix/crypto/goolm" + "github.com/stretchr/testify/assert" + "maunium.net/go/mautrix/crypto/goolm/crypto" "maunium.net/go/mautrix/crypto/goolm/megolm" "maunium.net/go/mautrix/crypto/goolm/session" + "maunium.net/go/mautrix/crypto/olm" ) func TestOutboundPickleJSON(t *testing.T) { pickleKey := []byte("secretKey") sess, err := session.NewMegolmOutboundSession() - if err != nil { - t.Fatal(err) - } - kp, err := crypto.Ed25519GenerateKey(nil) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) + kp, err := crypto.Ed25519GenerateKey() + assert.NoError(t, err) sess.SigningKey = kp pickled, err := sess.PickleAsJSON(pickleKey) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) newSession := session.MegolmOutboundSession{} err = newSession.UnpickleAsJSON(pickled, pickleKey) - if err != nil { - t.Fatal(err) - } - if sess.SessionID() != newSession.SessionID() { - t.Fatal("session ids not equal") - } - if !bytes.Equal(sess.SigningKey.PrivateKey, newSession.SigningKey.PrivateKey) { - t.Fatal("private keys not equal") - } - if !bytes.Equal(sess.Ratchet.Data[:], newSession.Ratchet.Data[:]) { - t.Fatal("ratchet data not equal") - } - if sess.Ratchet.Counter != newSession.Ratchet.Counter { - t.Fatal("ratchet counter not equal") - } + assert.NoError(t, err) + assert.Equal(t, sess.ID(), newSession.ID()) + assert.Equal(t, sess.SigningKey, newSession.SigningKey) + assert.Equal(t, sess.Ratchet, newSession.Ratchet) } func TestInboundPickleJSON(t *testing.T) { pickleKey := []byte("secretKey") sess := session.MegolmInboundSession{} - kp, err := crypto.Ed25519GenerateKey(nil) - if err != nil { - t.Fatal(err) - } + kp, err := crypto.Ed25519GenerateKey() + assert.NoError(t, err) sess.SigningKey = kp.PublicKey var randomData [megolm.RatchetParts * megolm.RatchetPartLength]byte _, err = rand.Read(randomData[:]) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) ratchet, err := megolm.New(0, randomData) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) sess.Ratchet = *ratchet pickled, err := sess.PickleAsJSON(pickleKey) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) newSession := session.MegolmInboundSession{} err = newSession.UnpickleAsJSON(pickled, pickleKey) - if err != nil { - t.Fatal(err) - } - if sess.SessionID() != newSession.SessionID() { - t.Fatal("sess ids not equal") - } - if !bytes.Equal(sess.SigningKey, newSession.SigningKey) { - t.Fatal("private keys not equal") - } - if !bytes.Equal(sess.Ratchet.Data[:], newSession.Ratchet.Data[:]) { - t.Fatal("ratchet data not equal") - } - if sess.Ratchet.Counter != newSession.Ratchet.Counter { - t.Fatal("ratchet counter not equal") - } + assert.NoError(t, err) + assert.Equal(t, sess.ID(), newSession.ID()) + assert.Equal(t, sess.SigningKey, newSession.SigningKey) + assert.Equal(t, sess.Ratchet, newSession.Ratchet) } func TestGroupSendReceive(t *testing.T) { @@ -100,46 +65,27 @@ func TestGroupSendReceive(t *testing.T) { ) outboundSession, err := session.NewMegolmOutboundSession() - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) copy(outboundSession.Ratchet.Data[:], randomData) - if outboundSession.Ratchet.Counter != 0 { - t.Fatal("ratchet counter is not correkt") - } + assert.EqualValues(t, 0, outboundSession.Ratchet.Counter) + sessionSharing, err := outboundSession.SessionSharingMessage() - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) plainText := []byte("Message") ciphertext, err := outboundSession.Encrypt(plainText) - if err != nil { - t.Fatal(err) - } - if outboundSession.Ratchet.Counter != 1 { - t.Fatal("ratchet counter is not correkt") - } + assert.NoError(t, err) + assert.EqualValues(t, 1, outboundSession.Ratchet.Counter) //build inbound session inboundSession, err := session.NewMegolmInboundSession(sessionSharing) - if err != nil { - t.Fatal(err) - } - if !inboundSession.SigningKeyVerified { - t.Fatal("key not verified") - } - if inboundSession.SessionID() != outboundSession.SessionID() { - t.Fatal("session ids not equal") - } + assert.NoError(t, err) + assert.True(t, inboundSession.SigningKeyVerified) + assert.Equal(t, outboundSession.ID(), inboundSession.ID()) //decode message decoded, _, err := inboundSession.Decrypt(ciphertext) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(plainText, decoded) { - t.Fatal("messages not equal") - } + assert.NoError(t, err) + assert.Equal(t, plainText, decoded) } func TestGroupSessionExportImport(t *testing.T) { @@ -158,45 +104,26 @@ func TestGroupSessionExportImport(t *testing.T) { //init inbound inboundSession, err := session.NewMegolmInboundSession(sessionKey) - if err != nil { - t.Fatal(err) - } - if !inboundSession.SigningKeyVerified { - t.Fatal("signing key not verified") - } + assert.NoError(t, err) + assert.True(t, inboundSession.SigningKeyVerified) decrypted, _, err := inboundSession.Decrypt(message) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(plaintext, decrypted) { - t.Fatal("message is not correct") - } + assert.NoError(t, err) + assert.Equal(t, plaintext, decrypted) //Export the keys - exported, err := inboundSession.SessionExportMessage(0) - if err != nil { - t.Fatal(err) - } + exported, err := inboundSession.Export(0) + assert.NoError(t, err) secondInboundSession, err := session.NewMegolmInboundSessionFromExport(exported) - if err != nil { - t.Fatal(err) - } - if secondInboundSession.SigningKeyVerified { - t.Fatal("signing key is verified") - } + assert.NoError(t, err) + assert.False(t, secondInboundSession.SigningKeyVerified) + //decrypt with new session decrypted, _, err = secondInboundSession.Decrypt(message) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(plaintext, decrypted) { - t.Fatal("message is not correct") - } - if !secondInboundSession.SigningKeyVerified { - t.Fatal("signing key not verified") - } + assert.NoError(t, err) + assert.Equal(t, plaintext, decrypted) + assert.True(t, secondInboundSession.SigningKeyVerified) } func TestBadSignatureGroupMessage(t *testing.T) { @@ -215,70 +142,43 @@ func TestBadSignatureGroupMessage(t *testing.T) { //init inbound inboundSession, err := session.NewMegolmInboundSession(sessionKey) - if err != nil { - t.Fatal(err) - } - if !inboundSession.SigningKeyVerified { - t.Fatal("signing key not verified") - } + assert.NoError(t, err) + assert.True(t, inboundSession.SigningKeyVerified) decrypted, _, err := inboundSession.Decrypt(message) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(plaintext, decrypted) { - t.Fatal("message is not correct") - } + assert.NoError(t, err) + assert.Equal(t, plaintext, decrypted) //Now twiddle the signature copy(message[len(message)-1:], []byte("E")) _, _, err = inboundSession.Decrypt(message) - if err == nil { - t.Fatal("Signature was changed but did not cause an error") - } - if !errors.Is(err, goolm.ErrBadSignature) { - t.Fatalf("wrong error %s", err.Error()) - } + assert.ErrorIs(t, err, olm.ErrBadSignature) } func TestOutbountPickle(t *testing.T) { pickledDataFromLibOlm := []byte("icDKYm0b4aO23WgUuOxdpPoxC0UlEOYPVeuduNH3IkpFsmnWx5KuEOpxGiZw5IuB/sSn2RZUCTiJ90IvgC7AClkYGHep9O8lpiqQX73XVKD9okZDCAkBc83eEq0DKYC7HBkGRAU/4T6QPIBBY3UK4QZwULLE/fLsi3j4YZBehMtnlsqgHK0q1bvX4cRznZItUO3TiOp5I+6PnQka6n8eHTyIEh3tCetilD+BKnHvtakE0eHHvG6pjEsMNN/vs7lkB5rV6XkoUKHLTE1dAfFunYEeHEZuKQpbG385dBwaMJXt4JrC0hU5jnv6jWNqAA0Ud9GxRDvkp04") pickleKey := []byte("secret_key") sess, err := session.MegolmOutboundSessionFromPickled(pickledDataFromLibOlm, pickleKey) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) newPickled, err := sess.Pickle(pickleKey) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(pickledDataFromLibOlm, newPickled) { - t.Fatal("pickled version does not equal libolm version") - } + assert.NoError(t, err) + assert.Equal(t, pickledDataFromLibOlm, newPickled) + pickledDataFromLibOlm = append(pickledDataFromLibOlm, []byte("a")...) _, err = session.MegolmOutboundSessionFromPickled(pickledDataFromLibOlm, pickleKey) - if err == nil { - t.Fatal("should have gotten an error") - } + assert.ErrorIs(t, err, olm.ErrBadMAC) } func TestInbountPickle(t *testing.T) { pickledDataFromLibOlm := []byte("1/IPCdtUoQxMba5XT7sjjUW0Hrs7no9duGFnhsEmxzFX2H3qtRc4eaFBRZYXxOBRTGZ6eMgy3IiSrgAQ1gUlSZf5Q4AVKeBkhvN4LZ6hdhQFv91mM+C2C55/4B9/gDjJEbDGiRgLoMqbWPDV+y0F4h0KaR1V1PiTCC7zCi4WdxJQ098nJLgDL4VSsDbnaLcSMO60FOYgRN4KsLaKUGkXiiUBWp4boFMCiuTTOiyH8XlH0e9uWc0vMLyGNUcO8kCbpAnx3v1JTIVan3WGsnGv4K8Qu4M8GAkZewpexrsb2BSNNeLclOV9/cR203Y5KlzXcpiWNXSs8XoB3TLEtHYMnjuakMQfyrcXKIQntg4xPD/+wvfqkcMg9i7pcplQh7X2OK5ylrMZQrZkJ1fAYBGbBz1tykWOjfrZ") pickleKey := []byte("secret_key") sess, err := session.MegolmInboundSessionFromPickled(pickledDataFromLibOlm, pickleKey) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) newPickled, err := sess.Pickle(pickleKey) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(pickledDataFromLibOlm, newPickled) { - t.Fatal("pickled version does not equal libolm version") - } + assert.NoError(t, err) + assert.Equal(t, pickledDataFromLibOlm, newPickled) + pickledDataFromLibOlm = append(pickledDataFromLibOlm, []byte("a")...) _, err = session.MegolmInboundSessionFromPickled(pickledDataFromLibOlm, pickleKey) - if err == nil { - t.Fatal("should have gotten an error") - } + assert.ErrorIs(t, err, base64.CorruptInputError(416)) } diff --git a/crypto/goolm/session/olm_session.go b/crypto/goolm/session/olm_session.go index 6655e0a5..a1cb8d66 100644 --- a/crypto/goolm/session/olm_session.go +++ b/crypto/goolm/session/olm_session.go @@ -2,18 +2,17 @@ package session import ( "bytes" + "crypto/sha256" "encoding/base64" - "errors" "fmt" - "io" + "strings" - "maunium.net/go/mautrix/crypto/goolm" - "maunium.net/go/mautrix/crypto/goolm/cipher" "maunium.net/go/mautrix/crypto/goolm/crypto" + "maunium.net/go/mautrix/crypto/goolm/goolmbase64" "maunium.net/go/mautrix/crypto/goolm/libolmpickle" "maunium.net/go/mautrix/crypto/goolm/message" - "maunium.net/go/mautrix/crypto/goolm/olm" - "maunium.net/go/mautrix/crypto/goolm/utilities" + "maunium.net/go/mautrix/crypto/goolm/ratchet" + "maunium.net/go/mautrix/crypto/olm" "maunium.net/go/mautrix/id" ) @@ -32,9 +31,11 @@ type OlmSession struct { AliceIdentityKey crypto.Curve25519PublicKey `json:"alice_id_key"` AliceBaseKey crypto.Curve25519PublicKey `json:"alice_base_key"` BobOneTimeKey crypto.Curve25519PublicKey `json:"bob_one_time_key"` - Ratchet olm.Ratchet `json:"ratchet"` + Ratchet ratchet.Ratchet `json:"ratchet"` } +var _ olm.Session = (*OlmSession)(nil) + // SearchOTKFunc is used to retrieve a crypto.OneTimeKey from a public key. type SearchOTKFunc = func(crypto.Curve25519PublicKey) *crypto.OneTimeKey @@ -42,33 +43,25 @@ type SearchOTKFunc = func(crypto.Curve25519PublicKey) *crypto.OneTimeKey // the Session using the supplied key. func OlmSessionFromJSONPickled(pickled, key []byte) (*OlmSession, error) { if len(pickled) == 0 { - return nil, fmt.Errorf("sessionFromPickled: %w", goolm.ErrEmptyInput) + return nil, fmt.Errorf("sessionFromPickled: %w", olm.ErrEmptyInput) } a := &OlmSession{} - err := a.UnpickleAsJSON(pickled, key) - if err != nil { - return nil, err - } - return a, nil + return a, a.UnpickleAsJSON(pickled, key) } // OlmSessionFromPickled loads the OlmSession details from a pickled base64 string. The input is decrypted with the supplied key. func OlmSessionFromPickled(pickled, key []byte) (*OlmSession, error) { if len(pickled) == 0 { - return nil, fmt.Errorf("sessionFromPickled: %w", goolm.ErrEmptyInput) + return nil, fmt.Errorf("sessionFromPickled: %w", olm.ErrEmptyInput) } a := &OlmSession{} - err := a.Unpickle(pickled, key) - if err != nil { - return nil, err - } - return a, nil + return a, a.Unpickle(pickled, key) } // NewOlmSession creates a new Session. func NewOlmSession() *OlmSession { s := &OlmSession{} - s.Ratchet = *olm.New() + s.Ratchet = *ratchet.New() return s } @@ -77,12 +70,12 @@ func NewOlmSession() *OlmSession { func NewOutboundOlmSession(identityKeyAlice crypto.Curve25519KeyPair, identityKeyBob crypto.Curve25519PublicKey, oneTimeKeyBob crypto.Curve25519PublicKey) (*OlmSession, error) { s := NewOlmSession() //generate E_A - baseKey, err := crypto.Curve25519GenerateKey(nil) + baseKey, err := crypto.Curve25519GenerateKey() if err != nil { return nil, err } //generate T_0 - ratchetKey, err := crypto.Curve25519GenerateKey(nil) + ratchetKey, err := crypto.Curve25519GenerateKey() if err != nil { return nil, err } @@ -117,7 +110,7 @@ func NewOutboundOlmSession(identityKeyAlice crypto.Curve25519KeyPair, identityKe // NewInboundOlmSession creates a new inbound session from receiving the first message. func NewInboundOlmSession(identityKeyAlice *crypto.Curve25519PublicKey, receivedOTKMsg []byte, searchBobOTK SearchOTKFunc, identityKeyBob crypto.Curve25519KeyPair) (*OlmSession, error) { - decodedOTKMsg, err := goolm.Base64Decode(receivedOTKMsg) + decodedOTKMsg, err := goolmbase64.Decode(receivedOTKMsg) if err != nil { return nil, err } @@ -130,7 +123,7 @@ func NewInboundOlmSession(identityKeyAlice *crypto.Curve25519PublicKey, received return nil, fmt.Errorf("OneTimeKeyMessage decode: %w", err) } if !oneTimeMsg.CheckFields(identityKeyAlice) { - return nil, fmt.Errorf("OneTimeKeyMessage check fields: %w", goolm.ErrBadMessageFormat) + return nil, fmt.Errorf("OneTimeKeyMessage check fields: %w", olm.ErrBadMessageFormat) } //Either the identityKeyAlice is set and/or the oneTimeMsg.IdentityKey is set, which is checked @@ -138,7 +131,7 @@ func NewInboundOlmSession(identityKeyAlice *crypto.Curve25519PublicKey, received if identityKeyAlice != nil && len(oneTimeMsg.IdentityKey) != 0 { //if both are set, compare them if !identityKeyAlice.Equal(oneTimeMsg.IdentityKey) { - return nil, fmt.Errorf("OneTimeKeyMessage identity keys: %w", goolm.ErrBadMessageKeyID) + return nil, fmt.Errorf("OneTimeKeyMessage identity keys: %w", olm.ErrBadMessageKeyID) } } if identityKeyAlice == nil { @@ -148,7 +141,7 @@ func NewInboundOlmSession(identityKeyAlice *crypto.Curve25519PublicKey, received oneTimeKeyBob := searchBobOTK(oneTimeMsg.OneTimeKey) if oneTimeKeyBob == nil { - return nil, fmt.Errorf("ourOneTimeKey: %w", goolm.ErrBadMessageKeyID) + return nil, fmt.Errorf("ourOneTimeKey: %w", olm.ErrBadMessageKeyID) } //Calculate shared secret via Triple Diffie-Hellman @@ -175,11 +168,11 @@ func NewInboundOlmSession(identityKeyAlice *crypto.Curve25519PublicKey, received msg := message.Message{} err = msg.Decode(oneTimeMsg.Message) if err != nil { - return nil, fmt.Errorf("Message decode: %w", err) + return nil, fmt.Errorf("message decode: %w", err) } if len(msg.RatchetKey) == 0 { - return nil, fmt.Errorf("Message missing ratchet key: %w", goolm.ErrBadMessageFormat) + return nil, fmt.Errorf("message missing ratchet key: %w", olm.ErrBadMessageFormat) } //Init Ratchet s.Ratchet.InitializeAsBob(secret, msg.RatchetKey) @@ -194,40 +187,64 @@ func NewInboundOlmSession(identityKeyAlice *crypto.Curve25519PublicKey, received // PickleAsJSON returns an Session as a base64 string encrypted using the supplied key. The unencrypted representation of the Account is in JSON format. func (a OlmSession) PickleAsJSON(key []byte) ([]byte, error) { - return utilities.PickleAsJSON(a, olmSessionPickleVersionJSON, key) + return libolmpickle.PickleAsJSON(a, olmSessionPickleVersionJSON, key) } // UnpickleAsJSON updates an Session by a base64 encrypted string with the key. The unencrypted representation has to be in JSON format. func (a *OlmSession) UnpickleAsJSON(pickled, key []byte) error { - return utilities.UnpickleAsJSON(a, pickled, key, olmSessionPickleVersionJSON) + return libolmpickle.UnpickleAsJSON(a, pickled, key, olmSessionPickleVersionJSON) } // ID returns an identifier for this Session. Will be the same for both ends of the conversation. // Generated by hashing the public keys used to create the session. -func (s OlmSession) ID() id.SessionID { - message := make([]byte, 3*crypto.Curve25519KeyLength) +func (s *OlmSession) ID() id.SessionID { + message := make([]byte, 3*crypto.Curve25519PrivateKeyLength) copy(message, s.AliceIdentityKey) - copy(message[crypto.Curve25519KeyLength:], s.AliceBaseKey) - copy(message[2*crypto.Curve25519KeyLength:], s.BobOneTimeKey) - hash := crypto.SHA256(message) - res := id.SessionID(goolm.Base64Encode(hash)) + copy(message[crypto.Curve25519PrivateKeyLength:], s.AliceBaseKey) + copy(message[2*crypto.Curve25519PrivateKeyLength:], s.BobOneTimeKey) + hash := sha256.Sum256(message) + res := id.SessionID(base64.RawStdEncoding.EncodeToString(hash[:])) return res } // HasReceivedMessage returns true if this session has received any message. -func (s OlmSession) HasReceivedMessage() bool { +func (s *OlmSession) HasReceivedMessage() bool { return s.ReceivedMessage } -// MatchesInboundSessionFrom checks if the oneTimeKeyMsg message is set for this inbound -// Session. This can happen if multiple messages are sent to this Account -// before this Account sends a message in reply. Returns true if the session -// matches. Returns false if the session does not match. -func (s OlmSession) MatchesInboundSessionFrom(theirIdentityKeyEncoded *id.Curve25519, receivedOTKMsg []byte) (bool, error) { - if len(receivedOTKMsg) == 0 { - return false, fmt.Errorf("inbound match: %w", goolm.ErrEmptyInput) +// MatchesInboundSession checks if the PRE_KEY message is for this in-bound +// Session. This can happen if multiple messages are sent to this Account +// before this Account sends a message in reply. Returns true if the session +// matches. Returns false if the session does not match. Returns error on +// failure. +func (s *OlmSession) MatchesInboundSession(oneTimeKeyMsg string) (bool, error) { + return s.matchesInboundSession(nil, []byte(oneTimeKeyMsg)) +} + +// MatchesInboundSessionFrom checks if the PRE_KEY message is for this in-bound +// Session. This can happen if multiple messages are sent to this Account +// before this Account sends a message in reply. Returns true if the session +// matches. Returns false if the session does not match. Returns error on +// failure. +func (s *OlmSession) MatchesInboundSessionFrom(theirIdentityKey, oneTimeKeyMsg string) (bool, error) { + var theirKey *id.Curve25519 + if theirIdentityKey != "" { + theirs := id.Curve25519(theirIdentityKey) + theirKey = &theirs } - decodedOTKMsg, err := goolm.Base64Decode(receivedOTKMsg) + + return s.matchesInboundSession(theirKey, []byte(oneTimeKeyMsg)) +} + +// matchesInboundSession checks if the oneTimeKeyMsg message is set for this +// inbound Session. This can happen if multiple messages are sent to this +// Account before this Account sends a message in reply. Returns true if the +// session matches. Returns false if the session does not match. +func (s *OlmSession) matchesInboundSession(theirIdentityKeyEncoded *id.Curve25519, receivedOTKMsg []byte) (bool, error) { + if len(receivedOTKMsg) == 0 { + return false, fmt.Errorf("inbound match: %w", olm.ErrEmptyInput) + } + decodedOTKMsg, err := goolmbase64.Decode(receivedOTKMsg) if err != nil { return false, err } @@ -266,20 +283,20 @@ func (s OlmSession) MatchesInboundSessionFrom(theirIdentityKeyEncoded *id.Curve2 // EncryptMsgType returns the type of the next message that Encrypt will // return. Returns MsgTypePreKey if the message will be a oneTimeKeyMsg. // Returns MsgTypeMsg if the message will be a normal message. -func (s OlmSession) EncryptMsgType() id.OlmMsgType { +func (s *OlmSession) EncryptMsgType() id.OlmMsgType { if s.ReceivedMessage { return id.OlmMsgTypeMsg } return id.OlmMsgTypePreKey } -// Encrypt encrypts a message using the Session. Returns the encrypted message base64 encoded. If reader is nil, crypto/rand is used for key generations. -func (s *OlmSession) Encrypt(plaintext []byte, reader io.Reader) (id.OlmMsgType, []byte, error) { +// Encrypt encrypts a message using the Session. Returns the encrypted message base64 encoded. +func (s *OlmSession) Encrypt(plaintext []byte) (id.OlmMsgType, []byte, error) { if len(plaintext) == 0 { - return 0, nil, fmt.Errorf("encrypt: %w", goolm.ErrEmptyInput) + return 0, nil, fmt.Errorf("encrypt: %w", olm.ErrEmptyInput) } messageType := s.EncryptMsgType() - encrypted, err := s.Ratchet.Encrypt(plaintext, reader) + encrypted, err := s.Ratchet.Encrypt(plaintext) if err != nil { return 0, nil, err } @@ -300,15 +317,15 @@ func (s *OlmSession) Encrypt(plaintext []byte, reader io.Reader) (id.OlmMsgType, result = messageBody } - return messageType, goolm.Base64Encode(result), nil + return messageType, goolmbase64.Encode(result), nil } // Decrypt decrypts a base64 encoded message using the Session. -func (s *OlmSession) Decrypt(crypttext []byte, msgType id.OlmMsgType) ([]byte, error) { +func (s *OlmSession) Decrypt(crypttext string, msgType id.OlmMsgType) ([]byte, error) { if len(crypttext) == 0 { - return nil, fmt.Errorf("decrypt: %w", goolm.ErrEmptyInput) + return nil, fmt.Errorf("decrypt: %w", olm.ErrEmptyInput) } - decodedCrypttext, err := goolm.Base64Decode(crypttext) + decodedCrypttext, err := base64.RawStdEncoding.DecodeString(crypttext) if err != nil { return nil, err } @@ -333,144 +350,80 @@ func (s *OlmSession) Decrypt(crypttext []byte, msgType id.OlmMsgType) ([]byte, e // Unpickle decodes the base64 encoded string and decrypts the result with the key. // The decrypted value is then passed to UnpickleLibOlm. func (o *OlmSession) Unpickle(pickled, key []byte) error { - decrypted, err := cipher.Unpickle(key, pickled) + if len(pickled) == 0 { + return olm.ErrEmptyInput + } + decrypted, err := libolmpickle.Unpickle(key, pickled) if err != nil { return err } - _, err = o.UnpickleLibOlm(decrypted) - return err + return o.UnpickleLibOlm(decrypted) } -// UnpickleLibOlm decodes the unencryted value and populates the Session accordingly. It returns the number of bytes read. -func (o *OlmSession) UnpickleLibOlm(value []byte) (int, error) { - //First 4 bytes are the accountPickleVersion - pickledVersion, curPos, err := libolmpickle.UnpickleUInt32(value) +// UnpickleLibOlm unpickles the unencryted value and populates the [OlmSession] +// accordingly. +func (o *OlmSession) UnpickleLibOlm(buf []byte) error { + decoder := libolmpickle.NewDecoder(buf) + pickledVersion, err := decoder.ReadUInt32() if err != nil { - return 0, err + return fmt.Errorf("unpickle olmSession: failed to read version: %w", err) } - includesChainIndex := true + + var includesChainIndex bool switch pickledVersion { case olmSessionPickleVersionLibOlm: includesChainIndex = false case uint32(0x80000001): includesChainIndex = true default: - return 0, fmt.Errorf("unpickle olmSession: %w", goolm.ErrBadVersion) + return fmt.Errorf("unpickle olmSession: %w (found version %d)", olm.ErrUnknownOlmPickleVersion, pickledVersion) } - var readBytes int - o.ReceivedMessage, readBytes, err = libolmpickle.UnpickleBool(value[curPos:]) - if err != nil { - return 0, err + + if o.ReceivedMessage, err = decoder.ReadBool(); err != nil { + return err + } else if err = o.AliceIdentityKey.UnpickleLibOlm(decoder); err != nil { + return err + } else if err = o.AliceBaseKey.UnpickleLibOlm(decoder); err != nil { + return err + } else if err = o.BobOneTimeKey.UnpickleLibOlm(decoder); err != nil { + return err } - curPos += readBytes - readBytes, err = o.AliceIdentityKey.UnpickleLibOlm(value[curPos:]) - if err != nil { - return 0, err - } - curPos += readBytes - readBytes, err = o.AliceBaseKey.UnpickleLibOlm(value[curPos:]) - if err != nil { - return 0, err - } - curPos += readBytes - readBytes, err = o.BobOneTimeKey.UnpickleLibOlm(value[curPos:]) - if err != nil { - return 0, err - } - curPos += readBytes - readBytes, err = o.Ratchet.UnpickleLibOlm(value[curPos:], includesChainIndex) - if err != nil { - return 0, err - } - curPos += readBytes - return curPos, nil + return o.Ratchet.UnpickleLibOlm(decoder, includesChainIndex) } -// Pickle returns a base64 encoded and with key encrypted pickled olmSession using PickleLibOlm(). -func (o OlmSession) Pickle(key []byte) ([]byte, error) { - pickeledBytes := make([]byte, o.PickleLen()) - written, err := o.PickleLibOlm(pickeledBytes) - if err != nil { - return nil, err +// Pickle returns a base64 encoded and with key encrypted pickled olmSession +// using PickleLibOlm(). +func (s *OlmSession) Pickle(key []byte) ([]byte, error) { + if len(key) == 0 { + return nil, olm.ErrNoKeyProvided } - if written != len(pickeledBytes) { - return nil, errors.New("number of written bytes not correct") - } - encrypted, err := cipher.Pickle(key, pickeledBytes) - if err != nil { - return nil, err - } - return encrypted, nil + return libolmpickle.Pickle(key, s.PickleLibOlm()) } -// PickleLibOlm encodes the session into target. target has to have a size of at least PickleLen() and is written to from index 0. -// It returns the number of bytes written. -func (o OlmSession) PickleLibOlm(target []byte) (int, error) { - if len(target) < o.PickleLen() { - return 0, fmt.Errorf("pickle MegolmOutboundSession: %w", goolm.ErrValueTooShort) - } - written := libolmpickle.PickleUInt32(olmSessionPickleVersionLibOlm, target) - written += libolmpickle.PickleBool(o.ReceivedMessage, target[written:]) - writtenRatchet, err := o.AliceIdentityKey.PickleLibOlm(target[written:]) - if err != nil { - return 0, fmt.Errorf("pickle MegolmOutboundSession: %w", err) - } - written += writtenRatchet - writtenRatchet, err = o.AliceBaseKey.PickleLibOlm(target[written:]) - if err != nil { - return 0, fmt.Errorf("pickle MegolmOutboundSession: %w", err) - } - written += writtenRatchet - writtenRatchet, err = o.BobOneTimeKey.PickleLibOlm(target[written:]) - if err != nil { - return 0, fmt.Errorf("pickle MegolmOutboundSession: %w", err) - } - written += writtenRatchet - writtenRatchet, err = o.Ratchet.PickleLibOlm(target[written:]) - if err != nil { - return 0, fmt.Errorf("pickle MegolmOutboundSession: %w", err) - } - written += writtenRatchet - return written, nil -} - -// PickleLen returns the actual number of bytes the pickled session will have. -func (o OlmSession) PickleLen() int { - length := libolmpickle.PickleUInt32Len(olmSessionPickleVersionLibOlm) - length += libolmpickle.PickleBoolLen(o.ReceivedMessage) - length += o.AliceIdentityKey.PickleLen() - length += o.AliceBaseKey.PickleLen() - length += o.BobOneTimeKey.PickleLen() - length += o.Ratchet.PickleLen() - return length -} - -// PickleLenMin returns the minimum number of bytes the pickled session must have. -func (o OlmSession) PickleLenMin() int { - length := libolmpickle.PickleUInt32Len(olmSessionPickleVersionLibOlm) - length += libolmpickle.PickleBoolLen(o.ReceivedMessage) - length += o.AliceIdentityKey.PickleLen() - length += o.AliceBaseKey.PickleLen() - length += o.BobOneTimeKey.PickleLen() - length += o.Ratchet.PickleLenMin() - return length +// PickleLibOlm pickles the session and returns the raw bytes. +func (o *OlmSession) PickleLibOlm() []byte { + encoder := libolmpickle.NewEncoder() + encoder.WriteUInt32(olmSessionPickleVersionLibOlm) + encoder.WriteBool(o.ReceivedMessage) + o.AliceIdentityKey.PickleLibOlm(encoder) + o.AliceBaseKey.PickleLibOlm(encoder) + o.BobOneTimeKey.PickleLibOlm(encoder) + o.Ratchet.PickleLibOlm(encoder) + return encoder.Bytes() } // Describe returns a string describing the current state of the session for debugging. -func (o OlmSession) Describe() string { - var res string - if o.Ratchet.SenderChains.IsSet { - res += fmt.Sprintf("sender chain index: %d ", o.Ratchet.SenderChains.CKey.Index) - } else { - res += "sender chain index: " - } - res += "receiver chain indicies:" +func (o *OlmSession) Describe() string { + var builder strings.Builder + builder.WriteString("sender chain index: ") + builder.WriteString(fmt.Sprint(o.Ratchet.SenderChains.CKey.Index)) + builder.WriteString(" receiver chain indices:") for _, curChain := range o.Ratchet.ReceiverChains { - res += fmt.Sprintf(" %d", curChain.CKey.Index) + builder.WriteString(fmt.Sprintf(" %d", curChain.CKey.Index)) } - res += " skipped message keys:" + builder.WriteString(" skipped message keys:") for _, curSkip := range o.Ratchet.SkippedMessageKeys { - res += fmt.Sprintf(" %d", curSkip.MKey.Index) + builder.WriteString(fmt.Sprintf(" %d", curSkip.MKey.Index)) } - return res + return builder.String() } diff --git a/crypto/goolm/session/olm_session_test.go b/crypto/goolm/session/olm_session_test.go index 11b13c32..f87c2e7e 100644 --- a/crypto/goolm/session/olm_session_test.go +++ b/crypto/goolm/session/olm_session_test.go @@ -1,44 +1,32 @@ package session_test import ( - "bytes" "encoding/base64" - "errors" "testing" - "maunium.net/go/mautrix/crypto/goolm" + "github.com/stretchr/testify/assert" + "maunium.net/go/mautrix/crypto/goolm/crypto" "maunium.net/go/mautrix/crypto/goolm/session" + "maunium.net/go/mautrix/crypto/olm" "maunium.net/go/mautrix/id" ) func TestOlmSession(t *testing.T) { pickleKey := []byte("secretKey") - aliceKeyPair, err := crypto.Curve25519GenerateKey(nil) - if err != nil { - t.Fatal(err) - } - bobKeyPair, err := crypto.Curve25519GenerateKey(nil) - if err != nil { - t.Fatal(err) - } - bobOneTimeKey, err := crypto.Curve25519GenerateKey(nil) - if err != nil { - t.Fatal(err) - } + aliceKeyPair, err := crypto.Curve25519GenerateKey() + assert.NoError(t, err) + bobKeyPair, err := crypto.Curve25519GenerateKey() + assert.NoError(t, err) + bobOneTimeKey, err := crypto.Curve25519GenerateKey() + assert.NoError(t, err) aliceSession, err := session.NewOutboundOlmSession(aliceKeyPair, bobKeyPair.PublicKey, bobOneTimeKey.PublicKey) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) //create a message so that there are more keys to marshal plaintext := []byte("Test message from Alice to Bob") - msgType, message, err := aliceSession.Encrypt(plaintext, nil) - if err != nil { - t.Fatal(err) - } - if msgType != id.OlmMsgTypePreKey { - t.Fatal("Wrong message type") - } + msgType, message, err := aliceSession.Encrypt(plaintext) + assert.NoError(t, err) + assert.Equal(t, id.OlmMsgTypePreKey, msgType) searchFunc := func(target crypto.Curve25519PublicKey) *crypto.OneTimeKey { if target.Equal(bobOneTimeKey.PublicKey) { @@ -52,92 +40,58 @@ func TestOlmSession(t *testing.T) { } //bob receives message bobSession, err := session.NewInboundOlmSession(nil, message, searchFunc, bobKeyPair) - if err != nil { - t.Fatal(err) - } - decryptedMsg, err := bobSession.Decrypt(message, msgType) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(plaintext, decryptedMsg) { - t.Fatalf("messages are not equal:\n%v\n%v\n", plaintext, decryptedMsg) - } + assert.NoError(t, err) + decryptedMsg, err := bobSession.Decrypt(string(message), msgType) + assert.NoError(t, err) + assert.Equal(t, plaintext, decryptedMsg) // Alice pickles session pickled, err := aliceSession.PickleAsJSON(pickleKey) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) //bob sends a message plaintext = []byte("A message from Bob to Alice") - msgType, message, err = bobSession.Encrypt(plaintext, nil) - if err != nil { - t.Fatal(err) - } - if msgType != id.OlmMsgTypeMsg { - t.Fatal("Wrong message type") - } + msgType, message, err = bobSession.Encrypt(plaintext) + assert.NoError(t, err) + assert.Equal(t, id.OlmMsgTypeMsg, msgType) //Alice unpickles session newAliceSession, err := session.OlmSessionFromJSONPickled(pickled, pickleKey) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) //Alice receives message - decryptedMsg, err = newAliceSession.Decrypt(message, msgType) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(plaintext, decryptedMsg) { - t.Fatalf("messages are not equal:\n%v\n%v\n", plaintext, decryptedMsg) - } + decryptedMsg, err = newAliceSession.Decrypt(string(message), msgType) + assert.NoError(t, err) + assert.Equal(t, plaintext, decryptedMsg) //Alice receives message again - _, err = newAliceSession.Decrypt(message, msgType) - if err == nil { - t.Fatal("should have gotten an error") - } + _, err = newAliceSession.Decrypt(string(message), msgType) + assert.ErrorIs(t, err, olm.ErrMessageKeyNotFound) //Alice sends another message plaintext = []byte("A second message to Bob") - msgType, message, err = newAliceSession.Encrypt(plaintext, nil) - if err != nil { - t.Fatal(err) - } - if msgType != id.OlmMsgTypeMsg { - t.Fatal("Wrong message type") - } + msgType, message, err = newAliceSession.Encrypt(plaintext) + assert.NoError(t, err) + assert.Equal(t, id.OlmMsgTypeMsg, msgType) + //bob receives message - decryptedMsg, err = bobSession.Decrypt(message, msgType) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(plaintext, decryptedMsg) { - t.Fatalf("messages are not equal:\n%v\n%v\n", plaintext, decryptedMsg) - } + decryptedMsg, err = bobSession.Decrypt(string(message), msgType) + assert.NoError(t, err) + assert.Equal(t, plaintext, decryptedMsg) } func TestSessionPickle(t *testing.T) { pickledDataFromLibOlm := []byte("icDKYm0b4aO23WgUuOxdpPoxC0UlEOYPVeuduNH3IkpFsmnWx5KuEOpxGiZw5IuB/sSn2RZUCTiJ90IvgC7AClkYGHep9O8lpiqQX73XVKD9okZDCAkBc83eEq0DKYC7HBkGRAU/4T6QPIBBY3UK4QZwULLE/fLsi3j4YZBehMtnlsqgHK0q1bvX4cRznZItVKR4ro0O9EAk6LLxJtSnRu5elSUk7YXT") pickleKey := []byte("secret_key") sess, err := session.OlmSessionFromPickled(pickledDataFromLibOlm, pickleKey) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) newPickled, err := sess.Pickle(pickleKey) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(pickledDataFromLibOlm, newPickled) { - t.Fatal("pickled version does not equal libolm version") - } + assert.NoError(t, err) + assert.Equal(t, pickledDataFromLibOlm, newPickled) + pickledDataFromLibOlm = append(pickledDataFromLibOlm, []byte("a")...) _, err = session.OlmSessionFromPickled(pickledDataFromLibOlm, pickleKey) - if err == nil { - t.Fatal("should have gotten an error") - } + assert.ErrorIs(t, err, base64.CorruptInputError(224)) } func TestDecrypts(t *testing.T) { @@ -148,7 +102,7 @@ func TestDecrypts(t *testing.T) { {0xe9, 0xe9, 0xc9, 0xc1, 0xe9, 0xe9, 0xc9, 0xe9, 0xc9, 0xc1, 0xe9, 0xe9, 0xc9, 0xc1}, } expectedErr := []error{ - goolm.ErrInputToSmall, + olm.ErrInputToSmall, // Why are these being tested 🤔 base64.CorruptInputError(0), base64.CorruptInputError(0), @@ -161,17 +115,9 @@ func TestDecrypts(t *testing.T) { "dGvPXeH8qLeNZA") pickleKey := []byte("") sess, err := session.OlmSessionFromPickled(sessionPickled, pickleKey) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) for curIndex, curMessage := range messages { - _, err := sess.Decrypt(curMessage, id.OlmMsgTypePreKey) - if err != nil { - if !errors.Is(err, expectedErr[curIndex]) { - t.Fatal(err) - } - } else { - t.Fatal("error expected") - } + _, err := sess.Decrypt(string(curMessage), id.OlmMsgTypePreKey) + assert.ErrorIs(t, err, expectedErr[curIndex]) } } diff --git a/crypto/goolm/session/register.go b/crypto/goolm/session/register.go new file mode 100644 index 00000000..b95a44ac --- /dev/null +++ b/crypto/goolm/session/register.go @@ -0,0 +1,61 @@ +// Copyright (c) 2024 Sumner Evans +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package session + +import ( + "maunium.net/go/mautrix/crypto/olm" +) + +func Register() { + // Inbound Session + olm.InitInboundGroupSessionFromPickled = func(pickled, key []byte) (olm.InboundGroupSession, error) { + if len(pickled) == 0 { + return nil, olm.ErrEmptyInput + } + if len(key) == 0 { + key = []byte(" ") + } + return MegolmInboundSessionFromPickled(pickled, key) + } + olm.InitNewInboundGroupSession = func(sessionKey []byte) (olm.InboundGroupSession, error) { + if len(sessionKey) == 0 { + return nil, olm.ErrEmptyInput + } + return NewMegolmInboundSession(sessionKey) + } + olm.InitInboundGroupSessionImport = func(sessionKey []byte) (olm.InboundGroupSession, error) { + if len(sessionKey) == 0 { + return nil, olm.ErrEmptyInput + } + return NewMegolmInboundSessionFromExport(sessionKey) + } + olm.InitBlankInboundGroupSession = func() olm.InboundGroupSession { + return &MegolmInboundSession{} + } + + // Outbound Session + olm.InitNewOutboundGroupSessionFromPickled = func(pickled, key []byte) (olm.OutboundGroupSession, error) { + if len(pickled) == 0 { + return nil, olm.ErrEmptyInput + } + lenKey := len(key) + if lenKey == 0 { + key = []byte(" ") + } + return MegolmOutboundSessionFromPickled(pickled, key) + } + olm.InitNewOutboundGroupSession = func() (olm.OutboundGroupSession, error) { return NewMegolmOutboundSession() } + olm.InitNewBlankOutboundGroupSession = func() olm.OutboundGroupSession { return &MegolmOutboundSession{} } + + // Olm Session + olm.InitSessionFromPickled = func(pickled, key []byte) (olm.Session, error) { + return OlmSessionFromPickled(pickled, key) + } + olm.InitNewBlankSession = func() olm.Session { + return NewOlmSession() + } +} diff --git a/crypto/keybackup.go b/crypto/keybackup.go index d3701e93..7b3c30db 100644 --- a/crypto/keybackup.go +++ b/crypto/keybackup.go @@ -2,6 +2,8 @@ package crypto import ( "context" + "encoding/base64" + "errors" "fmt" "time" @@ -11,6 +13,7 @@ import ( "maunium.net/go/mautrix/crypto/backup" "maunium.net/go/mautrix/crypto/olm" "maunium.net/go/mautrix/crypto/signatures" + "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" ) @@ -21,7 +24,7 @@ func (mach *OlmMachine) DownloadAndStoreLatestKeyBackup(ctx context.Context, meg ctx = log.WithContext(ctx) - versionInfo, err := mach.GetAndVerifyLatestKeyBackupVersion(ctx) + versionInfo, err := mach.GetAndVerifyLatestKeyBackupVersion(ctx, megolmBackupKey) if err != nil { return "", err } else if versionInfo == nil { @@ -32,7 +35,7 @@ func (mach *OlmMachine) DownloadAndStoreLatestKeyBackup(ctx context.Context, meg return versionInfo.Version, err } -func (mach *OlmMachine) GetAndVerifyLatestKeyBackupVersion(ctx context.Context) (*mautrix.RespRoomKeysVersion[backup.MegolmAuthData], error) { +func (mach *OlmMachine) GetAndVerifyLatestKeyBackupVersion(ctx context.Context, megolmBackupKey *backup.MegolmBackupKey) (*mautrix.RespRoomKeysVersion[backup.MegolmAuthData], error) { versionInfo, err := mach.Client.GetKeyBackupLatestVersion(ctx) if err != nil { return nil, err @@ -48,12 +51,33 @@ func (mach *OlmMachine) GetAndVerifyLatestKeyBackupVersion(ctx context.Context) Stringer("key_backup_version", versionInfo.Version). Logger() + // https://spec.matrix.org/v1.10/client-server-api/#server-side-key-backups + // "Clients must only store keys in backups after they have ensured that the auth_data is trusted. This can be done either... + // ...by deriving the public key from a private key that it obtained from a trusted source. Trusted sources for the private + // key include the user entering the key, retrieving the key stored in secret storage, or obtaining the key via secret sharing + // from a verified device belonging to the same user." + if megolmBackupKey != nil { + megolmBackupDerivedPublicKey := id.Ed25519(base64.RawStdEncoding.EncodeToString(megolmBackupKey.PublicKey().Bytes())) + if versionInfo.AuthData.PublicKey == megolmBackupDerivedPublicKey { + log.Debug().Msg("Key backup is trusted based on derived public key") + return versionInfo, nil + } + log.Debug(). + Stringer("expected_key", megolmBackupDerivedPublicKey). + Stringer("actual_key", versionInfo.AuthData.PublicKey). + Msg("key backup public keys do not match, proceeding to check device signatures") + } + + // "...or checking that it is signed by the user’s master cross-signing key or by a verified device belonging to the same user" userSignatures, ok := versionInfo.AuthData.Signatures[mach.Client.UserID] if !ok { return nil, fmt.Errorf("no signature from user %s found in key backup", mach.Client.UserID) } crossSigningPubkeys := mach.GetOwnCrossSigningPublicKeys(ctx) + if crossSigningPubkeys == nil { + return nil, ErrCrossSigningPubkeysNotCached + } signatureVerified := false for keyID := range userSignatures { @@ -66,10 +90,12 @@ func (mach *OlmMachine) GetAndVerifyLatestKeyBackupVersion(ctx context.Context) var key id.Ed25519 if keyName == crossSigningPubkeys.MasterKey.String() { key = crossSigningPubkeys.MasterKey - } else if device, err := mach.GetOrFetchDevice(ctx, mach.Client.UserID, id.DeviceID(keyName)); err != nil { - log.Warn().Err(err).Msg("Failed to fetch device") + } else if device, err := mach.CryptoStore.GetDevice(ctx, mach.Client.UserID, id.DeviceID(keyName)); err != nil { + return nil, fmt.Errorf("failed to get device %s/%s from store: %w", mach.Client.UserID, keyName, err) + } else if device == nil { + log.Warn().Err(err).Msg("Device does not exist, ignoring signature") continue - } else if !mach.IsDeviceTrusted(device) { + } else if !mach.IsDeviceTrusted(ctx, device) { log.Warn().Err(err).Msg("Device is not trusted") continue } else { @@ -82,6 +108,7 @@ func (mach *OlmMachine) GetAndVerifyLatestKeyBackupVersion(ctx context.Context) continue } else { // One of the signatures is valid, break from the loop. + log.Debug().Stringer("key_id", keyID).Msg("key backup is trusted based on matching signature") signatureVerified = true break } @@ -112,7 +139,7 @@ func (mach *OlmMachine) GetAndStoreKeyBackup(ctx context.Context, version id.Key continue } - err = mach.ImportRoomKeyFromBackup(ctx, version, roomID, sessionID, sessionData) + _, err = mach.ImportRoomKeyFromBackup(ctx, version, roomID, sessionID, sessionData) if err != nil { log.Warn().Err(err).Msg("Failed to import room key from backup") failedCount++ @@ -130,55 +157,84 @@ func (mach *OlmMachine) GetAndStoreKeyBackup(ctx context.Context, version id.Key return nil } -func (mach *OlmMachine) ImportRoomKeyFromBackup(ctx context.Context, version id.KeyBackupVersion, roomID id.RoomID, sessionID id.SessionID, keyBackupData *backup.MegolmSessionData) error { - log := zerolog.Ctx(ctx).With(). - Str("room_id", roomID.String()). - Str("session_id", sessionID.String()). - Logger() +var ( + ErrUnknownAlgorithmInKeyBackup = errors.New("ignoring room key in backup with weird algorithm") + ErrMismatchingSessionIDInKeyBackup = errors.New("mismatched session ID while creating inbound group session from key backup") + ErrFailedToStoreNewInboundGroupSessionFromBackup = errors.New("failed to store new inbound group session from key backup") +) + +func (mach *OlmMachine) ImportRoomKeyFromBackupWithoutSaving( + ctx context.Context, + version id.KeyBackupVersion, + roomID id.RoomID, + config *event.EncryptionEventContent, + sessionID id.SessionID, + keyBackupData *backup.MegolmSessionData, +) (*InboundGroupSession, error) { + log := zerolog.Ctx(ctx) if keyBackupData.Algorithm != id.AlgorithmMegolmV1 { - return fmt.Errorf("ignoring room key in backup with weird algorithm %s", keyBackupData.Algorithm) + return nil, fmt.Errorf("%w %s", ErrUnknownAlgorithmInKeyBackup, keyBackupData.Algorithm) } igsInternal, err := olm.InboundGroupSessionImport([]byte(keyBackupData.SessionKey)) if err != nil { - return fmt.Errorf("failed to import inbound group session: %w", err) + return nil, fmt.Errorf("failed to import inbound group session: %w", err) } else if igsInternal.ID() != sessionID { log.Warn(). + Stringer("room_id", roomID). + Stringer("session_id", sessionID). Stringer("actual_session_id", igsInternal.ID()). Msg("Mismatched session ID while creating inbound group session from key backup") - return fmt.Errorf("mismatched session ID while creating inbound group session from key backup") + return nil, ErrMismatchingSessionIDInKeyBackup } var maxAge time.Duration var maxMessages int - if config, err := mach.StateStore.GetEncryptionEvent(ctx, roomID); err != nil { - log.Error().Err(err).Msg("Failed to get encryption event for room") - } else if config != nil { + if config != nil { maxAge = time.Duration(config.RotationPeriodMillis) * time.Millisecond maxMessages = config.RotationPeriodMessages } - if firstKnownIndex := igsInternal.FirstKnownIndex(); firstKnownIndex > 0 { - log.Warn().Uint32("first_known_index", firstKnownIndex).Msg("Importing partial session") - } - - igs := &InboundGroupSession{ - Internal: *igsInternal, + return &InboundGroupSession{ + Internal: igsInternal, SigningKey: keyBackupData.SenderClaimedKeys.Ed25519, SenderKey: keyBackupData.SenderKey, RoomID: roomID, - ForwardingChains: append(keyBackupData.ForwardingKeyChain, keyBackupData.SenderKey.String()), + ForwardingChains: keyBackupData.ForwardingKeyChain, id: sessionID, ReceivedAt: time.Now().UTC(), MaxAge: maxAge.Milliseconds(), MaxMessages: maxMessages, KeyBackupVersion: version, - } - err = mach.CryptoStore.PutGroupSession(ctx, roomID, keyBackupData.SenderKey, sessionID, igs) - if err != nil { - return fmt.Errorf("failed to store new inbound group session: %w", err) - } - mach.markSessionReceived(ctx, sessionID) - return nil + KeySource: id.KeySourceBackup, + }, nil +} + +func (mach *OlmMachine) ImportRoomKeyFromBackup(ctx context.Context, version id.KeyBackupVersion, roomID id.RoomID, sessionID id.SessionID, keyBackupData *backup.MegolmSessionData) (*InboundGroupSession, error) { + config, err := mach.StateStore.GetEncryptionEvent(ctx, roomID) + if err != nil { + zerolog.Ctx(ctx).Err(err). + Stringer("room_id", roomID). + Stringer("session_id", sessionID). + Msg("Failed to get encryption event for room") + } + imported, err := mach.ImportRoomKeyFromBackupWithoutSaving(ctx, version, roomID, config, sessionID, keyBackupData) + if err != nil { + return nil, err + } + firstKnownIndex := imported.Internal.FirstKnownIndex() + if firstKnownIndex > 0 { + zerolog.Ctx(ctx).Warn(). + Stringer("room_id", roomID). + Stringer("session_id", sessionID). + Uint32("first_known_index", firstKnownIndex). + Msg("Importing partial session") + } + err = mach.CryptoStore.PutGroupSession(ctx, imported) + if err != nil { + return nil, fmt.Errorf("%w: %w", ErrFailedToStoreNewInboundGroupSessionFromBackup, err) + } + mach.MarkSessionReceived(ctx, roomID, sessionID, firstKnownIndex) + return imported, nil } diff --git a/crypto/keyexport.go b/crypto/keyexport.go index 91bfb6c6..1904c8a5 100644 --- a/crypto/keyexport.go +++ b/crypto/keyexport.go @@ -1,4 +1,4 @@ -// Copyright (c) 2020 Tulir Asokan +// Copyright (c) 2025 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this @@ -11,21 +11,26 @@ import ( "crypto/aes" "crypto/cipher" "crypto/hmac" - "crypto/rand" "crypto/sha256" "crypto/sha512" "encoding/base64" "encoding/binary" "encoding/json" + "errors" "fmt" "math" + "go.mau.fi/util/dbutil" + "go.mau.fi/util/exbytes" + "go.mau.fi/util/exerrors" + "go.mau.fi/util/random" "golang.org/x/crypto/pbkdf2" - "maunium.net/go/mautrix/crypto/olm" "maunium.net/go/mautrix/id" ) +var ErrNoSessionsForExport = errors.New("no sessions provided for export") + type SenderClaimedKeys struct { Ed25519 id.Ed25519 `json:"ed25519"` } @@ -66,45 +71,27 @@ func computeKey(passphrase string, salt []byte, rounds int) (encryptionKey, hash } func makeExportIV() []byte { - iv := make([]byte, 16) - _, err := rand.Read(iv) - if err != nil { - panic(olm.NotEnoughGoRandom) - } + iv := random.Bytes(16) // Set bit 63 to zero iv[7] &= 0b11111110 return iv } func makeExportKeys(passphrase string) (encryptionKey, hashKey, salt, iv []byte) { - salt = make([]byte, 16) - _, err := rand.Read(salt) - if err != nil { - panic(olm.NotEnoughGoRandom) - } - + salt = random.Bytes(16) encryptionKey, hashKey = computeKey(passphrase, salt, defaultPassphraseRounds) - iv = makeExportIV() return } -func exportSessions(sessions []*InboundGroupSession) ([]ExportedSession, error) { - export := make([]ExportedSession, len(sessions)) +func exportSessions(sessions []*InboundGroupSession) ([]*ExportedSession, error) { + export := make([]*ExportedSession, len(sessions)) + var err error for i, session := range sessions { - key, err := session.Internal.Export(session.Internal.FirstKnownIndex()) + export[i], err = session.export() if err != nil { return nil, fmt.Errorf("failed to export session: %w", err) } - export[i] = ExportedSession{ - Algorithm: id.AlgorithmMegolmV1, - ForwardingChains: session.ForwardingChains, - RoomID: session.RoomID, - SenderKey: session.SenderKey, - SenderClaimedKeys: SenderClaimedKeys{}, - SessionID: session.ID(), - SessionKey: string(key), - } } return export, nil } @@ -117,46 +104,74 @@ func exportSessionsJSON(sessions []*InboundGroupSession) ([]byte, error) { return json.Marshal(exportedSessions) } -func min(a, b int) int { - if a > b { - return b +func formatKeyExportData(data []byte) []byte { + encodedLen := base64.StdEncoding.EncodedLen(len(data)) + outputLength := len(exportPrefix) + + encodedLen + int(math.Ceil(float64(encodedLen)/exportLineLengthLimit)) + + len(exportSuffix) + output := make([]byte, 0, outputLength) + outputWriter := (*exbytes.Writer)(&output) + base64Writer := base64.NewEncoder(base64.StdEncoding, outputWriter) + lineByteCount := base64.StdEncoding.DecodedLen(exportLineLengthLimit) + exerrors.Must(outputWriter.WriteString(exportPrefix)) + for i := 0; i < len(data); i += lineByteCount { + exerrors.Must(base64Writer.Write(data[i:min(i+lineByteCount, len(data))])) + if i+lineByteCount >= len(data) { + exerrors.PanicIfNotNil(base64Writer.Close()) + } + exerrors.PanicIfNotNil(outputWriter.WriteByte('\n')) } - return a + exerrors.Must(outputWriter.WriteString(exportSuffix)) + if len(output) != outputLength { + panic(fmt.Errorf("unexpected length %d / %d", len(output), outputLength)) + } + return output } -func formatKeyExportData(data []byte) []byte { - base64Data := make([]byte, base64.StdEncoding.EncodedLen(len(data))) - base64.StdEncoding.Encode(base64Data, data) - - // Prefix + data and newline for each 76 characters of data + suffix - outputLength := len(exportPrefix) + - len(base64Data) + int(math.Ceil(float64(len(base64Data))/exportLineLengthLimit)) + - len(exportSuffix) - - var buf bytes.Buffer - buf.Grow(outputLength) - buf.WriteString(exportPrefix) - for ptr := 0; ptr < len(base64Data); ptr += exportLineLengthLimit { - buf.Write(base64Data[ptr:min(ptr+exportLineLengthLimit, len(base64Data))]) - buf.WriteRune('\n') +func ExportKeysIter(passphrase string, sessions dbutil.RowIter[*InboundGroupSession]) ([]byte, error) { + buf := bytes.NewBuffer(make([]byte, 0, 50*1024)) + enc := json.NewEncoder(buf) + buf.WriteByte('[') + err := sessions.Iter(func(session *InboundGroupSession) (bool, error) { + exported, err := session.export() + if err != nil { + return false, err + } + err = enc.Encode(exported) + if err != nil { + return false, err + } + buf.WriteByte(',') + return true, nil + }) + if err != nil { + return nil, err } - buf.WriteString(exportSuffix) - if buf.Len() != buf.Cap() || buf.Len() != outputLength { - panic(fmt.Errorf("unexpected length %d / %d / %d", buf.Len(), buf.Cap(), outputLength)) + output := buf.Bytes() + if len(output) == 1 { + return nil, ErrNoSessionsForExport } - return buf.Bytes() + output[len(output)-1] = ']' // Replace the last comma with a closing bracket + return EncryptKeyExport(passphrase, output) } // ExportKeys exports the given Megolm sessions with the format specified in the Matrix spec. // See https://spec.matrix.org/v1.2/client-server-api/#key-exports func ExportKeys(passphrase string, sessions []*InboundGroupSession) ([]byte, error) { - // Make all the keys necessary for exporting - encryptionKey, hashKey, salt, iv := makeExportKeys(passphrase) + if len(sessions) == 0 { + return nil, ErrNoSessionsForExport + } // Export all the given sessions and put them in JSON unencryptedData, err := exportSessionsJSON(sessions) if err != nil { return nil, err } + return EncryptKeyExport(passphrase, unencryptedData) +} + +func EncryptKeyExport(passphrase string, unencryptedData json.RawMessage) ([]byte, error) { + // Make all the keys necessary for exporting + encryptionKey, hashKey, salt, iv := makeExportKeys(passphrase) // The export data consists of: // 1 byte of export format version diff --git a/crypto/keyexport_test.go b/crypto/keyexport_test.go new file mode 100644 index 00000000..fd6f105d --- /dev/null +++ b/crypto/keyexport_test.go @@ -0,0 +1,35 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package crypto_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "go.mau.fi/util/exerrors" + "go.mau.fi/util/exfmt" + + "maunium.net/go/mautrix/crypto" + "maunium.net/go/mautrix/crypto/olm" +) + +func TestExportKeys(t *testing.T) { + acc := crypto.NewOlmAccount() + sess := exerrors.Must(crypto.NewInboundGroupSession( + acc.IdentityKey(), + acc.SigningKey(), + "!room:example.com", + exerrors.Must(olm.NewOutboundGroupSession()).Key(), + 7*exfmt.Day, + 100, + false, + )) + data, err := crypto.ExportKeys("meow", []*crypto.InboundGroupSession{sess}) + assert.NoError(t, err) + assert.Len(t, data, 893) +} diff --git a/crypto/keyimport.go b/crypto/keyimport.go index da51774f..3ffc74a5 100644 --- a/crypto/keyimport.go +++ b/crypto/keyimport.go @@ -36,6 +36,10 @@ var ( var exportPrefixBytes, exportSuffixBytes = []byte(exportPrefix), []byte(exportSuffix) func decodeKeyExport(data []byte) ([]byte, error) { + // Fix some types of corruption in the key export file before checking anything + if bytes.IndexByte(data, '\r') != -1 { + data = bytes.ReplaceAll(data, []byte{'\r', '\n'}, []byte{'\n'}) + } // If the valid prefix and suffix aren't there, it's probably not a Matrix key export if !bytes.HasPrefix(data, exportPrefixBytes) { return nil, ErrMissingExportPrefix @@ -104,25 +108,27 @@ func (mach *OlmMachine) importExportedRoomKey(ctx context.Context, session Expor return false, ErrMismatchingExportedSessionID } igs := &InboundGroupSession{ - Internal: *igsInternal, - SigningKey: session.SenderClaimedKeys.Ed25519, - SenderKey: session.SenderKey, - RoomID: session.RoomID, - // TODO should we add something here to mark the signing key as unverified like key requests do? + Internal: igsInternal, + SigningKey: session.SenderClaimedKeys.Ed25519, + SenderKey: session.SenderKey, + RoomID: session.RoomID, ForwardingChains: session.ForwardingChains, - - ReceivedAt: time.Now().UTC(), + KeySource: id.KeySourceImport, + ReceivedAt: time.Now().UTC(), } - existingIGS, _ := mach.CryptoStore.GetGroupSession(ctx, igs.RoomID, igs.SenderKey, igs.ID()) - if existingIGS != nil && existingIGS.Internal.FirstKnownIndex() <= igs.Internal.FirstKnownIndex() { - // We already have an equivalent or better session in the store, so don't override it. + existingIGS, _ := mach.CryptoStore.GetGroupSession(ctx, igs.RoomID, igs.ID()) + firstKnownIndex := igs.Internal.FirstKnownIndex() + if existingIGS != nil && existingIGS.Internal.FirstKnownIndex() <= firstKnownIndex { + // We already have an equivalent or better session in the store, so don't override it, + // but do notify the session received callback just in case. + mach.MarkSessionReceived(ctx, session.RoomID, igs.ID(), existingIGS.Internal.FirstKnownIndex()) return false, nil } - err = mach.CryptoStore.PutGroupSession(ctx, igs.RoomID, igs.SenderKey, igs.ID(), igs) + err = mach.CryptoStore.PutGroupSession(ctx, igs) if err != nil { return false, fmt.Errorf("failed to store imported session: %w", err) } - mach.markSessionReceived(ctx, igs.ID()) + mach.MarkSessionReceived(ctx, session.RoomID, igs.ID(), firstKnownIndex) return true, nil } diff --git a/crypto/keysharing.go b/crypto/keysharing.go index 05e7f894..19a68c87 100644 --- a/crypto/keysharing.go +++ b/crypto/keysharing.go @@ -33,6 +33,7 @@ var ( KeyShareRejectBlacklisted = KeyShareRejection{event.RoomKeyWithheldBlacklisted, "You have been blacklisted by this device"} KeyShareRejectUnverified = KeyShareRejection{event.RoomKeyWithheldUnverified, "This device does not share keys to unverified devices"} KeyShareRejectOtherUser = KeyShareRejection{event.RoomKeyWithheldUnauthorized, "This device does not share keys to other users"} + KeyShareRejectNotRecipient = KeyShareRejection{event.RoomKeyWithheldUnauthorized, "You were not in the original recipient list for that session, or that session didn't originate from this device"} KeyShareRejectUnavailable = KeyShareRejection{event.RoomKeyWithheldUnavailable, "Requested session ID not found on this device"} KeyShareRejectInternalError = KeyShareRejection{event.RoomKeyWithheldUnavailable, "An internal error occurred while trying to share the requested session"} ) @@ -58,11 +59,15 @@ func (mach *OlmMachine) RequestRoomKey(ctx context.Context, toUser id.UserID, to select { case <-keyResponseReceived: // key request successful - mach.Log.Debug().Msgf("Key for session %v was received, cancelling other key requests", sessionID) + mach.Log.Debug(). + Stringer("session_id", sessionID). + Msg("Key for session was received, cancelling other key requests") resChan <- true case <-ctx.Done(): // if the context is done, key request was unsuccessful - mach.Log.Debug().Msgf("Context closed (%v) before forwared key for session %v received, sending key request cancellation", ctx.Err(), sessionID) + mach.Log.Debug().Err(err). + Stringer("session_id", sessionID). + Msg("Context closed before forwarded key for session received, sending key request cancellation") resChan <- false } @@ -168,11 +173,12 @@ func (mach *OlmMachine) importForwardedRoomKey(ctx context.Context, evt *Decrypt if content.MaxMessages != 0 { maxMessages = content.MaxMessages } - if firstKnownIndex := igsInternal.FirstKnownIndex(); firstKnownIndex > 0 { + firstKnownIndex := igsInternal.FirstKnownIndex() + if firstKnownIndex > 0 { log.Warn().Uint32("first_known_index", firstKnownIndex).Msg("Importing partial session") } igs := &InboundGroupSession{ - Internal: *igsInternal, + Internal: igsInternal, SigningKey: evt.Keys.Ed25519, SenderKey: content.SenderKey, RoomID: content.RoomID, @@ -183,13 +189,19 @@ func (mach *OlmMachine) importForwardedRoomKey(ctx context.Context, evt *Decrypt MaxAge: maxAge.Milliseconds(), MaxMessages: maxMessages, IsScheduled: content.IsScheduled, + KeySource: id.KeySourceForward, } - err = mach.CryptoStore.PutGroupSession(ctx, content.RoomID, content.SenderKey, content.SessionID, igs) + existingIGS, _ := mach.CryptoStore.GetGroupSession(ctx, igs.RoomID, igs.ID()) + if existingIGS != nil && existingIGS.Internal.FirstKnownIndex() <= igs.Internal.FirstKnownIndex() { + // We already have an equivalent or better session in the store, so don't override it. + return false + } + err = mach.CryptoStore.PutGroupSession(ctx, igs) if err != nil { log.Error().Err(err).Msg("Failed to store new inbound group session") return false } - mach.markSessionReceived(ctx, content.SessionID) + mach.MarkSessionReceived(ctx, content.RoomID, content.SessionID, firstKnownIndex) log.Debug().Msg("Received forwarded inbound group session") return true } @@ -203,6 +215,7 @@ func (mach *OlmMachine) rejectKeyRequest(ctx context.Context, rejection KeyShare RoomID: request.RoomID, Algorithm: request.Algorithm, SessionID: request.SessionID, + //lint:ignore SA1019 This is just echoing back the deprecated field SenderKey: request.SenderKey, Code: rejection.Code, Reason: rejection.Reason, @@ -243,13 +256,23 @@ func (mach *OlmMachine) sendToOneDevice(ctx context.Context, userID id.UserID, d func (mach *OlmMachine) defaultAllowKeyShare(ctx context.Context, device *id.Device, evt event.RequestedKeyInfo) *KeyShareRejection { log := mach.machOrContextLog(ctx) if mach.Client.UserID != device.UserID { + if mach.DisableSharedGroupSessionTracking { + log.Debug().Msg("Rejecting key request from another user as recipient list tracking is disabled") + return &KeyShareRejectOtherUser + } isShared, err := mach.CryptoStore.IsOutboundGroupSessionShared(ctx, device.UserID, device.IdentityKey, evt.SessionID) if err != nil { log.Err(err).Msg("Rejecting key request due to internal error when checking session sharing") return &KeyShareRejectNoResponse } else if !isShared { - log.Debug().Msg("Rejecting key request for unshared session") - return &KeyShareRejectOtherUser + igs, _ := mach.CryptoStore.GetGroupSession(ctx, evt.RoomID, evt.SessionID) + if igs != nil && igs.SenderKey == mach.OwnIdentity().IdentityKey { + log.Debug().Msg("Rejecting key request for unshared session") + return &KeyShareRejectNotRecipient + } + // Note: this case will also happen for redacted sessions and database errors + log.Debug().Msg("Rejecting key request for session created by another device") + return &KeyShareRejectNoResponse } log.Debug().Msg("Accepting key request for shared session") return nil @@ -259,7 +282,7 @@ func (mach *OlmMachine) defaultAllowKeyShare(ctx context.Context, device *id.Dev } else if device.Trust == id.TrustStateBlacklisted { log.Debug().Msg("Rejecting key request from blacklisted device") return &KeyShareRejectBlacklisted - } else if trustState := mach.ResolveTrust(device); trustState >= mach.ShareKeysMinTrust { + } else if trustState, _ := mach.ResolveTrustContext(ctx, device); trustState >= mach.ShareKeysMinTrust { log.Debug(). Str("min_trust", mach.SendKeysMinTrust.String()). Str("device_trust", trustState.String()). @@ -303,11 +326,13 @@ func (mach *OlmMachine) HandleRoomKeyRequest(ctx context.Context, sender id.User return } - igs, err := mach.CryptoStore.GetGroupSession(ctx, content.Body.RoomID, content.Body.SenderKey, content.Body.SessionID) + igs, err := mach.CryptoStore.GetGroupSession(ctx, content.Body.RoomID, content.Body.SessionID) if err != nil { if errors.Is(err, ErrGroupSessionWithheld) { log.Debug().Err(err).Msg("Requested group session not available") - mach.rejectKeyRequest(ctx, KeyShareRejectUnavailable, device, content.Body) + if sender != mach.Client.UserID { + mach.rejectKeyRequest(ctx, KeyShareRejectUnavailable, device, content.Body) + } } else { log.Error().Err(err).Msg("Failed to get group session to forward") mach.rejectKeyRequest(ctx, KeyShareRejectInternalError, device, content.Body) @@ -315,12 +340,14 @@ func (mach *OlmMachine) HandleRoomKeyRequest(ctx context.Context, sender id.User return } else if igs == nil { log.Error().Msg("Didn't find group session to forward") - mach.rejectKeyRequest(ctx, KeyShareRejectUnavailable, device, content.Body) + if sender != mach.Client.UserID { + mach.rejectKeyRequest(ctx, KeyShareRejectUnavailable, device, content.Body) + } return } if internalID := igs.ID(); internalID != content.Body.SessionID { // Should this be an error? - log = log.With().Str("unexpected_session_id", internalID.String()).Logger() + log = log.With().Stringer("unexpected_session_id", internalID).Logger() } firstKnownIndex := igs.Internal.FirstKnownIndex() @@ -340,7 +367,7 @@ func (mach *OlmMachine) HandleRoomKeyRequest(ctx context.Context, sender id.User SessionID: igs.ID(), SessionKey: string(exportedKey), }, - SenderKey: content.Body.SenderKey, + SenderKey: igs.SenderKey, ForwardingKeyChain: igs.ForwardingChains, SenderClaimedKey: igs.SigningKey, }, @@ -360,7 +387,7 @@ func (mach *OlmMachine) HandleBeeperRoomKeyAck(ctx context.Context, sender id.Us Int("first_message_index", content.FirstMessageIndex). Logger() - sess, err := mach.CryptoStore.GetGroupSession(ctx, content.RoomID, "", content.SessionID) + sess, err := mach.CryptoStore.GetGroupSession(ctx, content.RoomID, content.SessionID) if err != nil { if errors.Is(err, ErrGroupSessionWithheld) { log.Debug().Err(err).Msg("Acked group session was already redacted") @@ -380,7 +407,7 @@ func (mach *OlmMachine) HandleBeeperRoomKeyAck(ctx context.Context, sender id.Us isInbound := sess.SenderKey == mach.OwnIdentity().IdentityKey if isInbound && mach.DeleteOutboundKeysOnAck && content.FirstMessageIndex == 0 { log.Debug().Msg("Redacting inbound copy of outbound group session after ack") - err = mach.CryptoStore.RedactGroupSession(ctx, content.RoomID, sess.SenderKey, content.SessionID, "outbound session acked") + err = mach.CryptoStore.RedactGroupSession(ctx, content.RoomID, content.SessionID, "outbound session acked") if err != nil { log.Err(err).Msg("Failed to redact group session") } diff --git a/crypto/libolm/account.go b/crypto/libolm/account.go new file mode 100644 index 00000000..0350f083 --- /dev/null +++ b/crypto/libolm/account.go @@ -0,0 +1,419 @@ +package libolm + +// #cgo LDFLAGS: -lolm -lstdc++ +// #include +import "C" + +import ( + "crypto/rand" + "encoding/base64" + "encoding/json" + "runtime" + "unsafe" + + "github.com/tidwall/gjson" + + "maunium.net/go/mautrix/crypto/olm" + "maunium.net/go/mautrix/id" +) + +// Account stores a device account for end to end encrypted messaging. +type Account struct { + int *C.OlmAccount + mem []byte +} + +// Ensure that [Account] implements [olm.Account]. +var _ olm.Account = (*Account)(nil) + +// AccountFromPickled loads an Account from a pickled base64 string. Decrypts +// the Account using the supplied key. Returns error on failure. If the key +// doesn't match the one used to encrypt the Account then the error will be +// "BAD_ACCOUNT_KEY". If the base64 couldn't be decoded then the error will be +// "INVALID_BASE64". +func AccountFromPickled(pickled, key []byte) (*Account, error) { + if len(pickled) == 0 { + return nil, olm.ErrEmptyInput + } + a := NewBlankAccount() + return a, a.Unpickle(pickled, key) +} + +func NewBlankAccount() *Account { + memory := make([]byte, accountSize()) + return &Account{ + int: C.olm_account(unsafe.Pointer(unsafe.SliceData(memory))), + mem: memory, + } +} + +// NewAccount creates a new [Account]. +func NewAccount() (*Account, error) { + a := NewBlankAccount() + random := make([]byte, a.createRandomLen()+1) + _, err := rand.Read(random) + if err != nil { + panic(olm.ErrNotEnoughGoRandom) + } + ret := C.olm_create_account( + (*C.OlmAccount)(a.int), + unsafe.Pointer(unsafe.SliceData(random)), + C.size_t(len(random))) + runtime.KeepAlive(random) + if ret == errorVal() { + return nil, a.lastError() + } else { + return a, nil + } +} + +// accountSize returns the size of an account object in bytes. +func accountSize() uint { + return uint(C.olm_account_size()) +} + +// lastError returns an error describing the most recent error to happen to an +// account. +func (a *Account) lastError() error { + return convertError(C.GoString(C.olm_account_last_error((*C.OlmAccount)(a.int)))) +} + +// Clear clears the memory used to back this Account. +func (a *Account) Clear() error { + r := C.olm_clear_account((*C.OlmAccount)(a.int)) + if r == errorVal() { + return a.lastError() + } else { + return nil + } +} + +// pickleLen returns the number of bytes needed to store an Account. +func (a *Account) pickleLen() uint { + return uint(C.olm_pickle_account_length((*C.OlmAccount)(a.int))) +} + +// createRandomLen returns the number of random bytes needed to create an +// Account. +func (a *Account) createRandomLen() uint { + return uint(C.olm_create_account_random_length((*C.OlmAccount)(a.int))) +} + +// identityKeysLen returns the size of the output buffer needed to hold the +// identity keys. +func (a *Account) identityKeysLen() uint { + return uint(C.olm_account_identity_keys_length((*C.OlmAccount)(a.int))) +} + +// signatureLen returns the length of an ed25519 signature encoded as base64. +func (a *Account) signatureLen() uint { + return uint(C.olm_account_signature_length((*C.OlmAccount)(a.int))) +} + +// oneTimeKeysLen returns the size of the output buffer needed to hold the one +// time keys. +func (a *Account) oneTimeKeysLen() uint { + return uint(C.olm_account_one_time_keys_length((*C.OlmAccount)(a.int))) +} + +// genOneTimeKeysRandomLen returns the number of random bytes needed to +// generate a given number of new one time keys. +func (a *Account) genOneTimeKeysRandomLen(num uint) uint { + return uint(C.olm_account_generate_one_time_keys_random_length( + (*C.OlmAccount)(a.int), + C.size_t(num))) +} + +// Pickle returns an Account as a base64 string. Encrypts the Account using the +// supplied key. +func (a *Account) Pickle(key []byte) ([]byte, error) { + if len(key) == 0 { + return nil, olm.ErrNoKeyProvided + } + pickled := make([]byte, a.pickleLen()) + r := C.olm_pickle_account( + (*C.OlmAccount)(a.int), + unsafe.Pointer(unsafe.SliceData(key)), + C.size_t(len(key)), + unsafe.Pointer(unsafe.SliceData(pickled)), + C.size_t(len(pickled))) + if r == errorVal() { + return nil, a.lastError() + } + return pickled[:r], nil +} + +func (a *Account) Unpickle(pickled, key []byte) error { + if len(key) == 0 { + return olm.ErrNoKeyProvided + } + r := C.olm_unpickle_account( + (*C.OlmAccount)(a.int), + unsafe.Pointer(unsafe.SliceData(key)), + C.size_t(len(key)), + unsafe.Pointer(unsafe.SliceData(pickled)), + C.size_t(len(pickled))) + if r == errorVal() { + return a.lastError() + } + return nil +} + +// Deprecated +func (a *Account) GobEncode() ([]byte, error) { + pickled, err := a.Pickle(pickleKey) + if err != nil { + return nil, err + } + length := base64.RawStdEncoding.DecodedLen(len(pickled)) + rawPickled := make([]byte, length) + _, err = base64.RawStdEncoding.Decode(rawPickled, pickled) + return rawPickled, err +} + +// Deprecated +func (a *Account) GobDecode(rawPickled []byte) error { + if a.int == nil { + *a = *NewBlankAccount() + } + length := base64.RawStdEncoding.EncodedLen(len(rawPickled)) + pickled := make([]byte, length) + base64.RawStdEncoding.Encode(pickled, rawPickled) + return a.Unpickle(pickled, pickleKey) +} + +// Deprecated +func (a *Account) MarshalJSON() ([]byte, error) { + pickled, err := a.Pickle(pickleKey) + if err != nil { + return nil, err + } + quotes := make([]byte, len(pickled)+2) + quotes[0] = '"' + quotes[len(quotes)-1] = '"' + copy(quotes[1:len(quotes)-1], pickled) + return quotes, nil +} + +// Deprecated +func (a *Account) UnmarshalJSON(data []byte) error { + if len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' { + return olm.ErrInputNotJSONString + } + if a.int == nil { + *a = *NewBlankAccount() + } + return a.Unpickle(data[1:len(data)-1], pickleKey) +} + +// IdentityKeysJSON returns the public parts of the identity keys for the Account. +func (a *Account) IdentityKeysJSON() ([]byte, error) { + identityKeys := make([]byte, a.identityKeysLen()) + r := C.olm_account_identity_keys( + (*C.OlmAccount)(a.int), + unsafe.Pointer(unsafe.SliceData(identityKeys)), + C.size_t(len(identityKeys))) + if r == errorVal() { + return nil, a.lastError() + } else { + return identityKeys, nil + } +} + +// IdentityKeys returns the public parts of the Ed25519 and Curve25519 identity +// keys for the Account. +func (a *Account) IdentityKeys() (id.Ed25519, id.Curve25519, error) { + identityKeysJSON, err := a.IdentityKeysJSON() + if err != nil { + return "", "", err + } + results := gjson.GetManyBytes(identityKeysJSON, "ed25519", "curve25519") + return id.Ed25519(results[0].Str), id.Curve25519(results[1].Str), nil +} + +// Sign returns the signature of a message using the ed25519 key for this +// Account. +func (a *Account) Sign(message []byte) ([]byte, error) { + if len(message) == 0 { + panic(olm.ErrEmptyInput) + } + signature := make([]byte, a.signatureLen()) + r := C.olm_account_sign( + (*C.OlmAccount)(a.int), + unsafe.Pointer(unsafe.SliceData(message)), + C.size_t(len(message)), + unsafe.Pointer(unsafe.SliceData(signature)), + C.size_t(len(signature))) + runtime.KeepAlive(message) + if r == errorVal() { + panic(a.lastError()) + } + return signature, nil +} + +// OneTimeKeys returns the public parts of the unpublished one time keys for +// the Account. +// +// The returned data is a struct with the single value "Curve25519", which is +// itself an object mapping key id to base64-encoded Curve25519 key. For +// example: +// +// { +// Curve25519: { +// "AAAAAA": "wo76WcYtb0Vk/pBOdmduiGJ0wIEjW4IBMbbQn7aSnTo", +// "AAAAAB": "LRvjo46L1X2vx69sS9QNFD29HWulxrmW11Up5AfAjgU" +// } +// } +func (a *Account) OneTimeKeys() (map[string]id.Curve25519, error) { + oneTimeKeysJSON := make([]byte, a.oneTimeKeysLen()) + r := C.olm_account_one_time_keys( + (*C.OlmAccount)(a.int), + unsafe.Pointer(unsafe.SliceData(oneTimeKeysJSON)), + C.size_t(len(oneTimeKeysJSON)), + ) + if r == errorVal() { + return nil, a.lastError() + } + var oneTimeKeys struct { + Curve25519 map[string]id.Curve25519 `json:"curve25519"` + } + return oneTimeKeys.Curve25519, json.Unmarshal(oneTimeKeysJSON, &oneTimeKeys) +} + +// MarkKeysAsPublished marks the current set of one time keys as being +// published. +func (a *Account) MarkKeysAsPublished() { + C.olm_account_mark_keys_as_published((*C.OlmAccount)(a.int)) +} + +// MaxNumberOfOneTimeKeys returns the largest number of one time keys this +// Account can store. +func (a *Account) MaxNumberOfOneTimeKeys() uint { + return uint(C.olm_account_max_number_of_one_time_keys((*C.OlmAccount)(a.int))) +} + +// GenOneTimeKeys generates a number of new one time keys. If the total number +// of keys stored by this Account exceeds MaxNumberOfOneTimeKeys then the old +// keys are discarded. +func (a *Account) GenOneTimeKeys(num uint) error { + random := make([]byte, a.genOneTimeKeysRandomLen(num)+1) + _, err := rand.Read(random) + if err != nil { + return olm.ErrNotEnoughGoRandom + } + r := C.olm_account_generate_one_time_keys( + (*C.OlmAccount)(a.int), + C.size_t(num), + unsafe.Pointer(unsafe.SliceData(random)), + C.size_t(len(random)), + ) + runtime.KeepAlive(random) + if r == errorVal() { + return a.lastError() + } + return nil +} + +// NewOutboundSession creates a new out-bound session for sending messages to a +// given curve25519 identityKey and oneTimeKey. Returns error on failure. If the +// keys couldn't be decoded as base64 then the error will be "INVALID_BASE64" +func (a *Account) NewOutboundSession(theirIdentityKey, theirOneTimeKey id.Curve25519) (olm.Session, error) { + if len(theirIdentityKey) == 0 || len(theirOneTimeKey) == 0 { + return nil, olm.ErrEmptyInput + } + s := NewBlankSession() + random := make([]byte, s.createOutboundRandomLen()+1) + _, err := rand.Read(random) + if err != nil { + panic(olm.ErrNotEnoughGoRandom) + } + theirIdentityKeyCopy := []byte(theirIdentityKey) + theirOneTimeKeyCopy := []byte(theirOneTimeKey) + r := C.olm_create_outbound_session( + (*C.OlmSession)(s.int), + (*C.OlmAccount)(a.int), + unsafe.Pointer(unsafe.SliceData(theirIdentityKeyCopy)), + C.size_t(len(theirIdentityKeyCopy)), + unsafe.Pointer(unsafe.SliceData(theirOneTimeKeyCopy)), + C.size_t(len(theirOneTimeKeyCopy)), + unsafe.Pointer(unsafe.SliceData(random)), + C.size_t(len(random)), + ) + runtime.KeepAlive(random) + runtime.KeepAlive(theirIdentityKeyCopy) + runtime.KeepAlive(theirOneTimeKeyCopy) + if r == errorVal() { + return nil, s.lastError() + } + return s, nil +} + +// NewInboundSession creates a new in-bound session for sending/receiving +// messages from an incoming PRE_KEY message. Returns error on failure. If +// the base64 couldn't be decoded then the error will be "INVALID_BASE64". If +// the message was for an unsupported protocol version then the error will be +// "BAD_MESSAGE_VERSION". If the message couldn't be decoded then then the +// error will be "BAD_MESSAGE_FORMAT". If the message refers to an unknown one +// time key then the error will be "BAD_MESSAGE_KEY_ID". +func (a *Account) NewInboundSession(oneTimeKeyMsg string) (olm.Session, error) { + if len(oneTimeKeyMsg) == 0 { + return nil, olm.ErrEmptyInput + } + s := NewBlankSession() + oneTimeKeyMsgCopy := []byte(oneTimeKeyMsg) + r := C.olm_create_inbound_session( + (*C.OlmSession)(s.int), + (*C.OlmAccount)(a.int), + unsafe.Pointer(unsafe.SliceData(oneTimeKeyMsgCopy)), + C.size_t(len(oneTimeKeyMsgCopy)), + ) + runtime.KeepAlive(oneTimeKeyMsgCopy) + if r == errorVal() { + return nil, s.lastError() + } + return s, nil +} + +// NewInboundSessionFrom creates a new in-bound session for sending/receiving +// messages from an incoming PRE_KEY message. Returns error on failure. If +// the base64 couldn't be decoded then the error will be "INVALID_BASE64". If +// the message was for an unsupported protocol version then the error will be +// "BAD_MESSAGE_VERSION". If the message couldn't be decoded then then the +// error will be "BAD_MESSAGE_FORMAT". If the message refers to an unknown one +// time key then the error will be "BAD_MESSAGE_KEY_ID". +func (a *Account) NewInboundSessionFrom(theirIdentityKey *id.Curve25519, oneTimeKeyMsg string) (olm.Session, error) { + if theirIdentityKey == nil || len(oneTimeKeyMsg) == 0 { + return nil, olm.ErrEmptyInput + } + theirIdentityKeyCopy := []byte(*theirIdentityKey) + oneTimeKeyMsgCopy := []byte(oneTimeKeyMsg) + s := NewBlankSession() + r := C.olm_create_inbound_session_from( + (*C.OlmSession)(s.int), + (*C.OlmAccount)(a.int), + unsafe.Pointer(unsafe.SliceData(theirIdentityKeyCopy)), + C.size_t(len(theirIdentityKeyCopy)), + unsafe.Pointer(unsafe.SliceData(oneTimeKeyMsgCopy)), + C.size_t(len(oneTimeKeyMsgCopy)), + ) + runtime.KeepAlive(theirIdentityKeyCopy) + runtime.KeepAlive(oneTimeKeyMsgCopy) + if r == errorVal() { + return nil, s.lastError() + } + return s, nil +} + +// RemoveOneTimeKeys removes the one time keys that the session used from the +// Account. Returns error on failure. If the Account doesn't have any +// matching one time keys then the error will be "BAD_MESSAGE_KEY_ID". +func (a *Account) RemoveOneTimeKeys(s olm.Session) error { + r := C.olm_remove_one_time_keys( + (*C.OlmAccount)(a.int), + (*C.OlmSession)(s.(*Session).int), + ) + if r == errorVal() { + return a.lastError() + } + return nil +} diff --git a/crypto/libolm/error.go b/crypto/libolm/error.go new file mode 100644 index 00000000..6fb5512b --- /dev/null +++ b/crypto/libolm/error.go @@ -0,0 +1,37 @@ +package libolm + +// #cgo LDFLAGS: -lolm -lstdc++ +// #include +import "C" + +import ( + "fmt" + + "maunium.net/go/mautrix/crypto/olm" +) + +var errorMap = map[string]error{ + "NOT_ENOUGH_RANDOM": olm.ErrLibolmNotEnoughRandom, + "OUTPUT_BUFFER_TOO_SMALL": olm.ErrLibolmOutputBufferTooSmall, + "BAD_MESSAGE_VERSION": olm.ErrWrongProtocolVersion, + "BAD_MESSAGE_FORMAT": olm.ErrBadMessageFormat, + "BAD_MESSAGE_MAC": olm.ErrBadMAC, + "BAD_MESSAGE_KEY_ID": olm.ErrBadMessageKeyID, + "INVALID_BASE64": olm.ErrLibolmInvalidBase64, + "BAD_ACCOUNT_KEY": olm.ErrLibolmBadAccountKey, + "UNKNOWN_PICKLE_VERSION": olm.ErrUnknownOlmPickleVersion, + "CORRUPTED_PICKLE": olm.ErrLibolmCorruptedPickle, + "BAD_SESSION_KEY": olm.ErrLibolmBadSessionKey, + "UNKNOWN_MESSAGE_INDEX": olm.ErrUnknownMessageIndex, + "BAD_LEGACY_ACCOUNT_PICKLE": olm.ErrLibolmBadLegacyAccountPickle, + "BAD_SIGNATURE": olm.ErrBadSignature, + "INPUT_BUFFER_TOO_SMALL": olm.ErrInputToSmall, +} + +func convertError(errCode string) error { + err, ok := errorMap[errCode] + if ok { + return err + } + return fmt.Errorf("unknown error: %s", errCode) +} diff --git a/crypto/libolm/inboundgroupsession.go b/crypto/libolm/inboundgroupsession.go new file mode 100644 index 00000000..8815ac32 --- /dev/null +++ b/crypto/libolm/inboundgroupsession.go @@ -0,0 +1,327 @@ +package libolm + +// #cgo LDFLAGS: -lolm -lstdc++ +// #include +import "C" + +import ( + "bytes" + "encoding/base64" + "runtime" + "unsafe" + + "maunium.net/go/mautrix/crypto/olm" + "maunium.net/go/mautrix/id" +) + +// InboundGroupSession stores an inbound encrypted messaging session for a +// group. +type InboundGroupSession struct { + int *C.OlmInboundGroupSession + mem []byte +} + +// Ensure that [InboundGroupSession] implements [olm.InboundGroupSession]. +var _ olm.InboundGroupSession = (*InboundGroupSession)(nil) + +// InboundGroupSessionFromPickled loads an InboundGroupSession from a pickled +// base64 string. Decrypts the InboundGroupSession using the supplied key. +// Returns error on failure. If the key doesn't match the one used to encrypt +// the InboundGroupSession then the error will be "BAD_SESSION_KEY". If the +// base64 couldn't be decoded then the error will be "INVALID_BASE64". +func InboundGroupSessionFromPickled(pickled, key []byte) (*InboundGroupSession, error) { + if len(pickled) == 0 { + return nil, olm.ErrEmptyInput + } + lenKey := len(key) + if lenKey == 0 { + key = []byte(" ") + } + s := NewBlankInboundGroupSession() + return s, s.Unpickle(pickled, key) +} + +// NewInboundGroupSession creates a new inbound group session from a key +// exported from OutboundGroupSession.Key(). Returns error on failure. +// If the sessionKey is not valid base64 the error will be +// "OLM_INVALID_BASE64". If the session_key is invalid the error will be +// "OLM_BAD_SESSION_KEY". +func NewInboundGroupSession(sessionKey []byte) (*InboundGroupSession, error) { + if len(sessionKey) == 0 { + return nil, olm.ErrEmptyInput + } + s := NewBlankInboundGroupSession() + r := C.olm_init_inbound_group_session( + (*C.OlmInboundGroupSession)(s.int), + (*C.uint8_t)(unsafe.Pointer(unsafe.SliceData(sessionKey))), + C.size_t(len(sessionKey)), + ) + runtime.KeepAlive(sessionKey) + if r == errorVal() { + return nil, s.lastError() + } + return s, nil +} + +// InboundGroupSessionImport imports an inbound group session from a previous +// export. Returns error on failure. If the sessionKey is not valid base64 +// the error will be "OLM_INVALID_BASE64". If the session_key is invalid the +// error will be "OLM_BAD_SESSION_KEY". +func InboundGroupSessionImport(sessionKey []byte) (*InboundGroupSession, error) { + if len(sessionKey) == 0 { + return nil, olm.ErrEmptyInput + } + s := NewBlankInboundGroupSession() + r := C.olm_import_inbound_group_session( + (*C.OlmInboundGroupSession)(s.int), + (*C.uint8_t)(unsafe.Pointer(unsafe.SliceData(sessionKey))), + C.size_t(len(sessionKey)), + ) + runtime.KeepAlive(sessionKey) + if r == errorVal() { + return nil, s.lastError() + } + return s, nil +} + +// inboundGroupSessionSize is the size of an inbound group session object in +// bytes. +func inboundGroupSessionSize() uint { + return uint(C.olm_inbound_group_session_size()) +} + +// newInboundGroupSession initialises an empty InboundGroupSession. +func NewBlankInboundGroupSession() *InboundGroupSession { + memory := make([]byte, inboundGroupSessionSize()) + return &InboundGroupSession{ + int: C.olm_inbound_group_session(unsafe.Pointer(unsafe.SliceData(memory))), + mem: memory, + } +} + +// lastError returns an error describing the most recent error to happen to an +// inbound group session. +func (s *InboundGroupSession) lastError() error { + return convertError(C.GoString(C.olm_inbound_group_session_last_error((*C.OlmInboundGroupSession)(s.int)))) +} + +// Clear clears the memory used to back this InboundGroupSession. +func (s *InboundGroupSession) Clear() error { + r := C.olm_clear_inbound_group_session((*C.OlmInboundGroupSession)(s.int)) + if r == errorVal() { + return s.lastError() + } + return nil +} + +// pickleLen returns the number of bytes needed to store an inbound group +// session. +func (s *InboundGroupSession) pickleLen() uint { + return uint(C.olm_pickle_inbound_group_session_length((*C.OlmInboundGroupSession)(s.int))) +} + +// Pickle returns an InboundGroupSession as a base64 string. Encrypts the +// InboundGroupSession using the supplied key. +func (s *InboundGroupSession) Pickle(key []byte) ([]byte, error) { + if len(key) == 0 { + return nil, olm.ErrNoKeyProvided + } + pickled := make([]byte, s.pickleLen()) + r := C.olm_pickle_inbound_group_session( + (*C.OlmInboundGroupSession)(s.int), + unsafe.Pointer(unsafe.SliceData(key)), + C.size_t(len(key)), + unsafe.Pointer(unsafe.SliceData(pickled)), + C.size_t(len(pickled)), + ) + runtime.KeepAlive(key) + if r == errorVal() { + return nil, s.lastError() + } + return pickled[:r], nil +} + +func (s *InboundGroupSession) Unpickle(pickled, key []byte) error { + if len(key) == 0 { + return olm.ErrNoKeyProvided + } else if len(pickled) == 0 { + return olm.ErrEmptyInput + } + r := C.olm_unpickle_inbound_group_session( + (*C.OlmInboundGroupSession)(s.int), + unsafe.Pointer(unsafe.SliceData(key)), + C.size_t(len(key)), + unsafe.Pointer(unsafe.SliceData(pickled)), + C.size_t(len(pickled)), + ) + runtime.KeepAlive(key) + if r == errorVal() { + return s.lastError() + } + return nil +} + +// Deprecated +func (s *InboundGroupSession) GobEncode() ([]byte, error) { + pickled, err := s.Pickle(pickleKey) + if err != nil { + return nil, err + } + length := base64.RawStdEncoding.DecodedLen(len(pickled)) + rawPickled := make([]byte, length) + _, err = base64.RawStdEncoding.Decode(rawPickled, pickled) + return rawPickled, err +} + +// Deprecated +func (s *InboundGroupSession) GobDecode(rawPickled []byte) error { + if s == nil || s.int == nil { + *s = *NewBlankInboundGroupSession() + } + length := base64.RawStdEncoding.EncodedLen(len(rawPickled)) + pickled := make([]byte, length) + base64.RawStdEncoding.Encode(pickled, rawPickled) + return s.Unpickle(pickled, pickleKey) +} + +// Deprecated +func (s *InboundGroupSession) MarshalJSON() ([]byte, error) { + pickled, err := s.Pickle(pickleKey) + if err != nil { + return nil, err + } + quotes := make([]byte, len(pickled)+2) + quotes[0] = '"' + quotes[len(quotes)-1] = '"' + copy(quotes[1:len(quotes)-1], pickled) + return quotes, nil +} + +// Deprecated +func (s *InboundGroupSession) UnmarshalJSON(data []byte) error { + if len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' { + return olm.ErrInputNotJSONString + } + if s == nil || s.int == nil { + *s = *NewBlankInboundGroupSession() + } + return s.Unpickle(data[1:len(data)-1], pickleKey) +} + +// decryptMaxPlaintextLen returns the maximum number of bytes of plain-text a +// given message could decode to. The actual size could be different due to +// padding. Returns error on failure. If the message base64 couldn't be +// decoded then the error will be "INVALID_BASE64". If the message is for an +// unsupported version of the protocol then the error will be +// "BAD_MESSAGE_VERSION". If the message couldn't be decoded then the error +// will be "BAD_MESSAGE_FORMAT". +func (s *InboundGroupSession) decryptMaxPlaintextLen(message []byte) (uint, error) { + if len(message) == 0 { + return 0, olm.ErrEmptyInput + } + // olm_group_decrypt_max_plaintext_length destroys the input, so we have to clone it + messageCopy := bytes.Clone(message) + r := C.olm_group_decrypt_max_plaintext_length( + (*C.OlmInboundGroupSession)(s.int), + (*C.uint8_t)(unsafe.Pointer(unsafe.SliceData(messageCopy))), + C.size_t(len(messageCopy)), + ) + runtime.KeepAlive(messageCopy) + if r == errorVal() { + return 0, s.lastError() + } + return uint(r), nil +} + +// Decrypt decrypts a message using the InboundGroupSession. Returns the the +// plain-text and message index on success. Returns error on failure. If the +// base64 couldn't be decoded then the error will be "INVALID_BASE64". If the +// message is for an unsupported version of the protocol then the error will be +// "BAD_MESSAGE_VERSION". If the message couldn't be decoded then the error +// will be BAD_MESSAGE_FORMAT". If the MAC on the message was invalid then the +// error will be "BAD_MESSAGE_MAC". If we do not have a session key +// corresponding to the message's index (ie, it was sent before the session key +// was shared with us) the error will be "OLM_UNKNOWN_MESSAGE_INDEX". +func (s *InboundGroupSession) Decrypt(message []byte) ([]byte, uint, error) { + if len(message) == 0 { + return nil, 0, olm.ErrEmptyInput + } + decryptMaxPlaintextLen, err := s.decryptMaxPlaintextLen(message) + if err != nil { + return nil, 0, err + } + messageCopy := bytes.Clone(message) + plaintext := make([]byte, decryptMaxPlaintextLen) + var messageIndex uint32 + r := C.olm_group_decrypt( + (*C.OlmInboundGroupSession)(s.int), + (*C.uint8_t)(unsafe.Pointer(unsafe.SliceData(messageCopy))), + C.size_t(len(messageCopy)), + (*C.uint8_t)(unsafe.Pointer(unsafe.SliceData(plaintext))), + C.size_t(len(plaintext)), + (*C.uint32_t)(unsafe.Pointer(&messageIndex)), + ) + runtime.KeepAlive(messageCopy) + if r == errorVal() { + return nil, 0, s.lastError() + } + return plaintext[:r], uint(messageIndex), nil +} + +// sessionIdLen returns the number of bytes needed to store a session ID. +func (s *InboundGroupSession) sessionIdLen() uint { + return uint(C.olm_inbound_group_session_id_length((*C.OlmInboundGroupSession)(s.int))) +} + +// ID returns a base64-encoded identifier for this session. +func (s *InboundGroupSession) ID() id.SessionID { + sessionID := make([]byte, s.sessionIdLen()) + r := C.olm_inbound_group_session_id( + (*C.OlmInboundGroupSession)(s.int), + (*C.uint8_t)(unsafe.Pointer(unsafe.SliceData(sessionID))), + C.size_t(len(sessionID)), + ) + if r == errorVal() { + panic(s.lastError()) + } + return id.SessionID(sessionID[:r]) +} + +// FirstKnownIndex returns the first message index we know how to decrypt. +func (s *InboundGroupSession) FirstKnownIndex() uint32 { + return uint32(C.olm_inbound_group_session_first_known_index((*C.OlmInboundGroupSession)(s.int))) +} + +// IsVerified check if the session has been verified as a valid session. (A +// session is verified either because the original session share was signed, or +// because we have subsequently successfully decrypted a message.) +func (s *InboundGroupSession) IsVerified() bool { + return uint(C.olm_inbound_group_session_is_verified((*C.OlmInboundGroupSession)(s.int))) == 1 +} + +// exportLen returns the number of bytes needed to export an inbound group +// session. +func (s *InboundGroupSession) exportLen() uint { + return uint(C.olm_export_inbound_group_session_length((*C.OlmInboundGroupSession)(s.int))) +} + +// Export returns the base64-encoded ratchet key for this session, at the given +// index, in a format which can be used by +// InboundGroupSession.InboundGroupSessionImport(). Encrypts the +// InboundGroupSession using the supplied key. Returns error on failure. +// if we do not have a session key corresponding to the given index (ie, it was +// sent before the session key was shared with us) the error will be +// "OLM_UNKNOWN_MESSAGE_INDEX". +func (s *InboundGroupSession) Export(messageIndex uint32) ([]byte, error) { + key := make([]byte, s.exportLen()) + r := C.olm_export_inbound_group_session( + (*C.OlmInboundGroupSession)(s.int), + (*C.uint8_t)(unsafe.Pointer(unsafe.SliceData(key))), + C.size_t(len(key)), + C.uint32_t(messageIndex), + ) + if r == errorVal() { + return nil, s.lastError() + } + return key[:r], nil +} diff --git a/crypto/libolm/libolm.go b/crypto/libolm/libolm.go new file mode 100644 index 00000000..18815767 --- /dev/null +++ b/crypto/libolm/libolm.go @@ -0,0 +1,10 @@ +package libolm + +// #cgo LDFLAGS: -lolm -lstdc++ +// #include +import "C" + +// errorVal returns the value that olm functions return if there was an error. +func errorVal() C.size_t { + return C.olm_error() +} diff --git a/crypto/libolm/outboundgroupsession.go b/crypto/libolm/outboundgroupsession.go new file mode 100644 index 00000000..ca5b68f7 --- /dev/null +++ b/crypto/libolm/outboundgroupsession.go @@ -0,0 +1,245 @@ +package libolm + +// #cgo LDFLAGS: -lolm -lstdc++ +// #include +import "C" + +import ( + "crypto/rand" + "encoding/base64" + "runtime" + "unsafe" + + "maunium.net/go/mautrix/crypto/olm" + "maunium.net/go/mautrix/id" +) + +// OutboundGroupSession stores an outbound encrypted messaging session +// for a group. +type OutboundGroupSession struct { + int *C.OlmOutboundGroupSession + mem []byte +} + +// Ensure that [OutboundGroupSession] implements [olm.OutboundGroupSession]. +var _ olm.OutboundGroupSession = (*OutboundGroupSession)(nil) + +func NewOutboundGroupSession() (*OutboundGroupSession, error) { + s := NewBlankOutboundGroupSession() + random := make([]byte, s.createRandomLen()+1) + _, err := rand.Read(random) + if err != nil { + return nil, err + } + r := C.olm_init_outbound_group_session( + (*C.OlmOutboundGroupSession)(s.int), + (*C.uint8_t)(unsafe.Pointer(unsafe.SliceData(random))), + C.size_t(len(random)), + ) + runtime.KeepAlive(random) + if r == errorVal() { + return nil, s.lastError() + } + return s, nil +} + +// outboundGroupSessionSize is the size of an outbound group session object in +// bytes. +func outboundGroupSessionSize() uint { + return uint(C.olm_outbound_group_session_size()) +} + +// NewBlankOutboundGroupSession initialises an empty [OutboundGroupSession]. +func NewBlankOutboundGroupSession() *OutboundGroupSession { + memory := make([]byte, outboundGroupSessionSize()) + return &OutboundGroupSession{ + int: C.olm_outbound_group_session(unsafe.Pointer(unsafe.SliceData(memory))), + mem: memory, + } +} + +// lastError returns an error describing the most recent error to happen to an +// outbound group session. +func (s *OutboundGroupSession) lastError() error { + return convertError(C.GoString(C.olm_outbound_group_session_last_error((*C.OlmOutboundGroupSession)(s.int)))) +} + +// Clear clears the memory used to back this OutboundGroupSession. +func (s *OutboundGroupSession) Clear() error { + r := C.olm_clear_outbound_group_session((*C.OlmOutboundGroupSession)(s.int)) + if r == errorVal() { + return s.lastError() + } else { + return nil + } +} + +// pickleLen returns the number of bytes needed to store an outbound group +// session. +func (s *OutboundGroupSession) pickleLen() uint { + return uint(C.olm_pickle_outbound_group_session_length((*C.OlmOutboundGroupSession)(s.int))) +} + +// Pickle returns an OutboundGroupSession as a base64 string. Encrypts the +// OutboundGroupSession using the supplied key. +func (s *OutboundGroupSession) Pickle(key []byte) ([]byte, error) { + if len(key) == 0 { + return nil, olm.ErrNoKeyProvided + } + pickled := make([]byte, s.pickleLen()) + r := C.olm_pickle_outbound_group_session( + (*C.OlmOutboundGroupSession)(s.int), + unsafe.Pointer(unsafe.SliceData(key)), + C.size_t(len(key)), + unsafe.Pointer(unsafe.SliceData(pickled)), + C.size_t(len(pickled)), + ) + runtime.KeepAlive(key) + if r == errorVal() { + return nil, s.lastError() + } + return pickled[:r], nil +} + +func (s *OutboundGroupSession) Unpickle(pickled, key []byte) error { + if len(key) == 0 { + return olm.ErrNoKeyProvided + } + r := C.olm_unpickle_outbound_group_session( + (*C.OlmOutboundGroupSession)(s.int), + unsafe.Pointer(unsafe.SliceData(key)), + C.size_t(len(key)), + unsafe.Pointer(unsafe.SliceData(pickled)), + C.size_t(len(pickled)), + ) + runtime.KeepAlive(pickled) + runtime.KeepAlive(key) + if r == errorVal() { + return s.lastError() + } + return nil +} + +// Deprecated +func (s *OutboundGroupSession) GobEncode() ([]byte, error) { + pickled, err := s.Pickle(pickleKey) + if err != nil { + return nil, err + } + length := base64.RawStdEncoding.DecodedLen(len(pickled)) + rawPickled := make([]byte, length) + _, err = base64.RawStdEncoding.Decode(rawPickled, pickled) + return rawPickled, err +} + +// Deprecated +func (s *OutboundGroupSession) GobDecode(rawPickled []byte) error { + if s == nil || s.int == nil { + *s = *NewBlankOutboundGroupSession() + } + length := base64.RawStdEncoding.EncodedLen(len(rawPickled)) + pickled := make([]byte, length) + base64.RawStdEncoding.Encode(pickled, rawPickled) + return s.Unpickle(pickled, pickleKey) +} + +// Deprecated +func (s *OutboundGroupSession) MarshalJSON() ([]byte, error) { + pickled, err := s.Pickle(pickleKey) + if err != nil { + return nil, err + } + quotes := make([]byte, len(pickled)+2) + quotes[0] = '"' + quotes[len(quotes)-1] = '"' + copy(quotes[1:len(quotes)-1], pickled) + return quotes, nil +} + +// Deprecated +func (s *OutboundGroupSession) UnmarshalJSON(data []byte) error { + if len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' { + return olm.ErrInputNotJSONString + } + if s == nil || s.int == nil { + *s = *NewBlankOutboundGroupSession() + } + return s.Unpickle(data[1:len(data)-1], pickleKey) +} + +// createRandomLen returns the number of random bytes needed to create an +// Account. +func (s *OutboundGroupSession) createRandomLen() uint { + return uint(C.olm_init_outbound_group_session_random_length((*C.OlmOutboundGroupSession)(s.int))) +} + +// encryptMsgLen returns the size of the next message in bytes for the given +// number of plain-text bytes. +func (s *OutboundGroupSession) encryptMsgLen(plainTextLen int) uint { + return uint(C.olm_group_encrypt_message_length((*C.OlmOutboundGroupSession)(s.int), C.size_t(plainTextLen))) +} + +// Encrypt encrypts a message using the Session. Returns the encrypted message +// as base64. +func (s *OutboundGroupSession) Encrypt(plaintext []byte) ([]byte, error) { + if len(plaintext) == 0 { + return nil, olm.ErrEmptyInput + } + message := make([]byte, s.encryptMsgLen(len(plaintext))) + r := C.olm_group_encrypt( + (*C.OlmOutboundGroupSession)(s.int), + (*C.uint8_t)(unsafe.Pointer(unsafe.SliceData(plaintext))), + C.size_t(len(plaintext)), + (*C.uint8_t)(unsafe.Pointer(unsafe.SliceData(message))), + C.size_t(len(message)), + ) + runtime.KeepAlive(plaintext) + if r == errorVal() { + return nil, s.lastError() + } + return message[:r], nil +} + +// sessionIdLen returns the number of bytes needed to store a session ID. +func (s *OutboundGroupSession) sessionIdLen() uint { + return uint(C.olm_outbound_group_session_id_length((*C.OlmOutboundGroupSession)(s.int))) +} + +// ID returns a base64-encoded identifier for this session. +func (s *OutboundGroupSession) ID() id.SessionID { + sessionID := make([]byte, s.sessionIdLen()) + r := C.olm_outbound_group_session_id( + (*C.OlmOutboundGroupSession)(s.int), + (*C.uint8_t)(unsafe.Pointer(unsafe.SliceData(sessionID))), + C.size_t(len(sessionID)), + ) + if r == errorVal() { + panic(s.lastError()) + } + return id.SessionID(sessionID[:r]) +} + +// MessageIndex returns the message index for this session. Each message is +// sent with an increasing index; this returns the index for the next message. +func (s *OutboundGroupSession) MessageIndex() uint { + return uint(C.olm_outbound_group_session_message_index((*C.OlmOutboundGroupSession)(s.int))) +} + +// sessionKeyLen returns the number of bytes needed to store a session key. +func (s *OutboundGroupSession) sessionKeyLen() uint { + return uint(C.olm_outbound_group_session_key_length((*C.OlmOutboundGroupSession)(s.int))) +} + +// Key returns the base64-encoded current ratchet key for this session. +func (s *OutboundGroupSession) Key() string { + sessionKey := make([]byte, s.sessionKeyLen()) + r := C.olm_outbound_group_session_key( + (*C.OlmOutboundGroupSession)(s.int), + (*C.uint8_t)(unsafe.Pointer(unsafe.SliceData(sessionKey))), + C.size_t(len(sessionKey)), + ) + if r == errorVal() { + panic(s.lastError()) + } + return string(sessionKey[:r]) +} diff --git a/crypto/olm/pk_libolm.go b/crypto/libolm/pk.go similarity index 52% rename from crypto/olm/pk_libolm.go rename to crypto/libolm/pk.go index 0854b4d1..2683cf15 100644 --- a/crypto/olm/pk_libolm.go +++ b/crypto/libolm/pk.go @@ -4,9 +4,7 @@ // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. -//go:build !goolm - -package olm +package libolm // #cgo LDFLAGS: -lolm -lstdc++ // #include @@ -16,24 +14,26 @@ import "C" import ( "crypto/rand" "encoding/json" + "runtime" "unsafe" "github.com/tidwall/sjson" "maunium.net/go/mautrix/crypto/canonicaljson" + "maunium.net/go/mautrix/crypto/olm" "maunium.net/go/mautrix/id" ) -// LibOlmPKSigning stores a key pair for signing messages. -type LibOlmPKSigning struct { +// PKSigning stores a key pair for signing messages. +type PKSigning struct { int *C.OlmPkSigning mem []byte publicKey id.Ed25519 seed []byte } -// Ensure that LibOlmPKSigning implements PKSigning. -var _ PKSigning = (*LibOlmPKSigning)(nil) +// Ensure that [PKSigning] implements [olm.PKSigning]. +var _ olm.PKSigning = (*PKSigning)(nil) func pkSigningSize() uint { return uint(C.olm_pk_signing_size()) @@ -51,22 +51,27 @@ func pkSigningSignatureLength() uint { return uint(C.olm_pk_signature_length()) } -func newBlankPKSigning() *LibOlmPKSigning { +func newBlankPKSigning() *PKSigning { memory := make([]byte, pkSigningSize()) - return &LibOlmPKSigning{ - int: C.olm_pk_signing(unsafe.Pointer(&memory[0])), + return &PKSigning{ + int: C.olm_pk_signing(unsafe.Pointer(unsafe.SliceData(memory))), mem: memory, } } // NewPKSigningFromSeed creates a new [PKSigning] object using the given seed. -func NewPKSigningFromSeed(seed []byte) (PKSigning, error) { +func NewPKSigningFromSeed(seed []byte) (*PKSigning, error) { p := newBlankPKSigning() p.clear() pubKey := make([]byte, pkSigningPublicKeyLength()) - if C.olm_pk_signing_key_from_seed((*C.OlmPkSigning)(p.int), - unsafe.Pointer(&pubKey[0]), C.size_t(len(pubKey)), - unsafe.Pointer(&seed[0]), C.size_t(len(seed))) == errorVal() { + r := C.olm_pk_signing_key_from_seed( + (*C.OlmPkSigning)(p.int), + unsafe.Pointer(unsafe.SliceData(pubKey)), + C.size_t(len(pubKey)), + unsafe.Pointer(unsafe.SliceData(seed)), + C.size_t(len(seed)), + ) + if r == errorVal() { return nil, p.lastError() } p.publicKey = id.Ed25519(pubKey) @@ -74,44 +79,51 @@ func NewPKSigningFromSeed(seed []byte) (PKSigning, error) { return p, nil } -// NewPKSigning creates a new LibOlmPKSigning object, containing a key pair for +// NewPKSigning creates a new [PKSigning] object, containing a key pair for // signing messages. -func NewPKSigning() (PKSigning, error) { +func NewPKSigning() (*PKSigning, error) { // Generate the seed seed := make([]byte, pkSigningSeedLength()) _, err := rand.Read(seed) if err != nil { - panic(NotEnoughGoRandom) + panic(olm.ErrNotEnoughGoRandom) } pk, err := NewPKSigningFromSeed(seed) return pk, err } -func (p *LibOlmPKSigning) PublicKey() id.Ed25519 { +func (p *PKSigning) PublicKey() id.Ed25519 { return p.publicKey } -func (p *LibOlmPKSigning) Seed() []byte { +func (p *PKSigning) Seed() []byte { return p.seed } -// clear clears the underlying memory of a LibOlmPKSigning object. -func (p *LibOlmPKSigning) clear() { +// clear clears the underlying memory of a [PKSigning] object. +func (p *PKSigning) clear() { C.olm_clear_pk_signing((*C.OlmPkSigning)(p.int)) } // Sign creates a signature for the given message using this key. -func (p *LibOlmPKSigning) Sign(message []byte) ([]byte, error) { +func (p *PKSigning) Sign(message []byte) ([]byte, error) { signature := make([]byte, pkSigningSignatureLength()) - if C.olm_pk_sign((*C.OlmPkSigning)(p.int), (*C.uint8_t)(unsafe.Pointer(&message[0])), C.size_t(len(message)), - (*C.uint8_t)(unsafe.Pointer(&signature[0])), C.size_t(len(signature))) == errorVal() { + r := C.olm_pk_sign( + (*C.OlmPkSigning)(p.int), + (*C.uint8_t)(unsafe.Pointer(unsafe.SliceData(message))), + C.size_t(len(message)), + (*C.uint8_t)(unsafe.Pointer(unsafe.SliceData(signature))), + C.size_t(len(signature)), + ) + runtime.KeepAlive(message) + if r == errorVal() { return nil, p.lastError() } return signature, nil } // SignJSON creates a signature for the given object after encoding it to canonical JSON. -func (p *LibOlmPKSigning) SignJSON(obj interface{}) (string, error) { +func (p *PKSigning) SignJSON(obj interface{}) (string, error) { objJSON, err := json.Marshal(obj) if err != nil { return "", err @@ -126,15 +138,15 @@ func (p *LibOlmPKSigning) SignJSON(obj interface{}) (string, error) { } // lastError returns the last error that happened in relation to this -// LibOlmPKSigning object. -func (p *LibOlmPKSigning) lastError() error { +// [PKSigning] object. +func (p *PKSigning) lastError() error { return convertError(C.GoString(C.olm_pk_signing_last_error((*C.OlmPkSigning)(p.int)))) } -type LibOlmPKDecryption struct { +type PKDecryption struct { int *C.OlmPkDecryption mem []byte - PublicKey []byte + publicKey []byte } func pkDecryptionSize() uint { @@ -145,34 +157,56 @@ func pkDecryptionPublicKeySize() uint { return uint(C.olm_pk_key_length()) } -func NewPkDecryption(privateKey []byte) (*LibOlmPKDecryption, error) { +func NewPkDecryption(privateKey []byte) (*PKDecryption, error) { memory := make([]byte, pkDecryptionSize()) - p := &LibOlmPKDecryption{ - int: C.olm_pk_decryption(unsafe.Pointer(&memory[0])), + p := &PKDecryption{ + int: C.olm_pk_decryption(unsafe.Pointer(unsafe.SliceData(memory))), mem: memory, } p.clear() pubKey := make([]byte, pkDecryptionPublicKeySize()) - if C.olm_pk_key_from_private((*C.OlmPkDecryption)(p.int), - unsafe.Pointer(&pubKey[0]), C.size_t(len(pubKey)), - unsafe.Pointer(&privateKey[0]), C.size_t(len(privateKey))) == errorVal() { + r := C.olm_pk_key_from_private( + (*C.OlmPkDecryption)(p.int), + unsafe.Pointer(unsafe.SliceData(pubKey)), + C.size_t(len(pubKey)), + unsafe.Pointer(unsafe.SliceData(privateKey)), + C.size_t(len(privateKey)), + ) + runtime.KeepAlive(privateKey) + if r == errorVal() { return nil, p.lastError() } - p.PublicKey = pubKey + p.publicKey = pubKey return p, nil } -func (p *LibOlmPKDecryption) Decrypt(ephemeralKey []byte, mac []byte, ciphertext []byte) ([]byte, error) { - maxPlaintextLength := uint(C.olm_pk_max_plaintext_length((*C.OlmPkDecryption)(p.int), C.size_t(len(ciphertext)))) +func (p *PKDecryption) PublicKey() id.Curve25519 { + return id.Curve25519(p.publicKey) +} + +func (p *PKDecryption) Decrypt(ephemeralKey []byte, mac []byte, ciphertext []byte) ([]byte, error) { + maxPlaintextLength := uint(C.olm_pk_max_plaintext_length( + (*C.OlmPkDecryption)(p.int), + C.size_t(len(ciphertext)), + )) plaintext := make([]byte, maxPlaintextLength) - size := C.olm_pk_decrypt((*C.OlmPkDecryption)(p.int), - unsafe.Pointer(&ephemeralKey[0]), C.size_t(len(ephemeralKey)), - unsafe.Pointer(&mac[0]), C.size_t(len(mac)), - unsafe.Pointer(&ciphertext[0]), C.size_t(len(ciphertext)), - unsafe.Pointer(&plaintext[0]), C.size_t(len(plaintext))) + size := C.olm_pk_decrypt( + (*C.OlmPkDecryption)(p.int), + unsafe.Pointer(unsafe.SliceData(ephemeralKey)), + C.size_t(len(ephemeralKey)), + unsafe.Pointer(unsafe.SliceData(mac)), + C.size_t(len(mac)), + unsafe.Pointer(unsafe.SliceData(ciphertext)), + C.size_t(len(ciphertext)), + unsafe.Pointer(unsafe.SliceData(plaintext)), + C.size_t(len(plaintext)), + ) + runtime.KeepAlive(ephemeralKey) + runtime.KeepAlive(mac) + runtime.KeepAlive(ciphertext) if size == errorVal() { return nil, p.lastError() } @@ -181,12 +215,12 @@ func (p *LibOlmPKDecryption) Decrypt(ephemeralKey []byte, mac []byte, ciphertext } // Clear clears the underlying memory of a PkDecryption object. -func (p *LibOlmPKDecryption) clear() { +func (p *PKDecryption) clear() { C.olm_clear_pk_decryption((*C.OlmPkDecryption)(p.int)) } // lastError returns the last error that happened in relation to this -// LibOlmPKDecryption object. -func (p *LibOlmPKDecryption) lastError() error { +// [PKDecryption] object. +func (p *PKDecryption) lastError() error { return convertError(C.GoString(C.olm_pk_decryption_last_error((*C.OlmPkDecryption)(p.int)))) } diff --git a/crypto/libolm/register.go b/crypto/libolm/register.go new file mode 100644 index 00000000..ddf84613 --- /dev/null +++ b/crypto/libolm/register.go @@ -0,0 +1,75 @@ +package libolm + +// #cgo LDFLAGS: -lolm -lstdc++ +// #include +import "C" +import ( + "unsafe" + + "maunium.net/go/mautrix/crypto/olm" +) + +var pickleKey = []byte("maunium.net/go/mautrix/crypto/olm") + +func Register() { + olm.Driver = "libolm" + + olm.GetVersion = func() (major, minor, patch uint8) { + C.olm_get_library_version( + (*C.uint8_t)(unsafe.Pointer(&major)), + (*C.uint8_t)(unsafe.Pointer(&minor)), + (*C.uint8_t)(unsafe.Pointer(&patch))) + return 3, 2, 15 + } + olm.SetPickleKeyImpl = func(key []byte) { + pickleKey = key + } + + olm.InitNewAccount = func() (olm.Account, error) { + return NewAccount() + } + olm.InitBlankAccount = func() olm.Account { + return NewBlankAccount() + } + olm.InitNewAccountFromPickled = func(pickled, key []byte) (olm.Account, error) { + return AccountFromPickled(pickled, key) + } + + olm.InitSessionFromPickled = func(pickled, key []byte) (olm.Session, error) { + return SessionFromPickled(pickled, key) + } + olm.InitNewBlankSession = func() olm.Session { + return NewBlankSession() + } + + olm.InitNewPKSigning = func() (olm.PKSigning, error) { return NewPKSigning() } + olm.InitNewPKSigningFromSeed = func(seed []byte) (olm.PKSigning, error) { + return NewPKSigningFromSeed(seed) + } + olm.InitNewPKDecryptionFromPrivateKey = func(privateKey []byte) (olm.PKDecryption, error) { + return NewPkDecryption(privateKey) + } + + olm.InitInboundGroupSessionFromPickled = func(pickled, key []byte) (olm.InboundGroupSession, error) { + return InboundGroupSessionFromPickled(pickled, key) + } + olm.InitNewInboundGroupSession = func(sessionKey []byte) (olm.InboundGroupSession, error) { + return NewInboundGroupSession(sessionKey) + } + olm.InitInboundGroupSessionImport = func(sessionKey []byte) (olm.InboundGroupSession, error) { + return InboundGroupSessionImport(sessionKey) + } + olm.InitBlankInboundGroupSession = func() olm.InboundGroupSession { + return NewBlankInboundGroupSession() + } + + olm.InitNewOutboundGroupSessionFromPickled = func(pickled, key []byte) (olm.OutboundGroupSession, error) { + if len(pickled) == 0 { + return nil, olm.ErrEmptyInput + } + s := NewBlankOutboundGroupSession() + return s, s.Unpickle(pickled, key) + } + olm.InitNewOutboundGroupSession = func() (olm.OutboundGroupSession, error) { return NewOutboundGroupSession() } + olm.InitNewBlankOutboundGroupSession = func() olm.OutboundGroupSession { return NewBlankOutboundGroupSession() } +} diff --git a/crypto/libolm/session.go b/crypto/libolm/session.go new file mode 100644 index 00000000..1441df26 --- /dev/null +++ b/crypto/libolm/session.go @@ -0,0 +1,401 @@ +// Copyright (c) 2024 Sumner Evans +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package libolm + +// #cgo LDFLAGS: -lolm -lstdc++ +// #include +// #include +// #include +// void olm_session_describe(OlmSession * session, char *buf, size_t buflen) __attribute__((weak)); +// void meowlm_session_describe(OlmSession * session, char *buf, size_t buflen) { +// if (olm_session_describe) { +// olm_session_describe(session, buf, buflen); +// } else { +// sprintf(buf, "olm_session_describe not supported"); +// } +// } +import "C" + +import ( + "crypto/rand" + "encoding/base64" + "runtime" + "unsafe" + + "maunium.net/go/mautrix/crypto/olm" + "maunium.net/go/mautrix/id" +) + +// Session stores an end to end encrypted messaging session. +type Session struct { + int *C.OlmSession + mem []byte +} + +// Ensure that [Session] implements [olm.Session]. +var _ olm.Session = (*Session)(nil) + +// sessionSize is the size of a session object in bytes. +func sessionSize() uint { + return uint(C.olm_session_size()) +} + +// SessionFromPickled loads a Session from a pickled base64 string. Decrypts +// the Session using the supplied key. Returns error on failure. If the key +// doesn't match the one used to encrypt the Session then the error will be +// "BAD_SESSION_KEY". If the base64 couldn't be decoded then the error will be +// "INVALID_BASE64". +func SessionFromPickled(pickled, key []byte) (*Session, error) { + if len(pickled) == 0 { + return nil, olm.ErrEmptyInput + } + s := NewBlankSession() + return s, s.Unpickle(pickled, key) +} + +func NewBlankSession() *Session { + memory := make([]byte, sessionSize()) + return &Session{ + int: C.olm_session(unsafe.Pointer(unsafe.SliceData(memory))), + mem: memory, + } +} + +// lastError returns an error describing the most recent error to happen to a +// session. +func (s *Session) lastError() error { + return convertError(C.GoString(C.olm_session_last_error((*C.OlmSession)(s.int)))) +} + +// Clear clears the memory used to back this Session. +func (s *Session) Clear() error { + r := C.olm_clear_session((*C.OlmSession)(s.int)) + if r == errorVal() { + return s.lastError() + } + return nil +} + +// pickleLen returns the number of bytes needed to store a session. +func (s *Session) pickleLen() uint { + return uint(C.olm_pickle_session_length((*C.OlmSession)(s.int))) +} + +// createOutboundRandomLen returns the number of random bytes needed to create +// an outbound session. +func (s *Session) createOutboundRandomLen() uint { + return uint(C.olm_create_outbound_session_random_length((*C.OlmSession)(s.int))) +} + +// idLen returns the length of the buffer needed to return the id for this +// session. +func (s *Session) idLen() uint { + return uint(C.olm_session_id_length((*C.OlmSession)(s.int))) +} + +// encryptRandomLen returns the number of random bytes needed to encrypt the +// next message. +func (s *Session) encryptRandomLen() uint { + return uint(C.olm_encrypt_random_length((*C.OlmSession)(s.int))) +} + +// encryptMsgLen returns the size of the next message in bytes for the given +// number of plain-text bytes. +func (s *Session) encryptMsgLen(plainTextLen int) uint { + return uint(C.olm_encrypt_message_length((*C.OlmSession)(s.int), C.size_t(plainTextLen))) +} + +// decryptMaxPlaintextLen returns the maximum number of bytes of plain-text a +// given message could decode to. The actual size could be different due to +// padding. Returns error on failure. If the message base64 couldn't be +// decoded then the error will be "INVALID_BASE64". If the message is for an +// unsupported version of the protocol then the error will be +// "BAD_MESSAGE_VERSION". If the message couldn't be decoded then the error +// will be "BAD_MESSAGE_FORMAT". +func (s *Session) decryptMaxPlaintextLen(message string, msgType id.OlmMsgType) (uint, error) { + if len(message) == 0 { + return 0, olm.ErrEmptyInput + } + messageCopy := []byte(message) + r := C.olm_decrypt_max_plaintext_length( + (*C.OlmSession)(s.int), + C.size_t(msgType), + unsafe.Pointer(unsafe.SliceData((messageCopy))), + C.size_t(len(messageCopy)), + ) + runtime.KeepAlive(messageCopy) + if r == errorVal() { + return 0, s.lastError() + } + return uint(r), nil +} + +// Pickle returns a Session as a base64 string. Encrypts the Session using the +// supplied key. +func (s *Session) Pickle(key []byte) ([]byte, error) { + if len(key) == 0 { + return nil, olm.ErrNoKeyProvided + } + pickled := make([]byte, s.pickleLen()) + r := C.olm_pickle_session( + (*C.OlmSession)(s.int), + unsafe.Pointer(unsafe.SliceData(key)), + C.size_t(len(key)), + unsafe.Pointer(unsafe.SliceData(pickled)), + C.size_t(len(pickled))) + runtime.KeepAlive(key) + if r == errorVal() { + panic(s.lastError()) + } + return pickled[:r], nil +} + +// Unpickle unpickles the base64-encoded Olm session decrypting it with the +// provided key. This function mutates the input pickled data slice. +func (s *Session) Unpickle(pickled, key []byte) error { + if len(key) == 0 { + return olm.ErrNoKeyProvided + } + r := C.olm_unpickle_session( + (*C.OlmSession)(s.int), + unsafe.Pointer(unsafe.SliceData(key)), + C.size_t(len(key)), + unsafe.Pointer(unsafe.SliceData(pickled)), + C.size_t(len(pickled))) + runtime.KeepAlive(pickled) + runtime.KeepAlive(key) + if r == errorVal() { + return s.lastError() + } + return nil +} + +// Deprecated +func (s *Session) GobEncode() ([]byte, error) { + pickled, err := s.Pickle(pickleKey) + if err != nil { + return nil, err + } + length := base64.RawStdEncoding.DecodedLen(len(pickled)) + rawPickled := make([]byte, length) + _, err = base64.RawStdEncoding.Decode(rawPickled, pickled) + return rawPickled, err +} + +// Deprecated +func (s *Session) GobDecode(rawPickled []byte) error { + if s == nil || s.int == nil { + *s = *NewBlankSession() + } + length := base64.RawStdEncoding.EncodedLen(len(rawPickled)) + pickled := make([]byte, length) + base64.RawStdEncoding.Encode(pickled, rawPickled) + return s.Unpickle(pickled, pickleKey) +} + +// Deprecated +func (s *Session) MarshalJSON() ([]byte, error) { + pickled, err := s.Pickle(pickleKey) + if err != nil { + return nil, err + } + quotes := make([]byte, len(pickled)+2) + quotes[0] = '"' + quotes[len(quotes)-1] = '"' + copy(quotes[1:len(quotes)-1], pickled) + return quotes, nil +} + +// Deprecated +func (s *Session) UnmarshalJSON(data []byte) error { + if len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' { + return olm.ErrInputNotJSONString + } + if s == nil || s.int == nil { + *s = *NewBlankSession() + } + return s.Unpickle(data[1:len(data)-1], pickleKey) +} + +// Id returns an identifier for this Session. Will be the same for both ends +// of the conversation. +func (s *Session) ID() id.SessionID { + sessionID := make([]byte, s.idLen()) + r := C.olm_session_id( + (*C.OlmSession)(s.int), + unsafe.Pointer(unsafe.SliceData(sessionID)), + C.size_t(len(sessionID)), + ) + if r == errorVal() { + panic(s.lastError()) + } + return id.SessionID(sessionID) +} + +// HasReceivedMessage returns true if this session has received any message. +func (s *Session) HasReceivedMessage() bool { + switch C.olm_session_has_received_message((*C.OlmSession)(s.int)) { + case 0: + return false + default: + return true + } +} + +// MatchesInboundSession checks if the PRE_KEY message is for this in-bound +// Session. This can happen if multiple messages are sent to this Account +// before this Account sends a message in reply. Returns true if the session +// matches. Returns false if the session does not match. Returns error on +// failure. If the base64 couldn't be decoded then the error will be +// "INVALID_BASE64". If the message was for an unsupported protocol version +// then the error will be "BAD_MESSAGE_VERSION". If the message couldn't be +// decoded then then the error will be "BAD_MESSAGE_FORMAT". +func (s *Session) MatchesInboundSession(oneTimeKeyMsg string) (bool, error) { + if len(oneTimeKeyMsg) == 0 { + return false, olm.ErrEmptyInput + } + oneTimeKeyMsgCopy := []byte(oneTimeKeyMsg) + r := C.olm_matches_inbound_session( + (*C.OlmSession)(s.int), + unsafe.Pointer(unsafe.SliceData(oneTimeKeyMsgCopy)), + C.size_t(len(oneTimeKeyMsgCopy)), + ) + runtime.KeepAlive(oneTimeKeyMsgCopy) + if r == 1 { + return true, nil + } else if r == 0 { + return false, nil + } else { // if r == errorVal() + return false, s.lastError() + } +} + +// MatchesInboundSessionFrom checks if the PRE_KEY message is for this in-bound +// Session. This can happen if multiple messages are sent to this Account +// before this Account sends a message in reply. Returns true if the session +// matches. Returns false if the session does not match. Returns error on +// failure. If the base64 couldn't be decoded then the error will be +// "INVALID_BASE64". If the message was for an unsupported protocol version +// then the error will be "BAD_MESSAGE_VERSION". If the message couldn't be +// decoded then then the error will be "BAD_MESSAGE_FORMAT". +func (s *Session) MatchesInboundSessionFrom(theirIdentityKey, oneTimeKeyMsg string) (bool, error) { + if len(theirIdentityKey) == 0 || len(oneTimeKeyMsg) == 0 { + return false, olm.ErrEmptyInput + } + theirIdentityKeyCopy := []byte(theirIdentityKey) + oneTimeKeyMsgCopy := []byte(oneTimeKeyMsg) + r := C.olm_matches_inbound_session_from( + (*C.OlmSession)(s.int), + unsafe.Pointer(unsafe.SliceData(theirIdentityKeyCopy)), + C.size_t(len(theirIdentityKeyCopy)), + unsafe.Pointer(unsafe.SliceData(oneTimeKeyMsgCopy)), + C.size_t(len(oneTimeKeyMsgCopy)), + ) + runtime.KeepAlive(theirIdentityKeyCopy) + runtime.KeepAlive(oneTimeKeyMsgCopy) + if r == 1 { + return true, nil + } else if r == 0 { + return false, nil + } else { // if r == errorVal() + return false, s.lastError() + } +} + +// EncryptMsgType returns the type of the next message that Encrypt will +// return. Returns MsgTypePreKey if the message will be a PRE_KEY message. +// Returns MsgTypeMsg if the message will be a normal message. Returns error +// on failure. +func (s *Session) EncryptMsgType() id.OlmMsgType { + switch C.olm_encrypt_message_type((*C.OlmSession)(s.int)) { + case C.size_t(id.OlmMsgTypePreKey): + return id.OlmMsgTypePreKey + case C.size_t(id.OlmMsgTypeMsg): + return id.OlmMsgTypeMsg + default: + panic("olm_encrypt_message_type returned invalid result") + } +} + +// Encrypt encrypts a message using the Session. Returns the encrypted message +// as base64. +func (s *Session) Encrypt(plaintext []byte) (id.OlmMsgType, []byte, error) { + if len(plaintext) == 0 { + return 0, nil, olm.ErrEmptyInput + } + // Make the slice be at least length 1 + random := make([]byte, s.encryptRandomLen()+1) + _, err := rand.Read(random) + if err != nil { + // TODO can we just return err here? + return 0, nil, olm.ErrNotEnoughGoRandom + } + messageType := s.EncryptMsgType() + message := make([]byte, s.encryptMsgLen(len(plaintext))) + r := C.olm_encrypt( + (*C.OlmSession)(s.int), + unsafe.Pointer(unsafe.SliceData(plaintext)), + C.size_t(len(plaintext)), + unsafe.Pointer(unsafe.SliceData(random)), + C.size_t(len(random)), + unsafe.Pointer(unsafe.SliceData(message)), + C.size_t(len(message)), + ) + runtime.KeepAlive(plaintext) + runtime.KeepAlive(random) + if r == errorVal() { + return 0, nil, s.lastError() + } + return messageType, message[:r], nil +} + +// Decrypt decrypts a message using the Session. Returns the the plain-text on +// success. Returns error on failure. If the base64 couldn't be decoded then +// the error will be "INVALID_BASE64". If the message is for an unsupported +// version of the protocol then the error will be "BAD_MESSAGE_VERSION". If +// the message couldn't be decoded then the error will be BAD_MESSAGE_FORMAT". +// If the MAC on the message was invalid then the error will be +// "BAD_MESSAGE_MAC". +func (s *Session) Decrypt(message string, msgType id.OlmMsgType) ([]byte, error) { + if len(message) == 0 { + return nil, olm.ErrEmptyInput + } + decryptMaxPlaintextLen, err := s.decryptMaxPlaintextLen(message, msgType) + if err != nil { + return nil, err + } + messageCopy := []byte(message) + plaintext := make([]byte, decryptMaxPlaintextLen) + r := C.olm_decrypt( + (*C.OlmSession)(s.int), + C.size_t(msgType), + unsafe.Pointer(unsafe.SliceData(messageCopy)), + C.size_t(len(messageCopy)), + unsafe.Pointer(unsafe.SliceData(plaintext)), + C.size_t(len(plaintext)), + ) + runtime.KeepAlive(messageCopy) + if r == errorVal() { + return nil, s.lastError() + } + return plaintext[:r], nil +} + +// https://gitlab.matrix.org/matrix-org/olm/-/blob/3.2.8/include/olm/olm.h#L392-393 +const maxDescribeSize = 600 + +// Describe generates a string describing the internal state of an olm session for debugging and logging purposes. +func (s *Session) Describe() string { + desc := (*C.char)(C.malloc(C.size_t(maxDescribeSize))) + defer C.free(unsafe.Pointer(desc)) + C.meowlm_session_describe( + (*C.OlmSession)(s.int), + desc, + C.size_t(maxDescribeSize), + ) + return C.GoString(desc) +} diff --git a/crypto/machine.go b/crypto/machine.go index 4417faf3..fa051f94 100644 --- a/crypto/machine.go +++ b/crypto/machine.go @@ -11,15 +11,19 @@ import ( "errors" "fmt" "sync" + "sync/atomic" "time" "github.com/rs/zerolog" + "go.mau.fi/util/ptr" - "maunium.net/go/mautrix/crypto/ssss" - "maunium.net/go/mautrix/id" + "go.mau.fi/util/exzerolog" "maunium.net/go/mautrix" + "maunium.net/go/mautrix/crypto/olm" + "maunium.net/go/mautrix/crypto/ssss" "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" ) // OlmMachine is the main struct for handling Matrix end-to-end encryption. @@ -31,7 +35,12 @@ type OlmMachine struct { CryptoStore Store StateStore StateStore - PlaintextMentions bool + backgroundCtx context.Context + cancelBackgroundCtx context.CancelFunc + + PlaintextMentions bool + MSC4392Relations bool + AllowEncryptedState bool // Never ask the server for keys automatically as a side effect during Megolm decryption. DisableDecryptKeyFetching bool @@ -39,6 +48,8 @@ type OlmMachine struct { // Don't mark outbound Olm sessions as shared for devices they were initially sent to. DisableSharedGroupSessionTracking bool + IgnorePostDecryptionParseErrors bool + SendKeysMinTrust id.TrustState ShareKeysMinTrust id.TrustState @@ -53,19 +64,23 @@ type OlmMachine struct { keyWaitersLock sync.Mutex // Optional callback which is called when we save a session to store - SessionReceived func(context.Context, id.SessionID) + SessionReceived func(context.Context, id.RoomID, id.SessionID, uint32) devicesToUnwedge map[id.IdentityKey]bool devicesToUnwedgeLock sync.Mutex recentlyUnwedged map[id.IdentityKey]time.Time recentlyUnwedgedLock sync.Mutex + olmHashSavePoints []time.Time + lastHashDelete time.Time + olmHashSavePointLock sync.Mutex olmLock sync.Mutex megolmEncryptLock sync.Mutex megolmDecryptLock sync.Mutex - otkUploadLock sync.Mutex - lastOTKUpload time.Time + otkUploadLock sync.Mutex + lastOTKUpload time.Time + receivedOTKsForSelf atomic.Bool CrossSigningKeys *CrossSigningKeysCache crossSigningPubkeys *CrossSigningPublicKeysCache @@ -78,6 +93,7 @@ type OlmMachine struct { RatchetKeysOnDecrypt bool DeleteFullyUsedKeysOnDecrypt bool DeleteKeysOnDeviceDelete bool + DisableRatchetTracking bool DisableDeviceChangeKeyRotation bool @@ -120,6 +136,7 @@ func NewOlmMachine(client *mautrix.Client, log *zerolog.Logger, cryptoStore Stor recentlyUnwedged: make(map[id.IdentityKey]time.Time), secretListeners: make(map[string]chan<- string), } + mach.backgroundCtx, mach.cancelBackgroundCtx = context.WithCancel(context.Background()) mach.AllowKeyShare = mach.defaultAllowKeyShare return mach } @@ -132,6 +149,11 @@ func (mach *OlmMachine) machOrContextLog(ctx context.Context) *zerolog.Logger { return log } +func (mach *OlmMachine) SetBackgroundCtx(ctx context.Context) { + mach.cancelBackgroundCtx() + mach.backgroundCtx, mach.cancelBackgroundCtx = context.WithCancel(ctx) +} + // Load loads the Olm account information from the crypto store. If there's no olm account, a new one is created. // This must be called before using the machine. func (mach *OlmMachine) Load(ctx context.Context) (err error) { @@ -142,9 +164,23 @@ func (mach *OlmMachine) Load(ctx context.Context) (err error) { if mach.account == nil { mach.account = NewOlmAccount() } + zerolog.Ctx(ctx).Debug(). + Str("machine_ptr", fmt.Sprintf("%p", mach)). + Str("account_ptr", fmt.Sprintf("%p", mach.account.Internal)). + Str("olm_driver", olm.Driver). + Msg("Loaded olm account") return nil } +func (mach *OlmMachine) Destroy() { + mach.Log.Debug(). + Str("machine_ptr", fmt.Sprintf("%p", mach)). + Str("account_ptr", fmt.Sprintf("%p", ptr.Val(mach.account).Internal)). + Msg("Destroying olm machine") + mach.cancelBackgroundCtx() + // TODO actually destroy something? +} + func (mach *OlmMachine) saveAccount(ctx context.Context) error { err := mach.CryptoStore.PutAccount(ctx, mach.account) if err != nil { @@ -170,7 +206,7 @@ func (mach *OlmMachine) FlushStore(ctx context.Context) error { func (mach *OlmMachine) timeTrace(ctx context.Context, thing string, expectedDuration time.Duration) func() { start := time.Now() return func() { - duration := time.Now().Sub(start) + duration := time.Since(start) if duration > expectedDuration { zerolog.Ctx(ctx).Warn(). Str("action", thing). @@ -206,13 +242,14 @@ func (mach *OlmMachine) OwnIdentity() *id.Device { } } -type asEventProcessor interface { +type ASEventProcessor interface { On(evtType event.Type, handler func(ctx context.Context, evt *event.Event)) OnOTK(func(ctx context.Context, otk *mautrix.OTKCount)) OnDeviceList(func(ctx context.Context, lists *mautrix.DeviceLists, since string)) + Dispatch(ctx context.Context, evt *event.Event) } -func (mach *OlmMachine) AddAppserviceListener(ep asEventProcessor) { +func (mach *OlmMachine) AddAppserviceListener(ep ASEventProcessor) { // ToDeviceForwardedRoomKey and ToDeviceRoomKey should only be present inside encrypted to-device events ep.On(event.ToDeviceEncrypted, mach.HandleToDeviceEvent) ep.On(event.ToDeviceRoomKeyRequest, mach.HandleToDeviceEvent) @@ -242,14 +279,29 @@ func (mach *OlmMachine) HandleDeviceLists(ctx context.Context, dl *mautrix.Devic } } +func (mach *OlmMachine) otkCountIsForCrossSigningKey(otkCount *mautrix.OTKCount) bool { + if mach.crossSigningPubkeys == nil || otkCount.UserID != mach.Client.UserID { + return false + } + switch id.Ed25519(otkCount.DeviceID) { + case mach.crossSigningPubkeys.MasterKey, mach.crossSigningPubkeys.UserSigningKey, mach.crossSigningPubkeys.SelfSigningKey: + return true + } + return false +} + func (mach *OlmMachine) HandleOTKCounts(ctx context.Context, otkCount *mautrix.OTKCount) { + receivedOTKsForSelf := mach.receivedOTKsForSelf.Load() if (len(otkCount.UserID) > 0 && otkCount.UserID != mach.Client.UserID) || (len(otkCount.DeviceID) > 0 && otkCount.DeviceID != mach.Client.DeviceID) { - // TODO This log probably needs to be silence-able if someone wants to use encrypted appservices with multiple e2ee sessions - mach.Log.Warn(). - Str("target_user_id", otkCount.UserID.String()). - Str("target_device_id", otkCount.DeviceID.String()). - Msg("Dropping OTK counts targeted to someone else") + if otkCount.UserID != mach.Client.UserID || (!receivedOTKsForSelf && !mach.otkCountIsForCrossSigningKey(otkCount)) { + mach.Log.Warn(). + Str("target_user_id", otkCount.UserID.String()). + Str("target_device_id", otkCount.DeviceID.String()). + Msg("Dropping OTK counts targeted to someone else") + } return + } else if !receivedOTKsForSelf { + mach.receivedOTKsForSelf.Store(true) } minCount := mach.account.Internal.MaxNumberOfOneTimeKeys() / 2 @@ -258,7 +310,7 @@ func (mach *OlmMachine) HandleOTKCounts(ctx context.Context, otkCount *mautrix.O log := mach.Log.With().Str("trace_id", traceID).Logger() ctx = log.WithContext(ctx) log.Debug(). - Int("keys_left", otkCount.Curve25519). + Int("keys_left", otkCount.SignedCurve25519). Msg("Sync response said we have less than 50 signed curve25519 keys left, sharing new ones...") err := mach.ShareKeys(ctx, otkCount.SignedCurve25519) if err != nil { @@ -288,6 +340,7 @@ func (mach *OlmMachine) ProcessSyncResponse(ctx context.Context, resp *mautrix.R } mach.HandleOTKCounts(ctx, &resp.DeviceOTKCount) + mach.MarkOlmHashSavePoint(ctx) return true } @@ -330,20 +383,20 @@ func (mach *OlmMachine) HandleMemberEvent(ctx context.Context, evt *event.Event) Msg("Got membership state change, invalidating group session in room") err := mach.CryptoStore.RemoveOutboundGroupSession(ctx, evt.RoomID) if err != nil { - mach.Log.Warn().Str("room_id", evt.RoomID.String()).Msg("Failed to invalidate outbound group session") + mach.Log.Warn().Stringer("room_id", evt.RoomID).Msg("Failed to invalidate outbound group session") } } -func (mach *OlmMachine) HandleEncryptedEvent(ctx context.Context, evt *event.Event) { +func (mach *OlmMachine) HandleEncryptedEvent(ctx context.Context, evt *event.Event) *DecryptedOlmEvent { if _, ok := evt.Content.Parsed.(*event.EncryptedEventContent); !ok { mach.machOrContextLog(ctx).Warn().Msg("Passed invalid event to encrypted handler") - return + return nil } decryptedEvt, err := mach.decryptOlmEvent(ctx, evt) if err != nil { mach.machOrContextLog(ctx).Error().Err(err).Msg("Failed to decrypt to-device event") - return + return nil } log := mach.machOrContextLog(ctx).With(). @@ -372,6 +425,37 @@ func (mach *OlmMachine) HandleEncryptedEvent(ctx context.Context, evt *event.Eve log.Trace().Msg("Handled secret send event") default: log.Debug().Msg("Unhandled encrypted to-device event") + return decryptedEvt + } + return nil +} + +const olmHashSavePointCount = 5 +const olmHashDeleteMinInterval = 10 * time.Minute +const minSavePointInterval = 1 * time.Minute + +// MarkOlmHashSavePoint marks the current time as a save point for olm hashes and deletes old hashes if needed. +// +// This should be called after all to-device events in a sync have been processed. +// The function will then delete old olm hashes after enough syncs have happened +// (such that it's unlikely for the olm messages to repeat). +func (mach *OlmMachine) MarkOlmHashSavePoint(ctx context.Context) { + mach.olmHashSavePointLock.Lock() + defer mach.olmHashSavePointLock.Unlock() + if len(mach.olmHashSavePoints) > 0 && time.Since(mach.olmHashSavePoints[len(mach.olmHashSavePoints)-1]) < minSavePointInterval { + return + } + mach.olmHashSavePoints = append(mach.olmHashSavePoints, time.Now()) + if len(mach.olmHashSavePoints) > olmHashSavePointCount { + sp := mach.olmHashSavePoints[0] + mach.olmHashSavePoints = mach.olmHashSavePoints[1:] + if time.Since(mach.lastHashDelete) > olmHashDeleteMinInterval { + err := mach.CryptoStore.DeleteOldOlmHashes(ctx, sp) + mach.lastHashDelete = time.Now() + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to delete old olm hashes") + } + } } } @@ -505,25 +589,24 @@ func (mach *OlmMachine) SendEncryptedToDevice(ctx context.Context, device *id.De return err } -func (mach *OlmMachine) createGroupSession(ctx context.Context, senderKey id.SenderKey, signingKey id.Ed25519, roomID id.RoomID, sessionID id.SessionID, sessionKey string, maxAge time.Duration, maxMessages int, isScheduled bool) { +func (mach *OlmMachine) createGroupSession(ctx context.Context, senderKey id.SenderKey, signingKey id.Ed25519, roomID id.RoomID, sessionID id.SessionID, sessionKey string, maxAge time.Duration, maxMessages int, isScheduled bool) error { log := zerolog.Ctx(ctx) igs, err := NewInboundGroupSession(senderKey, signingKey, roomID, sessionKey, maxAge, maxMessages, isScheduled) if err != nil { - log.Error().Err(err).Msg("Failed to create inbound group session") - return + return fmt.Errorf("failed to create inbound group session: %w", err) } else if igs.ID() != sessionID { log.Warn(). Str("expected_session_id", sessionID.String()). Str("actual_session_id", igs.ID().String()). Msg("Mismatched session ID while creating inbound group session") - return + return fmt.Errorf("mismatched session ID while creating inbound group session") } - err = mach.CryptoStore.PutGroupSession(ctx, roomID, senderKey, sessionID, igs) + err = mach.CryptoStore.PutGroupSession(ctx, igs) if err != nil { - log.Error().Err(err).Str("session_id", sessionID.String()).Msg("Failed to store new inbound group session") - return + log.Err(err).Stringer("session_id", sessionID).Msg("Failed to store new inbound group session") + return fmt.Errorf("failed to store new inbound group session: %w", err) } - mach.markSessionReceived(ctx, sessionID) + mach.MarkSessionReceived(ctx, roomID, sessionID, igs.Internal.FirstKnownIndex()) log.Debug(). Str("session_id", sessionID.String()). Str("sender_key", senderKey.String()). @@ -531,11 +614,12 @@ func (mach *OlmMachine) createGroupSession(ctx context.Context, senderKey id.Sen Int("max_messages", maxMessages). Bool("is_scheduled", isScheduled). Msg("Received inbound group session") + return nil } -func (mach *OlmMachine) markSessionReceived(ctx context.Context, id id.SessionID) { +func (mach *OlmMachine) MarkSessionReceived(ctx context.Context, roomID id.RoomID, id id.SessionID, firstKnownIndex uint32) { if mach.SessionReceived != nil { - mach.SessionReceived(ctx, id) + mach.SessionReceived(ctx, roomID, id, firstKnownIndex) } mach.keyWaitersLock.Lock() @@ -557,7 +641,7 @@ func (mach *OlmMachine) WaitForSession(ctx context.Context, roomID id.RoomID, se } mach.keyWaitersLock.Unlock() // Handle race conditions where a session appears between the failed decryption and WaitForSession call. - sess, err := mach.CryptoStore.GetGroupSession(ctx, roomID, senderKey, sessionID) + sess, err := mach.CryptoStore.GetGroupSession(ctx, roomID, sessionID) if sess != nil || errors.Is(err, ErrGroupSessionWithheld) { return true } @@ -565,7 +649,7 @@ func (mach *OlmMachine) WaitForSession(ctx context.Context, roomID id.RoomID, se case <-ch: return true case <-time.After(timeout): - sess, err = mach.CryptoStore.GetGroupSession(ctx, roomID, senderKey, sessionID) + sess, err = mach.CryptoStore.GetGroupSession(ctx, roomID, sessionID) // Check if the session somehow appeared in the store without telling us // We accept withheld sessions as received, as then the decryption attempt will show the error. return sess != nil || errors.Is(err, ErrGroupSessionWithheld) @@ -574,14 +658,6 @@ func (mach *OlmMachine) WaitForSession(ctx context.Context, roomID id.RoomID, se } } -func stringifyArray[T ~string](arr []T) []string { - strs := make([]string, len(arr)) - for i, v := range arr { - strs[i] = string(v) - } - return strs -} - func (mach *OlmMachine) receiveRoomKey(ctx context.Context, evt *DecryptedOlmEvent, content *event.RoomKeyEventContent) { log := zerolog.Ctx(ctx).With(). Str("algorithm", string(content.Algorithm)). @@ -622,11 +698,14 @@ func (mach *OlmMachine) receiveRoomKey(ctx context.Context, evt *DecryptedOlmEve log.Err(err).Msg("Failed to redact previous megolm sessions") } else { log.Info(). - Strs("session_ids", stringifyArray(sessionIDs)). + Array("session_ids", exzerolog.ArrayOfStrs(sessionIDs)). Msg("Redacted previous megolm sessions") } } - mach.createGroupSession(ctx, evt.SenderKey, evt.Keys.Ed25519, content.RoomID, content.SessionID, content.SessionKey, maxAge, maxMessages, content.IsScheduled) + err = mach.createGroupSession(ctx, evt.SenderKey, evt.Keys.Ed25519, content.RoomID, content.SessionID, content.SessionKey, maxAge, maxMessages, content.IsScheduled) + if err != nil { + log.Err(err).Msg("Failed to create inbound group session") + } } func (mach *OlmMachine) HandleRoomKeyWithheld(ctx context.Context, content *event.RoomKeyWithheldEventContent) { @@ -634,6 +713,7 @@ func (mach *OlmMachine) HandleRoomKeyWithheld(ctx context.Context, content *even zerolog.Ctx(ctx).Debug().Interface("content", content).Msg("Non-megolm room key withheld event") return } + // TODO log if there's a conflict? (currently ignored) err := mach.CryptoStore.PutWithheldGroupSession(ctx, *content) if err != nil { zerolog.Ctx(ctx).Error().Err(err).Msg("Failed to save room key withheld event") @@ -650,7 +730,7 @@ func (mach *OlmMachine) ShareKeys(ctx context.Context, currentOTKCount int) erro start := time.Now() mach.otkUploadLock.Lock() defer mach.otkUploadLock.Unlock() - if mach.lastOTKUpload.Add(1*time.Minute).After(start) || currentOTKCount < 0 { + if mach.lastOTKUpload.Add(1*time.Minute).After(start) || (currentOTKCount < 0 && mach.account.Shared) { log.Debug().Msg("Checking OTK count from server due to suspiciously close share keys requests or negative OTK count") resp, err := mach.Client.UploadKeys(ctx, &mautrix.ReqUploadKeys{}) if err != nil { @@ -681,6 +761,11 @@ func (mach *OlmMachine) ShareKeys(ctx context.Context, currentOTKCount int) erro log.Debug().Msg("No one-time keys nor device keys got when trying to share keys") return nil } + // Save the keys before sending the upload request in case there is a + // network failure. + if err := mach.saveAccount(ctx); err != nil { + return err + } req := &mautrix.ReqUploadKeys{ DeviceKeys: deviceKeys, OneTimeKeys: oneTimeKeys, @@ -691,6 +776,7 @@ func (mach *OlmMachine) ShareKeys(ctx context.Context, currentOTKCount int) erro return err } mach.lastOTKUpload = time.Now() + mach.account.Internal.MarkKeysAsPublished() mach.account.Shared = true return mach.saveAccount(ctx) } @@ -702,7 +788,7 @@ func (mach *OlmMachine) ExpiredKeyDeleteLoop(ctx context.Context) { if err != nil { log.Err(err).Msg("Failed to redact expired megolm sessions") } else if len(sessionIDs) > 0 { - log.Info().Strs("session_ids", stringifyArray(sessionIDs)).Msg("Redacted expired megolm sessions") + log.Info().Array("session_ids", exzerolog.ArrayOfStrs(sessionIDs)).Msg("Redacted expired megolm sessions") } else { log.Debug().Msg("Didn't find any expired megolm sessions") } diff --git a/crypto/machine_bench_test.go b/crypto/machine_bench_test.go new file mode 100644 index 00000000..fd40d795 --- /dev/null +++ b/crypto/machine_bench_test.go @@ -0,0 +1,67 @@ +// Copyright (c) 2025 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package crypto_test + +import ( + "context" + "fmt" + "math/rand/v2" + "testing" + + "github.com/rs/zerolog" + globallog "github.com/rs/zerolog/log" // zerolog-allow-global-log + "github.com/stretchr/testify/require" + + "maunium.net/go/mautrix/crypto/cryptohelper" + "maunium.net/go/mautrix/id" + "maunium.net/go/mautrix/mockserver" +) + +func randomDeviceCount(r *rand.Rand) int { + k := 1 + for k < 10 && r.IntN(3) > 0 { + k++ + } + return k +} + +func BenchmarkOlmMachine_ShareGroupSession(b *testing.B) { + globallog.Logger = zerolog.Nop() + server := mockserver.Create(b) + server.PopOTKs = false + server.MemoryStore = false + var i int + var shareTargets []id.UserID + r := rand.New(rand.NewPCG(293, 0)) + var totalDeviceCount int + for i = 1; i < 1000; i++ { + userID := id.UserID(fmt.Sprintf("@user%d:localhost", i)) + deviceCount := randomDeviceCount(r) + for j := 0; j < deviceCount; j++ { + client, _ := server.Login(b, nil, userID, id.DeviceID(fmt.Sprintf("u%d_d%d", i, j))) + mach := client.Crypto.(*cryptohelper.CryptoHelper).Machine() + keysCache, err := mach.GenerateCrossSigningKeys() + require.NoError(b, err) + err = mach.PublishCrossSigningKeys(context.TODO(), keysCache, nil) + require.NoError(b, err) + } + totalDeviceCount += deviceCount + shareTargets = append(shareTargets, userID) + } + for b.Loop() { + client, _ := server.Login(b, nil, id.UserID(fmt.Sprintf("@benchuser%d:localhost", i)), id.DeviceID(fmt.Sprintf("u%d_d1", i))) + mach := client.Crypto.(*cryptohelper.CryptoHelper).Machine() + keysCache, err := mach.GenerateCrossSigningKeys() + require.NoError(b, err) + err = mach.PublishCrossSigningKeys(context.TODO(), keysCache, nil) + require.NoError(b, err) + err = mach.ShareGroupSession(context.TODO(), "!room:localhost", shareTargets) + require.NoError(b, err) + i++ + } + fmt.Println(totalDeviceCount, "devices total") +} diff --git a/crypto/machine_test.go b/crypto/machine_test.go index d3750d34..872c3ac4 100644 --- a/crypto/machine_test.go +++ b/crypto/machine_test.go @@ -36,28 +36,24 @@ func (mockStateStore) FindSharedRooms(context.Context, id.UserID) ([]id.RoomID, func newMachine(t *testing.T, userID id.UserID) *OlmMachine { client, err := mautrix.NewClient("http://localhost", userID, "token") - if err != nil { - t.Fatalf("Error creating client: %v", err) - } + require.NoError(t, err, "Error creating client") client.DeviceID = "device1" gobStore := NewMemoryStore(nil) - if err != nil { - t.Fatalf("Error creating Gob store: %v", err) - } + require.NoError(t, err, "Error creating Gob store") machine := NewOlmMachine(client, nil, gobStore, mockStateStore{}) - if err := machine.Load(context.TODO()); err != nil { - t.Fatalf("Error creating account: %v", err) - } + err = machine.Load(context.TODO()) + require.NoError(t, err, "Error creating account") return machine } func TestRatchetMegolmSession(t *testing.T) { mach := newMachine(t, "user1") - outSess := mach.newOutboundGroupSession(context.TODO(), "meow") - inSess, err := mach.CryptoStore.GetGroupSession(context.TODO(), "meow", mach.OwnIdentity().IdentityKey, outSess.ID()) + outSess, err := mach.newOutboundGroupSession(context.TODO(), "meow") + assert.NoError(t, err) + inSess, err := mach.CryptoStore.GetGroupSession(context.TODO(), "meow", outSess.ID()) require.NoError(t, err) assert.Equal(t, uint32(0), inSess.Internal.FirstKnownIndex()) err = inSess.RatchetTo(10) @@ -77,12 +73,11 @@ func TestOlmMachineOlmMegolmSessions(t *testing.T) { otk = otkTmp break } + machineIn.account.Internal.MarkKeysAsPublished() // create outbound olm session for sending machine using OTK olmSession, err := machineOut.account.Internal.NewOutboundSession(machineIn.account.IdentityKey(), otk.Key) - if err != nil { - t.Errorf("Failed to create outbound olm session: %v", err) - } + require.NoError(t, err, "Error creating outbound olm session") // store sender device identity in receiving machine store machineIn.CryptoStore.PutDevices(context.TODO(), "user1", map[id.DeviceID]*id.Device{ @@ -95,7 +90,8 @@ func TestOlmMachineOlmMegolmSessions(t *testing.T) { }) // create & store outbound megolm session for sending the event later - megolmOutSession := machineOut.newOutboundGroupSession(context.TODO(), "room1") + megolmOutSession, err := machineOut.newOutboundGroupSession(context.TODO(), "room1") + assert.NoError(t, err) megolmOutSession.Shared = true machineOut.CryptoStore.AddOutboundGroupSession(context.TODO(), megolmOutSession) @@ -118,29 +114,21 @@ func TestOlmMachineOlmMegolmSessions(t *testing.T) { Type: event.ToDeviceEncrypted, Sender: "user1", }, senderKey, content.Type, content.Body) - if err != nil { - t.Errorf("Error decrypting olm content: %v", err) - } + require.NoError(t, err, "Error decrypting olm ciphertext") + // store room key in new inbound group session roomKeyEvt := decrypted.Content.AsRoomKey() igs, err := NewInboundGroupSession(senderKey, signingKey, "room1", roomKeyEvt.SessionKey, 0, 0, false) - if err != nil { - t.Errorf("Error creating inbound megolm session: %v", err) - } - if err = machineIn.CryptoStore.PutGroupSession(context.TODO(), "room1", senderKey, igs.ID(), igs); err != nil { - t.Errorf("Error storing inbound megolm session: %v", err) - } + require.NoError(t, err, "Error creating inbound group session") + err = machineIn.CryptoStore.PutGroupSession(context.TODO(), igs) + require.NoError(t, err, "Error storing inbound group session") } // encrypt event with megolm session in sending machine eventContent := map[string]string{"hello": "world"} encryptedEvtContent, err := machineOut.EncryptMegolmEvent(context.TODO(), "room1", event.EventMessage, eventContent) - if err != nil { - t.Errorf("Error encrypting megolm event: %v", err) - } - if megolmOutSession.MessageCount != 1 { - t.Errorf("Megolm outbound session message count is not 1 but %d", megolmOutSession.MessageCount) - } + require.NoError(t, err, "Error encrypting megolm event") + assert.Equal(t, 1, megolmOutSession.MessageCount) encryptedEvt := &event.Event{ Content: event.Content{Parsed: encryptedEvtContent}, @@ -152,22 +140,12 @@ func TestOlmMachineOlmMegolmSessions(t *testing.T) { // decrypt event on receiving machine and confirm decryptedEvt, err := machineIn.DecryptMegolmEvent(context.TODO(), encryptedEvt) - if err != nil { - t.Errorf("Error decrypting megolm event: %v", err) - } - if decryptedEvt.Type != event.EventMessage { - t.Errorf("Expected event type %v, got %v", event.EventMessage, decryptedEvt.Type) - } - if decryptedEvt.Content.Raw["hello"] != "world" { - t.Errorf("Expected event content %v, got %v", eventContent, decryptedEvt.Content.Raw) - } + require.NoError(t, err, "Error decrypting megolm event") + assert.Equal(t, event.EventMessage, decryptedEvt.Type) + assert.Equal(t, "world", decryptedEvt.Content.Raw["hello"]) machineOut.EncryptMegolmEvent(context.TODO(), "room1", event.EventMessage, eventContent) - if megolmOutSession.Expired() { - t.Error("Megolm outbound session expired before 3rd message") - } + assert.False(t, megolmOutSession.Expired(), "Megolm outbound session expired before 3rd message") machineOut.EncryptMegolmEvent(context.TODO(), "room1", event.EventMessage, eventContent) - if !megolmOutSession.Expired() { - t.Error("Megolm outbound session not expired after 3rd message") - } + assert.True(t, megolmOutSession.Expired(), "Megolm outbound session not expired after 3rd message") } diff --git a/crypto/olm/account.go b/crypto/olm/account.go index 37458d1b..2ec5dd70 100644 --- a/crypto/olm/account.go +++ b/crypto/olm/account.go @@ -1,28 +1,105 @@ -//go:build !goolm +// Copyright (c) 2024 Sumner Evans +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. package olm -// #cgo LDFLAGS: -lolm -lstdc++ -// #include -import "C" - import ( - "crypto/rand" - "encoding/base64" - "encoding/json" - "unsafe" - - "github.com/tidwall/gjson" - "github.com/tidwall/sjson" - - "maunium.net/go/mautrix/crypto/canonicaljson" "maunium.net/go/mautrix/id" ) -// Account stores a device account for end to end encrypted messaging. -type Account struct { - int *C.OlmAccount - mem []byte +type Account interface { + // Pickle returns an Account as a base64 string. Encrypts the Account using the + // supplied key. + Pickle(key []byte) ([]byte, error) + + // Unpickle loads an Account from a pickled base64 string. Decrypts the + // Account using the supplied key. Returns error on failure. + Unpickle(pickled, key []byte) error + + // IdentityKeysJSON returns the public parts of the identity keys for the Account. + IdentityKeysJSON() ([]byte, error) + + // IdentityKeys returns the public parts of the Ed25519 and Curve25519 identity + // keys for the Account. + IdentityKeys() (id.Ed25519, id.Curve25519, error) + + // Sign returns the signature of a message using the ed25519 key for this + // Account. + Sign(message []byte) ([]byte, error) + + // OneTimeKeys returns the public parts of the unpublished one time keys for + // the Account. + // + // The returned data is a struct with the single value "Curve25519", which is + // itself an object mapping key id to base64-encoded Curve25519 key. For + // example: + // + // { + // Curve25519: { + // "AAAAAA": "wo76WcYtb0Vk/pBOdmduiGJ0wIEjW4IBMbbQn7aSnTo", + // "AAAAAB": "LRvjo46L1X2vx69sS9QNFD29HWulxrmW11Up5AfAjgU" + // } + // } + OneTimeKeys() (map[string]id.Curve25519, error) + + // MarkKeysAsPublished marks the current set of one time keys as being + // published. + MarkKeysAsPublished() + + // MaxNumberOfOneTimeKeys returns the largest number of one time keys this + // Account can store. + MaxNumberOfOneTimeKeys() uint + + // GenOneTimeKeys generates a number of new one time keys. If the total + // number of keys stored by this Account exceeds MaxNumberOfOneTimeKeys + // then the old keys are discarded. + GenOneTimeKeys(num uint) error + + // NewOutboundSession creates a new out-bound session for sending messages to a + // given curve25519 identityKey and oneTimeKey. Returns error on failure. If the + // keys couldn't be decoded as base64 then the error will be "INVALID_BASE64" + NewOutboundSession(theirIdentityKey, theirOneTimeKey id.Curve25519) (Session, error) + + // NewInboundSession creates a new in-bound session for sending/receiving + // messages from an incoming PRE_KEY message. Returns error on failure. If + // the base64 couldn't be decoded then the error will be "INVALID_BASE64". If + // the message was for an unsupported protocol version then the error will be + // "BAD_MESSAGE_VERSION". If the message couldn't be decoded then then the + // error will be "BAD_MESSAGE_FORMAT". If the message refers to an unknown one + // time key then the error will be "BAD_MESSAGE_KEY_ID". + NewInboundSession(oneTimeKeyMsg string) (Session, error) + + // NewInboundSessionFrom creates a new in-bound session for sending/receiving + // messages from an incoming PRE_KEY message. Returns error on failure. If + // the base64 couldn't be decoded then the error will be "INVALID_BASE64". If + // the message was for an unsupported protocol version then the error will be + // "BAD_MESSAGE_VERSION". If the message couldn't be decoded then then the + // error will be "BAD_MESSAGE_FORMAT". If the message refers to an unknown one + // time key then the error will be "BAD_MESSAGE_KEY_ID". + NewInboundSessionFrom(theirIdentityKey *id.Curve25519, oneTimeKeyMsg string) (Session, error) + + // RemoveOneTimeKeys removes the one time keys that the session used from the + // Account. Returns error on failure. If the Account doesn't have any + // matching one time keys then the error will be "BAD_MESSAGE_KEY_ID". + RemoveOneTimeKeys(s Session) error +} + +var Driver = "none" + +var InitBlankAccount func() Account +var InitNewAccount func() (Account, error) +var InitNewAccountFromPickled func(pickled, key []byte) (Account, error) + +// NewAccount creates a new Account. +func NewAccount() (Account, error) { + return InitNewAccount() +} + +func NewBlankAccount() Account { + return InitBlankAccount() } // AccountFromPickled loads an Account from a pickled base64 string. Decrypts @@ -30,375 +107,6 @@ type Account struct { // doesn't match the one used to encrypt the Account then the error will be // "BAD_ACCOUNT_KEY". If the base64 couldn't be decoded then the error will be // "INVALID_BASE64". -func AccountFromPickled(pickled, key []byte) (*Account, error) { - if len(pickled) == 0 { - return nil, EmptyInput - } - a := NewBlankAccount() - return a, a.Unpickle(pickled, key) -} - -func NewBlankAccount() *Account { - memory := make([]byte, accountSize()) - return &Account{ - int: C.olm_account(unsafe.Pointer(&memory[0])), - mem: memory, - } -} - -// NewAccount creates a new Account. -func NewAccount() *Account { - a := NewBlankAccount() - random := make([]byte, a.createRandomLen()+1) - _, err := rand.Read(random) - if err != nil { - panic(NotEnoughGoRandom) - } - r := C.olm_create_account( - (*C.OlmAccount)(a.int), - unsafe.Pointer(&random[0]), - C.size_t(len(random))) - if r == errorVal() { - panic(a.lastError()) - } else { - return a - } -} - -// accountSize returns the size of an account object in bytes. -func accountSize() uint { - return uint(C.olm_account_size()) -} - -// lastError returns an error describing the most recent error to happen to an -// account. -func (a *Account) lastError() error { - return convertError(C.GoString(C.olm_account_last_error((*C.OlmAccount)(a.int)))) -} - -// Clear clears the memory used to back this Account. -func (a *Account) Clear() error { - r := C.olm_clear_account((*C.OlmAccount)(a.int)) - if r == errorVal() { - return a.lastError() - } else { - return nil - } -} - -// pickleLen returns the number of bytes needed to store an Account. -func (a *Account) pickleLen() uint { - return uint(C.olm_pickle_account_length((*C.OlmAccount)(a.int))) -} - -// createRandomLen returns the number of random bytes needed to create an -// Account. -func (a *Account) createRandomLen() uint { - return uint(C.olm_create_account_random_length((*C.OlmAccount)(a.int))) -} - -// identityKeysLen returns the size of the output buffer needed to hold the -// identity keys. -func (a *Account) identityKeysLen() uint { - return uint(C.olm_account_identity_keys_length((*C.OlmAccount)(a.int))) -} - -// signatureLen returns the length of an ed25519 signature encoded as base64. -func (a *Account) signatureLen() uint { - return uint(C.olm_account_signature_length((*C.OlmAccount)(a.int))) -} - -// oneTimeKeysLen returns the size of the output buffer needed to hold the one -// time keys. -func (a *Account) oneTimeKeysLen() uint { - return uint(C.olm_account_one_time_keys_length((*C.OlmAccount)(a.int))) -} - -// genOneTimeKeysRandomLen returns the number of random bytes needed to -// generate a given number of new one time keys. -func (a *Account) genOneTimeKeysRandomLen(num uint) uint { - return uint(C.olm_account_generate_one_time_keys_random_length( - (*C.OlmAccount)(a.int), - C.size_t(num))) -} - -// Pickle returns an Account as a base64 string. Encrypts the Account using the -// supplied key. -func (a *Account) Pickle(key []byte) []byte { - if len(key) == 0 { - panic(NoKeyProvided) - } - pickled := make([]byte, a.pickleLen()) - r := C.olm_pickle_account( - (*C.OlmAccount)(a.int), - unsafe.Pointer(&key[0]), - C.size_t(len(key)), - unsafe.Pointer(&pickled[0]), - C.size_t(len(pickled))) - if r == errorVal() { - panic(a.lastError()) - } - return pickled[:r] -} - -func (a *Account) Unpickle(pickled, key []byte) error { - if len(key) == 0 { - return NoKeyProvided - } - r := C.olm_unpickle_account( - (*C.OlmAccount)(a.int), - unsafe.Pointer(&key[0]), - C.size_t(len(key)), - unsafe.Pointer(&pickled[0]), - C.size_t(len(pickled))) - if r == errorVal() { - return a.lastError() - } - return nil -} - -// Deprecated -func (a *Account) GobEncode() ([]byte, error) { - pickled := a.Pickle(pickleKey) - length := base64.RawStdEncoding.DecodedLen(len(pickled)) - rawPickled := make([]byte, length) - _, err := base64.RawStdEncoding.Decode(rawPickled, pickled) - return rawPickled, err -} - -// Deprecated -func (a *Account) GobDecode(rawPickled []byte) error { - if a.int == nil { - *a = *NewBlankAccount() - } - length := base64.RawStdEncoding.EncodedLen(len(rawPickled)) - pickled := make([]byte, length) - base64.RawStdEncoding.Encode(pickled, rawPickled) - return a.Unpickle(pickled, pickleKey) -} - -// Deprecated -func (a *Account) MarshalJSON() ([]byte, error) { - pickled := a.Pickle(pickleKey) - quotes := make([]byte, len(pickled)+2) - quotes[0] = '"' - quotes[len(quotes)-1] = '"' - copy(quotes[1:len(quotes)-1], pickled) - return quotes, nil -} - -// Deprecated -func (a *Account) UnmarshalJSON(data []byte) error { - if len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' { - return InputNotJSONString - } - if a.int == nil { - *a = *NewBlankAccount() - } - return a.Unpickle(data[1:len(data)-1], pickleKey) -} - -// IdentityKeysJSON returns the public parts of the identity keys for the Account. -func (a *Account) IdentityKeysJSON() []byte { - identityKeys := make([]byte, a.identityKeysLen()) - r := C.olm_account_identity_keys( - (*C.OlmAccount)(a.int), - unsafe.Pointer(&identityKeys[0]), - C.size_t(len(identityKeys))) - if r == errorVal() { - panic(a.lastError()) - } else { - return identityKeys - } -} - -// IdentityKeys returns the public parts of the Ed25519 and Curve25519 identity -// keys for the Account. -func (a *Account) IdentityKeys() (id.Ed25519, id.Curve25519) { - identityKeysJSON := a.IdentityKeysJSON() - results := gjson.GetManyBytes(identityKeysJSON, "ed25519", "curve25519") - return id.Ed25519(results[0].Str), id.Curve25519(results[1].Str) -} - -// Sign returns the signature of a message using the ed25519 key for this -// Account. -func (a *Account) Sign(message []byte) []byte { - if len(message) == 0 { - panic(EmptyInput) - } - signature := make([]byte, a.signatureLen()) - r := C.olm_account_sign( - (*C.OlmAccount)(a.int), - unsafe.Pointer(&message[0]), - C.size_t(len(message)), - unsafe.Pointer(&signature[0]), - C.size_t(len(signature))) - if r == errorVal() { - panic(a.lastError()) - } - return signature -} - -// SignJSON signs the given JSON object following the Matrix specification: -// https://matrix.org/docs/spec/appendices#signing-json -func (a *Account) SignJSON(obj interface{}) (string, error) { - objJSON, err := json.Marshal(obj) - if err != nil { - return "", err - } - objJSON, _ = sjson.DeleteBytes(objJSON, "unsigned") - objJSON, _ = sjson.DeleteBytes(objJSON, "signatures") - return string(a.Sign(canonicaljson.CanonicalJSONAssumeValid(objJSON))), nil -} - -// OneTimeKeys returns the public parts of the unpublished one time keys for -// the Account. -// -// The returned data is a struct with the single value "Curve25519", which is -// itself an object mapping key id to base64-encoded Curve25519 key. For -// example: -// -// { -// Curve25519: { -// "AAAAAA": "wo76WcYtb0Vk/pBOdmduiGJ0wIEjW4IBMbbQn7aSnTo", -// "AAAAAB": "LRvjo46L1X2vx69sS9QNFD29HWulxrmW11Up5AfAjgU" -// } -// } -func (a *Account) OneTimeKeys() map[string]id.Curve25519 { - oneTimeKeysJSON := make([]byte, a.oneTimeKeysLen()) - r := C.olm_account_one_time_keys( - (*C.OlmAccount)(a.int), - unsafe.Pointer(&oneTimeKeysJSON[0]), - C.size_t(len(oneTimeKeysJSON))) - if r == errorVal() { - panic(a.lastError()) - } - var oneTimeKeys struct { - Curve25519 map[string]id.Curve25519 `json:"curve25519"` - } - err := json.Unmarshal(oneTimeKeysJSON, &oneTimeKeys) - if err != nil { - panic(err) - } - return oneTimeKeys.Curve25519 -} - -// MarkKeysAsPublished marks the current set of one time keys as being -// published. -func (a *Account) MarkKeysAsPublished() { - C.olm_account_mark_keys_as_published((*C.OlmAccount)(a.int)) -} - -// MaxNumberOfOneTimeKeys returns the largest number of one time keys this -// Account can store. -func (a *Account) MaxNumberOfOneTimeKeys() uint { - return uint(C.olm_account_max_number_of_one_time_keys((*C.OlmAccount)(a.int))) -} - -// GenOneTimeKeys generates a number of new one time keys. If the total number -// of keys stored by this Account exceeds MaxNumberOfOneTimeKeys then the old -// keys are discarded. -func (a *Account) GenOneTimeKeys(num uint) { - random := make([]byte, a.genOneTimeKeysRandomLen(num)+1) - _, err := rand.Read(random) - if err != nil { - panic(NotEnoughGoRandom) - } - r := C.olm_account_generate_one_time_keys( - (*C.OlmAccount)(a.int), - C.size_t(num), - unsafe.Pointer(&random[0]), - C.size_t(len(random))) - if r == errorVal() { - panic(a.lastError()) - } -} - -// NewOutboundSession creates a new out-bound session for sending messages to a -// given curve25519 identityKey and oneTimeKey. Returns error on failure. If the -// keys couldn't be decoded as base64 then the error will be "INVALID_BASE64" -func (a *Account) NewOutboundSession(theirIdentityKey, theirOneTimeKey id.Curve25519) (*Session, error) { - if len(theirIdentityKey) == 0 || len(theirOneTimeKey) == 0 { - return nil, EmptyInput - } - s := NewBlankSession() - random := make([]byte, s.createOutboundRandomLen()+1) - _, err := rand.Read(random) - if err != nil { - panic(NotEnoughGoRandom) - } - r := C.olm_create_outbound_session( - (*C.OlmSession)(s.int), - (*C.OlmAccount)(a.int), - unsafe.Pointer(&([]byte(theirIdentityKey)[0])), - C.size_t(len(theirIdentityKey)), - unsafe.Pointer(&([]byte(theirOneTimeKey)[0])), - C.size_t(len(theirOneTimeKey)), - unsafe.Pointer(&random[0]), - C.size_t(len(random))) - if r == errorVal() { - return nil, s.lastError() - } - return s, nil -} - -// NewInboundSession creates a new in-bound session for sending/receiving -// messages from an incoming PRE_KEY message. Returns error on failure. If -// the base64 couldn't be decoded then the error will be "INVALID_BASE64". If -// the message was for an unsupported protocol version then the error will be -// "BAD_MESSAGE_VERSION". If the message couldn't be decoded then then the -// error will be "BAD_MESSAGE_FORMAT". If the message refers to an unknown one -// time key then the error will be "BAD_MESSAGE_KEY_ID". -func (a *Account) NewInboundSession(oneTimeKeyMsg string) (*Session, error) { - if len(oneTimeKeyMsg) == 0 { - return nil, EmptyInput - } - s := NewBlankSession() - r := C.olm_create_inbound_session( - (*C.OlmSession)(s.int), - (*C.OlmAccount)(a.int), - unsafe.Pointer(&([]byte(oneTimeKeyMsg)[0])), - C.size_t(len(oneTimeKeyMsg))) - if r == errorVal() { - return nil, s.lastError() - } - return s, nil -} - -// NewInboundSessionFrom creates a new in-bound session for sending/receiving -// messages from an incoming PRE_KEY message. Returns error on failure. If -// the base64 couldn't be decoded then the error will be "INVALID_BASE64". If -// the message was for an unsupported protocol version then the error will be -// "BAD_MESSAGE_VERSION". If the message couldn't be decoded then then the -// error will be "BAD_MESSAGE_FORMAT". If the message refers to an unknown one -// time key then the error will be "BAD_MESSAGE_KEY_ID". -func (a *Account) NewInboundSessionFrom(theirIdentityKey id.Curve25519, oneTimeKeyMsg string) (*Session, error) { - if len(theirIdentityKey) == 0 || len(oneTimeKeyMsg) == 0 { - return nil, EmptyInput - } - s := NewBlankSession() - r := C.olm_create_inbound_session_from( - (*C.OlmSession)(s.int), - (*C.OlmAccount)(a.int), - unsafe.Pointer(&([]byte(theirIdentityKey)[0])), - C.size_t(len(theirIdentityKey)), - unsafe.Pointer(&([]byte(oneTimeKeyMsg)[0])), - C.size_t(len(oneTimeKeyMsg))) - if r == errorVal() { - return nil, s.lastError() - } - return s, nil -} - -// RemoveOneTimeKeys removes the one time keys that the session used from the -// Account. Returns error on failure. If the Account doesn't have any -// matching one time keys then the error will be "BAD_MESSAGE_KEY_ID". -func (a *Account) RemoveOneTimeKeys(s *Session) error { - r := C.olm_remove_one_time_keys( - (*C.OlmAccount)(a.int), - (*C.OlmSession)(s.int)) - if r == errorVal() { - return a.lastError() - } - return nil +func AccountFromPickled(pickled, key []byte) (Account, error) { + return InitNewAccountFromPickled(pickled, key) } diff --git a/crypto/olm/account_goolm.go b/crypto/olm/account_goolm.go deleted file mode 100644 index eeff54f9..00000000 --- a/crypto/olm/account_goolm.go +++ /dev/null @@ -1,154 +0,0 @@ -//go:build goolm - -package olm - -import ( - "encoding/json" - - "github.com/tidwall/sjson" - - "maunium.net/go/mautrix/crypto/canonicaljson" - "maunium.net/go/mautrix/crypto/goolm/account" - "maunium.net/go/mautrix/id" -) - -// Account stores a device account for end to end encrypted messaging. -type Account struct { - account.Account -} - -// NewAccount creates a new Account. -func NewAccount() *Account { - a, err := account.NewAccount(nil) - if err != nil { - panic(err) - } - ac := &Account{} - ac.Account = *a - return ac -} - -func NewBlankAccount() *Account { - return &Account{} -} - -// Clear clears the memory used to back this Account. -func (a *Account) Clear() error { - a.Account = account.Account{} - return nil -} - -// Pickle returns an Account as a base64 string. Encrypts the Account using the -// supplied key. -func (a *Account) Pickle(key []byte) []byte { - if len(key) == 0 { - panic(NoKeyProvided) - } - pickled, err := a.Account.Pickle(key) - if err != nil { - panic(err) - } - return pickled -} - -// IdentityKeysJSON returns the public parts of the identity keys for the Account. -func (a *Account) IdentityKeysJSON() []byte { - identityKeys, err := a.Account.IdentityKeysJSON() - if err != nil { - panic(err) - } - return identityKeys -} - -// Sign returns the signature of a message using the ed25519 key for this -// Account. -func (a *Account) Sign(message []byte) []byte { - if len(message) == 0 { - panic(EmptyInput) - } - signature, err := a.Account.Sign(message) - if err != nil { - panic(err) - } - return signature -} - -// SignJSON signs the given JSON object following the Matrix specification: -// https://matrix.org/docs/spec/appendices#signing-json -func (a *Account) SignJSON(obj interface{}) (string, error) { - objJSON, err := json.Marshal(obj) - if err != nil { - return "", err - } - objJSON, _ = sjson.DeleteBytes(objJSON, "unsigned") - objJSON, _ = sjson.DeleteBytes(objJSON, "signatures") - return string(a.Sign(canonicaljson.CanonicalJSONAssumeValid(objJSON))), nil -} - -// MaxNumberOfOneTimeKeys returns the largest number of one time keys this -// Account can store. -func (a *Account) MaxNumberOfOneTimeKeys() uint { - return uint(account.MaxOneTimeKeys) -} - -// GenOneTimeKeys generates a number of new one time keys. If the total number -// of keys stored by this Account exceeds MaxNumberOfOneTimeKeys then the old -// keys are discarded. -func (a *Account) GenOneTimeKeys(num uint) { - err := a.Account.GenOneTimeKeys(nil, num) - if err != nil { - panic(err) - } -} - -// NewOutboundSession creates a new out-bound session for sending messages to a -// given curve25519 identityKey and oneTimeKey. Returns error on failure. -func (a *Account) NewOutboundSession(theirIdentityKey, theirOneTimeKey id.Curve25519) (*Session, error) { - if len(theirIdentityKey) == 0 || len(theirOneTimeKey) == 0 { - return nil, EmptyInput - } - s := &Session{} - newSession, err := a.Account.NewOutboundSession(theirIdentityKey, theirOneTimeKey) - if err != nil { - return nil, err - } - s.OlmSession = *newSession - return s, nil -} - -// NewInboundSession creates a new in-bound session for sending/receiving -// messages from an incoming PRE_KEY message. Returns error on failure. -func (a *Account) NewInboundSession(oneTimeKeyMsg string) (*Session, error) { - if len(oneTimeKeyMsg) == 0 { - return nil, EmptyInput - } - s := &Session{} - newSession, err := a.Account.NewInboundSession(nil, []byte(oneTimeKeyMsg)) - if err != nil { - return nil, err - } - s.OlmSession = *newSession - return s, nil -} - -// NewInboundSessionFrom creates a new in-bound session for sending/receiving -// messages from an incoming PRE_KEY message. Returns error on failure. -func (a *Account) NewInboundSessionFrom(theirIdentityKey id.Curve25519, oneTimeKeyMsg string) (*Session, error) { - if len(theirIdentityKey) == 0 || len(oneTimeKeyMsg) == 0 { - return nil, EmptyInput - } - s := &Session{} - newSession, err := a.Account.NewInboundSession(&theirIdentityKey, []byte(oneTimeKeyMsg)) - if err != nil { - return nil, err - } - s.OlmSession = *newSession - return s, nil -} - -// RemoveOneTimeKeys removes the one time keys that the session used from the -// Account. Returns error on failure. -func (a *Account) RemoveOneTimeKeys(s *Session) error { - a.Account.RemoveOneTimeKeys(&s.OlmSession) - return nil -} diff --git a/crypto/olm/account_test.go b/crypto/olm/account_test.go new file mode 100644 index 00000000..0e055881 --- /dev/null +++ b/crypto/olm/account_test.go @@ -0,0 +1,122 @@ +// Copyright (c) 2024 Sumner Evans +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package olm_test + +import ( + "bytes" + "encoding/base64" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.mau.fi/util/exerrors" + + "maunium.net/go/mautrix/crypto/ed25519" + "maunium.net/go/mautrix/crypto/goolm/account" + "maunium.net/go/mautrix/crypto/libolm" + "maunium.net/go/mautrix/crypto/olm" +) + +func ensureAccountsEqual(t *testing.T, a, b olm.Account) { + t.Helper() + + assert.Equal(t, a.MaxNumberOfOneTimeKeys(), b.MaxNumberOfOneTimeKeys()) + + aEd25519, aCurve25519, err := a.IdentityKeys() + require.NoError(t, err) + bEd25519, bCurve25519, err := b.IdentityKeys() + require.NoError(t, err) + assert.Equal(t, aEd25519, bEd25519) + assert.Equal(t, aCurve25519, bCurve25519) + + aIdentityKeysJSON, err := a.IdentityKeysJSON() + require.NoError(t, err) + bIdentityKeysJSON, err := b.IdentityKeysJSON() + require.NoError(t, err) + assert.JSONEq(t, string(aIdentityKeysJSON), string(bIdentityKeysJSON)) + + aOTKs, err := a.OneTimeKeys() + require.NoError(t, err) + bOTKs, err := b.OneTimeKeys() + require.NoError(t, err) + assert.Equal(t, aOTKs, bOTKs) +} + +// TestAccount_UnpickleLibolmToGoolm tests creating an account from libolm, +// pickling it, and importing it into goolm. +func TestAccount_UnpickleLibolmToGoolm(t *testing.T) { + libolmAccount, err := libolm.NewAccount() + require.NoError(t, err) + + require.NoError(t, libolmAccount.GenOneTimeKeys(50)) + + libolmPickled, err := libolmAccount.Pickle([]byte("test")) + require.NoError(t, err) + + goolmAccount, err := account.AccountFromPickled(libolmPickled, []byte("test")) + require.NoError(t, err) + + ensureAccountsEqual(t, libolmAccount, goolmAccount) + + goolmPickled, err := goolmAccount.Pickle([]byte("test")) + require.NoError(t, err) + assert.Equal(t, libolmPickled, goolmPickled) +} + +// TestAccount_UnpickleGoolmToLibolm tests creating an account from goolm, +// pickling it, and importing it into libolm. +func TestAccount_UnpickleGoolmToLibolm(t *testing.T) { + goolmAccount, err := account.NewAccount() + require.NoError(t, err) + + require.NoError(t, goolmAccount.GenOneTimeKeys(50)) + + goolmPickled, err := goolmAccount.Pickle([]byte("test")) + require.NoError(t, err) + + libolmAccount, err := libolm.AccountFromPickled(bytes.Clone(goolmPickled), []byte("test")) + require.NoError(t, err) + + ensureAccountsEqual(t, libolmAccount, goolmAccount) + + libolmPickled, err := libolmAccount.Pickle([]byte("test")) + require.NoError(t, err) + assert.Equal(t, goolmPickled, libolmPickled) +} + +func FuzzAccount_Sign(f *testing.F) { + f.Add([]byte("anything")) + + libolmAccount := exerrors.Must(libolm.NewAccount()) + goolmAccount := exerrors.Must(account.AccountFromPickled(exerrors.Must(libolmAccount.Pickle([]byte("test"))), []byte("test"))) + + f.Fuzz(func(t *testing.T, message []byte) { + if len(message) == 0 { + t.Skip("empty message is not supported") + } + + libolmSignature, err := libolmAccount.Sign(bytes.Clone(message)) + require.NoError(t, err) + goolmSignature, err := goolmAccount.Sign(bytes.Clone(message)) + require.NoError(t, err) + assert.Equal(t, goolmSignature, libolmSignature) + + goolmSignatureBytes, err := base64.RawStdEncoding.DecodeString(string(goolmSignature)) + require.NoError(t, err) + libolmSignatureBytes, err := base64.RawStdEncoding.DecodeString(string(libolmSignature)) + require.NoError(t, err) + + libolmEd25519, _, err := libolmAccount.IdentityKeys() + require.NoError(t, err) + + assert.True(t, ed25519.Verify(ed25519.PublicKey(libolmEd25519.Bytes()), message, libolmSignatureBytes)) + assert.True(t, ed25519.Verify(ed25519.PublicKey(libolmEd25519.Bytes()), message, goolmSignatureBytes)) + + assert.True(t, goolmAccount.IdKeys.Ed25519.Verify(bytes.Clone(message), libolmSignatureBytes)) + assert.True(t, goolmAccount.IdKeys.Ed25519.Verify(bytes.Clone(message), goolmSignatureBytes)) + }) +} diff --git a/crypto/olm/error.go b/crypto/olm/error.go deleted file mode 100644 index 63352e20..00000000 --- a/crypto/olm/error.go +++ /dev/null @@ -1,62 +0,0 @@ -//go:build !goolm - -package olm - -import ( - "errors" - "fmt" -) - -// Error codes from go-olm -var ( - EmptyInput = errors.New("empty input") - NoKeyProvided = errors.New("no pickle key provided") - NotEnoughGoRandom = errors.New("couldn't get enough randomness from crypto/rand") - SignatureNotFound = errors.New("input JSON doesn't contain signature from specified device") - InputNotJSONString = errors.New("input doesn't look like a JSON string") -) - -// Error codes from olm code -var ( - NotEnoughRandom = errors.New("not enough entropy was supplied") - OutputBufferTooSmall = errors.New("supplied output buffer is too small") - BadMessageVersion = errors.New("the message version is unsupported") - BadMessageFormat = errors.New("the message couldn't be decoded") - BadMessageMAC = errors.New("the message couldn't be decrypted") - BadMessageKeyID = errors.New("the message references an unknown key ID") - InvalidBase64 = errors.New("the input base64 was invalid") - BadAccountKey = errors.New("the supplied account key is invalid") - UnknownPickleVersion = errors.New("the pickled object is too new") - CorruptedPickle = errors.New("the pickled object couldn't be decoded") - BadSessionKey = errors.New("attempt to initialise an inbound group session from an invalid session key") - UnknownMessageIndex = errors.New("attempt to decode a message whose index is earlier than our earliest known session key") - BadLegacyAccountPickle = errors.New("attempt to unpickle an account which uses pickle version 1") - BadSignature = errors.New("received message had a bad signature") - InputBufferTooSmall = errors.New("the input data was too small to be valid") -) - -var errorMap = map[string]error{ - "NOT_ENOUGH_RANDOM": NotEnoughRandom, - "OUTPUT_BUFFER_TOO_SMALL": OutputBufferTooSmall, - "BAD_MESSAGE_VERSION": BadMessageVersion, - "BAD_MESSAGE_FORMAT": BadMessageFormat, - "BAD_MESSAGE_MAC": BadMessageMAC, - "BAD_MESSAGE_KEY_ID": BadMessageKeyID, - "INVALID_BASE64": InvalidBase64, - "BAD_ACCOUNT_KEY": BadAccountKey, - "UNKNOWN_PICKLE_VERSION": UnknownPickleVersion, - "CORRUPTED_PICKLE": CorruptedPickle, - "BAD_SESSION_KEY": BadSessionKey, - "UNKNOWN_MESSAGE_INDEX": UnknownMessageIndex, - "BAD_LEGACY_ACCOUNT_PICKLE": BadLegacyAccountPickle, - "BAD_SIGNATURE": BadSignature, - "INPUT_BUFFER_TOO_SMALL": InputBufferTooSmall, -} - -func convertError(errCode string) error { - err, ok := errorMap[errCode] - if ok { - return err - } - return fmt.Errorf("unknown error: %s", errCode) -} diff --git a/crypto/olm/error_goolm.go b/crypto/olm/error_goolm.go deleted file mode 100644 index 0e54e566..00000000 --- a/crypto/olm/error_goolm.go +++ /dev/null @@ -1,23 +0,0 @@ -//go:build goolm - -package olm - -import ( - "errors" - - "maunium.net/go/mautrix/crypto/goolm" -) - -// Error codes from go-olm -var ( - EmptyInput = goolm.ErrEmptyInput - NoKeyProvided = goolm.ErrNoKeyProvided - NotEnoughGoRandom = errors.New("couldn't get enough randomness from crypto/rand") - SignatureNotFound = errors.New("input JSON doesn't contain signature from specified device") - InputNotJSONString = errors.New("input doesn't look like a JSON string") -) - -// Error codes from olm code -var ( - UnknownMessageIndex = goolm.ErrRatchetNotAvailable -) diff --git a/crypto/olm/errors.go b/crypto/olm/errors.go new file mode 100644 index 00000000..9e522b2a --- /dev/null +++ b/crypto/olm/errors.go @@ -0,0 +1,76 @@ +// Copyright (c) 2024 Sumner Evans +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package olm + +import "errors" + +// Those are the most common used errors +var ( + ErrBadSignature = errors.New("bad signature") + ErrBadMAC = errors.New("the message couldn't be decrypted (bad mac)") + ErrBadMessageFormat = errors.New("the message couldn't be decoded") + ErrBadVerification = errors.New("bad verification") + ErrWrongProtocolVersion = errors.New("wrong protocol version") + ErrEmptyInput = errors.New("empty input") + ErrNoKeyProvided = errors.New("no key provided") + ErrBadMessageKeyID = errors.New("the message references an unknown key ID") + ErrUnknownMessageIndex = errors.New("attempt to decode a message whose index is earlier than our earliest known session key") + ErrMsgIndexTooHigh = errors.New("message index too high") + ErrProtocolViolation = errors.New("not protocol message order") + ErrMessageKeyNotFound = errors.New("message key not found") + ErrChainTooHigh = errors.New("chain index too high") + ErrBadInput = errors.New("bad input") + ErrUnknownOlmPickleVersion = errors.New("unknown olm pickle version") + ErrUnknownJSONPickleVersion = errors.New("unknown JSON pickle version") + ErrInputToSmall = errors.New("input too small (truncated?)") +) + +// Error codes from go-olm +var ( + ErrNotEnoughGoRandom = errors.New("couldn't get enough randomness from crypto/rand") + ErrInputNotJSONString = errors.New("input doesn't look like a JSON string") +) + +// Error codes from olm code +var ( + ErrLibolmInvalidBase64 = errors.New("the input base64 was invalid") + + ErrLibolmNotEnoughRandom = errors.New("not enough entropy was supplied") + ErrLibolmOutputBufferTooSmall = errors.New("supplied output buffer is too small") + ErrLibolmBadAccountKey = errors.New("the supplied account key is invalid") + ErrLibolmCorruptedPickle = errors.New("the pickled object couldn't be decoded") + ErrLibolmBadSessionKey = errors.New("attempt to initialise an inbound group session from an invalid session key") + ErrLibolmBadLegacyAccountPickle = errors.New("attempt to unpickle an account which uses pickle version 1") +) + +// Deprecated: use variables prefixed with Err +var ( + EmptyInput = ErrEmptyInput + BadSignature = ErrBadSignature + InvalidBase64 = ErrLibolmInvalidBase64 + BadMessageKeyID = ErrBadMessageKeyID + BadMessageFormat = ErrBadMessageFormat + BadMessageVersion = ErrWrongProtocolVersion + BadMessageMAC = ErrBadMAC + UnknownPickleVersion = ErrUnknownOlmPickleVersion + NotEnoughRandom = ErrLibolmNotEnoughRandom + OutputBufferTooSmall = ErrLibolmOutputBufferTooSmall + BadAccountKey = ErrLibolmBadAccountKey + CorruptedPickle = ErrLibolmCorruptedPickle + BadSessionKey = ErrLibolmBadSessionKey + UnknownMessageIndex = ErrUnknownMessageIndex + BadLegacyAccountPickle = ErrLibolmBadLegacyAccountPickle + InputBufferTooSmall = ErrInputToSmall + NoKeyProvided = ErrNoKeyProvided + + NotEnoughGoRandom = ErrNotEnoughGoRandom + InputNotJSONString = ErrInputNotJSONString + + ErrBadVersion = ErrUnknownJSONPickleVersion + ErrWrongPickleVersion = ErrUnknownJSONPickleVersion + ErrRatchetNotAvailable = ErrUnknownMessageIndex +) diff --git a/crypto/olm/groupsession_test.go b/crypto/olm/groupsession_test.go new file mode 100644 index 00000000..0f845e90 --- /dev/null +++ b/crypto/olm/groupsession_test.go @@ -0,0 +1,48 @@ +package olm_test + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "maunium.net/go/mautrix/crypto/goolm/session" + "maunium.net/go/mautrix/crypto/libolm" +) + +// TestEncryptDecrypt_GoolmToLibolm tests encryption where goolm encrypts and libolm decrypts +func TestEncryptDecrypt_GoolmToLibolm(t *testing.T) { + goolmOutbound, err := session.NewMegolmOutboundSession() + require.NoError(t, err) + + libolmInbound, err := libolm.NewInboundGroupSession([]byte(goolmOutbound.Key())) + require.NoError(t, err) + + for i := 0; i < 10; i++ { + ciphertext, err := goolmOutbound.Encrypt([]byte(fmt.Sprintf("message %d", i))) + require.NoError(t, err) + + plaintext, msgIdx, err := libolmInbound.Decrypt(ciphertext) + assert.NoError(t, err) + assert.Equal(t, []byte(fmt.Sprintf("message %d", i)), plaintext) + assert.Equal(t, goolmOutbound.MessageIndex()-1, msgIdx) + } +} + +func TestEncryptDecrypt_LibolmToGoolm(t *testing.T) { + libolmOutbound, err := libolm.NewOutboundGroupSession() + require.NoError(t, err) + goolmInbound, err := session.NewMegolmInboundSession([]byte(libolmOutbound.Key())) + require.NoError(t, err) + + for i := 0; i < 10; i++ { + ciphertext, err := libolmOutbound.Encrypt([]byte(fmt.Sprintf("message %d", i))) + require.NoError(t, err) + + plaintext, msgIdx, err := goolmInbound.Decrypt(ciphertext) + assert.NoError(t, err) + assert.Equal(t, []byte(fmt.Sprintf("message %d", i)), plaintext) + assert.Equal(t, libolmOutbound.MessageIndex()-1, msgIdx) + } +} diff --git a/crypto/olm/inboundgroupsession.go b/crypto/olm/inboundgroupsession.go index a3bd3b65..8839b48c 100644 --- a/crypto/olm/inboundgroupsession.go +++ b/crypto/olm/inboundgroupsession.go @@ -1,310 +1,80 @@ -//go:build !goolm +// Copyright (c) 2024 Sumner Evans +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. package olm -// #cgo LDFLAGS: -lolm -lstdc++ -// #include -import "C" +import "maunium.net/go/mautrix/id" -import ( - "encoding/base64" - "unsafe" +type InboundGroupSession interface { + // Pickle returns an InboundGroupSession as a base64 string. Encrypts the + // InboundGroupSession using the supplied key. + Pickle(key []byte) ([]byte, error) - "maunium.net/go/mautrix/id" -) + // Unpickle loads an [InboundGroupSession] from a pickled base64 string. + // Decrypts the [InboundGroupSession] using the supplied key. + Unpickle(pickled, key []byte) error -// InboundGroupSession stores an inbound encrypted messaging session for a -// group. -type InboundGroupSession struct { - int *C.OlmInboundGroupSession - mem []byte + // Decrypt decrypts a message using the [InboundGroupSession]. Returns the + // plain-text and message index on success. Returns error on failure. If + // the base64 couldn't be decoded then the error will be "INVALID_BASE64". + // If the message is for an unsupported version of the protocol then the + // error will be "BAD_MESSAGE_VERSION". If the message couldn't be decoded + // then the error will be BAD_MESSAGE_FORMAT". If the MAC on the message + // was invalid then the error will be "BAD_MESSAGE_MAC". If we do not have + // a session key corresponding to the message's index (ie, it was sent + // before the session key was shared with us) the error will be + // "OLM_UNKNOWN_MESSAGE_INDEX". + Decrypt(message []byte) ([]byte, uint, error) + + // ID returns a base64-encoded identifier for this session. + ID() id.SessionID + + // FirstKnownIndex returns the first message index we know how to decrypt. + FirstKnownIndex() uint32 + + // IsVerified check if the session has been verified as a valid session. + // (A session is verified either because the original session share was + // signed, or because we have subsequently successfully decrypted a + // message.) + IsVerified() bool + + // Export returns the base64-encoded ratchet key for this session, at the + // given index, in a format which can be used by + // InboundGroupSession.InboundGroupSessionImport(). Encrypts the + // InboundGroupSession using the supplied key. Returns error on failure. + // if we do not have a session key corresponding to the given index (ie, it + // was sent before the session key was shared with us) the error will be + // "OLM_UNKNOWN_MESSAGE_INDEX". + Export(messageIndex uint32) ([]byte, error) } +var InitInboundGroupSessionFromPickled func(pickled, key []byte) (InboundGroupSession, error) +var InitNewInboundGroupSession func(sessionKey []byte) (InboundGroupSession, error) +var InitInboundGroupSessionImport func(sessionKey []byte) (InboundGroupSession, error) +var InitBlankInboundGroupSession func() InboundGroupSession + // InboundGroupSessionFromPickled loads an InboundGroupSession from a pickled -// base64 string. Decrypts the InboundGroupSession using the supplied key. -// Returns error on failure. If the key doesn't match the one used to encrypt -// the InboundGroupSession then the error will be "BAD_SESSION_KEY". If the -// base64 couldn't be decoded then the error will be "INVALID_BASE64". -func InboundGroupSessionFromPickled(pickled, key []byte) (*InboundGroupSession, error) { - if len(pickled) == 0 { - return nil, EmptyInput - } - lenKey := len(key) - if lenKey == 0 { - key = []byte(" ") - } - s := NewBlankInboundGroupSession() - return s, s.Unpickle(pickled, key) +// base64 string. Decrypts the InboundGroupSession using the supplied key. +// Returns error on failure. +func InboundGroupSessionFromPickled(pickled, key []byte) (InboundGroupSession, error) { + return InitInboundGroupSessionFromPickled(pickled, key) } // NewInboundGroupSession creates a new inbound group session from a key -// exported from OutboundGroupSession.Key(). Returns error on failure. -// If the sessionKey is not valid base64 the error will be -// "OLM_INVALID_BASE64". If the session_key is invalid the error will be -// "OLM_BAD_SESSION_KEY". -func NewInboundGroupSession(sessionKey []byte) (*InboundGroupSession, error) { - if len(sessionKey) == 0 { - return nil, EmptyInput - } - s := NewBlankInboundGroupSession() - r := C.olm_init_inbound_group_session( - (*C.OlmInboundGroupSession)(s.int), - (*C.uint8_t)(&sessionKey[0]), - C.size_t(len(sessionKey))) - if r == errorVal() { - return nil, s.lastError() - } - return s, nil +// exported from OutboundGroupSession.Key(). Returns error on failure. +func NewInboundGroupSession(sessionKey []byte) (InboundGroupSession, error) { + return InitNewInboundGroupSession(sessionKey) } // InboundGroupSessionImport imports an inbound group session from a previous -// export. Returns error on failure. If the sessionKey is not valid base64 -// the error will be "OLM_INVALID_BASE64". If the session_key is invalid the -// error will be "OLM_BAD_SESSION_KEY". -func InboundGroupSessionImport(sessionKey []byte) (*InboundGroupSession, error) { - if len(sessionKey) == 0 { - return nil, EmptyInput - } - s := NewBlankInboundGroupSession() - r := C.olm_import_inbound_group_session( - (*C.OlmInboundGroupSession)(s.int), - (*C.uint8_t)(&sessionKey[0]), - C.size_t(len(sessionKey))) - if r == errorVal() { - return nil, s.lastError() - } - return s, nil +// export. Returns error on failure. +func InboundGroupSessionImport(sessionKey []byte) (InboundGroupSession, error) { + return InitInboundGroupSessionImport(sessionKey) } -// inboundGroupSessionSize is the size of an inbound group session object in -// bytes. -func inboundGroupSessionSize() uint { - return uint(C.olm_inbound_group_session_size()) -} - -// newInboundGroupSession initialises an empty InboundGroupSession. -func NewBlankInboundGroupSession() *InboundGroupSession { - memory := make([]byte, inboundGroupSessionSize()) - return &InboundGroupSession{ - int: C.olm_inbound_group_session(unsafe.Pointer(&memory[0])), - mem: memory, - } -} - -// lastError returns an error describing the most recent error to happen to an -// inbound group session. -func (s *InboundGroupSession) lastError() error { - return convertError(C.GoString(C.olm_inbound_group_session_last_error((*C.OlmInboundGroupSession)(s.int)))) -} - -// Clear clears the memory used to back this InboundGroupSession. -func (s *InboundGroupSession) Clear() error { - r := C.olm_clear_inbound_group_session((*C.OlmInboundGroupSession)(s.int)) - if r == errorVal() { - return s.lastError() - } - return nil -} - -// pickleLen returns the number of bytes needed to store an inbound group -// session. -func (s *InboundGroupSession) pickleLen() uint { - return uint(C.olm_pickle_inbound_group_session_length((*C.OlmInboundGroupSession)(s.int))) -} - -// Pickle returns an InboundGroupSession as a base64 string. Encrypts the -// InboundGroupSession using the supplied key. -func (s *InboundGroupSession) Pickle(key []byte) []byte { - if len(key) == 0 { - panic(NoKeyProvided) - } - pickled := make([]byte, s.pickleLen()) - r := C.olm_pickle_inbound_group_session( - (*C.OlmInboundGroupSession)(s.int), - unsafe.Pointer(&key[0]), - C.size_t(len(key)), - unsafe.Pointer(&pickled[0]), - C.size_t(len(pickled))) - if r == errorVal() { - panic(s.lastError()) - } - return pickled[:r] -} - -func (s *InboundGroupSession) Unpickle(pickled, key []byte) error { - if len(key) == 0 { - return NoKeyProvided - } else if len(pickled) == 0 { - return EmptyInput - } - r := C.olm_unpickle_inbound_group_session( - (*C.OlmInboundGroupSession)(s.int), - unsafe.Pointer(&key[0]), - C.size_t(len(key)), - unsafe.Pointer(&pickled[0]), - C.size_t(len(pickled))) - if r == errorVal() { - return s.lastError() - } - return nil -} - -// Deprecated -func (s *InboundGroupSession) GobEncode() ([]byte, error) { - pickled := s.Pickle(pickleKey) - length := base64.RawStdEncoding.DecodedLen(len(pickled)) - rawPickled := make([]byte, length) - _, err := base64.RawStdEncoding.Decode(rawPickled, pickled) - return rawPickled, err -} - -// Deprecated -func (s *InboundGroupSession) GobDecode(rawPickled []byte) error { - if s == nil || s.int == nil { - *s = *NewBlankInboundGroupSession() - } - length := base64.RawStdEncoding.EncodedLen(len(rawPickled)) - pickled := make([]byte, length) - base64.RawStdEncoding.Encode(pickled, rawPickled) - return s.Unpickle(pickled, pickleKey) -} - -// Deprecated -func (s *InboundGroupSession) MarshalJSON() ([]byte, error) { - pickled := s.Pickle(pickleKey) - quotes := make([]byte, len(pickled)+2) - quotes[0] = '"' - quotes[len(quotes)-1] = '"' - copy(quotes[1:len(quotes)-1], pickled) - return quotes, nil -} - -// Deprecated -func (s *InboundGroupSession) UnmarshalJSON(data []byte) error { - if len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' { - return InputNotJSONString - } - if s == nil || s.int == nil { - *s = *NewBlankInboundGroupSession() - } - return s.Unpickle(data[1:len(data)-1], pickleKey) -} - -func clone(original []byte) []byte { - clone := make([]byte, len(original)) - copy(clone, original) - return clone -} - -// decryptMaxPlaintextLen returns the maximum number of bytes of plain-text a -// given message could decode to. The actual size could be different due to -// padding. Returns error on failure. If the message base64 couldn't be -// decoded then the error will be "INVALID_BASE64". If the message is for an -// unsupported version of the protocol then the error will be -// "BAD_MESSAGE_VERSION". If the message couldn't be decoded then the error -// will be "BAD_MESSAGE_FORMAT". -func (s *InboundGroupSession) decryptMaxPlaintextLen(message []byte) (uint, error) { - if len(message) == 0 { - return 0, EmptyInput - } - // olm_group_decrypt_max_plaintext_length destroys the input, so we have to clone it - message = clone(message) - r := C.olm_group_decrypt_max_plaintext_length( - (*C.OlmInboundGroupSession)(s.int), - (*C.uint8_t)(&message[0]), - C.size_t(len(message))) - if r == errorVal() { - return 0, s.lastError() - } - return uint(r), nil -} - -// Decrypt decrypts a message using the InboundGroupSession. Returns the the -// plain-text and message index on success. Returns error on failure. If the -// base64 couldn't be decoded then the error will be "INVALID_BASE64". If the -// message is for an unsupported version of the protocol then the error will be -// "BAD_MESSAGE_VERSION". If the message couldn't be decoded then the error -// will be BAD_MESSAGE_FORMAT". If the MAC on the message was invalid then the -// error will be "BAD_MESSAGE_MAC". If we do not have a session key -// corresponding to the message's index (ie, it was sent before the session key -// was shared with us) the error will be "OLM_UNKNOWN_MESSAGE_INDEX". -func (s *InboundGroupSession) Decrypt(message []byte) ([]byte, uint, error) { - if len(message) == 0 { - return nil, 0, EmptyInput - } - decryptMaxPlaintextLen, err := s.decryptMaxPlaintextLen(message) - if err != nil { - return nil, 0, err - } - messageCopy := make([]byte, len(message)) - copy(messageCopy, message) - plaintext := make([]byte, decryptMaxPlaintextLen) - var messageIndex uint32 - r := C.olm_group_decrypt( - (*C.OlmInboundGroupSession)(s.int), - (*C.uint8_t)(&messageCopy[0]), - C.size_t(len(messageCopy)), - (*C.uint8_t)(&plaintext[0]), - C.size_t(len(plaintext)), - (*C.uint32_t)(&messageIndex)) - if r == errorVal() { - return nil, 0, s.lastError() - } - return plaintext[:r], uint(messageIndex), nil -} - -// sessionIdLen returns the number of bytes needed to store a session ID. -func (s *InboundGroupSession) sessionIdLen() uint { - return uint(C.olm_inbound_group_session_id_length((*C.OlmInboundGroupSession)(s.int))) -} - -// ID returns a base64-encoded identifier for this session. -func (s *InboundGroupSession) ID() id.SessionID { - sessionID := make([]byte, s.sessionIdLen()) - r := C.olm_inbound_group_session_id( - (*C.OlmInboundGroupSession)(s.int), - (*C.uint8_t)(&sessionID[0]), - C.size_t(len(sessionID))) - if r == errorVal() { - panic(s.lastError()) - } - return id.SessionID(sessionID[:r]) -} - -// FirstKnownIndex returns the first message index we know how to decrypt. -func (s *InboundGroupSession) FirstKnownIndex() uint32 { - return uint32(C.olm_inbound_group_session_first_known_index((*C.OlmInboundGroupSession)(s.int))) -} - -// IsVerified check if the session has been verified as a valid session. (A -// session is verified either because the original session share was signed, or -// because we have subsequently successfully decrypted a message.) -func (s *InboundGroupSession) IsVerified() uint { - return uint(C.olm_inbound_group_session_is_verified((*C.OlmInboundGroupSession)(s.int))) -} - -// exportLen returns the number of bytes needed to export an inbound group -// session. -func (s *InboundGroupSession) exportLen() uint { - return uint(C.olm_export_inbound_group_session_length((*C.OlmInboundGroupSession)(s.int))) -} - -// Export returns the base64-encoded ratchet key for this session, at the given -// index, in a format which can be used by -// InboundGroupSession.InboundGroupSessionImport(). Encrypts the -// InboundGroupSession using the supplied key. Returns error on failure. -// if we do not have a session key corresponding to the given index (ie, it was -// sent before the session key was shared with us) the error will be -// "OLM_UNKNOWN_MESSAGE_INDEX". -func (s *InboundGroupSession) Export(messageIndex uint32) ([]byte, error) { - key := make([]byte, s.exportLen()) - r := C.olm_export_inbound_group_session( - (*C.OlmInboundGroupSession)(s.int), - (*C.uint8_t)(&key[0]), - C.size_t(len(key)), - C.uint32_t(messageIndex)) - if r == errorVal() { - return nil, s.lastError() - } - return key[:r], nil +func NewBlankInboundGroupSession() InboundGroupSession { + return InitBlankInboundGroupSession() } diff --git a/crypto/olm/inboundgroupsession_goolm.go b/crypto/olm/inboundgroupsession_goolm.go deleted file mode 100644 index 4e561cf7..00000000 --- a/crypto/olm/inboundgroupsession_goolm.go +++ /dev/null @@ -1,149 +0,0 @@ -//go:build goolm - -package olm - -import ( - "maunium.net/go/mautrix/crypto/goolm/session" - "maunium.net/go/mautrix/id" -) - -// InboundGroupSession stores an inbound encrypted messaging session for a -// group. -type InboundGroupSession struct { - session.MegolmInboundSession -} - -// InboundGroupSessionFromPickled loads an InboundGroupSession from a pickled -// base64 string. Decrypts the InboundGroupSession using the supplied key. -// Returns error on failure. -func InboundGroupSessionFromPickled(pickled, key []byte) (*InboundGroupSession, error) { - if len(pickled) == 0 { - return nil, EmptyInput - } - lenKey := len(key) - if lenKey == 0 { - key = []byte(" ") - } - megolmSession, err := session.MegolmInboundSessionFromPickled(pickled, key) - if err != nil { - return nil, err - } - return &InboundGroupSession{ - MegolmInboundSession: *megolmSession, - }, nil -} - -// NewInboundGroupSession creates a new inbound group session from a key -// exported from OutboundGroupSession.Key(). Returns error on failure. -func NewInboundGroupSession(sessionKey []byte) (*InboundGroupSession, error) { - if len(sessionKey) == 0 { - return nil, EmptyInput - } - megolmSession, err := session.NewMegolmInboundSession(sessionKey) - if err != nil { - return nil, err - } - return &InboundGroupSession{ - MegolmInboundSession: *megolmSession, - }, nil -} - -// InboundGroupSessionImport imports an inbound group session from a previous -// export. Returns error on failure. -func InboundGroupSessionImport(sessionKey []byte) (*InboundGroupSession, error) { - if len(sessionKey) == 0 { - return nil, EmptyInput - } - megolmSession, err := session.NewMegolmInboundSessionFromExport(sessionKey) - if err != nil { - return nil, err - } - return &InboundGroupSession{ - MegolmInboundSession: *megolmSession, - }, nil -} - -func NewBlankInboundGroupSession() *InboundGroupSession { - return &InboundGroupSession{} -} - -// Clear clears the memory used to back this InboundGroupSession. -func (s *InboundGroupSession) Clear() error { - s.MegolmInboundSession = session.MegolmInboundSession{} - return nil -} - -// Pickle returns an InboundGroupSession as a base64 string. Encrypts the -// InboundGroupSession using the supplied key. -func (s *InboundGroupSession) Pickle(key []byte) []byte { - if len(key) == 0 { - panic(NoKeyProvided) - } - pickled, err := s.MegolmInboundSession.Pickle(key) - if err != nil { - panic(err) - } - return pickled -} - -func (s *InboundGroupSession) Unpickle(pickled, key []byte) error { - if len(key) == 0 { - return NoKeyProvided - } else if len(pickled) == 0 { - return EmptyInput - } - sOlm, err := session.MegolmInboundSessionFromPickled(pickled, key) - if err != nil { - return err - } - s.MegolmInboundSession = *sOlm - return nil -} - -// Decrypt decrypts a message using the InboundGroupSession. Returns the the -// plain-text and message index on success. Returns error on failure. -func (s *InboundGroupSession) Decrypt(message []byte) ([]byte, uint, error) { - if len(message) == 0 { - return nil, 0, EmptyInput - } - plaintext, messageIndex, err := s.MegolmInboundSession.Decrypt(message) - if err != nil { - return nil, 0, err - } - return plaintext, uint(messageIndex), nil -} - -// ID returns a base64-encoded identifier for this session. -func (s *InboundGroupSession) ID() id.SessionID { - return s.MegolmInboundSession.SessionID() -} - -// FirstKnownIndex returns the first message index we know how to decrypt. -func (s *InboundGroupSession) FirstKnownIndex() uint32 { - return s.MegolmInboundSession.InitialRatchet.Counter -} - -// IsVerified check if the session has been verified as a valid session. (A -// session is verified either because the original session share was signed, or -// because we have subsequently successfully decrypted a message.) -func (s *InboundGroupSession) IsVerified() uint { - if s.MegolmInboundSession.SigningKeyVerified { - return 1 - } - return 0 -} - -// Export returns the base64-encoded ratchet key for this session, at the given -// index, in a format which can be used by -// InboundGroupSession.InboundGroupSessionImport(). Encrypts the -// InboundGroupSession using the supplied key. Returns error on failure. -// if we do not have a session key corresponding to the given index (ie, it was -// sent before the session key was shared with us) the error will be -// returned. -func (s *InboundGroupSession) Export(messageIndex uint32) ([]byte, error) { - res, err := s.MegolmInboundSession.SessionExportMessage(messageIndex) - if err != nil { - return nil, err - } - return res, nil -} diff --git a/crypto/olm/olm.go b/crypto/olm/olm.go index 685e1b6b..fa2345e1 100644 --- a/crypto/olm/olm.go +++ b/crypto/olm/olm.go @@ -1,34 +1,20 @@ -//go:build !goolm +// Copyright (c) 2024 Sumner Evans +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. package olm -// #cgo LDFLAGS: -lolm -lstdc++ -// #include -import "C" -import ( - "maunium.net/go/mautrix/id" -) - -// Signatures is the data structure used to sign JSON objects. -type Signatures map[id.UserID]map[id.DeviceKeyID]string +var GetVersion func() (major, minor, patch uint8) +var SetPickleKeyImpl func(key []byte) // Version returns the version number of the olm library. func Version() (major, minor, patch uint8) { - C.olm_get_library_version( - (*C.uint8_t)(&major), - (*C.uint8_t)(&minor), - (*C.uint8_t)(&patch)) - return + return GetVersion() } -// errorVal returns the value that olm functions return if there was an error. -func errorVal() C.size_t { - return C.olm_error() -} - -var pickleKey = []byte("maunium.net/go/mautrix/crypto/olm") - // SetPickleKey sets the global pickle key used when encoding structs with Gob or JSON. func SetPickleKey(key []byte) { - pickleKey = key + SetPickleKeyImpl(key) } diff --git a/crypto/olm/olm_goolm.go b/crypto/olm/olm_goolm.go deleted file mode 100644 index dbe12a76..00000000 --- a/crypto/olm/olm_goolm.go +++ /dev/null @@ -1,20 +0,0 @@ -//go:build goolm - -package olm - -import ( - "maunium.net/go/mautrix/id" -) - -// Signatures is the data structure used to sign JSON objects. -type Signatures map[id.UserID]map[id.DeviceKeyID]string - -// Version returns the version number of the olm library. -func Version() (major, minor, patch uint8) { - return 3, 2, 15 -} - -// SetPickleKey sets the global pickle key used when encoding structs with Gob or JSON. -func SetPickleKey(key []byte) { - panic("gob and json encoding is deprecated and not supported with goolm") -} diff --git a/crypto/olm/outboundgroupsession.go b/crypto/olm/outboundgroupsession.go index b6a33d36..7e582b7e 100644 --- a/crypto/olm/outboundgroupsession.go +++ b/crypto/olm/outboundgroupsession.go @@ -1,239 +1,57 @@ -//go:build !goolm +// Copyright (c) 2024 Sumner Evans +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. package olm -// #cgo LDFLAGS: -lolm -lstdc++ -// #include -import "C" +import "maunium.net/go/mautrix/id" -import ( - "crypto/rand" - "encoding/base64" - "unsafe" +type OutboundGroupSession interface { + // Pickle returns a Session as a base64 string. Encrypts the Session using + // the supplied key. + Pickle(key []byte) ([]byte, error) - "maunium.net/go/mautrix/id" -) + // Unpickle loads an [OutboundGroupSession] from a pickled base64 string. + // Decrypts the [OutboundGroupSession] using the supplied key. + Unpickle(pickled, key []byte) error -// OutboundGroupSession stores an outbound encrypted messaging session for a -// group. -type OutboundGroupSession struct { - int *C.OlmOutboundGroupSession - mem []byte + // Encrypt encrypts a message using the [OutboundGroupSession]. Returns the + // encrypted message as base64. + Encrypt(plaintext []byte) ([]byte, error) + + // ID returns a base64-encoded identifier for this session. + ID() id.SessionID + + // MessageIndex returns the message index for this session. Each message + // is sent with an increasing index; this returns the index for the next + // message. + MessageIndex() uint + + // Key returns the base64-encoded current ratchet key for this session. + Key() string } +var InitNewOutboundGroupSessionFromPickled func(pickled, key []byte) (OutboundGroupSession, error) +var InitNewOutboundGroupSession func() (OutboundGroupSession, error) +var InitNewBlankOutboundGroupSession func() OutboundGroupSession + // OutboundGroupSessionFromPickled loads an OutboundGroupSession from a pickled // base64 string. Decrypts the OutboundGroupSession using the supplied key. // Returns error on failure. If the key doesn't match the one used to encrypt // the OutboundGroupSession then the error will be "BAD_SESSION_KEY". If the // base64 couldn't be decoded then the error will be "INVALID_BASE64". -func OutboundGroupSessionFromPickled(pickled, key []byte) (*OutboundGroupSession, error) { - if len(pickled) == 0 { - return nil, EmptyInput - } - s := NewBlankOutboundGroupSession() - return s, s.Unpickle(pickled, key) +func OutboundGroupSessionFromPickled(pickled, key []byte) (OutboundGroupSession, error) { + return InitNewOutboundGroupSessionFromPickled(pickled, key) } // NewOutboundGroupSession creates a new outbound group session. -func NewOutboundGroupSession() *OutboundGroupSession { - s := NewBlankOutboundGroupSession() - random := make([]byte, s.createRandomLen()+1) - _, err := rand.Read(random) - if err != nil { - panic(NotEnoughGoRandom) - } - r := C.olm_init_outbound_group_session( - (*C.OlmOutboundGroupSession)(s.int), - (*C.uint8_t)(&random[0]), - C.size_t(len(random))) - if r == errorVal() { - panic(s.lastError()) - } - return s +func NewOutboundGroupSession() (OutboundGroupSession, error) { + return InitNewOutboundGroupSession() } -// outboundGroupSessionSize is the size of an outbound group session object in -// bytes. -func outboundGroupSessionSize() uint { - return uint(C.olm_outbound_group_session_size()) -} - -// newOutboundGroupSession initialises an empty OutboundGroupSession. -func NewBlankOutboundGroupSession() *OutboundGroupSession { - memory := make([]byte, outboundGroupSessionSize()) - return &OutboundGroupSession{ - int: C.olm_outbound_group_session(unsafe.Pointer(&memory[0])), - mem: memory, - } -} - -// lastError returns an error describing the most recent error to happen to an -// outbound group session. -func (s *OutboundGroupSession) lastError() error { - return convertError(C.GoString(C.olm_outbound_group_session_last_error((*C.OlmOutboundGroupSession)(s.int)))) -} - -// Clear clears the memory used to back this OutboundGroupSession. -func (s *OutboundGroupSession) Clear() error { - r := C.olm_clear_outbound_group_session((*C.OlmOutboundGroupSession)(s.int)) - if r == errorVal() { - return s.lastError() - } else { - return nil - } -} - -// pickleLen returns the number of bytes needed to store an outbound group -// session. -func (s *OutboundGroupSession) pickleLen() uint { - return uint(C.olm_pickle_outbound_group_session_length((*C.OlmOutboundGroupSession)(s.int))) -} - -// Pickle returns an OutboundGroupSession as a base64 string. Encrypts the -// OutboundGroupSession using the supplied key. -func (s *OutboundGroupSession) Pickle(key []byte) []byte { - if len(key) == 0 { - panic(NoKeyProvided) - } - pickled := make([]byte, s.pickleLen()) - r := C.olm_pickle_outbound_group_session( - (*C.OlmOutboundGroupSession)(s.int), - unsafe.Pointer(&key[0]), - C.size_t(len(key)), - unsafe.Pointer(&pickled[0]), - C.size_t(len(pickled))) - if r == errorVal() { - panic(s.lastError()) - } - return pickled[:r] -} - -func (s *OutboundGroupSession) Unpickle(pickled, key []byte) error { - if len(key) == 0 { - return NoKeyProvided - } - r := C.olm_unpickle_outbound_group_session( - (*C.OlmOutboundGroupSession)(s.int), - unsafe.Pointer(&key[0]), - C.size_t(len(key)), - unsafe.Pointer(&pickled[0]), - C.size_t(len(pickled))) - if r == errorVal() { - return s.lastError() - } - return nil -} - -// Deprecated -func (s *OutboundGroupSession) GobEncode() ([]byte, error) { - pickled := s.Pickle(pickleKey) - length := base64.RawStdEncoding.DecodedLen(len(pickled)) - rawPickled := make([]byte, length) - _, err := base64.RawStdEncoding.Decode(rawPickled, pickled) - return rawPickled, err -} - -// Deprecated -func (s *OutboundGroupSession) GobDecode(rawPickled []byte) error { - if s == nil || s.int == nil { - *s = *NewBlankOutboundGroupSession() - } - length := base64.RawStdEncoding.EncodedLen(len(rawPickled)) - pickled := make([]byte, length) - base64.RawStdEncoding.Encode(pickled, rawPickled) - return s.Unpickle(pickled, pickleKey) -} - -// Deprecated -func (s *OutboundGroupSession) MarshalJSON() ([]byte, error) { - pickled := s.Pickle(pickleKey) - quotes := make([]byte, len(pickled)+2) - quotes[0] = '"' - quotes[len(quotes)-1] = '"' - copy(quotes[1:len(quotes)-1], pickled) - return quotes, nil -} - -// Deprecated -func (s *OutboundGroupSession) UnmarshalJSON(data []byte) error { - if len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' { - return InputNotJSONString - } - if s == nil || s.int == nil { - *s = *NewBlankOutboundGroupSession() - } - return s.Unpickle(data[1:len(data)-1], pickleKey) -} - -// createRandomLen returns the number of random bytes needed to create an -// Account. -func (s *OutboundGroupSession) createRandomLen() uint { - return uint(C.olm_init_outbound_group_session_random_length((*C.OlmOutboundGroupSession)(s.int))) -} - -// encryptMsgLen returns the size of the next message in bytes for the given -// number of plain-text bytes. -func (s *OutboundGroupSession) encryptMsgLen(plainTextLen int) uint { - return uint(C.olm_group_encrypt_message_length((*C.OlmOutboundGroupSession)(s.int), C.size_t(plainTextLen))) -} - -// Encrypt encrypts a message using the Session. Returns the encrypted message -// as base64. -func (s *OutboundGroupSession) Encrypt(plaintext []byte) []byte { - if len(plaintext) == 0 { - panic(EmptyInput) - } - message := make([]byte, s.encryptMsgLen(len(plaintext))) - r := C.olm_group_encrypt( - (*C.OlmOutboundGroupSession)(s.int), - (*C.uint8_t)(&plaintext[0]), - C.size_t(len(plaintext)), - (*C.uint8_t)(&message[0]), - C.size_t(len(message))) - if r == errorVal() { - panic(s.lastError()) - } - return message[:r] -} - -// sessionIdLen returns the number of bytes needed to store a session ID. -func (s *OutboundGroupSession) sessionIdLen() uint { - return uint(C.olm_outbound_group_session_id_length((*C.OlmOutboundGroupSession)(s.int))) -} - -// ID returns a base64-encoded identifier for this session. -func (s *OutboundGroupSession) ID() id.SessionID { - sessionID := make([]byte, s.sessionIdLen()) - r := C.olm_outbound_group_session_id( - (*C.OlmOutboundGroupSession)(s.int), - (*C.uint8_t)(&sessionID[0]), - C.size_t(len(sessionID))) - if r == errorVal() { - panic(s.lastError()) - } - return id.SessionID(sessionID[:r]) -} - -// MessageIndex returns the message index for this session. Each message is -// sent with an increasing index; this returns the index for the next message. -func (s *OutboundGroupSession) MessageIndex() uint { - return uint(C.olm_outbound_group_session_message_index((*C.OlmOutboundGroupSession)(s.int))) -} - -// sessionKeyLen returns the number of bytes needed to store a session key. -func (s *OutboundGroupSession) sessionKeyLen() uint { - return uint(C.olm_outbound_group_session_key_length((*C.OlmOutboundGroupSession)(s.int))) -} - -// Key returns the base64-encoded current ratchet key for this session. -func (s *OutboundGroupSession) Key() string { - sessionKey := make([]byte, s.sessionKeyLen()) - r := C.olm_outbound_group_session_key( - (*C.OlmOutboundGroupSession)(s.int), - (*C.uint8_t)(&sessionKey[0]), - C.size_t(len(sessionKey))) - if r == errorVal() { - panic(s.lastError()) - } - return string(sessionKey[:r]) +// NewBlankOutboundGroupSession initialises an empty [OutboundGroupSession]. +func NewBlankOutboundGroupSession() OutboundGroupSession { + return InitNewBlankOutboundGroupSession() } diff --git a/crypto/olm/outboundgroupsession_goolm.go b/crypto/olm/outboundgroupsession_goolm.go deleted file mode 100644 index 7c201213..00000000 --- a/crypto/olm/outboundgroupsession_goolm.go +++ /dev/null @@ -1,111 +0,0 @@ -//go:build goolm - -package olm - -import ( - "maunium.net/go/mautrix/crypto/goolm/session" - "maunium.net/go/mautrix/id" -) - -// OutboundGroupSession stores an outbound encrypted messaging session for a -// group. -type OutboundGroupSession struct { - session.MegolmOutboundSession -} - -// OutboundGroupSessionFromPickled loads an OutboundGroupSession from a pickled -// base64 string. Decrypts the OutboundGroupSession using the supplied key. -// Returns error on failure. If the key doesn't match the one used to encrypt -// the OutboundGroupSession then the error will be "BAD_SESSION_KEY". If the -// base64 couldn't be decoded then the error will be "INVALID_BASE64". -func OutboundGroupSessionFromPickled(pickled, key []byte) (*OutboundGroupSession, error) { - if len(pickled) == 0 { - return nil, EmptyInput - } - lenKey := len(key) - if lenKey == 0 { - key = []byte(" ") - } - megolmSession, err := session.MegolmOutboundSessionFromPickled(pickled, key) - if err != nil { - return nil, err - } - return &OutboundGroupSession{ - MegolmOutboundSession: *megolmSession, - }, nil -} - -// NewOutboundGroupSession creates a new outbound group session. -func NewOutboundGroupSession() *OutboundGroupSession { - megolmSession, err := session.NewMegolmOutboundSession() - if err != nil { - panic(err) - } - return &OutboundGroupSession{ - MegolmOutboundSession: *megolmSession, - } -} - -// newOutboundGroupSession initialises an empty OutboundGroupSession. -func NewBlankOutboundGroupSession() *OutboundGroupSession { - return &OutboundGroupSession{} -} - -// Clear clears the memory used to back this OutboundGroupSession. -func (s *OutboundGroupSession) Clear() error { - s.MegolmOutboundSession = session.MegolmOutboundSession{} - return nil -} - -// Pickle returns an OutboundGroupSession as a base64 string. Encrypts the -// OutboundGroupSession using the supplied key. -func (s *OutboundGroupSession) Pickle(key []byte) []byte { - if len(key) == 0 { - panic(NoKeyProvided) - } - pickled, err := s.MegolmOutboundSession.Pickle(key) - if err != nil { - panic(err) - } - return pickled -} - -func (s *OutboundGroupSession) Unpickle(pickled, key []byte) error { - if len(key) == 0 { - return NoKeyProvided - } - return s.MegolmOutboundSession.Unpickle(pickled, key) -} - -// Encrypt encrypts a message using the Session. Returns the encrypted message -// as base64. -func (s *OutboundGroupSession) Encrypt(plaintext []byte) []byte { - if len(plaintext) == 0 { - panic(EmptyInput) - } - message, err := s.MegolmOutboundSession.Encrypt(plaintext) - if err != nil { - panic(err) - } - return message -} - -// ID returns a base64-encoded identifier for this session. -func (s *OutboundGroupSession) ID() id.SessionID { - return s.MegolmOutboundSession.SessionID() -} - -// MessageIndex returns the message index for this session. Each message is -// sent with an increasing index; this returns the index for the next message. -func (s *OutboundGroupSession) MessageIndex() uint { - return uint(s.MegolmOutboundSession.Ratchet.Counter) -} - -// Key returns the base64-encoded current ratchet key for this session. -func (s *OutboundGroupSession) Key() string { - message, err := s.MegolmOutboundSession.SessionSharingMessage() - if err != nil { - panic(err) - } - return string(message) -} diff --git a/crypto/olm/outboundgroupsession_test.go b/crypto/olm/outboundgroupsession_test.go new file mode 100644 index 00000000..cbbc89f7 --- /dev/null +++ b/crypto/olm/outboundgroupsession_test.go @@ -0,0 +1,133 @@ +package olm_test + +import ( + "bytes" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "maunium.net/go/mautrix/crypto/goolm/session" + "maunium.net/go/mautrix/crypto/libolm" +) + +func TestMegolmOutboundSessionPickle_RoundtripThroughGoolm(t *testing.T) { + libolmSession, err := libolm.NewOutboundGroupSession() + require.NoError(t, err) + libolmPickled, err := libolmSession.Pickle([]byte("test")) + require.NoError(t, err) + + goolmSession, err := session.MegolmOutboundSessionFromPickled(libolmPickled, []byte("test")) + require.NoError(t, err) + + goolmPickled, err := goolmSession.Pickle([]byte("test")) + require.NoError(t, err) + + assert.Equal(t, libolmPickled, goolmPickled, "pickled versions are not the same") + + libolmSession2, err := libolm.NewOutboundGroupSession() + require.NoError(t, err) + err = libolmSession2.Unpickle(bytes.Clone(goolmPickled), []byte("test")) + require.NoError(t, err) + + assert.Equal(t, libolmSession.Key(), libolmSession2.Key()) +} + +func TestMegolmOutboundSessionPickle_RoundtripThroughLibolm(t *testing.T) { + goolmSession, err := session.NewMegolmOutboundSession() + require.NoError(t, err) + + goolmPickled, err := goolmSession.Pickle([]byte("test")) + require.NoError(t, err) + + libolmSession, err := libolm.NewOutboundGroupSession() + require.NoError(t, err) + err = libolmSession.Unpickle(bytes.Clone(goolmPickled), []byte("test")) + require.NoError(t, err) + + libolmPickled, err := libolmSession.Pickle([]byte("test")) + require.NoError(t, err) + + assert.Equal(t, goolmPickled, libolmPickled, "pickled versions are not the same") + + goolmSession2, err := session.MegolmOutboundSessionFromPickled(libolmPickled, []byte("test")) + require.NoError(t, err) + + assert.Equal(t, goolmSession.Key(), goolmSession2.Key()) + assert.Equal(t, goolmSession.SigningKey.PrivateKey, goolmSession2.SigningKey.PrivateKey) +} + +func TestMegolmOutboundSessionPickleLibolm(t *testing.T) { + libolmSession, err := libolm.NewOutboundGroupSession() + require.NoError(t, err) + libolmPickled, err := libolmSession.Pickle([]byte("test")) + require.NoError(t, err) + + goolmSession, err := session.MegolmOutboundSessionFromPickled(bytes.Clone(libolmPickled), []byte("test")) + require.NoError(t, err) + goolmPickled, err := goolmSession.Pickle([]byte("test")) + require.NoError(t, err) + + assert.Equal(t, libolmPickled, goolmPickled, "pickled versions are not the same") + assert.Equal(t, goolmSession.SigningKey.PrivateKey.PubKey(), goolmSession.SigningKey.PublicKey) + + // Ensure that the key export is the same and that the pickle is the same + assert.Equal(t, libolmSession.Key(), goolmSession.Key(), "keys are not the same") +} + +func TestMegolmOutboundSessionPickleGoolm(t *testing.T) { + goolmSession, err := session.NewMegolmOutboundSession() + require.NoError(t, err) + goolmPickled, err := goolmSession.Pickle([]byte("test")) + require.NoError(t, err) + + libolmSession, err := libolm.NewOutboundGroupSession() + require.NoError(t, err) + err = libolmSession.Unpickle(bytes.Clone(goolmPickled), []byte("test")) + require.NoError(t, err) + libolmPickled, err := libolmSession.Pickle([]byte("test")) + require.NoError(t, err) + + assert.Equal(t, libolmPickled, goolmPickled, "pickled versions are not the same") + assert.Equal(t, goolmSession.SigningKey.PrivateKey.PubKey(), goolmSession.SigningKey.PublicKey) + + // Ensure that the key export is the same and that the pickle is the same + assert.Equal(t, libolmSession.Key(), goolmSession.Key(), "keys are not the same") +} + +func FuzzMegolmOutboundSession_Encrypt(f *testing.F) { + f.Add([]byte("anything")) + + f.Fuzz(func(t *testing.T, plaintext []byte) { + if len(plaintext) == 0 { + t.Skip("empty plaintext is not supported") + } + + libolmSession, err := libolm.NewOutboundGroupSession() + require.NoError(t, err) + libolmPickled, err := libolmSession.Pickle([]byte("test")) + require.NoError(t, err) + + goolmSession, err := session.MegolmOutboundSessionFromPickled(bytes.Clone(libolmPickled), []byte("test")) + require.NoError(t, err) + + assert.Equal(t, libolmSession.Key(), goolmSession.Key()) + + // Encrypt the plaintext ten times because the ratchet increments. + for i := 0; i < 10; i++ { + assert.EqualValues(t, i, libolmSession.MessageIndex()) + assert.EqualValues(t, i, goolmSession.MessageIndex()) + + libolmEncrypted, err := libolmSession.Encrypt(plaintext) + require.NoError(t, err) + + goolmEncrypted, err := goolmSession.Encrypt(plaintext) + require.NoError(t, err) + + assert.Equal(t, libolmEncrypted, goolmEncrypted) + + assert.EqualValues(t, i+1, libolmSession.MessageIndex()) + assert.EqualValues(t, i+1, goolmSession.MessageIndex()) + } + }) +} diff --git a/crypto/olm/pk_interface.go b/crypto/olm/pk.go similarity index 52% rename from crypto/olm/pk_interface.go rename to crypto/olm/pk.go index 11c41431..70ee452d 100644 --- a/crypto/olm/pk_interface.go +++ b/crypto/olm/pk.go @@ -7,7 +7,6 @@ package olm import ( - "maunium.net/go/mautrix/crypto/goolm/pk" "maunium.net/go/mautrix/id" ) @@ -27,15 +26,32 @@ type PKSigning interface { SignJSON(obj any) (string, error) } -var _ PKSigning = (*pk.Signing)(nil) - // PKDecryption is an interface for decrypting messages. type PKDecryption interface { // PublicKey returns the public key. PublicKey() id.Curve25519 // Decrypt verifies and decrypts the given message. - Decrypt(ciphertext, mac []byte, key id.Curve25519) ([]byte, error) + Decrypt(ephemeralKey, mac, ciphertext []byte) ([]byte, error) } -var _ PKDecryption = (*pk.Decryption)(nil) +var InitNewPKSigning func() (PKSigning, error) +var InitNewPKSigningFromSeed func(seed []byte) (PKSigning, error) +var InitNewPKDecryptionFromPrivateKey func(privateKey []byte) (PKDecryption, error) + +// NewPKSigning creates a new [PKSigning] object, containing a key pair for +// signing messages. +func NewPKSigning() (PKSigning, error) { + return InitNewPKSigning() +} + +// NewPKSigningFromSeed creates a new PKSigning object using the given seed. +func NewPKSigningFromSeed(seed []byte) (PKSigning, error) { + return InitNewPKSigningFromSeed(seed) +} + +// NewPKDecryptionFromPrivateKey creates a new [PKDecryption] from a +// base64-encoded private key. +func NewPKDecryptionFromPrivateKey(privateKey []byte) (PKDecryption, error) { + return InitNewPKDecryptionFromPrivateKey(privateKey) +} diff --git a/crypto/olm/pk_goolm.go b/crypto/olm/pk_goolm.go deleted file mode 100644 index 372c94fa..00000000 --- a/crypto/olm/pk_goolm.go +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright (c) 2024 Sumner Evans -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -// When the goolm build flag is enabled, this file will make [PKSigning] -// constructors use the goolm constuctors. - -//go:build goolm - -package olm - -import "maunium.net/go/mautrix/crypto/goolm/pk" - -// NewPKSigningFromSeed creates a new PKSigning object using the given seed. -func NewPKSigningFromSeed(seed []byte) (PKSigning, error) { - return pk.NewSigningFromSeed(seed) -} - -// NewPKSigning creates a new [PKSigning] object, containing a key pair for -// signing messages. -func NewPKSigning() (PKSigning, error) { - return pk.NewSigning() -} - -func NewPKDecryption(privateKey []byte) (PKDecryption, error) { - return pk.NewDecryption() -} diff --git a/crypto/olm/pk_test.go b/crypto/olm/pk_test.go index b57e6571..99ac1e6b 100644 --- a/crypto/olm/pk_test.go +++ b/crypto/olm/pk_test.go @@ -4,8 +4,7 @@ // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. -// Only run this test if goo is disabled (that is, libolm is used). -//go:build !goolm +// Only run this test if goolm is disabled (that is, libolm is used). package olm_test @@ -16,7 +15,7 @@ import ( "github.com/stretchr/testify/require" "maunium.net/go/mautrix/crypto/goolm/pk" - "maunium.net/go/mautrix/crypto/olm" + "maunium.net/go/mautrix/crypto/libolm" ) func FuzzSign(f *testing.F) { @@ -24,7 +23,7 @@ func FuzzSign(f *testing.F) { goolmPkSigning, err := pk.NewSigningFromSeed(seed) require.NoError(f, err) - libolmPkSigning, err := olm.NewPKSigningFromSeed(seed) + libolmPkSigning, err := libolm.NewPKSigningFromSeed(seed) require.NoError(f, err) f.Add([]byte("message")) diff --git a/crypto/olm/session.go b/crypto/olm/session.go index 185e0b3d..c4b91ffc 100644 --- a/crypto/olm/session.go +++ b/crypto/olm/session.go @@ -1,362 +1,83 @@ -//go:build !goolm +// Copyright (c) 2024 Sumner Evans +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. package olm -// #cgo LDFLAGS: -lolm -lstdc++ -// #include -// #include -// #include -// void olm_session_describe(OlmSession * session, char *buf, size_t buflen) __attribute__((weak)); -// void meowlm_session_describe(OlmSession * session, char *buf, size_t buflen) { -// if (olm_session_describe) { -// olm_session_describe(session, buf, buflen); -// } else { -// sprintf(buf, "olm_session_describe not supported"); -// } -// } -import "C" +import "maunium.net/go/mautrix/id" -import ( - "crypto/rand" - "encoding/base64" - "unsafe" +type Session interface { + // Pickle returns a Session as a base64 string. Encrypts the Session using + // the supplied key. + Pickle(key []byte) ([]byte, error) - "maunium.net/go/mautrix/id" -) + // Unpickle loads a Session from a pickled base64 string. Decrypts the + // Session using the supplied key. + Unpickle(pickled, key []byte) error -// Session stores an end to end encrypted messaging session. -type Session struct { - int *C.OlmSession - mem []byte + // ID returns an identifier for this Session. Will be the same for both + // ends of the conversation. + ID() id.SessionID + + // HasReceivedMessage returns true if this session has received any + // message. + HasReceivedMessage() bool + + // MatchesInboundSession checks if the PRE_KEY message is for this in-bound + // Session. This can happen if multiple messages are sent to this Account + // before this Account sends a message in reply. Returns true if the + // session matches. Returns false if the session does not match. Returns + // error on failure. If the base64 couldn't be decoded then the error will + // be "INVALID_BASE64". If the message was for an unsupported protocol + // version then the error will be "BAD_MESSAGE_VERSION". If the message + // couldn't be decoded then then the error will be "BAD_MESSAGE_FORMAT". + MatchesInboundSession(oneTimeKeyMsg string) (bool, error) + + // MatchesInboundSessionFrom checks if the PRE_KEY message is for this + // in-bound Session. This can happen if multiple messages are sent to this + // Account before this Account sends a message in reply. Returns true if + // the session matches. Returns false if the session does not match. + // Returns error on failure. If the base64 couldn't be decoded then the + // error will be "INVALID_BASE64". If the message was for an unsupported + // protocol version then the error will be "BAD_MESSAGE_VERSION". If the + // message couldn't be decoded then then the error will be + // "BAD_MESSAGE_FORMAT". + MatchesInboundSessionFrom(theirIdentityKey, oneTimeKeyMsg string) (bool, error) + + // EncryptMsgType returns the type of the next message that Encrypt will + // return. Returns MsgTypePreKey if the message will be a PRE_KEY message. + // Returns MsgTypeMsg if the message will be a normal message. + EncryptMsgType() id.OlmMsgType + + // Encrypt encrypts a message using the Session. Returns the encrypted + // message as base64. + Encrypt(plaintext []byte) (id.OlmMsgType, []byte, error) + + // Decrypt decrypts a message using the Session. Returns the plain-text on + // success. Returns error on failure. If the base64 couldn't be decoded + // then the error will be "INVALID_BASE64". If the message is for an + // unsupported version of the protocol then the error will be + // "BAD_MESSAGE_VERSION". If the message couldn't be decoded then the error + // will be BAD_MESSAGE_FORMAT". If the MAC on the message was invalid then + // the error will be "BAD_MESSAGE_MAC". + Decrypt(message string, msgType id.OlmMsgType) ([]byte, error) + + // Describe generates a string describing the internal state of an olm + // session for debugging and logging purposes. + Describe() string } -// sessionSize is the size of a session object in bytes. -func sessionSize() uint { - return uint(C.olm_session_size()) -} +var InitSessionFromPickled func(pickled, key []byte) (Session, error) +var InitNewBlankSession func() Session // SessionFromPickled loads a Session from a pickled base64 string. Decrypts -// the Session using the supplied key. Returns error on failure. If the key -// doesn't match the one used to encrypt the Session then the error will be -// "BAD_SESSION_KEY". If the base64 couldn't be decoded then the error will be -// "INVALID_BASE64". -func SessionFromPickled(pickled, key []byte) (*Session, error) { - if len(pickled) == 0 { - return nil, EmptyInput - } - s := NewBlankSession() - return s, s.Unpickle(pickled, key) +// the Session using the supplied key. Returns error on failure. +func SessionFromPickled(pickled, key []byte) (Session, error) { + return InitSessionFromPickled(pickled, key) } -func NewBlankSession() *Session { - memory := make([]byte, sessionSize()) - return &Session{ - int: C.olm_session(unsafe.Pointer(&memory[0])), - mem: memory, - } -} - -// lastError returns an error describing the most recent error to happen to a -// session. -func (s *Session) lastError() error { - return convertError(C.GoString(C.olm_session_last_error((*C.OlmSession)(s.int)))) -} - -// Clear clears the memory used to back this Session. -func (s *Session) Clear() error { - r := C.olm_clear_session((*C.OlmSession)(s.int)) - if r == errorVal() { - return s.lastError() - } - return nil -} - -// pickleLen returns the number of bytes needed to store a session. -func (s *Session) pickleLen() uint { - return uint(C.olm_pickle_session_length((*C.OlmSession)(s.int))) -} - -// createOutboundRandomLen returns the number of random bytes needed to create -// an outbound session. -func (s *Session) createOutboundRandomLen() uint { - return uint(C.olm_create_outbound_session_random_length((*C.OlmSession)(s.int))) -} - -// idLen returns the length of the buffer needed to return the id for this -// session. -func (s *Session) idLen() uint { - return uint(C.olm_session_id_length((*C.OlmSession)(s.int))) -} - -// encryptRandomLen returns the number of random bytes needed to encrypt the -// next message. -func (s *Session) encryptRandomLen() uint { - return uint(C.olm_encrypt_random_length((*C.OlmSession)(s.int))) -} - -// encryptMsgLen returns the size of the next message in bytes for the given -// number of plain-text bytes. -func (s *Session) encryptMsgLen(plainTextLen int) uint { - return uint(C.olm_encrypt_message_length((*C.OlmSession)(s.int), C.size_t(plainTextLen))) -} - -// decryptMaxPlaintextLen returns the maximum number of bytes of plain-text a -// given message could decode to. The actual size could be different due to -// padding. Returns error on failure. If the message base64 couldn't be -// decoded then the error will be "INVALID_BASE64". If the message is for an -// unsupported version of the protocol then the error will be -// "BAD_MESSAGE_VERSION". If the message couldn't be decoded then the error -// will be "BAD_MESSAGE_FORMAT". -func (s *Session) decryptMaxPlaintextLen(message string, msgType id.OlmMsgType) (uint, error) { - if len(message) == 0 { - return 0, EmptyInput - } - r := C.olm_decrypt_max_plaintext_length( - (*C.OlmSession)(s.int), - C.size_t(msgType), - unsafe.Pointer(C.CString(message)), - C.size_t(len(message))) - if r == errorVal() { - return 0, s.lastError() - } - return uint(r), nil -} - -// Pickle returns a Session as a base64 string. Encrypts the Session using the -// supplied key. -func (s *Session) Pickle(key []byte) []byte { - if len(key) == 0 { - panic(NoKeyProvided) - } - pickled := make([]byte, s.pickleLen()) - r := C.olm_pickle_session( - (*C.OlmSession)(s.int), - unsafe.Pointer(&key[0]), - C.size_t(len(key)), - unsafe.Pointer(&pickled[0]), - C.size_t(len(pickled))) - if r == errorVal() { - panic(s.lastError()) - } - return pickled[:r] -} - -func (s *Session) Unpickle(pickled, key []byte) error { - if len(key) == 0 { - return NoKeyProvided - } - r := C.olm_unpickle_session( - (*C.OlmSession)(s.int), - unsafe.Pointer(&key[0]), - C.size_t(len(key)), - unsafe.Pointer(&pickled[0]), - C.size_t(len(pickled))) - if r == errorVal() { - return s.lastError() - } - return nil -} - -// Deprecated -func (s *Session) GobEncode() ([]byte, error) { - pickled := s.Pickle(pickleKey) - length := base64.RawStdEncoding.DecodedLen(len(pickled)) - rawPickled := make([]byte, length) - _, err := base64.RawStdEncoding.Decode(rawPickled, pickled) - return rawPickled, err -} - -// Deprecated -func (s *Session) GobDecode(rawPickled []byte) error { - if s == nil || s.int == nil { - *s = *NewBlankSession() - } - length := base64.RawStdEncoding.EncodedLen(len(rawPickled)) - pickled := make([]byte, length) - base64.RawStdEncoding.Encode(pickled, rawPickled) - return s.Unpickle(pickled, pickleKey) -} - -// Deprecated -func (s *Session) MarshalJSON() ([]byte, error) { - pickled := s.Pickle(pickleKey) - quotes := make([]byte, len(pickled)+2) - quotes[0] = '"' - quotes[len(quotes)-1] = '"' - copy(quotes[1:len(quotes)-1], pickled) - return quotes, nil -} - -// Deprecated -func (s *Session) UnmarshalJSON(data []byte) error { - if len(data) == 0 || len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' { - return InputNotJSONString - } - if s == nil || s.int == nil { - *s = *NewBlankSession() - } - return s.Unpickle(data[1:len(data)-1], pickleKey) -} - -// Id returns an identifier for this Session. Will be the same for both ends -// of the conversation. -func (s *Session) ID() id.SessionID { - sessionID := make([]byte, s.idLen()) - r := C.olm_session_id( - (*C.OlmSession)(s.int), - unsafe.Pointer(&sessionID[0]), - C.size_t(len(sessionID))) - if r == errorVal() { - panic(s.lastError()) - } - return id.SessionID(sessionID) -} - -// HasReceivedMessage returns true if this session has received any message. -func (s *Session) HasReceivedMessage() bool { - switch C.olm_session_has_received_message((*C.OlmSession)(s.int)) { - case 0: - return false - default: - return true - } -} - -// MatchesInboundSession checks if the PRE_KEY message is for this in-bound -// Session. This can happen if multiple messages are sent to this Account -// before this Account sends a message in reply. Returns true if the session -// matches. Returns false if the session does not match. Returns error on -// failure. If the base64 couldn't be decoded then the error will be -// "INVALID_BASE64". If the message was for an unsupported protocol version -// then the error will be "BAD_MESSAGE_VERSION". If the message couldn't be -// decoded then then the error will be "BAD_MESSAGE_FORMAT". -func (s *Session) MatchesInboundSession(oneTimeKeyMsg string) (bool, error) { - if len(oneTimeKeyMsg) == 0 { - return false, EmptyInput - } - r := C.olm_matches_inbound_session( - (*C.OlmSession)(s.int), - unsafe.Pointer(&([]byte(oneTimeKeyMsg))[0]), - C.size_t(len(oneTimeKeyMsg))) - if r == 1 { - return true, nil - } else if r == 0 { - return false, nil - } else { // if r == errorVal() - return false, s.lastError() - } -} - -// MatchesInboundSessionFrom checks if the PRE_KEY message is for this in-bound -// Session. This can happen if multiple messages are sent to this Account -// before this Account sends a message in reply. Returns true if the session -// matches. Returns false if the session does not match. Returns error on -// failure. If the base64 couldn't be decoded then the error will be -// "INVALID_BASE64". If the message was for an unsupported protocol version -// then the error will be "BAD_MESSAGE_VERSION". If the message couldn't be -// decoded then then the error will be "BAD_MESSAGE_FORMAT". -func (s *Session) MatchesInboundSessionFrom(theirIdentityKey, oneTimeKeyMsg string) (bool, error) { - if len(theirIdentityKey) == 0 || len(oneTimeKeyMsg) == 0 { - return false, EmptyInput - } - r := C.olm_matches_inbound_session_from( - (*C.OlmSession)(s.int), - unsafe.Pointer(&([]byte(theirIdentityKey))[0]), - C.size_t(len(theirIdentityKey)), - unsafe.Pointer(&([]byte(oneTimeKeyMsg))[0]), - C.size_t(len(oneTimeKeyMsg))) - if r == 1 { - return true, nil - } else if r == 0 { - return false, nil - } else { // if r == errorVal() - return false, s.lastError() - } -} - -// EncryptMsgType returns the type of the next message that Encrypt will -// return. Returns MsgTypePreKey if the message will be a PRE_KEY message. -// Returns MsgTypeMsg if the message will be a normal message. Returns error -// on failure. -func (s *Session) EncryptMsgType() id.OlmMsgType { - switch C.olm_encrypt_message_type((*C.OlmSession)(s.int)) { - case C.size_t(id.OlmMsgTypePreKey): - return id.OlmMsgTypePreKey - case C.size_t(id.OlmMsgTypeMsg): - return id.OlmMsgTypeMsg - default: - panic("olm_encrypt_message_type returned invalid result") - } -} - -// Encrypt encrypts a message using the Session. Returns the encrypted message -// as base64. -func (s *Session) Encrypt(plaintext []byte) (id.OlmMsgType, []byte) { - if len(plaintext) == 0 { - panic(EmptyInput) - } - // Make the slice be at least length 1 - random := make([]byte, s.encryptRandomLen()+1) - _, err := rand.Read(random) - if err != nil { - panic(NotEnoughGoRandom) - } - messageType := s.EncryptMsgType() - message := make([]byte, s.encryptMsgLen(len(plaintext))) - r := C.olm_encrypt( - (*C.OlmSession)(s.int), - unsafe.Pointer(&plaintext[0]), - C.size_t(len(plaintext)), - unsafe.Pointer(&random[0]), - C.size_t(len(random)), - unsafe.Pointer(&message[0]), - C.size_t(len(message))) - if r == errorVal() { - panic(s.lastError()) - } - return messageType, message[:r] -} - -// Decrypt decrypts a message using the Session. Returns the the plain-text on -// success. Returns error on failure. If the base64 couldn't be decoded then -// the error will be "INVALID_BASE64". If the message is for an unsupported -// version of the protocol then the error will be "BAD_MESSAGE_VERSION". If -// the message couldn't be decoded then the error will be BAD_MESSAGE_FORMAT". -// If the MAC on the message was invalid then the error will be -// "BAD_MESSAGE_MAC". -func (s *Session) Decrypt(message string, msgType id.OlmMsgType) ([]byte, error) { - if len(message) == 0 { - return nil, EmptyInput - } - decryptMaxPlaintextLen, err := s.decryptMaxPlaintextLen(message, msgType) - if err != nil { - return nil, err - } - messageCopy := []byte(message) - plaintext := make([]byte, decryptMaxPlaintextLen) - r := C.olm_decrypt( - (*C.OlmSession)(s.int), - C.size_t(msgType), - unsafe.Pointer(&(messageCopy)[0]), - C.size_t(len(messageCopy)), - unsafe.Pointer(&plaintext[0]), - C.size_t(len(plaintext))) - if r == errorVal() { - return nil, s.lastError() - } - return plaintext[:r], nil -} - -// https://gitlab.matrix.org/matrix-org/olm/-/blob/3.2.8/include/olm/olm.h#L392-393 -const maxDescribeSize = 600 - -// Describe generates a string describing the internal state of an olm session for debugging and logging purposes. -func (s *Session) Describe() string { - desc := (*C.char)(C.malloc(C.size_t(maxDescribeSize))) - defer C.free(unsafe.Pointer(desc)) - C.meowlm_session_describe( - (*C.OlmSession)(s.int), - desc, - C.size_t(maxDescribeSize)) - return C.GoString(desc) +func NewBlankSession() Session { + return InitNewBlankSession() } diff --git a/crypto/olm/session_goolm.go b/crypto/olm/session_goolm.go deleted file mode 100644 index c77efaa2..00000000 --- a/crypto/olm/session_goolm.go +++ /dev/null @@ -1,110 +0,0 @@ -//go:build goolm - -package olm - -import ( - "maunium.net/go/mautrix/crypto/goolm/session" - "maunium.net/go/mautrix/id" -) - -// Session stores an end to end encrypted messaging session. -type Session struct { - session.OlmSession -} - -// SessionFromPickled loads a Session from a pickled base64 string. Decrypts -// the Session using the supplied key. Returns error on failure. -func SessionFromPickled(pickled, key []byte) (*Session, error) { - if len(pickled) == 0 { - return nil, EmptyInput - } - s := NewBlankSession() - return s, s.Unpickle(pickled, key) -} - -func NewBlankSession() *Session { - return &Session{} -} - -// Clear clears the memory used to back this Session. -func (s *Session) Clear() error { - s.OlmSession = session.OlmSession{} - return nil -} - -// Pickle returns a Session as a base64 string. Encrypts the Session using the -// supplied key. -func (s *Session) Pickle(key []byte) []byte { - if len(key) == 0 { - panic(NoKeyProvided) - } - pickled, err := s.OlmSession.Pickle(key) - if err != nil { - panic(err) - } - return pickled -} - -func (s *Session) Unpickle(pickled, key []byte) error { - if len(key) == 0 { - return NoKeyProvided - } else if len(pickled) == 0 { - return EmptyInput - } - sOlm, err := session.OlmSessionFromPickled(pickled, key) - if err != nil { - return err - } - s.OlmSession = *sOlm - return nil -} - -// MatchesInboundSession checks if the PRE_KEY message is for this in-bound -// Session. This can happen if multiple messages are sent to this Account -// before this Account sends a message in reply. Returns true if the session -// matches. Returns false if the session does not match. Returns error on -// failure. -func (s *Session) MatchesInboundSession(oneTimeKeyMsg string) (bool, error) { - return s.MatchesInboundSessionFrom("", oneTimeKeyMsg) -} - -// MatchesInboundSessionFrom checks if the PRE_KEY message is for this in-bound -// Session. This can happen if multiple messages are sent to this Account -// before this Account sends a message in reply. Returns true if the session -// matches. Returns false if the session does not match. Returns error on -// failure. -func (s *Session) MatchesInboundSessionFrom(theirIdentityKey, oneTimeKeyMsg string) (bool, error) { - if theirIdentityKey != "" { - theirKey := id.Curve25519(theirIdentityKey) - return s.OlmSession.MatchesInboundSessionFrom(&theirKey, []byte(oneTimeKeyMsg)) - } - return s.OlmSession.MatchesInboundSessionFrom(nil, []byte(oneTimeKeyMsg)) - -} - -// Encrypt encrypts a message using the Session. Returns the encrypted message -// as base64. -func (s *Session) Encrypt(plaintext []byte) (id.OlmMsgType, []byte) { - if len(plaintext) == 0 { - panic(EmptyInput) - } - messageType, message, err := s.OlmSession.Encrypt(plaintext, nil) - if err != nil { - panic(err) - } - return messageType, message -} - -// Decrypt decrypts a message using the Session. Returns the the plain-text on -// success. Returns error on failure. -func (s *Session) Decrypt(message string, msgType id.OlmMsgType) ([]byte, error) { - if len(message) == 0 { - return nil, EmptyInput - } - return s.OlmSession.Decrypt([]byte(message), msgType) -} - -// Describe generates a string describing the internal state of an olm session for debugging and logging purposes. -func (s *Session) Describe() string { - return s.OlmSession.Describe() -} diff --git a/crypto/olm/session_test.go b/crypto/olm/session_test.go new file mode 100644 index 00000000..b0b9896f --- /dev/null +++ b/crypto/olm/session_test.go @@ -0,0 +1,119 @@ +// Copyright (c) 2024 Sumner Evans +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package olm_test + +import ( + "bytes" + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.mau.fi/util/exerrors" + "golang.org/x/exp/maps" + + "maunium.net/go/mautrix/crypto/goolm/account" + "maunium.net/go/mautrix/crypto/goolm/session" + "maunium.net/go/mautrix/crypto/libolm" + "maunium.net/go/mautrix/crypto/olm" + "maunium.net/go/mautrix/id" +) + +func TestBlankSession(t *testing.T) { + libolmSession := libolm.NewBlankSession() + session := session.NewOlmSession() + + assert.Equal(t, libolmSession.ID(), session.ID()) + assert.Equal(t, libolmSession.HasReceivedMessage(), session.HasReceivedMessage()) + assert.Equal(t, libolmSession.EncryptMsgType(), session.EncryptMsgType()) + assert.Equal(t, libolmSession.Describe(), session.Describe()) + + libolmPickled, err := libolmSession.Pickle([]byte("test")) + assert.NoError(t, err) + goolmPickled, err := session.Pickle([]byte("test")) + assert.NoError(t, err) + assert.Equal(t, goolmPickled, libolmPickled) +} + +func TestSessionPickle(t *testing.T) { + pickledDataFromLibOlm := []byte("icDKYm0b4aO23WgUuOxdpPoxC0UlEOYPVeuduNH3IkpFsmnWx5KuEOpxGiZw5IuB/sSn2RZUCTiJ90IvgC7AClkYGHep9O8lpiqQX73XVKD9okZDCAkBc83eEq0DKYC7HBkGRAU/4T6QPIBBY3UK4QZwULLE/fLsi3j4YZBehMtnlsqgHK0q1bvX4cRznZItVKR4ro0O9EAk6LLxJtSnRu5elSUk7YXT") + pickleKey := []byte("secret_key") + + goolmSession, err := session.OlmSessionFromPickled(bytes.Clone(pickledDataFromLibOlm), pickleKey) + assert.NoError(t, err) + + libolmSession, err := libolm.SessionFromPickled(bytes.Clone(pickledDataFromLibOlm), pickleKey) + assert.NoError(t, err) + + goolmPickled, err := goolmSession.Pickle(pickleKey) + require.NoError(t, err) + assert.Equal(t, pickledDataFromLibOlm, goolmPickled) + + libolmPickled, err := libolmSession.Pickle(pickleKey) + require.NoError(t, err) + assert.Equal(t, pickledDataFromLibOlm, libolmPickled) +} + +func TestSession_EncryptDecrypt(t *testing.T) { + combos := [][2]olm.Account{ + {exerrors.Must(libolm.NewAccount()), exerrors.Must(libolm.NewAccount())}, + {exerrors.Must(account.NewAccount()), exerrors.Must(account.NewAccount())}, + {exerrors.Must(libolm.NewAccount()), exerrors.Must(account.NewAccount())}, + {exerrors.Must(account.NewAccount()), exerrors.Must(libolm.NewAccount())}, + } + + for _, combo := range combos { + receiver, sender := combo[0], combo[1] + require.NoError(t, receiver.GenOneTimeKeys(50)) + require.NoError(t, sender.GenOneTimeKeys(50)) + + _, receiverCurve25519, err := receiver.IdentityKeys() + require.NoError(t, err) + accountAOTKs, err := receiver.OneTimeKeys() + require.NoError(t, err) + + senderSession, err := sender.NewOutboundSession(receiverCurve25519, accountAOTKs[maps.Keys(accountAOTKs)[0]]) + require.NoError(t, err) + + // Send a couple pre-key messages from sender -> receiver. + var receiverSession olm.Session + for i := 0; i < 10; i++ { + msgType, ciphertext, err := senderSession.Encrypt([]byte(fmt.Sprintf("prekey %d", i))) + require.NoError(t, err) + assert.Equal(t, id.OlmMsgTypePreKey, msgType) + + receiverSession, err = receiver.NewInboundSession(string(ciphertext)) + require.NoError(t, err) + + decrypted, err := receiverSession.Decrypt(string(ciphertext), msgType) + require.NoError(t, err) + assert.Equal(t, []byte(fmt.Sprintf("prekey %d", i)), decrypted) + } + + // Send some messages from receiver -> sender. + for i := 0; i < 10; i++ { + msgType, ciphertext, err := receiverSession.Encrypt([]byte(fmt.Sprintf("response %d", i))) + require.NoError(t, err) + assert.Equal(t, id.OlmMsgTypeMsg, msgType) + + decrypted, err := senderSession.Decrypt(string(ciphertext), msgType) + require.NoError(t, err) + assert.Equal(t, []byte(fmt.Sprintf("response %d", i)), decrypted) + } + + // Send some more messages from sender -> receiver + for i := 0; i < 10; i++ { + msgType, ciphertext, err := senderSession.Encrypt([]byte(fmt.Sprintf("%d", i))) + require.NoError(t, err) + assert.Equal(t, id.OlmMsgTypeMsg, msgType) + + decrypted, err := receiverSession.Decrypt(string(ciphertext), msgType) + require.NoError(t, err) + assert.Equal(t, []byte(fmt.Sprintf("%d", i)), decrypted) + } + } +} diff --git a/crypto/pkcs7/pkcs7.go b/crypto/pkcs7/pkcs7.go index c83c5afd..dc28ed6a 100644 --- a/crypto/pkcs7/pkcs7.go +++ b/crypto/pkcs7/pkcs7.go @@ -8,23 +8,23 @@ package pkcs7 import "bytes" -// Pad implements PKCS#7 padding as defined in [RFC2315]. It pads the plaintext -// to the given blockSize in the range [1, 255]. This is normally used in -// AES-CBC encryption. +// Pad implements PKCS#7 padding as defined in [RFC2315]. It pads the data to +// the given blockSize in the range [1, 255]. This is normally used in AES-CBC +// encryption. // // [RFC2315]: https://www.ietf.org/rfc/rfc2315.txt -func Pad(plaintext []byte, blockSize int) []byte { - padding := blockSize - len(plaintext)%blockSize - return append(plaintext, bytes.Repeat([]byte{byte(padding)}, padding)...) +func Pad(data []byte, blockSize int) []byte { + padding := blockSize - len(data)%blockSize + return append(data, bytes.Repeat([]byte{byte(padding)}, padding)...) } // Unpad implements PKCS#7 unpadding as defined in [RFC2315]. It unpads the -// plaintext by reading the padding amount from the last byte of the plaintext. -// This is normally used in AES-CBC decryption. +// data by reading the padding amount from the last byte of the data. This is +// normally used in AES-CBC decryption. // // [RFC2315]: https://www.ietf.org/rfc/rfc2315.txt -func Unpad(plaintext []byte) []byte { - length := len(plaintext) - unpadding := int(plaintext[length-1]) - return plaintext[:length-unpadding] +func Unpad(data []byte) []byte { + length := len(data) + unpadding := int(data[length-1]) + return data[:length-unpadding] } diff --git a/crypto/registergoolm.go b/crypto/registergoolm.go new file mode 100644 index 00000000..6b5b65fd --- /dev/null +++ b/crypto/registergoolm.go @@ -0,0 +1,11 @@ +//go:build goolm + +package crypto + +import ( + "maunium.net/go/mautrix/crypto/goolm" +) + +func init() { + goolm.Register() +} diff --git a/crypto/registerlibolm.go b/crypto/registerlibolm.go new file mode 100644 index 00000000..ef78b6b5 --- /dev/null +++ b/crypto/registerlibolm.go @@ -0,0 +1,9 @@ +//go:build !goolm + +package crypto + +import "maunium.net/go/mautrix/crypto/libolm" + +func init() { + libolm.Register() +} diff --git a/crypto/sessions.go b/crypto/sessions.go index 045af933..ccc7b784 100644 --- a/crypto/sessions.go +++ b/crypto/sessions.go @@ -8,6 +8,7 @@ package crypto import ( "errors" + "fmt" "time" "maunium.net/go/mautrix/crypto/olm" @@ -17,8 +18,14 @@ import ( ) var ( - SessionNotShared = errors.New("session has not been shared") - SessionExpired = errors.New("session has expired") + ErrSessionNotShared = errors.New("session has not been shared") + ErrSessionExpired = errors.New("session has expired") +) + +// Deprecated: use variables prefixed with Err +var ( + SessionNotShared = ErrSessionNotShared + SessionExpired = ErrSessionExpired ) // OlmSessionList is a list of OlmSessions. @@ -54,9 +61,9 @@ func (session *OlmSession) Describe() string { return session.Internal.Describe() } -func wrapSession(session *olm.Session) *OlmSession { +func wrapSession(session olm.Session) *OlmSession { return &OlmSession{ - Internal: *session, + Internal: session, ExpirationMixin: ExpirationMixin{ TimeMixin: TimeMixin{ CreationTime: time.Now(), @@ -68,7 +75,7 @@ func wrapSession(session *olm.Session) *OlmSession { } func (account *OlmAccount) NewInboundSessionFrom(senderKey id.Curve25519, ciphertext string) (*OlmSession, error) { - session, err := account.Internal.NewInboundSessionFrom(senderKey, ciphertext) + session, err := account.Internal.NewInboundSessionFrom(&senderKey, ciphertext) if err != nil { return nil, err } @@ -76,7 +83,7 @@ func (account *OlmAccount) NewInboundSessionFrom(senderKey id.Curve25519, cipher return wrapSession(session), nil } -func (session *OlmSession) Encrypt(plaintext []byte) (id.OlmMsgType, []byte) { +func (session *OlmSession) Encrypt(plaintext []byte) (id.OlmMsgType, []byte, error) { session.LastEncryptedTime = time.Now() return session.Internal.Encrypt(plaintext) } @@ -110,6 +117,7 @@ type InboundGroupSession struct { MaxMessages int IsScheduled bool KeyBackupVersion id.KeyBackupVersion + KeySource id.KeySource id id.SessionID } @@ -120,15 +128,16 @@ func NewInboundGroupSession(senderKey id.SenderKey, signingKey id.Ed25519, roomI return nil, err } return &InboundGroupSession{ - Internal: *igs, + Internal: igs, SigningKey: signingKey, SenderKey: senderKey, RoomID: roomID, - ForwardingChains: nil, + ForwardingChains: []string{}, ReceivedAt: time.Now().UTC(), MaxAge: maxAge.Milliseconds(), MaxMessages: maxMessages, IsScheduled: isScheduled, + KeySource: id.KeySourceDirect, }, nil } @@ -148,10 +157,26 @@ func (igs *InboundGroupSession) RatchetTo(index uint32) error { if err != nil { return err } - igs.Internal = *imported + igs.Internal = imported return nil } +func (igs *InboundGroupSession) export() (*ExportedSession, error) { + key, err := igs.Internal.Export(igs.Internal.FirstKnownIndex()) + if err != nil { + return nil, fmt.Errorf("failed to export session: %w", err) + } + return &ExportedSession{ + Algorithm: id.AlgorithmMegolmV1, + ForwardingChains: igs.ForwardingChains, + RoomID: igs.RoomID, + SenderKey: igs.SenderKey, + SenderClaimedKeys: SenderClaimedKeys{Ed25519: igs.SigningKey}, + SessionID: igs.ID(), + SessionKey: string(key), + }, nil +} + type OGSState int const ( @@ -180,9 +205,13 @@ type OutboundGroupSession struct { content *event.RoomKeyEventContent } -func NewOutboundGroupSession(roomID id.RoomID, encryptionContent *event.EncryptionEventContent) *OutboundGroupSession { +func NewOutboundGroupSession(roomID id.RoomID, encryptionContent *event.EncryptionEventContent) (*OutboundGroupSession, error) { + internal, err := olm.NewOutboundGroupSession() + if err != nil { + return nil, err + } ogs := &OutboundGroupSession{ - Internal: *olm.NewOutboundGroupSession(), + Internal: internal, ExpirationMixin: ExpirationMixin{ TimeMixin: TimeMixin{ CreationTime: time.Now(), @@ -196,14 +225,17 @@ func NewOutboundGroupSession(roomID id.RoomID, encryptionContent *event.Encrypti RoomID: roomID, } if encryptionContent != nil { + // Clamp rotation period to prevent unreasonable values + // Similar to https://github.com/matrix-org/matrix-rust-sdk/blob/matrix-sdk-crypto-0.7.1/crates/matrix-sdk-crypto/src/olm/group_sessions/outbound.rs#L415-L441 if encryptionContent.RotationPeriodMillis != 0 { ogs.MaxAge = time.Duration(encryptionContent.RotationPeriodMillis) * time.Millisecond + ogs.MaxAge = min(max(ogs.MaxAge, 1*time.Hour), 365*24*time.Hour) } if encryptionContent.RotationPeriodMessages != 0 { - ogs.MaxMessages = encryptionContent.RotationPeriodMessages + ogs.MaxMessages = min(max(encryptionContent.RotationPeriodMessages, 1), 10000) } } - return ogs + return ogs, nil } func (ogs *OutboundGroupSession) ShareContent() event.Content { @@ -231,13 +263,13 @@ func (ogs *OutboundGroupSession) Expired() bool { func (ogs *OutboundGroupSession) Encrypt(plaintext []byte) ([]byte, error) { if !ogs.Shared { - return nil, SessionNotShared + return nil, ErrSessionNotShared } else if ogs.Expired() { - return nil, SessionExpired + return nil, ErrSessionExpired } ogs.MessageCount++ ogs.LastEncryptedTime = time.Now() - return ogs.Internal.Encrypt(plaintext), nil + return ogs.Internal.Encrypt(plaintext) } type TimeMixin struct { diff --git a/crypto/sharing.go b/crypto/sharing.go index c0f3e209..10e37ccc 100644 --- a/crypto/sharing.go +++ b/crypto/sharing.go @@ -173,6 +173,19 @@ func (mach *OlmMachine) receiveSecret(ctx context.Context, evt *DecryptedOlmEven return } + // https://spec.matrix.org/v1.10/client-server-api/#msecretsend + // "The recipient must ensure... that the device is a verified device owned by the recipient" + if senderDevice, err := mach.GetOrFetchDevice(ctx, evt.Sender, evt.SenderDevice); err != nil { + log.Err(err).Msg("Failed to get or fetch sender device, rejecting secret") + return + } else if senderDevice == nil { + log.Warn().Msg("Unknown sender device, rejecting secret") + return + } else if !mach.IsDeviceTrusted(ctx, senderDevice) { + log.Warn().Msg("Sender device is not verified, rejecting secret") + return + } + mach.secretLock.Lock() secretChan := mach.secretListeners[content.RequestID] mach.secretLock.Unlock() diff --git a/crypto/sql_store.go b/crypto/sql_store.go index a3b3b74a..138cc557 100644 --- a/crypto/sql_store.go +++ b/crypto/sql_store.go @@ -13,6 +13,7 @@ import ( "encoding/json" "errors" "fmt" + "slices" "strings" "sync" "time" @@ -21,7 +22,7 @@ import ( "go.mau.fi/util/dbutil" "maunium.net/go/mautrix" - "maunium.net/go/mautrix/crypto/goolm/cipher" + "maunium.net/go/mautrix/crypto/goolm/libolmpickle" "maunium.net/go/mautrix/crypto/olm" "maunium.net/go/mautrix/crypto/sql_store_upgrade" "maunium.net/go/mautrix/event" @@ -52,14 +53,18 @@ var _ Store = (*SQLCryptoStore)(nil) // NewSQLCryptoStore initializes a new crypto Store using the given database, for a device's crypto material. // The stored material will be encrypted with the given key. func NewSQLCryptoStore(db *dbutil.Database, log dbutil.DatabaseLogger, accountID string, deviceID id.DeviceID, pickleKey []byte) *SQLCryptoStore { - return &SQLCryptoStore{ + store := &SQLCryptoStore{ DB: db.Child(sql_store_upgrade.VersionTableName, sql_store_upgrade.Table, log), PickleKey: pickleKey, AccountID: accountID, DeviceID: deviceID, - - olmSessionCache: make(map[id.SenderKey]map[id.SessionID]*OlmSession), } + store.InitFields() + return store +} + +func (store *SQLCryptoStore) InitFields() { + store.olmSessionCache = make(map[id.SenderKey]map[id.SessionID]*OlmSession) } // Flush does nothing for this implementation as data is already persisted in the database. @@ -77,8 +82,8 @@ func (store *SQLCryptoStore) PutNextBatch(ctx context.Context, nextBatch string) // GetNextBatch retrieves the next sync batch token for the current account. func (store *SQLCryptoStore) GetNextBatch(ctx context.Context) (string, error) { if store.SyncToken == "" { - err := store.DB.Conn(ctx). - QueryRowContext(ctx, "SELECT sync_token FROM crypto_account WHERE account_id=$1", store.AccountID). + err := store.DB. + QueryRow(ctx, "SELECT sync_token FROM crypto_account WHERE account_id=$1", store.AccountID). Scan(&store.SyncToken) if !errors.Is(err, sql.ErrNoRows) { return "", err @@ -123,8 +128,11 @@ func (store *SQLCryptoStore) FindDeviceID(ctx context.Context) (deviceID id.Devi // PutAccount stores an OlmAccount in the database. func (store *SQLCryptoStore) PutAccount(ctx context.Context, account *OlmAccount) error { store.Account = account - bytes := account.Internal.Pickle(store.PickleKey) - _, err := store.DB.Exec(ctx, ` + bytes, err := account.Internal.Pickle(store.PickleKey) + if err != nil { + return err + } + _, err = store.DB.Exec(ctx, ` INSERT INTO crypto_account (device_id, shared, sync_token, account, account_id, key_backup_version) VALUES ($1, $2, $3, $4, $5, $6) ON CONFLICT (account_id) DO UPDATE SET shared=excluded.shared, sync_token=excluded.sync_token, account=excluded.account, account_id=excluded.account_id, @@ -137,7 +145,7 @@ func (store *SQLCryptoStore) PutAccount(ctx context.Context, account *OlmAccount func (store *SQLCryptoStore) GetAccount(ctx context.Context) (*OlmAccount, error) { if store.Account == nil { row := store.DB.QueryRow(ctx, "SELECT shared, sync_token, account, key_backup_version FROM crypto_account WHERE account_id=$1", store.AccountID) - acc := &OlmAccount{Internal: *olm.NewBlankAccount()} + acc := &OlmAccount{Internal: olm.NewBlankAccount()} var accountBytes []byte err := row.Scan(&acc.Shared, &store.SyncToken, &accountBytes, &acc.KeyBackupVersion) if err == sql.ErrNoRows { @@ -183,7 +191,7 @@ func (store *SQLCryptoStore) GetSessions(ctx context.Context, key id.SenderKey) defer store.olmSessionCacheLock.Unlock() cache := store.getOlmSessionCache(key) for rows.Next() { - sess := OlmSession{Internal: *olm.NewBlankSession()} + sess := OlmSession{Internal: olm.NewBlankSession()} var sessionBytes []byte var sessionID id.SessionID err = rows.Scan(&sessionID, &sessionBytes, &sess.CreationTime, &sess.LastEncryptedTime, &sess.LastDecryptedTime) @@ -212,7 +220,7 @@ func (store *SQLCryptoStore) getOlmSessionCache(key id.SenderKey) map[id.Session return data } -// GetLatestSession retrieves the Olm session for a given sender key from the database that has the largest ID. +// GetLatestSession retrieves the Olm session for a given sender key from the database that had the most recent successful decryption. func (store *SQLCryptoStore) GetLatestSession(ctx context.Context, key id.SenderKey) (*OlmSession, error) { store.olmSessionCacheLock.Lock() defer store.olmSessionCacheLock.Unlock() @@ -220,7 +228,7 @@ func (store *SQLCryptoStore) GetLatestSession(ctx context.Context, key id.Sender row := store.DB.QueryRow(ctx, "SELECT session_id, session, created_at, last_encrypted, last_decrypted FROM crypto_olm_session WHERE sender_key=$1 AND account_id=$2 ORDER BY last_decrypted DESC LIMIT 1", key, store.AccountID) - sess := OlmSession{Internal: *olm.NewBlankSession()} + sess := OlmSession{Internal: olm.NewBlankSession()} var sessionBytes []byte var sessionID id.SessionID @@ -242,12 +250,26 @@ func (store *SQLCryptoStore) GetLatestSession(ctx context.Context, key id.Sender } } +// GetNewestSessionCreationTS gets the creation timestamp of the most recently created session with the given sender key. +// This will exclude sessions that have never been used to encrypt or decrypt a message. +func (store *SQLCryptoStore) GetNewestSessionCreationTS(ctx context.Context, key id.SenderKey) (createdAt time.Time, err error) { + err = store.DB.QueryRow(ctx, "SELECT created_at FROM crypto_olm_session WHERE sender_key=$1 AND account_id=$2 AND (last_encrypted <> created_at OR last_decrypted <> created_at) ORDER BY created_at DESC LIMIT 1", + key, store.AccountID).Scan(&createdAt) + if errors.Is(err, sql.ErrNoRows) { + err = nil + } + return +} + // AddSession persists an Olm session for a sender in the database. func (store *SQLCryptoStore) AddSession(ctx context.Context, key id.SenderKey, session *OlmSession) error { store.olmSessionCacheLock.Lock() defer store.olmSessionCacheLock.Unlock() - sessionBytes := session.Internal.Pickle(store.PickleKey) - _, err := store.DB.Exec(ctx, "INSERT INTO crypto_olm_session (session_id, sender_key, session, created_at, last_encrypted, last_decrypted, account_id) VALUES ($1, $2, $3, $4, $5, $6, $7)", + sessionBytes, err := session.Internal.Pickle(store.PickleKey) + if err != nil { + return err + } + _, err = store.DB.Exec(ctx, "INSERT INTO crypto_olm_session (session_id, sender_key, session, created_at, last_encrypted, last_decrypted, account_id) VALUES ($1, $2, $3, $4, $5, $6, $7)", session.ID(), key, sessionBytes, session.CreationTime, session.LastEncryptedTime, session.LastDecryptedTime, store.AccountID) store.getOlmSessionCache(key)[session.ID()] = session return err @@ -255,17 +277,41 @@ func (store *SQLCryptoStore) AddSession(ctx context.Context, key id.SenderKey, s // UpdateSession replaces the Olm session for a sender in the database. func (store *SQLCryptoStore) UpdateSession(ctx context.Context, _ id.SenderKey, session *OlmSession) error { - sessionBytes := session.Internal.Pickle(store.PickleKey) - _, err := store.DB.Exec(ctx, "UPDATE crypto_olm_session SET session=$1, last_encrypted=$2, last_decrypted=$3 WHERE session_id=$4 AND account_id=$5", + sessionBytes, err := session.Internal.Pickle(store.PickleKey) + if err != nil { + return err + } + _, err = store.DB.Exec(ctx, "UPDATE crypto_olm_session SET session=$1, last_encrypted=$2, last_decrypted=$3 WHERE session_id=$4 AND account_id=$5", sessionBytes, session.LastEncryptedTime, session.LastDecryptedTime, session.ID(), store.AccountID) return err } -func intishPtr[T int | int64](i T) *T { - if i == 0 { - return nil +func (store *SQLCryptoStore) DeleteSession(ctx context.Context, _ id.SenderKey, session *OlmSession) error { + _, err := store.DB.Exec(ctx, "DELETE FROM crypto_olm_session WHERE session_id=$1 AND account_id=$2", session.ID(), store.AccountID) + return err +} + +func (store *SQLCryptoStore) PutOlmHash(ctx context.Context, messageHash [32]byte, receivedAt time.Time) error { + _, err := store.DB.Exec(ctx, "INSERT INTO crypto_olm_message_hash (account_id, received_at, message_hash) VALUES ($1, $2, $3) ON CONFLICT (message_hash) DO NOTHING", store.AccountID, receivedAt.UnixMilli(), messageHash[:]) + return err +} + +func (store *SQLCryptoStore) GetOlmHash(ctx context.Context, messageHash [32]byte) (receivedAt time.Time, err error) { + var receivedAtInt int64 + err = store.DB.QueryRow(ctx, "SELECT received_at FROM crypto_olm_message_hash WHERE message_hash=$1", messageHash[:]).Scan(&receivedAtInt) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + err = nil + } + return } - return &i + receivedAt = time.UnixMilli(receivedAtInt) + return +} + +func (store *SQLCryptoStore) DeleteOldOlmHashes(ctx context.Context, beforeTS time.Time) error { + _, err := store.DB.Exec(ctx, "DELETE FROM crypto_olm_message_hash WHERE account_id = $1 AND received_at < $2", store.AccountID, beforeTS.UnixMilli()) + return err } func datePtr(t time.Time) *time.Time { @@ -276,46 +322,66 @@ func datePtr(t time.Time) *time.Time { } // PutGroupSession stores an inbound Megolm group session for a room, sender and session. -func (store *SQLCryptoStore) PutGroupSession(ctx context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, session *InboundGroupSession) error { - sessionBytes := session.Internal.Pickle(store.PickleKey) +func (store *SQLCryptoStore) PutGroupSession(ctx context.Context, session *InboundGroupSession) error { + sessionBytes, err := session.Internal.Pickle(store.PickleKey) + if err != nil { + return err + } + if session.ForwardingChains == nil { + session.ForwardingChains = []string{} + } forwardingChains := strings.Join(session.ForwardingChains, ",") ratchetSafety, err := json.Marshal(&session.RatchetSafety) if err != nil { return fmt.Errorf("failed to marshal ratchet safety info: %w", err) } + zerolog.Ctx(ctx).Debug(). + Stringer("session_id", session.ID()). + Str("account_id", store.AccountID). + Stringer("sender_key", session.SenderKey). + Stringer("signing_key", session.SigningKey). + Stringer("room_id", session.RoomID). + Time("received_at", session.ReceivedAt). + Int64("max_age", session.MaxAge). + Int("max_messages", session.MaxMessages). + Bool("is_scheduled", session.IsScheduled). + Stringer("key_backup_version", session.KeyBackupVersion). + Stringer("key_source", session.KeySource). + Msg("Upserting megolm inbound group session") _, err = store.DB.Exec(ctx, ` INSERT INTO crypto_megolm_inbound_session ( session_id, sender_key, signing_key, room_id, session, forwarding_chains, - ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version, account_id - ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13) + ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version, key_source, account_id + ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14) ON CONFLICT (session_id, account_id) DO UPDATE SET withheld_code=NULL, withheld_reason=NULL, sender_key=excluded.sender_key, signing_key=excluded.signing_key, room_id=excluded.room_id, session=excluded.session, forwarding_chains=excluded.forwarding_chains, ratchet_safety=excluded.ratchet_safety, received_at=excluded.received_at, max_age=excluded.max_age, max_messages=excluded.max_messages, is_scheduled=excluded.is_scheduled, - key_backup_version=excluded.key_backup_version + key_backup_version=excluded.key_backup_version, key_source=excluded.key_source `, - sessionID, senderKey, session.SigningKey, roomID, sessionBytes, forwardingChains, - ratchetSafety, datePtr(session.ReceivedAt), intishPtr(session.MaxAge), intishPtr(session.MaxMessages), - session.IsScheduled, session.KeyBackupVersion, store.AccountID, + session.ID(), session.SenderKey, session.SigningKey, session.RoomID, sessionBytes, forwardingChains, + ratchetSafety, datePtr(session.ReceivedAt), dbutil.NumPtr(session.MaxAge), dbutil.NumPtr(session.MaxMessages), + session.IsScheduled, session.KeyBackupVersion, session.KeySource, store.AccountID, ) return err } // GetGroupSession retrieves an inbound Megolm group session for a room, sender and session. -func (store *SQLCryptoStore) GetGroupSession(ctx context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID) (*InboundGroupSession, error) { - var senderKeyDB, signingKey, forwardingChains, withheldCode, withheldReason sql.NullString +func (store *SQLCryptoStore) GetGroupSession(ctx context.Context, roomID id.RoomID, sessionID id.SessionID) (*InboundGroupSession, error) { + var senderKey, signingKey, forwardingChains, withheldCode, withheldReason sql.NullString var sessionBytes, ratchetSafetyBytes []byte var receivedAt sql.NullTime var maxAge, maxMessages sql.NullInt64 var isScheduled bool var version id.KeyBackupVersion + var keySource id.KeySource err := store.DB.QueryRow(ctx, ` - SELECT sender_key, signing_key, session, forwarding_chains, withheld_code, withheld_reason, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version + SELECT sender_key, signing_key, session, forwarding_chains, withheld_code, withheld_reason, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version, key_source FROM crypto_megolm_inbound_session - WHERE room_id=$1 AND (sender_key=$2 OR $2 = '') AND session_id=$3 AND account_id=$4`, - roomID, senderKey, sessionID, store.AccountID, - ).Scan(&senderKeyDB, &signingKey, &sessionBytes, &forwardingChains, &withheldCode, &withheldReason, &ratchetSafetyBytes, &receivedAt, &maxAge, &maxMessages, &isScheduled, &version) + WHERE room_id=$1 AND session_id=$2 AND account_id=$3`, + roomID, sessionID, store.AccountID, + ).Scan(&senderKey, &signingKey, &sessionBytes, &forwardingChains, &withheldCode, &withheldReason, &ratchetSafetyBytes, &receivedAt, &maxAge, &maxMessages, &isScheduled, &version, &keySource) if errors.Is(err, sql.ErrNoRows) { return nil, nil } else if err != nil { @@ -325,19 +391,19 @@ func (store *SQLCryptoStore) GetGroupSession(ctx context.Context, roomID id.Room RoomID: roomID, Algorithm: id.AlgorithmMegolmV1, SessionID: sessionID, - SenderKey: senderKey, + SenderKey: id.Curve25519(senderKey.String), Code: event.RoomKeyWithheldCode(withheldCode.String), Reason: withheldReason.String, } } igs, chains, rs, err := store.postScanInboundGroupSession(sessionBytes, ratchetSafetyBytes, forwardingChains.String) - if senderKey == "" { - senderKey = id.Curve25519(senderKeyDB.String) + if err != nil { + return nil, err } return &InboundGroupSession{ - Internal: *igs, + Internal: igs, SigningKey: id.Ed25519(signingKey.String), - SenderKey: senderKey, + SenderKey: id.Curve25519(senderKey.String), RoomID: roomID, ForwardingChains: chains, RatchetSafety: rs, @@ -346,10 +412,11 @@ func (store *SQLCryptoStore) GetGroupSession(ctx context.Context, roomID id.Room MaxMessages: int(maxMessages.Int64), IsScheduled: isScheduled, KeyBackupVersion: version, + KeySource: keySource, }, nil } -func (store *SQLCryptoStore) RedactGroupSession(ctx context.Context, _ id.RoomID, _ id.SenderKey, sessionID id.SessionID, reason string) error { +func (store *SQLCryptoStore) RedactGroupSession(ctx context.Context, _ id.RoomID, sessionID id.SessionID, reason string) error { _, err := store.DB.Exec(ctx, ` UPDATE crypto_megolm_inbound_session SET withheld_code=$1, withheld_reason=$2, session=NULL, forwarding_chains=NULL @@ -369,10 +436,7 @@ func (store *SQLCryptoStore) RedactGroupSessions(ctx context.Context, roomID id. AND session IS NOT NULL AND is_scheduled=false AND received_at IS NOT NULL RETURNING session_id `, event.RoomKeyWithheldBeeperRedacted, "Session redacted: "+reason, roomID, senderKey, store.AccountID) - if err != nil { - return nil, err - } - return dbutil.NewRowIter(res, dbutil.ScanSingleColumn[id.SessionID]).AsList() + return dbutil.NewRowIterWithError(res, dbutil.ScanSingleColumn[id.SessionID], err).AsList() } func (store *SQLCryptoStore) RedactExpiredGroupSessions(ctx context.Context) ([]id.SessionID, error) { @@ -400,10 +464,7 @@ func (store *SQLCryptoStore) RedactExpiredGroupSessions(ctx context.Context) ([] return nil, fmt.Errorf("unsupported dialect") } res, err := store.DB.Query(ctx, query, event.RoomKeyWithheldBeeperRedacted, "Session redacted: expired", store.AccountID) - if err != nil { - return nil, err - } - return dbutil.NewRowIter(res, dbutil.ScanSingleColumn[id.SessionID]).AsList() + return dbutil.NewRowIterWithError(res, dbutil.ScanSingleColumn[id.SessionID], err).AsList() } func (store *SQLCryptoStore) RedactOutdatedGroupSessions(ctx context.Context) ([]id.SessionID, error) { @@ -413,25 +474,25 @@ func (store *SQLCryptoStore) RedactOutdatedGroupSessions(ctx context.Context) ([ WHERE account_id=$3 AND session IS NOT NULL AND received_at IS NULL RETURNING session_id `, event.RoomKeyWithheldBeeperRedacted, "Session redacted: outdated", store.AccountID) - if err != nil { - return nil, err - } - return dbutil.NewRowIter(res, dbutil.ScanSingleColumn[id.SessionID]).AsList() + return dbutil.NewRowIterWithError(res, dbutil.ScanSingleColumn[id.SessionID], err).AsList() } func (store *SQLCryptoStore) PutWithheldGroupSession(ctx context.Context, content event.RoomKeyWithheldEventContent) error { - _, err := store.DB.Exec(ctx, "INSERT INTO crypto_megolm_inbound_session (session_id, sender_key, room_id, withheld_code, withheld_reason, received_at, account_id) VALUES ($1, $2, $3, $4, $5, $6, $7)", - content.SessionID, content.SenderKey, content.RoomID, content.Code, content.Reason, time.Now().UTC(), store.AccountID) + _, err := store.DB.Exec(ctx, ` + INSERT INTO crypto_megolm_inbound_session (session_id, sender_key, room_id, withheld_code, withheld_reason, received_at, account_id) + VALUES ($1, $2, $3, $4, $5, $6, $7) + ON CONFLICT (session_id, account_id) DO NOTHING + `, content.SessionID, content.SenderKey, content.RoomID, content.Code, content.Reason, time.Now().UTC(), store.AccountID) return err } -func (store *SQLCryptoStore) GetWithheldGroupSession(ctx context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID) (*event.RoomKeyWithheldEventContent, error) { - var code, reason sql.NullString +func (store *SQLCryptoStore) GetWithheldGroupSession(ctx context.Context, roomID id.RoomID, sessionID id.SessionID) (*event.RoomKeyWithheldEventContent, error) { + var senderKey, code, reason sql.NullString err := store.DB.QueryRow(ctx, ` - SELECT withheld_code, withheld_reason FROM crypto_megolm_inbound_session - WHERE room_id=$1 AND sender_key=$2 AND session_id=$3 AND account_id=$4`, - roomID, senderKey, sessionID, store.AccountID, - ).Scan(&code, &reason) + SELECT withheld_code, withheld_reason, sender_key FROM crypto_megolm_inbound_session + WHERE room_id=$1 AND session_id=$2 AND account_id=$3`, + roomID, sessionID, store.AccountID, + ).Scan(&code, &reason, &senderKey) if errors.Is(err, sql.ErrNoRows) { return nil, nil } else if err != nil || !code.Valid { @@ -441,13 +502,13 @@ func (store *SQLCryptoStore) GetWithheldGroupSession(ctx context.Context, roomID RoomID: roomID, Algorithm: id.AlgorithmMegolmV1, SessionID: sessionID, - SenderKey: senderKey, + SenderKey: id.Curve25519(senderKey.String), Code: event.RoomKeyWithheldCode(code.String), Reason: reason.String, }, nil } -func (store *SQLCryptoStore) postScanInboundGroupSession(sessionBytes, ratchetSafetyBytes []byte, forwardingChains string) (igs *olm.InboundGroupSession, chains []string, safety RatchetSafety, err error) { +func (store *SQLCryptoStore) postScanInboundGroupSession(sessionBytes, ratchetSafetyBytes []byte, forwardingChains string) (igs olm.InboundGroupSession, chains []string, safety RatchetSafety, err error) { igs = olm.NewBlankInboundGroupSession() err = igs.Unpickle(sessionBytes, store.PickleKey) if err != nil { @@ -455,6 +516,8 @@ func (store *SQLCryptoStore) postScanInboundGroupSession(sessionBytes, ratchetSa } if forwardingChains != "" { chains = strings.Split(forwardingChains, ",") + } else { + chains = []string{} } var rs RatchetSafety if len(ratchetSafetyBytes) > 0 { @@ -474,13 +537,17 @@ func (store *SQLCryptoStore) scanInboundGroupSession(rows dbutil.Scannable) (*In var maxAge, maxMessages sql.NullInt64 var isScheduled bool var version id.KeyBackupVersion - err := rows.Scan(&roomID, &senderKey, &signingKey, &sessionBytes, &forwardingChains, &ratchetSafetyBytes, &receivedAt, &maxAge, &maxMessages, &isScheduled, &version) + var keySource id.KeySource + err := rows.Scan(&roomID, &senderKey, &signingKey, &sessionBytes, &forwardingChains, &ratchetSafetyBytes, &receivedAt, &maxAge, &maxMessages, &isScheduled, &version, &keySource) if err != nil { return nil, err } igs, chains, rs, err := store.postScanInboundGroupSession(sessionBytes, ratchetSafetyBytes, forwardingChains.String) + if err != nil { + return nil, err + } return &InboundGroupSession{ - Internal: *igs, + Internal: igs, SigningKey: id.Ed25519(signingKey.String), SenderKey: id.Curve25519(senderKey.String), RoomID: roomID, @@ -491,12 +558,13 @@ func (store *SQLCryptoStore) scanInboundGroupSession(rows dbutil.Scannable) (*In MaxMessages: int(maxMessages.Int64), IsScheduled: isScheduled, KeyBackupVersion: version, + KeySource: keySource, }, nil } func (store *SQLCryptoStore) GetGroupSessionsForRoom(ctx context.Context, roomID id.RoomID) dbutil.RowIter[*InboundGroupSession] { rows, err := store.DB.Query(ctx, ` - SELECT room_id, sender_key, signing_key, session, forwarding_chains, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version + SELECT room_id, sender_key, signing_key, session, forwarding_chains, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version, key_source FROM crypto_megolm_inbound_session WHERE room_id=$1 AND account_id=$2 AND session IS NOT NULL`, roomID, store.AccountID, ) @@ -505,7 +573,7 @@ func (store *SQLCryptoStore) GetGroupSessionsForRoom(ctx context.Context, roomID func (store *SQLCryptoStore) GetAllGroupSessions(ctx context.Context) dbutil.RowIter[*InboundGroupSession] { rows, err := store.DB.Query(ctx, ` - SELECT room_id, sender_key, signing_key, session, forwarding_chains, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version + SELECT room_id, sender_key, signing_key, session, forwarding_chains, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version, key_source FROM crypto_megolm_inbound_session WHERE account_id=$1 AND session IS NOT NULL`, store.AccountID, ) @@ -514,7 +582,7 @@ func (store *SQLCryptoStore) GetAllGroupSessions(ctx context.Context) dbutil.Row func (store *SQLCryptoStore) GetGroupSessionsWithoutKeyBackupVersion(ctx context.Context, version id.KeyBackupVersion) dbutil.RowIter[*InboundGroupSession] { rows, err := store.DB.Query(ctx, ` - SELECT room_id, sender_key, signing_key, session, forwarding_chains, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version + SELECT room_id, sender_key, signing_key, session, forwarding_chains, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version, key_source FROM crypto_megolm_inbound_session WHERE account_id=$1 AND session IS NOT NULL AND key_backup_version != $2`, store.AccountID, version, ) @@ -523,8 +591,11 @@ func (store *SQLCryptoStore) GetGroupSessionsWithoutKeyBackupVersion(ctx context // AddOutboundGroupSession stores an outbound Megolm session, along with the information about the room and involved devices. func (store *SQLCryptoStore) AddOutboundGroupSession(ctx context.Context, session *OutboundGroupSession) error { - sessionBytes := session.Internal.Pickle(store.PickleKey) - _, err := store.DB.Exec(ctx, ` + sessionBytes, err := session.Internal.Pickle(store.PickleKey) + if err != nil { + return err + } + _, err = store.DB.Exec(ctx, ` INSERT INTO crypto_megolm_outbound_session (room_id, session_id, session, shared, max_messages, message_count, max_age, created_at, last_used, account_id) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) @@ -539,8 +610,11 @@ func (store *SQLCryptoStore) AddOutboundGroupSession(ctx context.Context, sessio // UpdateOutboundGroupSession replaces an outbound Megolm session with for same room and session ID. func (store *SQLCryptoStore) UpdateOutboundGroupSession(ctx context.Context, session *OutboundGroupSession) error { - sessionBytes := session.Internal.Pickle(store.PickleKey) - _, err := store.DB.Exec(ctx, "UPDATE crypto_megolm_outbound_session SET session=$1, message_count=$2, last_used=$3 WHERE room_id=$4 AND session_id=$5 AND account_id=$6", + sessionBytes, err := session.Internal.Pickle(store.PickleKey) + if err != nil { + return err + } + _, err = store.DB.Exec(ctx, "UPDATE crypto_megolm_outbound_session SET session=$1, message_count=$2, last_used=$3 WHERE room_id=$4 AND session_id=$5 AND account_id=$6", sessionBytes, session.MessageCount, session.LastEncryptedTime, session.RoomID, session.ID(), store.AccountID) return err } @@ -565,7 +639,7 @@ func (store *SQLCryptoStore) GetOutboundGroupSession(ctx context.Context, roomID if err != nil { return nil, err } - ogs.Internal = *intOGS + ogs.Internal = intOGS ogs.RoomID = roomID ogs.MaxAge = time.Duration(maxAgeMS) * time.Millisecond return &ogs, nil @@ -595,6 +669,20 @@ func (store *SQLCryptoStore) IsOutboundGroupSessionShared(ctx context.Context, u // ValidateMessageIndex returns whether the given event information match the ones stored in the database // for the given sender key, session ID and index. If the index hasn't been stored, this will store it. func (store *SQLCryptoStore) ValidateMessageIndex(ctx context.Context, senderKey id.SenderKey, sessionID id.SessionID, eventID id.EventID, index uint, timestamp int64) (bool, error) { + if eventID == "" && timestamp == 0 { + var notOK bool + const validateEmptyQuery = ` + SELECT EXISTS(SELECT 1 FROM crypto_message_index WHERE sender_key=$1 AND session_id=$2 AND "index"=$3) + ` + err := store.DB.QueryRow(ctx, validateEmptyQuery, senderKey, sessionID, index).Scan(¬OK) + if notOK { + zerolog.Ctx(ctx).Debug(). + Uint("message_index", index). + Msg("Rejecting event without event ID and timestamp due to already knowing them") + } + return !notOK, err + } + const validateQuery = ` INSERT INTO crypto_message_index (sender_key, session_id, "index", event_id, timestamp) VALUES ($1, $2, $3, $4, $5) @@ -613,7 +701,7 @@ func (store *SQLCryptoStore) ValidateMessageIndex(ctx context.Context, senderKey Str("expected_event_id", expectedEventID.String()). Int64("expected_timestamp", expectedTimestamp). Int64("actual_timestamp", timestamp). - Msg("Failed to validate that message index wasn't duplicated") + Msg("Rejecting different event with duplicate message index") return false, nil } return true, nil @@ -641,11 +729,8 @@ func (store *SQLCryptoStore) GetDevices(ctx context.Context, userID id.UserID) ( } rows, err := store.DB.Query(ctx, "SELECT user_id, device_id, identity_key, signing_key, trust, deleted, name FROM crypto_device WHERE user_id=$1 AND deleted=false", userID) - if err != nil { - return nil, err - } data := make(map[id.DeviceID]*id.Device) - err = dbutil.NewRowIter(rows, scanDevice).Iter(func(device *id.Device) (bool, error) { + err = dbutil.NewRowIterWithError(rows, scanDevice, err).Iter(func(device *id.Device) (bool, error) { data[device.DeviceID] = device return true, nil }) @@ -744,6 +829,17 @@ func (store *SQLCryptoStore) PutDevices(ctx context.Context, userID id.UserID, d }) } +func userIDsToParams(users []id.UserID) (placeholders string, params []any) { + queryString := make([]string, len(users)) + params = make([]any, len(users)) + for i, user := range users { + queryString[i] = fmt.Sprintf("$%d", i+1) + params[i] = user + } + placeholders = strings.Join(queryString, ",") + return +} + // FilterTrackedUsers finds all the user IDs out of the given ones for which the database contains identity information. func (store *SQLCryptoStore) FilterTrackedUsers(ctx context.Context, users []id.UserID) ([]id.UserID, error) { var rows dbutil.Rows @@ -751,42 +847,29 @@ func (store *SQLCryptoStore) FilterTrackedUsers(ctx context.Context, users []id. if store.DB.Dialect == dbutil.Postgres && PostgresArrayWrapper != nil { rows, err = store.DB.Query(ctx, "SELECT user_id FROM crypto_tracked_user WHERE user_id = ANY($1)", PostgresArrayWrapper(users)) } else { - queryString := make([]string, len(users)) - params := make([]interface{}, len(users)) - for i, user := range users { - queryString[i] = fmt.Sprintf("?%d", i+1) - params[i] = user - } - rows, err = store.DB.Query(ctx, "SELECT user_id FROM crypto_tracked_user WHERE user_id IN ("+strings.Join(queryString, ",")+")", params...) + placeholders, params := userIDsToParams(users) + rows, err = store.DB.Query(ctx, "SELECT user_id FROM crypto_tracked_user WHERE user_id IN ("+placeholders+")", params...) } - if err != nil { - return users, err - } - return dbutil.NewRowIter(rows, dbutil.ScanSingleColumn[id.UserID]).AsList() + return dbutil.NewRowIterWithError(rows, dbutil.ScanSingleColumn[id.UserID], err).AsList() } // MarkTrackedUsersOutdated flags that the device list for given users are outdated. -func (store *SQLCryptoStore) MarkTrackedUsersOutdated(ctx context.Context, users []id.UserID) error { - return store.DB.DoTxn(ctx, nil, func(ctx context.Context) error { - // TODO refactor to use a single query - for _, userID := range users { - _, err := store.DB.Exec(ctx, "UPDATE crypto_tracked_user SET devices_outdated = true WHERE user_id = $1", userID) - if err != nil { - return fmt.Errorf("failed to update user in the tracked users list: %w", err) - } +func (store *SQLCryptoStore) MarkTrackedUsersOutdated(ctx context.Context, users []id.UserID) (err error) { + for chunk := range slices.Chunk(users, 1000) { + if store.DB.Dialect == dbutil.Postgres && PostgresArrayWrapper != nil { + _, err = store.DB.Exec(ctx, "UPDATE crypto_tracked_user SET devices_outdated = true WHERE user_id = ANY($1)", PostgresArrayWrapper(chunk)) + } else { + placeholders, params := userIDsToParams(chunk) + _, err = store.DB.Exec(ctx, "UPDATE crypto_tracked_user SET devices_outdated = true WHERE user_id IN ("+placeholders+")", params...) } - - return nil - }) + } + return } // GetOutdatedTrackerUsers gets all tracked users whose devices need to be updated. func (store *SQLCryptoStore) GetOutdatedTrackedUsers(ctx context.Context) ([]id.UserID, error) { rows, err := store.DB.Query(ctx, "SELECT user_id FROM crypto_tracked_user WHERE devices_outdated = TRUE") - if err != nil { - return nil, err - } - return dbutil.NewRowIter(rows, dbutil.ScanSingleColumn[id.UserID]).AsList() + return dbutil.NewRowIterWithError(rows, dbutil.ScanSingleColumn[id.UserID], err).AsList() } // PutCrossSigningKey stores a cross-signing key of some user along with its usage. @@ -871,30 +954,30 @@ func (store *SQLCryptoStore) DropSignaturesByKey(ctx context.Context, userID id. } func (store *SQLCryptoStore) PutSecret(ctx context.Context, name id.Secret, value string) error { - bytes, err := cipher.Pickle(store.PickleKey, []byte(value)) + bytes, err := libolmpickle.Pickle(store.PickleKey, []byte(value)) if err != nil { return err } _, err = store.DB.Exec(ctx, ` - INSERT INTO crypto_secrets (name, secret) VALUES ($1, $2) - ON CONFLICT (name) DO UPDATE SET secret=excluded.secret - `, name, bytes) + INSERT INTO crypto_secrets (account_id, name, secret) VALUES ($1, $2, $3) + ON CONFLICT (account_id, name) DO UPDATE SET secret=excluded.secret + `, store.AccountID, name, bytes) return err } func (store *SQLCryptoStore) GetSecret(ctx context.Context, name id.Secret) (value string, err error) { var bytes []byte - err = store.DB.QueryRow(ctx, `SELECT secret FROM crypto_secrets WHERE name=$1`, name).Scan(&bytes) + err = store.DB.QueryRow(ctx, `SELECT secret FROM crypto_secrets WHERE account_id=$1 AND name=$2`, store.AccountID, name).Scan(&bytes) if errors.Is(err, sql.ErrNoRows) { return "", nil } else if err != nil { return "", err } - bytes, err = cipher.Unpickle(store.PickleKey, bytes) + bytes, err = libolmpickle.Unpickle(store.PickleKey, bytes) return string(bytes), err } func (store *SQLCryptoStore) DeleteSecret(ctx context.Context, name id.Secret) (err error) { - _, err = store.DB.Exec(ctx, "DELETE FROM crypto_secrets WHERE name=$1", name) + _, err = store.DB.Exec(ctx, "DELETE FROM crypto_secrets WHERE account_id=$1 AND name=$2", store.AccountID, name) return } diff --git a/crypto/sql_store_upgrade/00-latest-revision.sql b/crypto/sql_store_upgrade/00-latest-revision.sql index 06aea750..3709f1e5 100644 --- a/crypto/sql_store_upgrade/00-latest-revision.sql +++ b/crypto/sql_store_upgrade/00-latest-revision.sql @@ -1,4 +1,4 @@ --- v0 -> v14 (compatible with v9+): Latest revision +-- v0 -> v19 (compatible with v15+): Latest revision CREATE TABLE IF NOT EXISTS crypto_account ( account_id TEXT PRIMARY KEY, device_id TEXT NOT NULL, @@ -43,6 +43,17 @@ CREATE TABLE IF NOT EXISTS crypto_olm_session ( last_encrypted timestamp NOT NULL, PRIMARY KEY (account_id, session_id) ); +CREATE INDEX crypto_olm_session_sender_key_idx ON crypto_olm_session (account_id, sender_key); + +CREATE TABLE crypto_olm_message_hash ( + account_id TEXT NOT NULL, + received_at BIGINT NOT NULL, + message_hash bytea NOT NULL PRIMARY KEY, + + CONSTRAINT crypto_olm_message_hash_account_fkey FOREIGN KEY (account_id) + REFERENCES crypto_account (account_id) ON DELETE CASCADE ON UPDATE CASCADE +); +CREATE INDEX crypto_olm_message_hash_account_idx ON crypto_olm_message_hash (account_id); CREATE TABLE IF NOT EXISTS crypto_megolm_inbound_session ( account_id TEXT, @@ -60,8 +71,11 @@ CREATE TABLE IF NOT EXISTS crypto_megolm_inbound_session ( max_messages INTEGER, is_scheduled BOOLEAN NOT NULL DEFAULT false, key_backup_version TEXT NOT NULL DEFAULT '', + key_source TEXT NOT NULL DEFAULT '', PRIMARY KEY (account_id, session_id) ); +-- Useful index to find keys that need backing up +CREATE INDEX crypto_megolm_inbound_session_backup_idx ON crypto_megolm_inbound_session(account_id, key_backup_version) WHERE session IS NOT NULL; CREATE TABLE IF NOT EXISTS crypto_megolm_outbound_session ( account_id TEXT, @@ -105,6 +119,9 @@ CREATE TABLE IF NOT EXISTS crypto_cross_signing_signatures ( ); CREATE TABLE IF NOT EXISTS crypto_secrets ( - name TEXT PRIMARY KEY NOT NULL, - secret bytea NOT NULL + account_id TEXT NOT NULL, + name TEXT NOT NULL, + secret bytea NOT NULL, + + PRIMARY KEY (account_id, name) ); diff --git a/crypto/sql_store_upgrade/15-fix-secrets.sql b/crypto/sql_store_upgrade/15-fix-secrets.sql new file mode 100644 index 00000000..d49cffae --- /dev/null +++ b/crypto/sql_store_upgrade/15-fix-secrets.sql @@ -0,0 +1,21 @@ +-- v15: Fix crypto_secrets table +CREATE TABLE crypto_secrets_new ( + account_id TEXT NOT NULL, + name TEXT NOT NULL, + secret bytea NOT NULL, + + PRIMARY KEY (account_id, name) +); + +INSERT INTO crypto_secrets_new (account_id, name, secret) +SELECT '', name, secret +FROM crypto_secrets; + +DROP TABLE crypto_secrets; + +ALTER TABLE crypto_secrets_new RENAME TO crypto_secrets; + +-- only: sqlite +UPDATE crypto_secrets SET account_id=(SELECT account_id FROM crypto_account ORDER BY rowid DESC LIMIT 1); +-- only: postgres +UPDATE crypto_secrets SET account_id=(SELECT account_id FROM crypto_account LIMIT 1); diff --git a/crypto/sql_store_upgrade/16-crypto-olm-sessions-index.sql b/crypto/sql_store_upgrade/16-crypto-olm-sessions-index.sql new file mode 100644 index 00000000..f0c3a0c5 --- /dev/null +++ b/crypto/sql_store_upgrade/16-crypto-olm-sessions-index.sql @@ -0,0 +1,2 @@ +-- v16 (compatible with v15+): Add index to crypto_olm_sessions to speedup lookups by sender_key +CREATE INDEX crypto_olm_session_sender_key_idx ON crypto_olm_session (account_id, sender_key); diff --git a/crypto/sql_store_upgrade/17-decrypted-olm-messages.sql b/crypto/sql_store_upgrade/17-decrypted-olm-messages.sql new file mode 100644 index 00000000..525bbb52 --- /dev/null +++ b/crypto/sql_store_upgrade/17-decrypted-olm-messages.sql @@ -0,0 +1,11 @@ +-- v17 (compatible with v15+): Add table for decrypted Olm message hashes +CREATE TABLE crypto_olm_message_hash ( + account_id TEXT NOT NULL, + received_at BIGINT NOT NULL, + message_hash bytea NOT NULL PRIMARY KEY, + + CONSTRAINT crypto_olm_message_hash_account_fkey FOREIGN KEY (account_id) + REFERENCES crypto_account (account_id) ON DELETE CASCADE ON UPDATE CASCADE +); + +CREATE INDEX crypto_olm_message_hash_account_idx ON crypto_olm_message_hash (account_id); diff --git a/crypto/sql_store_upgrade/18-megolm-inbound-session-backup-index.sql b/crypto/sql_store_upgrade/18-megolm-inbound-session-backup-index.sql new file mode 100644 index 00000000..da26da0f --- /dev/null +++ b/crypto/sql_store_upgrade/18-megolm-inbound-session-backup-index.sql @@ -0,0 +1,2 @@ +-- v18 (compatible with v15+): Add an index to the megolm_inbound_session table to make finding sessions to backup faster +CREATE INDEX crypto_megolm_inbound_session_backup_idx ON crypto_megolm_inbound_session(account_id, key_backup_version) WHERE session IS NOT NULL; diff --git a/crypto/sql_store_upgrade/19-megolm-session-source.sql b/crypto/sql_store_upgrade/19-megolm-session-source.sql new file mode 100644 index 00000000..f624222f --- /dev/null +++ b/crypto/sql_store_upgrade/19-megolm-session-source.sql @@ -0,0 +1,2 @@ +-- v19 (compatible with v15+): Store megolm session source +ALTER TABLE crypto_megolm_inbound_session ADD COLUMN key_source TEXT NOT NULL DEFAULT ''; diff --git a/crypto/sql_store_upgrade/upgrade.go b/crypto/sql_store_upgrade/upgrade.go index 08c995da..10c0c0c0 100644 --- a/crypto/sql_store_upgrade/upgrade.go +++ b/crypto/sql_store_upgrade/upgrade.go @@ -22,7 +22,7 @@ const VersionTableName = "crypto_version" var fs embed.FS func init() { - Table.Register(-1, 3, 0, "Unsupported version", false, func(ctx context.Context, database *dbutil.Database) error { + Table.Register(-1, 3, 0, "Unsupported version", dbutil.TxnModeOff, func(ctx context.Context, database *dbutil.Database) error { return fmt.Errorf("upgrading from versions 1 and 2 of the crypto store is no longer supported in mautrix-go v0.12+") }) Table.RegisterFS(fs) diff --git a/crypto/ssss/client.go b/crypto/ssss/client.go index 0cfdd24f..8691d032 100644 --- a/crypto/ssss/client.go +++ b/crypto/ssss/client.go @@ -53,7 +53,7 @@ func (mach *Machine) SetDefaultKeyID(ctx context.Context, keyID string) error { // GetKeyData gets the details about the given key ID. func (mach *Machine) GetKeyData(ctx context.Context, keyID string) (keyData *KeyMetadata, err error) { - keyData = &KeyMetadata{id: keyID} + keyData = &KeyMetadata{} err = mach.Client.GetAccountData(ctx, fmt.Sprintf("%s.%s", event.AccountDataSecretStorageKey.Type, keyID), keyData) return } @@ -95,6 +95,22 @@ func (mach *Machine) SetEncryptedAccountData(ctx context.Context, eventType even return mach.Client.SetAccountData(ctx, eventType.Type, &EncryptedAccountDataEventContent{Encrypted: encrypted}) } +// SetEncryptedAccountDataWithMetadata encrypts the given data with the given keys and stores it, +// alongside the unencrypted metadata, on the server. +func (mach *Machine) SetEncryptedAccountDataWithMetadata(ctx context.Context, eventType event.Type, data []byte, metadata map[string]any, keys ...*Key) error { + if len(keys) == 0 { + return ErrNoKeyGiven + } + encrypted := make(map[string]EncryptedKeyData, len(keys)) + for _, key := range keys { + encrypted[key.ID] = key.Encrypt(eventType.Type, data) + } + return mach.Client.SetAccountData(ctx, eventType.Type, &EncryptedAccountDataEventContent{ + Encrypted: encrypted, + Metadata: metadata, + }) +} + // GenerateAndUploadKey generates a new SSSS key and stores the metadata on the server. func (mach *Machine) GenerateAndUploadKey(ctx context.Context, passphrase string) (key *Key, err error) { key, err = NewKey(passphrase) diff --git a/crypto/ssss/key.go b/crypto/ssss/key.go index 3c38d3cd..78ebd8f3 100644 --- a/crypto/ssss/key.go +++ b/crypto/ssss/key.go @@ -7,11 +7,14 @@ package ssss import ( - "crypto/rand" + "crypto/hmac" + "crypto/sha256" "encoding/base64" "fmt" "strings" + "go.mau.fi/util/random" + "maunium.net/go/mautrix/crypto/utils" ) @@ -33,10 +36,7 @@ func NewKey(passphrase string) (*Key, error) { if len(passphrase) > 0 { // There's a passphrase. We need to generate a salt for it, set the metadata // and then compute the key using the passphrase and the metadata. - saltBytes := make([]byte, 24) - if _, err := rand.Read(saltBytes); err != nil { - return nil, fmt.Errorf("failed to get random bytes for salt: %w", err) - } + saltBytes := random.Bytes(24) keyData.Passphrase = &PassphraseMetadata{ Algorithm: PassphraseAlgorithmPBKDF2, Iterations: 500000, @@ -50,25 +50,21 @@ func NewKey(passphrase string) (*Key, error) { } } else { // No passphrase, just generate a random key - ssssKey = make([]byte, 32) - if _, err := rand.Read(ssssKey); err != nil { - return nil, fmt.Errorf("failed to get random bytes for key: %w", err) - } + ssssKey = random.Bytes(32) } // Generate a random ID for the key. It's what identifies the key in account data. - keyIDBytes := make([]byte, 24) - if _, err := rand.Read(keyIDBytes); err != nil { - return nil, fmt.Errorf("failed to get random bytes for key ID: %w", err) - } + keyIDBytes := random.Bytes(24) // We store a certain hash in the key metadata so that clients can check if the user entered the correct key. - var ivBytes [utils.AESCTRIVLength]byte - if _, err := rand.Read(ivBytes[:]); err != nil { - return nil, fmt.Errorf("failed to get random bytes for IV: %w", err) + ivBytes := random.Bytes(utils.AESCTRIVLength) + keyData.IV = base64.RawStdEncoding.EncodeToString(ivBytes) + macBytes, err := keyData.calculateHash(ssssKey) + if err != nil { + // This should never happen because we just generated the IV and key. + return nil, fmt.Errorf("failed to calculate hash: %w", err) } - keyData.IV = base64.RawStdEncoding.EncodeToString(ivBytes[:]) - keyData.MAC = keyData.calculateHash(ssssKey) + keyData.MAC = base64.RawStdEncoding.EncodeToString(macBytes) return &Key{ Key: ssssKey, @@ -114,12 +110,18 @@ func (key *Key) Decrypt(eventType string, data EncryptedKeyData) ([]byte, error) return nil, err } + mac, err := base64.RawStdEncoding.DecodeString(strings.TrimRight(data.MAC, "=")) + if err != nil { + return nil, err + } + // derive the AES and HMAC keys for the requested event type using the SSSS key aesKey, hmacKey := utils.DeriveKeysSHA256(key.Key, eventType) // compare the stored MAC with the one we calculated from the ciphertext - calcMac := utils.HMACSHA256B64(payload, hmacKey) - if strings.TrimRight(data.MAC, "=") != calcMac { + h := hmac.New(sha256.New, hmacKey[:]) + h.Write(payload) + if !hmac.Equal(h.Sum(nil), mac) { return nil, ErrKeyDataMACMismatch } diff --git a/crypto/ssss/meta.go b/crypto/ssss/meta.go index e752cf0c..34775fa7 100644 --- a/crypto/ssss/meta.go +++ b/crypto/ssss/meta.go @@ -7,7 +7,10 @@ package ssss import ( + "crypto/hmac" + "crypto/sha256" "encoding/base64" + "errors" "fmt" "strings" @@ -17,8 +20,6 @@ import ( // KeyMetadata represents server-side metadata about a SSSS key. The metadata can be used to get // the actual SSSS key from a passphrase or recovery key. type KeyMetadata struct { - id string - Name string `json:"name"` Algorithm Algorithm `json:"algorithm"` @@ -31,53 +32,92 @@ type KeyMetadata struct { } // VerifyRecoveryKey verifies that the given passphrase is valid and returns the computed SSSS key. -func (kd *KeyMetadata) VerifyPassphrase(passphrase string) (*Key, error) { +func (kd *KeyMetadata) VerifyPassphrase(keyID, passphrase string) (*Key, error) { ssssKey, err := kd.Passphrase.GetKey(passphrase) if err != nil { return nil, err - } else if !kd.VerifyKey(ssssKey) { - return nil, ErrIncorrectSSSSKey + } + err = kd.verifyKey(ssssKey) + if err != nil && !errors.Is(err, ErrUnverifiableKey) { + return nil, err } return &Key{ - ID: kd.id, + ID: keyID, Key: ssssKey, Metadata: kd, }, nil } // VerifyRecoveryKey verifies that the given recovery key is valid and returns the decoded SSSS key. -func (kd *KeyMetadata) VerifyRecoveryKey(recoverKey string) (*Key, error) { - ssssKey := utils.DecodeBase58RecoveryKey(recoverKey) +func (kd *KeyMetadata) VerifyRecoveryKey(keyID, recoveryKey string) (*Key, error) { + ssssKey := utils.DecodeBase58RecoveryKey(recoveryKey) if ssssKey == nil { return nil, ErrInvalidRecoveryKey - } else if !kd.VerifyKey(ssssKey) { - return nil, ErrIncorrectSSSSKey + } + err := kd.verifyKey(ssssKey) + if err != nil && !errors.Is(err, ErrUnverifiableKey) { + return nil, err } return &Key{ - ID: kd.id, + ID: keyID, Key: ssssKey, Metadata: kd, - }, nil + }, err +} + +func (kd *KeyMetadata) verifyKey(key []byte) error { + if kd.MAC == "" || kd.IV == "" { + return ErrUnverifiableKey + } + unpaddedMAC := strings.TrimRight(kd.MAC, "=") + expectedMACLength := base64.RawStdEncoding.EncodedLen(utils.SHAHashLength) + if len(unpaddedMAC) != expectedMACLength { + return fmt.Errorf("%w: invalid mac length %d (expected %d)", ErrCorruptedKeyMetadata, len(unpaddedMAC), expectedMACLength) + } + expectedMAC, err := base64.RawStdEncoding.DecodeString(unpaddedMAC) + if err != nil { + return fmt.Errorf("%w: failed to decode mac: %w", ErrCorruptedKeyMetadata, err) + } + calculatedMAC, err := kd.calculateHash(key) + if err != nil { + return err + } + // This doesn't really need to be constant time since it's fully local, but might as well be. + if !hmac.Equal(expectedMAC, calculatedMAC) { + return ErrIncorrectSSSSKey + } + return nil } // VerifyKey verifies the SSSS key is valid by calculating and comparing its MAC. func (kd *KeyMetadata) VerifyKey(key []byte) bool { - return strings.TrimRight(kd.MAC, "=") == kd.calculateHash(key) + return kd.verifyKey(key) == nil } // calculateHash calculates the hash used for checking if the key is entered correctly as described // in the spec: https://matrix.org/docs/spec/client_server/unstable#m-secret-storage-v1-aes-hmac-sha2 -func (kd *KeyMetadata) calculateHash(key []byte) string { +func (kd *KeyMetadata) calculateHash(key []byte) ([]byte, error) { aesKey, hmacKey := utils.DeriveKeysSHA256(key, "") + unpaddedIV := strings.TrimRight(kd.IV, "=") + expectedIVLength := base64.RawStdEncoding.EncodedLen(utils.AESCTRIVLength) + if len(unpaddedIV) < expectedIVLength || len(unpaddedIV) > expectedIVLength*3 { + return nil, fmt.Errorf("%w: invalid iv length %d (expected %d)", ErrCorruptedKeyMetadata, len(unpaddedIV), expectedIVLength) + } + rawIVBytes, err := base64.RawStdEncoding.DecodeString(unpaddedIV) + if err != nil { + return nil, fmt.Errorf("%w: failed to decode iv: %w", ErrCorruptedKeyMetadata, err) + } + // TODO log a warning for non-16 byte IVs? + // Certain broken clients like nheko generated 32-byte IVs where only the first 16 bytes were used. + ivBytes := *(*[utils.AESCTRIVLength]byte)(rawIVBytes[:utils.AESCTRIVLength]) - var ivBytes [utils.AESCTRIVLength]byte - _, _ = base64.RawStdEncoding.Decode(ivBytes[:], []byte(strings.TrimRight(kd.IV, "="))) - - cipher := utils.XorA256CTR(make([]byte, utils.AESCTRKeyLength), aesKey, ivBytes) - - return utils.HMACSHA256B64(cipher, hmacKey) + zeroes := make([]byte, utils.AESCTRKeyLength) + encryptedZeroes := utils.XorA256CTR(zeroes, aesKey, ivBytes) + h := hmac.New(sha256.New, hmacKey[:]) + h.Write(encryptedZeroes) + return h.Sum(nil), nil } // PassphraseMetadata represents server-side metadata about a SSSS key passphrase. diff --git a/crypto/ssss/meta_test.go b/crypto/ssss/meta_test.go index 2ad8f62a..d59809c7 100644 --- a/crypto/ssss/meta_test.go +++ b/crypto/ssss/meta_test.go @@ -8,10 +8,10 @@ package ssss_test import ( "encoding/json" - "errors" "testing" "github.com/stretchr/testify/assert" + "go.mau.fi/util/exerrors" "maunium.net/go/mautrix/crypto/ssss" ) @@ -41,12 +41,42 @@ const key2Meta = ` } ` +const key2MetaUnverified = ` +{ + "algorithm": "m.secret_storage.v1.aes-hmac-sha2" +} +` + +const key2MetaLongIV = ` +{ + "algorithm": "m.secret_storage.v1.aes-hmac-sha2", + "iv": "O0BOvTqiIAYjC+RMcyHfW2f/gdxjceTxoYtNlpPduJ8=", + "mac": "7k6OruQlWg0UmQjxGZ0ad4Q6DdwkgnoI7G6X3IjBYtI=" +} +` + +const key2MetaBrokenIV = ` +{ + "algorithm": "m.secret_storage.v1.aes-hmac-sha2", + "iv": "MeowMeowMeow", + "mac": "7k6OruQlWg0UmQjxGZ0ad4Q6DdwkgnoI7G6X3IjBYtI=" +} +` + +const key2MetaBrokenMAC = ` +{ + "algorithm": "m.secret_storage.v1.aes-hmac-sha2", + "iv": "O0BOvTqiIAYjC+RMcyHfWw==", + "mac": "7k6OruQlWg0UmQjxGZ0ad4Q6DdwkgnoI7G6X3IjBYtIMeowMeowMeow" +} +` + const key2ID = "NVe5vK6lZS9gEMQLJw0yqkzmE5Mr7dLv" const key2RecoveryKey = "EsUC xSxt XJgQ dz19 8WBZ rHdE GZo7 ybsn EFmG Y5HY MDAG GNWe" -func getKey1Meta() *ssss.KeyMetadata { +func getKeyMeta(meta string) *ssss.KeyMetadata { var km ssss.KeyMetadata - err := json.Unmarshal([]byte(key1Meta), &km) + err := json.Unmarshal([]byte(meta), &km) if err != nil { panic(err) } @@ -54,82 +84,91 @@ func getKey1Meta() *ssss.KeyMetadata { } func getKey1() *ssss.Key { - km := getKey1Meta() - key, err := km.VerifyRecoveryKey(key1RecoveryKey) - if err != nil { - panic(err) - } - key.ID = key1ID - return key -} - -func getKey2Meta() *ssss.KeyMetadata { - var km ssss.KeyMetadata - err := json.Unmarshal([]byte(key2Meta), &km) - if err != nil { - panic(err) - } - return &km + return exerrors.Must(getKeyMeta(key1Meta).VerifyRecoveryKey(key1ID, key1RecoveryKey)) } func getKey2() *ssss.Key { - km := getKey2Meta() - key, err := km.VerifyRecoveryKey(key2RecoveryKey) - if err != nil { - panic(err) - } - key.ID = key2ID - return key + return exerrors.Must(getKeyMeta(key2Meta).VerifyRecoveryKey(key2ID, key2RecoveryKey)) } func TestKeyMetadata_VerifyRecoveryKey_Correct(t *testing.T) { - km := getKey1Meta() - key, err := km.VerifyRecoveryKey(key1RecoveryKey) + km := getKeyMeta(key1Meta) + key, err := km.VerifyRecoveryKey(key1ID, key1RecoveryKey) assert.NoError(t, err) assert.NotNil(t, key) assert.Equal(t, key1RecoveryKey, key.RecoveryKey()) } func TestKeyMetadata_VerifyRecoveryKey_Correct2(t *testing.T) { - km := getKey2Meta() - key, err := km.VerifyRecoveryKey(key2RecoveryKey) + km := getKeyMeta(key2Meta) + key, err := km.VerifyRecoveryKey(key2ID, key2RecoveryKey) assert.NoError(t, err) assert.NotNil(t, key) assert.Equal(t, key2RecoveryKey, key.RecoveryKey()) } +func TestKeyMetadata_VerifyRecoveryKey_NonCompliant_LongIV(t *testing.T) { + km := getKeyMeta(key2MetaLongIV) + key, err := km.VerifyRecoveryKey(key2ID, key2RecoveryKey) + assert.NoError(t, err) + assert.NotNil(t, key) + assert.Equal(t, key2RecoveryKey, key.RecoveryKey()) +} + +func TestKeyMetadata_VerifyRecoveryKey_Unverified(t *testing.T) { + km := getKeyMeta(key2MetaUnverified) + key, err := km.VerifyRecoveryKey(key2ID, key2RecoveryKey) + assert.ErrorIs(t, err, ssss.ErrUnverifiableKey) + assert.NotNil(t, key) + assert.Equal(t, key2RecoveryKey, key.RecoveryKey()) +} + func TestKeyMetadata_VerifyRecoveryKey_Invalid(t *testing.T) { - km := getKey1Meta() - key, err := km.VerifyRecoveryKey("foo") - assert.True(t, errors.Is(err, ssss.ErrInvalidRecoveryKey), "unexpected error: %v", err) + km := getKeyMeta(key1Meta) + key, err := km.VerifyRecoveryKey(key1ID, "foo") + assert.ErrorIs(t, err, ssss.ErrInvalidRecoveryKey) assert.Nil(t, key) } func TestKeyMetadata_VerifyRecoveryKey_Incorrect(t *testing.T) { - km := getKey1Meta() - key, err := km.VerifyRecoveryKey(key2RecoveryKey) - assert.True(t, errors.Is(err, ssss.ErrIncorrectSSSSKey), "unexpected error: %v", err) + km := getKeyMeta(key1Meta) + key, err := km.VerifyRecoveryKey(key2ID, key2RecoveryKey) + assert.ErrorIs(t, err, ssss.ErrIncorrectSSSSKey) assert.Nil(t, key) } func TestKeyMetadata_VerifyPassphrase_Correct(t *testing.T) { - km := getKey1Meta() - key, err := km.VerifyPassphrase(key1Passphrase) + km := getKeyMeta(key1Meta) + key, err := km.VerifyPassphrase(key1ID, key1Passphrase) assert.NoError(t, err) assert.NotNil(t, key) assert.Equal(t, key1RecoveryKey, key.RecoveryKey()) } func TestKeyMetadata_VerifyPassphrase_Incorrect(t *testing.T) { - km := getKey1Meta() - key, err := km.VerifyPassphrase("incorrect horse battery staple") - assert.True(t, errors.Is(err, ssss.ErrIncorrectSSSSKey), "unexpected error %v", err) + km := getKeyMeta(key1Meta) + key, err := km.VerifyPassphrase(key1ID, "incorrect horse battery staple") + assert.ErrorIs(t, err, ssss.ErrIncorrectSSSSKey) assert.Nil(t, key) } func TestKeyMetadata_VerifyPassphrase_NotSet(t *testing.T) { - km := getKey2Meta() - key, err := km.VerifyPassphrase("hmm") - assert.True(t, errors.Is(err, ssss.ErrNoPassphrase), "unexpected error %v", err) + km := getKeyMeta(key2Meta) + key, err := km.VerifyPassphrase(key2ID, "hmm") + assert.ErrorIs(t, err, ssss.ErrNoPassphrase) + assert.Nil(t, key) +} + +func TestKeyMetadata_VerifyRecoveryKey_CorruptedIV(t *testing.T) { + km := getKeyMeta(key2MetaBrokenIV) + key, err := km.VerifyRecoveryKey(key2ID, key2RecoveryKey) + assert.ErrorIs(t, err, ssss.ErrCorruptedKeyMetadata) + assert.Nil(t, key) +} + +func TestKeyMetadata_VerifyRecoveryKey_CorruptedMAC(t *testing.T) { + km := getKeyMeta(key2MetaBrokenMAC) + key, err := km.VerifyRecoveryKey(key2ID, key2RecoveryKey) + assert.ErrorIs(t, err, ssss.ErrCorruptedKeyMetadata) assert.Nil(t, key) } diff --git a/crypto/ssss/types.go b/crypto/ssss/types.go index 60852c55..b7465d3e 100644 --- a/crypto/ssss/types.go +++ b/crypto/ssss/types.go @@ -26,6 +26,8 @@ var ( ErrUnsupportedPassphraseAlgorithm = errors.New("unsupported passphrase KDF algorithm") ErrIncorrectSSSSKey = errors.New("incorrect SSSS key") ErrInvalidRecoveryKey = errors.New("invalid recovery key") + ErrCorruptedKeyMetadata = errors.New("corrupted recovery key metadata") + ErrUnverifiableKey = errors.New("cannot verify recovery key: missing MAC or IV in metadata") ) // Algorithm is the identifier for an SSSS encryption algorithm. @@ -56,6 +58,7 @@ type EncryptedKeyData struct { type EncryptedAccountDataEventContent struct { Encrypted map[string]EncryptedKeyData `json:"encrypted"` + Metadata map[string]any `json:"com.beeper.metadata,omitzero"` } func (ed *EncryptedAccountDataEventContent) Decrypt(eventType string, key *Key) ([]byte, error) { diff --git a/crypto/store.go b/crypto/store.go index 3b6e6564..7620cf35 100644 --- a/crypto/store.go +++ b/crypto/store.go @@ -9,10 +9,14 @@ package crypto import ( "context" "fmt" + "slices" "sort" "sync" + "time" "go.mau.fi/util/dbutil" + "go.mau.fi/util/exsync" + "golang.org/x/exp/maps" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" @@ -41,22 +45,33 @@ type Store interface { HasSession(context.Context, id.SenderKey) bool // GetSessions returns all Olm sessions in the store with the given sender key. GetSessions(context.Context, id.SenderKey) (OlmSessionList, error) - // GetLatestSession returns the session with the highest session ID (lexiographically sorting). - // It's usually safe to return the most recently added session if sorting by session ID is too difficult. + // GetLatestSession returns the most recent session that should be used for encrypting outbound messages. + // It's usually the one with the most recent successful decryption or the highest ID lexically. GetLatestSession(context.Context, id.SenderKey) (*OlmSession, error) + // GetNewestSessionCreationTS returns the creation timestamp of the most recently created session for the given sender key. + GetNewestSessionCreationTS(context.Context, id.SenderKey) (time.Time, error) // UpdateSession updates a session that has previously been inserted with AddSession. UpdateSession(context.Context, id.SenderKey, *OlmSession) error + // DeleteSession deletes the given session that has been previously inserted with AddSession. + DeleteSession(context.Context, id.SenderKey, *OlmSession) error + + // PutOlmHash marks a given olm message hash as handled. + PutOlmHash(context.Context, [32]byte, time.Time) error + // GetOlmHash gets the time that a given olm hash was handled. + GetOlmHash(context.Context, [32]byte) (time.Time, error) + // DeleteOldOlmHashes deletes all olm hashes that were handled before the given time. + DeleteOldOlmHashes(context.Context, time.Time) error // PutGroupSession inserts an inbound Megolm session into the store. If an earlier withhold event has been inserted // with PutWithheldGroupSession, this call should replace that. However, PutWithheldGroupSession must not replace // sessions inserted with this call. - PutGroupSession(context.Context, id.RoomID, id.SenderKey, id.SessionID, *InboundGroupSession) error + PutGroupSession(context.Context, *InboundGroupSession) error // GetGroupSession gets an inbound Megolm session from the store. If the group session has been withheld // (i.e. a room key withheld event has been saved with PutWithheldGroupSession), this should return the // ErrGroupSessionWithheld error. The caller may use GetWithheldGroupSession to find more details. - GetGroupSession(context.Context, id.RoomID, id.SenderKey, id.SessionID) (*InboundGroupSession, error) + GetGroupSession(context.Context, id.RoomID, id.SessionID) (*InboundGroupSession, error) // RedactGroupSession removes the session data for the given inbound Megolm session from the store. - RedactGroupSession(context.Context, id.RoomID, id.SenderKey, id.SessionID, string) error + RedactGroupSession(context.Context, id.RoomID, id.SessionID, string) error // RedactGroupSessions removes the session data for all inbound Megolm sessions from a specific device and/or in a specific room. RedactGroupSessions(context.Context, id.RoomID, id.SenderKey, string) ([]id.SessionID, error) // RedactExpiredGroupSessions removes the session data for all inbound Megolm sessions that have expired. @@ -66,7 +81,7 @@ type Store interface { // PutWithheldGroupSession tells the store that a specific Megolm session was withheld. PutWithheldGroupSession(context.Context, event.RoomKeyWithheldEventContent) error // GetWithheldGroupSession gets the event content that was previously inserted with PutWithheldGroupSession. - GetWithheldGroupSession(context.Context, id.RoomID, id.SenderKey, id.SessionID) (*event.RoomKeyWithheldEventContent, error) + GetWithheldGroupSession(context.Context, id.RoomID, id.SessionID) (*event.RoomKeyWithheldEventContent, error) // GetGroupSessionsForRoom gets all the inbound Megolm sessions for a specific room. This is used for creating key // export files. Unlike GetGroupSession, this should not return any errors about withheld keys. @@ -131,6 +146,8 @@ type Store interface { IsKeySignedBy(ctx context.Context, userID id.UserID, key id.Ed25519, signedByUser id.UserID, signedByKey id.Ed25519) (bool, error) // DropSignaturesByKey deletes the signatures made by the given user and key from the store. It returns the number of signatures deleted. DropSignaturesByKey(context.Context, id.UserID, id.Ed25519) (int64, error) + // GetSignaturesForKeyBy retrieves the stored signatures for a given cross-signing or device key, by the given signer. + GetSignaturesForKeyBy(context.Context, id.UserID, id.Ed25519, id.UserID) (map[id.Ed25519]string, error) // PutSecret stores a named secret, replacing it if it exists already. PutSecret(context.Context, id.Secret, string) error @@ -160,8 +177,8 @@ type MemoryStore struct { Account *OlmAccount Sessions map[id.SenderKey]OlmSessionList - GroupSessions map[id.RoomID]map[id.SenderKey]map[id.SessionID]*InboundGroupSession - WithheldGroupSessions map[id.RoomID]map[id.SenderKey]map[id.SessionID]*event.RoomKeyWithheldEventContent + GroupSessions map[id.RoomID]map[id.SessionID]*InboundGroupSession + WithheldGroupSessions map[id.RoomID]map[id.SessionID]*event.RoomKeyWithheldEventContent OutGroupSessions map[id.RoomID]*OutboundGroupSession SharedGroupSessions map[id.UserID]map[id.IdentityKey]map[id.SessionID]struct{} MessageIndices map[messageIndexKey]messageIndexValue @@ -170,6 +187,7 @@ type MemoryStore struct { KeySignatures map[id.UserID]map[id.Ed25519]map[id.UserID]map[id.Ed25519]string OutdatedUsers map[id.UserID]struct{} Secrets map[id.Secret]string + OlmHashes *exsync.Set[[32]byte] } var _ Store = (*MemoryStore)(nil) @@ -182,8 +200,8 @@ func NewMemoryStore(saveCallback func() error) *MemoryStore { save: saveCallback, Sessions: make(map[id.SenderKey]OlmSessionList), - GroupSessions: make(map[id.RoomID]map[id.SenderKey]map[id.SessionID]*InboundGroupSession), - WithheldGroupSessions: make(map[id.RoomID]map[id.SenderKey]map[id.SessionID]*event.RoomKeyWithheldEventContent), + GroupSessions: make(map[id.RoomID]map[id.SessionID]*InboundGroupSession), + WithheldGroupSessions: make(map[id.RoomID]map[id.SessionID]*event.RoomKeyWithheldEventContent), OutGroupSessions: make(map[id.RoomID]*OutboundGroupSession), SharedGroupSessions: make(map[id.UserID]map[id.IdentityKey]map[id.SessionID]struct{}), MessageIndices: make(map[messageIndexKey]messageIndexValue), @@ -192,14 +210,14 @@ func NewMemoryStore(saveCallback func() error) *MemoryStore { KeySignatures: make(map[id.UserID]map[id.Ed25519]map[id.UserID]map[id.Ed25519]string), OutdatedUsers: make(map[id.UserID]struct{}), Secrets: make(map[id.Secret]string), + OlmHashes: exsync.NewSet[[32]byte](), } } func (gs *MemoryStore) Flush(_ context.Context) error { gs.lock.Lock() - err := gs.save() - gs.lock.Unlock() - return err + defer gs.lock.Unlock() + return gs.save() } func (gs *MemoryStore) GetAccount(_ context.Context) (*OlmAccount, error) { @@ -208,31 +226,42 @@ func (gs *MemoryStore) GetAccount(_ context.Context) (*OlmAccount, error) { func (gs *MemoryStore) PutAccount(_ context.Context, account *OlmAccount) error { gs.lock.Lock() + defer gs.lock.Unlock() gs.Account = account - err := gs.save() - gs.lock.Unlock() - return err + return gs.save() } func (gs *MemoryStore) GetSessions(_ context.Context, senderKey id.SenderKey) (OlmSessionList, error) { gs.lock.Lock() + defer gs.lock.Unlock() sessions, ok := gs.Sessions[senderKey] if !ok { sessions = []*OlmSession{} gs.Sessions[senderKey] = sessions } - gs.lock.Unlock() return sessions, nil } func (gs *MemoryStore) AddSession(_ context.Context, senderKey id.SenderKey, session *OlmSession) error { gs.lock.Lock() - sessions, _ := gs.Sessions[senderKey] + defer gs.lock.Unlock() + sessions := gs.Sessions[senderKey] gs.Sessions[senderKey] = append(sessions, session) sort.Sort(gs.Sessions[senderKey]) - err := gs.save() - gs.lock.Unlock() - return err + return gs.save() +} + +func (gs *MemoryStore) DeleteSession(ctx context.Context, senderKey id.SenderKey, target *OlmSession) error { + gs.lock.Lock() + defer gs.lock.Unlock() + sessions, ok := gs.Sessions[senderKey] + if !ok { + return nil + } + gs.Sessions[senderKey] = slices.DeleteFunc(sessions, func(session *OlmSession) bool { + return session == target + }) + return gs.save() } func (gs *MemoryStore) UpdateSession(_ context.Context, _ id.SenderKey, _ *OlmSession) error { @@ -242,102 +271,112 @@ func (gs *MemoryStore) UpdateSession(_ context.Context, _ id.SenderKey, _ *OlmSe func (gs *MemoryStore) HasSession(_ context.Context, senderKey id.SenderKey) bool { gs.lock.RLock() + defer gs.lock.RUnlock() sessions, ok := gs.Sessions[senderKey] - gs.lock.RUnlock() return ok && len(sessions) > 0 && !sessions[0].Expired() } +func (gs *MemoryStore) PutOlmHash(_ context.Context, hash [32]byte, receivedAt time.Time) error { + gs.OlmHashes.Add(hash) + return nil +} + +func (gs *MemoryStore) GetOlmHash(_ context.Context, hash [32]byte) (time.Time, error) { + if gs.OlmHashes.Has(hash) { + // The time isn't that important, so we just return the current time + return time.Now(), nil + } + return time.Time{}, nil +} + +func (gs *MemoryStore) DeleteOldOlmHashes(_ context.Context, beforeTS time.Time) error { + return nil +} + func (gs *MemoryStore) GetLatestSession(_ context.Context, senderKey id.SenderKey) (*OlmSession, error) { gs.lock.RLock() + defer gs.lock.RUnlock() sessions, ok := gs.Sessions[senderKey] - gs.lock.RUnlock() if !ok || len(sessions) == 0 { return nil, nil } - return sessions[0], nil + return sessions[len(sessions)-1], nil } -func (gs *MemoryStore) getGroupSessions(roomID id.RoomID, senderKey id.SenderKey) map[id.SessionID]*InboundGroupSession { +func (gs *MemoryStore) GetNewestSessionCreationTS(ctx context.Context, senderKey id.SenderKey) (createdAt time.Time, err error) { + var sess *OlmSession + sess, err = gs.GetLatestSession(ctx, senderKey) + if sess != nil { + createdAt = sess.CreationTime + } + return +} + +func (gs *MemoryStore) getGroupSessions(roomID id.RoomID) map[id.SessionID]*InboundGroupSession { room, ok := gs.GroupSessions[roomID] if !ok { - room = make(map[id.SenderKey]map[id.SessionID]*InboundGroupSession) + room = make(map[id.SessionID]*InboundGroupSession) gs.GroupSessions[roomID] = room } - sender, ok := room[senderKey] - if !ok { - sender = make(map[id.SessionID]*InboundGroupSession) - room[senderKey] = sender - } - return sender + return room } -func (gs *MemoryStore) PutGroupSession(_ context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, igs *InboundGroupSession) error { +func (gs *MemoryStore) PutGroupSession(_ context.Context, igs *InboundGroupSession) error { gs.lock.Lock() - gs.getGroupSessions(roomID, senderKey)[sessionID] = igs - err := gs.save() - gs.lock.Unlock() - return err + defer gs.lock.Unlock() + gs.getGroupSessions(igs.RoomID)[igs.ID()] = igs + return gs.save() } -func (gs *MemoryStore) GetGroupSession(_ context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID) (*InboundGroupSession, error) { +func (gs *MemoryStore) GetGroupSession(_ context.Context, roomID id.RoomID, sessionID id.SessionID) (*InboundGroupSession, error) { gs.lock.Lock() - session, ok := gs.getGroupSessions(roomID, senderKey)[sessionID] + defer gs.lock.Unlock() + session, ok := gs.getGroupSessions(roomID)[sessionID] if !ok { - withheld, ok := gs.getWithheldGroupSessions(roomID, senderKey)[sessionID] - gs.lock.Unlock() + withheld, ok := gs.getWithheldGroupSessions(roomID)[sessionID] if ok { return nil, fmt.Errorf("%w (%s)", ErrGroupSessionWithheld, withheld.Code) } return nil, nil } - gs.lock.Unlock() return session, nil } -func (gs *MemoryStore) RedactGroupSession(_ context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, reason string) error { +func (gs *MemoryStore) RedactGroupSession(_ context.Context, roomID id.RoomID, sessionID id.SessionID, reason string) error { gs.lock.Lock() - delete(gs.getGroupSessions(roomID, senderKey), sessionID) - err := gs.save() - gs.lock.Unlock() - return err + defer gs.lock.Unlock() + delete(gs.getGroupSessions(roomID), sessionID) + return gs.save() } func (gs *MemoryStore) RedactGroupSessions(_ context.Context, roomID id.RoomID, senderKey id.SenderKey, reason string) ([]id.SessionID, error) { gs.lock.Lock() + defer gs.lock.Unlock() var sessionIDs []id.SessionID if roomID != "" && senderKey != "" { - sessions := gs.getGroupSessions(roomID, senderKey) - for sessionID := range sessions { - sessionIDs = append(sessionIDs, sessionID) - delete(sessions, sessionID) + sessions := gs.getGroupSessions(roomID) + for sessionID, session := range sessions { + if session.SenderKey == senderKey { + sessionIDs = append(sessionIDs, sessionID) + delete(sessions, sessionID) + } } } else if senderKey != "" { for _, room := range gs.GroupSessions { - sessions, ok := room[senderKey] - if ok { - for sessionID := range sessions { + for sessionID, session := range room { + if session.SenderKey == senderKey { sessionIDs = append(sessionIDs, sessionID) + delete(room, sessionID) } - delete(room, senderKey) } } } else if roomID != "" { - room, ok := gs.GroupSessions[roomID] - if ok { - for senderKey := range room { - sessions := room[senderKey] - for sessionID := range sessions { - sessionIDs = append(sessionIDs, sessionID) - } - } - delete(gs.GroupSessions, roomID) - } + sessionIDs = maps.Keys(gs.GroupSessions[roomID]) + delete(gs.GroupSessions, roomID) } else { return nil, fmt.Errorf("room ID or sender key must be provided for redacting sessions") } - err := gs.save() - gs.lock.Unlock() - return sessionIDs, err + return sessionIDs, gs.save() } func (gs *MemoryStore) RedactExpiredGroupSessions(_ context.Context) ([]id.SessionID, error) { @@ -348,32 +387,26 @@ func (gs *MemoryStore) RedactOutdatedGroupSessions(_ context.Context) ([]id.Sess return nil, fmt.Errorf("not implemented") } -func (gs *MemoryStore) getWithheldGroupSessions(roomID id.RoomID, senderKey id.SenderKey) map[id.SessionID]*event.RoomKeyWithheldEventContent { +func (gs *MemoryStore) getWithheldGroupSessions(roomID id.RoomID) map[id.SessionID]*event.RoomKeyWithheldEventContent { room, ok := gs.WithheldGroupSessions[roomID] if !ok { - room = make(map[id.SenderKey]map[id.SessionID]*event.RoomKeyWithheldEventContent) + room = make(map[id.SessionID]*event.RoomKeyWithheldEventContent) gs.WithheldGroupSessions[roomID] = room } - sender, ok := room[senderKey] - if !ok { - sender = make(map[id.SessionID]*event.RoomKeyWithheldEventContent) - room[senderKey] = sender - } - return sender + return room } func (gs *MemoryStore) PutWithheldGroupSession(_ context.Context, content event.RoomKeyWithheldEventContent) error { gs.lock.Lock() - gs.getWithheldGroupSessions(content.RoomID, content.SenderKey)[content.SessionID] = &content - err := gs.save() - gs.lock.Unlock() - return err + defer gs.lock.Unlock() + gs.getWithheldGroupSessions(content.RoomID)[content.SessionID] = &content + return gs.save() } -func (gs *MemoryStore) GetWithheldGroupSession(_ context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID) (*event.RoomKeyWithheldEventContent, error) { +func (gs *MemoryStore) GetWithheldGroupSession(_ context.Context, roomID id.RoomID, sessionID id.SessionID) (*event.RoomKeyWithheldEventContent, error) { gs.lock.Lock() - session, ok := gs.getWithheldGroupSessions(roomID, senderKey)[sessionID] - gs.lock.Unlock() + defer gs.lock.Unlock() + session, ok := gs.getWithheldGroupSessions(roomID)[sessionID] if !ok { return nil, nil } @@ -387,51 +420,38 @@ func (gs *MemoryStore) GetGroupSessionsForRoom(_ context.Context, roomID id.Room if !ok { return nil } - var result []*InboundGroupSession - for _, sessions := range room { - for _, session := range sessions { - result = append(result, session) - } - } - return dbutil.NewSliceIter[*InboundGroupSession](result) + return dbutil.NewSliceIter(maps.Values(room)) } func (gs *MemoryStore) GetAllGroupSessions(_ context.Context) dbutil.RowIter[*InboundGroupSession] { gs.lock.Lock() + defer gs.lock.Unlock() var result []*InboundGroupSession for _, room := range gs.GroupSessions { - for _, sessions := range room { - for _, session := range sessions { - result = append(result, session) - } - } + result = append(result, maps.Values(room)...) } - gs.lock.Unlock() - return dbutil.NewSliceIter[*InboundGroupSession](result) + return dbutil.NewSliceIter(result) } func (gs *MemoryStore) GetGroupSessionsWithoutKeyBackupVersion(_ context.Context, version id.KeyBackupVersion) dbutil.RowIter[*InboundGroupSession] { gs.lock.Lock() + defer gs.lock.Unlock() var result []*InboundGroupSession for _, room := range gs.GroupSessions { - for _, sessions := range room { - for _, session := range sessions { - if session.KeyBackupVersion != version { - result = append(result, session) - } + for _, session := range room { + if session.KeyBackupVersion != version { + result = append(result, session) } } } - gs.lock.Unlock() - return dbutil.NewSliceIter[*InboundGroupSession](result) + return dbutil.NewSliceIter(result) } func (gs *MemoryStore) AddOutboundGroupSession(_ context.Context, session *OutboundGroupSession) error { gs.lock.Lock() + defer gs.lock.Unlock() gs.OutGroupSessions[session.RoomID] = session - err := gs.save() - gs.lock.Unlock() - return err + return gs.save() } func (gs *MemoryStore) UpdateOutboundGroupSession(_ context.Context, _ *OutboundGroupSession) error { @@ -441,8 +461,8 @@ func (gs *MemoryStore) UpdateOutboundGroupSession(_ context.Context, _ *Outbound func (gs *MemoryStore) GetOutboundGroupSession(_ context.Context, roomID id.RoomID) (*OutboundGroupSession, error) { gs.lock.RLock() + defer gs.lock.RUnlock() session, ok := gs.OutGroupSessions[roomID] - gs.lock.RUnlock() if !ok { return nil, nil } @@ -451,18 +471,18 @@ func (gs *MemoryStore) GetOutboundGroupSession(_ context.Context, roomID id.Room func (gs *MemoryStore) RemoveOutboundGroupSession(_ context.Context, roomID id.RoomID) error { gs.lock.Lock() + defer gs.lock.Unlock() session, ok := gs.OutGroupSessions[roomID] if !ok || session == nil { - gs.lock.Unlock() return nil } delete(gs.OutGroupSessions, roomID) - gs.lock.Unlock() return nil } func (gs *MemoryStore) MarkOutboundGroupSessionShared(_ context.Context, userID id.UserID, identityKey id.IdentityKey, sessionID id.SessionID) error { gs.lock.Lock() + defer gs.lock.Unlock() if _, ok := gs.SharedGroupSessions[userID]; !ok { gs.SharedGroupSessions[userID] = make(map[id.IdentityKey]map[id.SessionID]struct{}) @@ -475,7 +495,6 @@ func (gs *MemoryStore) MarkOutboundGroupSessionShared(_ context.Context, userID identities[identityKey][sessionID] = struct{}{} - gs.lock.Unlock() return nil } @@ -506,6 +525,9 @@ func (gs *MemoryStore) ValidateMessageIndex(_ context.Context, senderKey id.Send } val, ok := gs.MessageIndices[key] if !ok { + if eventID == "" && timestamp == 0 { + return true, nil + } gs.MessageIndices[key] = messageIndexValue{ EventID: eventID, Timestamp: timestamp, @@ -521,11 +543,11 @@ func (gs *MemoryStore) ValidateMessageIndex(_ context.Context, senderKey id.Send func (gs *MemoryStore) GetDevices(_ context.Context, userID id.UserID) (map[id.DeviceID]*id.Device, error) { gs.lock.RLock() + defer gs.lock.RUnlock() devices, ok := gs.Devices[userID] if !ok { devices = nil } - gs.lock.RUnlock() return devices, nil } @@ -560,30 +582,30 @@ func (gs *MemoryStore) FindDeviceByKey(_ context.Context, userID id.UserID, iden func (gs *MemoryStore) PutDevice(_ context.Context, userID id.UserID, device *id.Device) error { gs.lock.Lock() + defer gs.lock.Unlock() devices, ok := gs.Devices[userID] if !ok { devices = make(map[id.DeviceID]*id.Device) gs.Devices[userID] = devices } devices[device.DeviceID] = device - err := gs.save() - gs.lock.Unlock() - return err + return gs.save() } func (gs *MemoryStore) PutDevices(_ context.Context, userID id.UserID, devices map[id.DeviceID]*id.Device) error { gs.lock.Lock() + defer gs.lock.Unlock() gs.Devices[userID] = devices err := gs.save() if err == nil { delete(gs.OutdatedUsers, userID) } - gs.lock.Unlock() return err } func (gs *MemoryStore) FilterTrackedUsers(_ context.Context, users []id.UserID) ([]id.UserID, error) { gs.lock.RLock() + defer gs.lock.RUnlock() var ptr int for _, userID := range users { _, ok := gs.Devices[userID] @@ -592,33 +614,33 @@ func (gs *MemoryStore) FilterTrackedUsers(_ context.Context, users []id.UserID) ptr++ } } - gs.lock.RUnlock() return users[:ptr], nil } func (gs *MemoryStore) MarkTrackedUsersOutdated(_ context.Context, users []id.UserID) error { gs.lock.Lock() + defer gs.lock.Unlock() for _, userID := range users { if _, ok := gs.Devices[userID]; ok { gs.OutdatedUsers[userID] = struct{}{} } } - gs.lock.Unlock() return nil } func (gs *MemoryStore) GetOutdatedTrackedUsers(_ context.Context) ([]id.UserID, error) { gs.lock.RLock() + defer gs.lock.RUnlock() users := make([]id.UserID, 0, len(gs.OutdatedUsers)) for userID := range gs.OutdatedUsers { users = append(users, userID) } - gs.lock.RUnlock() return users, nil } func (gs *MemoryStore) PutCrossSigningKey(_ context.Context, userID id.UserID, usage id.CrossSigningUsage, key id.Ed25519) error { gs.lock.RLock() + defer gs.lock.RUnlock() userKeys, ok := gs.CrossSigningKeys[userID] if !ok { userKeys = make(map[id.CrossSigningUsage]id.CrossSigningKey) @@ -635,7 +657,6 @@ func (gs *MemoryStore) PutCrossSigningKey(_ context.Context, userID id.UserID, u } } err := gs.save() - gs.lock.RUnlock() return err } @@ -651,6 +672,7 @@ func (gs *MemoryStore) GetCrossSigningKeys(_ context.Context, userID id.UserID) func (gs *MemoryStore) PutSignature(_ context.Context, signedUserID id.UserID, signedKey id.Ed25519, signerUserID id.UserID, signerKey id.Ed25519, signature string) error { gs.lock.RLock() + defer gs.lock.RUnlock() signedUserSigs, ok := gs.KeySignatures[signedUserID] if !ok { signedUserSigs = make(map[id.Ed25519]map[id.UserID]map[id.Ed25519]string) @@ -667,9 +689,7 @@ func (gs *MemoryStore) PutSignature(_ context.Context, signedUserID id.UserID, s signaturesForKey[signerUserID] = signedByUser } signedByUser[signerKey] = signature - err := gs.save() - gs.lock.RUnlock() - return err + return gs.save() } func (gs *MemoryStore) GetSignaturesForKeyBy(_ context.Context, userID id.UserID, key id.Ed25519, signerID id.UserID) (map[id.Ed25519]string, error) { @@ -700,8 +720,9 @@ func (gs *MemoryStore) IsKeySignedBy(ctx context.Context, userID id.UserID, key } func (gs *MemoryStore) DropSignaturesByKey(_ context.Context, userID id.UserID, key id.Ed25519) (int64, error) { - var count int64 gs.lock.RLock() + defer gs.lock.RUnlock() + var count int64 for _, userSigs := range gs.KeySignatures { for _, keySigs := range userSigs { if signedBySigner, ok := keySigs[userID]; ok { @@ -712,27 +733,25 @@ func (gs *MemoryStore) DropSignaturesByKey(_ context.Context, userID id.UserID, } } } - gs.lock.RUnlock() return count, nil } func (gs *MemoryStore) PutSecret(_ context.Context, name id.Secret, value string) error { gs.lock.Lock() + defer gs.lock.Unlock() gs.Secrets[name] = value - gs.lock.Unlock() return nil } -func (gs *MemoryStore) GetSecret(_ context.Context, name id.Secret) (value string, _ error) { +func (gs *MemoryStore) GetSecret(_ context.Context, name id.Secret) (string, error) { gs.lock.RLock() - value = gs.Secrets[name] - gs.lock.RUnlock() - return + defer gs.lock.RUnlock() + return gs.Secrets[name], nil } func (gs *MemoryStore) DeleteSecret(_ context.Context, name id.Secret) error { gs.lock.Lock() + defer gs.lock.Unlock() delete(gs.Secrets, name) - gs.lock.Unlock() return nil } diff --git a/crypto/store_test.go b/crypto/store_test.go index e6969e3e..7a47243e 100644 --- a/crypto/store_test.go +++ b/crypto/store_test.go @@ -13,6 +13,8 @@ import ( "testing" _ "github.com/mattn/go-sqlite3" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "go.mau.fi/util/dbutil" "maunium.net/go/mautrix/crypto/olm" @@ -28,22 +30,14 @@ const groupSession = "9ZbsRqJuETbjnxPpKv29n3dubP/m5PSLbr9I9CIWS2O86F/Og1JZXhqT+4 func getCryptoStores(t *testing.T) map[string]Store { rawDB, err := sql.Open("sqlite3", ":memory:?_busy_timeout=5000") - if err != nil { - t.Fatalf("Error opening db: %v", err) - } + require.NoError(t, err, "Error opening raw database") db, err := dbutil.NewWithDB(rawDB, "sqlite3") - if err != nil { - t.Fatalf("Error opening db: %v", err) - } + require.NoError(t, err, "Error creating database wrapper") sqlStore := NewSQLCryptoStore(db, nil, "accid", id.DeviceID("dev"), []byte("test")) - if err = sqlStore.DB.Upgrade(context.TODO()); err != nil { - t.Fatalf("Error creating tables: %v", err) - } + err = sqlStore.DB.Upgrade(context.TODO()) + require.NoError(t, err, "Error upgrading database") gobStore := NewMemoryStore(nil) - if err != nil { - t.Fatalf("Error creating Gob store: %v", err) - } return map[string]Store{ "sql": sqlStore, @@ -55,9 +49,10 @@ func TestPutNextBatch(t *testing.T) { stores := getCryptoStores(t) store := stores["sql"].(*SQLCryptoStore) store.PutNextBatch(context.Background(), "batch1") - if batch, _ := store.GetNextBatch(context.Background()); batch != "batch1" { - t.Errorf("Expected batch1, got %v", batch) - } + + batch, err := store.GetNextBatch(context.Background()) + require.NoError(t, err, "Error retrieving next batch") + assert.Equal(t, "batch1", batch) } func TestPutAccount(t *testing.T) { @@ -67,15 +62,9 @@ func TestPutAccount(t *testing.T) { acc := NewOlmAccount() store.PutAccount(context.TODO(), acc) retrieved, err := store.GetAccount(context.TODO()) - if err != nil { - t.Fatalf("Error retrieving account: %v", err) - } - if acc.IdentityKey() != retrieved.IdentityKey() { - t.Errorf("Stored identity key %v, got %v", acc.IdentityKey(), retrieved.IdentityKey()) - } - if acc.SigningKey() != retrieved.SigningKey() { - t.Errorf("Stored signing key %v, got %v", acc.SigningKey(), retrieved.SigningKey()) - } + require.NoError(t, err, "Error retrieving account") + assert.Equal(t, acc.IdentityKey(), retrieved.IdentityKey(), "Identity key does not match") + assert.Equal(t, acc.SigningKey(), retrieved.SigningKey(), "Signing key does not match") }) } } @@ -85,18 +74,36 @@ func TestValidateMessageIndex(t *testing.T) { for storeName, store := range stores { t.Run(storeName, func(t *testing.T) { acc := NewOlmAccount() - if ok, _ := store.ValidateMessageIndex(context.TODO(), acc.IdentityKey(), "sess1", "event1", 0, 1000); !ok { - t.Error("First message not validated successfully") - } - if ok, _ := store.ValidateMessageIndex(context.TODO(), acc.IdentityKey(), "sess1", "event1", 0, 1001); ok { - t.Error("First message validated successfully after changing timestamp") - } - if ok, _ := store.ValidateMessageIndex(context.TODO(), acc.IdentityKey(), "sess1", "event2", 0, 1000); ok { - t.Error("First message validated successfully after changing event ID") - } - if ok, _ := store.ValidateMessageIndex(context.TODO(), acc.IdentityKey(), "sess1", "event1", 0, 1000); !ok { - t.Error("First message not validated successfully for a second time") - } + + // Validating without event ID and timestamp before we have them should work + ok, err := store.ValidateMessageIndex(context.TODO(), acc.IdentityKey(), "sess1", "", 0, 0) + require.NoError(t, err, "Error validating message index") + assert.True(t, ok, "First message validation should be valid") + + // First message should validate successfully + ok, err = store.ValidateMessageIndex(context.TODO(), acc.IdentityKey(), "sess1", "event1", 0, 1000) + require.NoError(t, err, "Error validating message index") + assert.True(t, ok, "First message validation should be valid") + + // Edit the timestamp and ensure validate fails + ok, err = store.ValidateMessageIndex(context.TODO(), acc.IdentityKey(), "sess1", "event1", 0, 1001) + require.NoError(t, err, "Error validating message index after timestamp change") + assert.False(t, ok, "First message validation should fail after timestamp change") + + // Edit the event ID and ensure validate fails + ok, err = store.ValidateMessageIndex(context.TODO(), acc.IdentityKey(), "sess1", "event2", 0, 1000) + require.NoError(t, err, "Error validating message index after event ID change") + assert.False(t, ok, "First message validation should fail after event ID change") + + // Validate again with the original parameters and ensure that it still passes + ok, err = store.ValidateMessageIndex(context.TODO(), acc.IdentityKey(), "sess1", "event1", 0, 1000) + require.NoError(t, err, "Error validating message index") + assert.True(t, ok, "First message validation should be valid") + + // Validating without event ID and timestamp must fail if we already know them + ok, err = store.ValidateMessageIndex(context.TODO(), acc.IdentityKey(), "sess1", "", 0, 0) + require.NoError(t, err, "Error validating message index") + assert.False(t, ok, "First message validation should be invalid") }) } } @@ -105,37 +112,26 @@ func TestStoreOlmSession(t *testing.T) { stores := getCryptoStores(t) for storeName, store := range stores { t.Run(storeName, func(t *testing.T) { - if store.HasSession(context.TODO(), olmSessID) { - t.Error("Found Olm session before inserting it") - } + require.False(t, store.HasSession(context.TODO(), olmSessID), "Found Olm session before inserting it") + olmInternal, err := olm.SessionFromPickled([]byte(olmPickled), []byte("test")) - if err != nil { - t.Fatalf("Error creating internal Olm session: %v", err) - } + require.NoError(t, err, "Error creating internal Olm session") olmSess := OlmSession{ id: olmSessID, - Internal: *olmInternal, + Internal: olmInternal, } err = store.AddSession(context.TODO(), olmSessID, &olmSess) - if err != nil { - t.Errorf("Error storing Olm session: %v", err) - } - if !store.HasSession(context.TODO(), olmSessID) { - t.Error("Not found Olm session after inserting it") - } + require.NoError(t, err, "Error storing Olm session") + assert.True(t, store.HasSession(context.TODO(), olmSessID), "Olm session not found after inserting it") retrieved, err := store.GetLatestSession(context.TODO(), olmSessID) - if err != nil { - t.Errorf("Failed retrieving Olm session: %v", err) - } + require.NoError(t, err, "Error retrieving Olm session") + assert.EqualValues(t, olmSessID, retrieved.ID()) - if retrieved.ID() != olmSessID { - t.Errorf("Expected session ID to be %v, got %v", olmSessID, retrieved.ID()) - } - if pickled := string(retrieved.Internal.Pickle([]byte("test"))); pickled != olmPickled { - t.Error("Pickled Olm session does not match original") - } + pickled, err := retrieved.Internal.Pickle([]byte("test")) + require.NoError(t, err, "Error pickling Olm session") + assert.EqualValues(t, pickled, olmPickled, "Pickled Olm session does not match original") }) } } @@ -147,30 +143,24 @@ func TestStoreMegolmSession(t *testing.T) { acc := NewOlmAccount() internal, err := olm.InboundGroupSessionFromPickled([]byte(groupSession), []byte("test")) - if err != nil { - t.Fatalf("Error creating internal inbound group session: %v", err) - } + require.NoError(t, err, "Error creating internal inbound group session") igs := &InboundGroupSession{ - Internal: *internal, + Internal: internal, SigningKey: acc.SigningKey(), SenderKey: acc.IdentityKey(), RoomID: "room1", } - err = store.PutGroupSession(context.TODO(), "room1", acc.IdentityKey(), igs.ID(), igs) - if err != nil { - t.Errorf("Error storing inbound group session: %v", err) - } + err = store.PutGroupSession(context.TODO(), igs) + require.NoError(t, err, "Error storing inbound group session") - retrieved, err := store.GetGroupSession(context.TODO(), "room1", acc.IdentityKey(), igs.ID()) - if err != nil { - t.Errorf("Error retrieving inbound group session: %v", err) - } + retrieved, err := store.GetGroupSession(context.TODO(), "room1", igs.ID()) + require.NoError(t, err, "Error retrieving inbound group session") - if pickled := string(retrieved.Internal.Pickle([]byte("test"))); pickled != groupSession { - t.Error("Pickled inbound group session does not match original") - } + pickled, err := retrieved.Internal.Pickle([]byte("test")) + require.NoError(t, err, "Error pickling inbound group session") + assert.EqualValues(t, pickled, groupSession, "Pickled inbound group session does not match original") }) } } @@ -180,39 +170,24 @@ func TestStoreOutboundMegolmSession(t *testing.T) { for storeName, store := range stores { t.Run(storeName, func(t *testing.T) { sess, err := store.GetOutboundGroupSession(context.TODO(), "room1") - if sess != nil { - t.Error("Got outbound session before inserting") - } - if err != nil { - t.Errorf("Error retrieving outbound session: %v", err) - } + require.NoError(t, err, "Error retrieving outbound session") + require.Nil(t, sess, "Got outbound session before inserting") - outbound := NewOutboundGroupSession("room1", nil) + outbound, err := NewOutboundGroupSession("room1", nil) + require.NoError(t, err) err = store.AddOutboundGroupSession(context.TODO(), outbound) - if err != nil { - t.Errorf("Error inserting outbound session: %v", err) - } + require.NoError(t, err, "Error inserting outbound session") sess, err = store.GetOutboundGroupSession(context.TODO(), "room1") - if sess == nil { - t.Error("Did not get outbound session after inserting") - } - if err != nil { - t.Errorf("Error retrieving outbound session: %v", err) - } + require.NoError(t, err, "Error retrieving outbound session") + assert.NotNil(t, sess, "Did not get outbound session after inserting") err = store.RemoveOutboundGroupSession(context.TODO(), "room1") - if err != nil { - t.Errorf("Error deleting outbound session: %v", err) - } + require.NoError(t, err, "Error deleting outbound session") sess, err = store.GetOutboundGroupSession(context.TODO(), "room1") - if sess != nil { - t.Error("Got outbound session after deleting") - } - if err != nil { - t.Errorf("Error retrieving outbound session: %v", err) - } + require.NoError(t, err, "Error retrieving outbound session after deletion") + assert.Nil(t, sess, "Got outbound session after deleting") }) } } @@ -234,58 +209,41 @@ func TestStoreOutboundMegolmSessionSharing(t *testing.T) { t.Run(storeName, func(t *testing.T) { device := resetDevice() err := store.PutDevice(context.TODO(), "user1", device) - if err != nil { - t.Errorf("Error storing devices: %v", err) - } + require.NoError(t, err, "Error storing device") shared, err := store.IsOutboundGroupSessionShared(context.TODO(), device.UserID, device.IdentityKey, "session1") - if err != nil { - t.Errorf("Error checking if outbound group session is shared: %v", err) - } else if shared { - t.Errorf("Outbound group session shared when it shouldn't") - } + require.NoError(t, err, "Error checking if outbound group session is shared") + assert.False(t, shared, "Outbound group session should not be shared initially") err = store.MarkOutboundGroupSessionShared(context.TODO(), device.UserID, device.IdentityKey, "session1") - if err != nil { - t.Errorf("Error marking outbound group session as shared: %v", err) - } + require.NoError(t, err, "Error marking outbound group session as shared") shared, err = store.IsOutboundGroupSessionShared(context.TODO(), device.UserID, device.IdentityKey, "session1") - if err != nil { - t.Errorf("Error checking if outbound group session is shared: %v", err) - } else if !shared { - t.Errorf("Outbound group session not shared when it should") - } + require.NoError(t, err, "Error checking if outbound group session is shared") + assert.True(t, shared, "Outbound group session should be shared after marking it as such") device = resetDevice() err = store.PutDevice(context.TODO(), "user1", device) - if err != nil { - t.Errorf("Error storing devices: %v", err) - } + require.NoError(t, err, "Error storing device after resetting") shared, err = store.IsOutboundGroupSessionShared(context.TODO(), device.UserID, device.IdentityKey, "session1") - if err != nil { - t.Errorf("Error checking if outbound group session is shared: %v", err) - } else if shared { - t.Errorf("Outbound group session shared when it shouldn't") - } + require.NoError(t, err, "Error checking if outbound group session is shared") + assert.False(t, shared, "Outbound group session should not be shared after resetting device") }) } } func TestStoreDevices(t *testing.T) { + devicesToCreate := 17 stores := getCryptoStores(t) for storeName, store := range stores { t.Run(storeName, func(t *testing.T) { outdated, err := store.GetOutdatedTrackedUsers(context.TODO()) - if err != nil { - t.Errorf("Error filtering tracked users: %v", err) - } - if len(outdated) > 0 { - t.Errorf("Got %d outdated tracked users when expected none", len(outdated)) - } + require.NoError(t, err, "Error filtering tracked users") + assert.Empty(t, outdated, "Expected no outdated tracked users initially") + deviceMap := make(map[id.DeviceID]*id.Device) - for i := 0; i < 17; i++ { + for i := 0; i < devicesToCreate; i++ { iStr := strconv.Itoa(i) acc := NewOlmAccount() deviceMap[id.DeviceID("dev"+iStr)] = &id.Device{ @@ -296,59 +254,33 @@ func TestStoreDevices(t *testing.T) { } } err = store.PutDevices(context.TODO(), "user1", deviceMap) - if err != nil { - t.Errorf("Error storing devices: %v", err) - } + require.NoError(t, err, "Error storing devices") devs, err := store.GetDevices(context.TODO(), "user1") - if err != nil { - t.Errorf("Error getting devices: %v", err) - } - if len(devs) != 17 { - t.Errorf("Stored 17 devices, got back %v", len(devs)) - } - if devs["dev0"].IdentityKey != deviceMap["dev0"].IdentityKey { - t.Errorf("First device identity key does not match") - } - if devs["dev16"].IdentityKey != deviceMap["dev16"].IdentityKey { - t.Errorf("Last device identity key does not match") - } + require.NoError(t, err, "Error getting devices") + assert.Len(t, devs, devicesToCreate, "Expected to get %d devices back", devicesToCreate) + assert.Equal(t, deviceMap, devs, "Stored devices do not match retrieved devices") filtered, err := store.FilterTrackedUsers(context.TODO(), []id.UserID{"user0", "user1", "user2"}) - if err != nil { - t.Errorf("Error filtering tracked users: %v", err) - } else if len(filtered) != 1 || filtered[0] != "user1" { - t.Errorf("Expected to get 'user1' from filter, got %v", filtered) - } + require.NoError(t, err, "Error filtering tracked users") + assert.Equal(t, []id.UserID{"user1"}, filtered, "Expected to get 'user1' from filter") outdated, err = store.GetOutdatedTrackedUsers(context.TODO()) - if err != nil { - t.Errorf("Error filtering tracked users: %v", err) - } - if len(outdated) > 0 { - t.Errorf("Got %d outdated tracked users when expected none", len(outdated)) - } + require.NoError(t, err, "Error filtering tracked users") + assert.Empty(t, outdated, "Expected no outdated tracked users after initial storage") + err = store.MarkTrackedUsersOutdated(context.TODO(), []id.UserID{"user0", "user1"}) - if err != nil { - t.Errorf("Error marking tracked users outdated: %v", err) - } + require.NoError(t, err, "Error marking tracked users outdated") + outdated, err = store.GetOutdatedTrackedUsers(context.TODO()) - if err != nil { - t.Errorf("Error filtering tracked users: %v", err) - } - if len(outdated) != 1 || outdated[0] != id.UserID("user1") { - t.Errorf("Got outdated tracked users %v when expected 'user1'", outdated) - } + require.NoError(t, err, "Error filtering tracked users") + assert.Equal(t, []id.UserID{"user1"}, outdated, "Expected 'user1' to be marked as outdated") + err = store.PutDevices(context.TODO(), "user1", deviceMap) - if err != nil { - t.Errorf("Error storing devices: %v", err) - } + require.NoError(t, err, "Error storing devices again") + outdated, err = store.GetOutdatedTrackedUsers(context.TODO()) - if err != nil { - t.Errorf("Error filtering tracked users: %v", err) - } - if len(outdated) > 0 { - t.Errorf("Got outdated tracked users %v when expected none", outdated) - } + require.NoError(t, err, "Error filtering tracked users") + assert.Empty(t, outdated, "Expected no outdated tracked users after re-storing devices") }) } } @@ -359,16 +291,11 @@ func TestStoreSecrets(t *testing.T) { t.Run(storeName, func(t *testing.T) { storedSecret := "trustno1" err := store.PutSecret(context.TODO(), id.SecretMegolmBackupV1, storedSecret) - if err != nil { - t.Errorf("Error storing secret: %v", err) - } + require.NoError(t, err, "Error storing secret") secret, err := store.GetSecret(context.TODO(), id.SecretMegolmBackupV1) - if err != nil { - t.Errorf("Error storing secret: %v", err) - } else if secret != storedSecret { - t.Errorf("Stored secret did not match: '%s' != '%s'", secret, storedSecret) - } + require.NoError(t, err, "Error retrieving secret") + assert.Equal(t, storedSecret, secret, "Retrieved secret does not match stored secret") }) } } diff --git a/crypto/utils/utils_test.go b/crypto/utils/utils_test.go index c4f01a68..b12fd9e2 100644 --- a/crypto/utils/utils_test.go +++ b/crypto/utils/utils_test.go @@ -9,6 +9,9 @@ package utils import ( "encoding/base64" "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestAES256Ctr(t *testing.T) { @@ -16,9 +19,7 @@ func TestAES256Ctr(t *testing.T) { key, iv := GenAttachmentA256CTR() enc := XorA256CTR([]byte(expected), key, iv) dec := XorA256CTR(enc, key, iv) - if string(dec) != expected { - t.Errorf("Expected decrypted using generated key/iv to be `%v`, got %v", expected, string(dec)) - } + assert.EqualValues(t, expected, dec, "Decrypted text should match original") var key2 [AESCTRKeyLength]byte var iv2 [AESCTRIVLength]byte @@ -29,9 +30,7 @@ func TestAES256Ctr(t *testing.T) { iv2[i] = byte(i) + 32 } dec2 := XorA256CTR([]byte{0x29, 0xc3, 0xff, 0x02, 0x21, 0xaf, 0x67, 0x73, 0x6e, 0xad, 0x9d}, key2, iv2) - if string(dec2) != expected { - t.Errorf("Expected decrypted using constant key/iv to be `%v`, got %v", expected, string(dec2)) - } + assert.EqualValues(t, expected, dec2, "Decrypted text with constant key/iv should match original") } func TestPBKDF(t *testing.T) { @@ -42,9 +41,7 @@ func TestPBKDF(t *testing.T) { key := PBKDF2SHA512([]byte("Hello world"), salt, 1000, 256) expected := "ffk9YdbVE1cgqOWgDaec0lH+rJzO+MuCcxpIn3Z6D0E=" keyB64 := base64.StdEncoding.EncodeToString([]byte(key)) - if keyB64 != expected { - t.Errorf("Expected base64 of generated key to be `%v`, got `%v`", expected, keyB64) - } + assert.Equal(t, expected, keyB64) } func TestDecodeSSSSKey(t *testing.T) { @@ -53,13 +50,10 @@ func TestDecodeSSSSKey(t *testing.T) { expected := "QCFDrXZYLEFnwf4NikVm62rYGJS2mNBEmAWLC3CgNPw=" decodedB64 := base64.StdEncoding.EncodeToString(decoded[:]) - if expected != decodedB64 { - t.Errorf("Expected decoded recovery key b64 to be `%v`, got `%v`", expected, decodedB64) - } + assert.Equal(t, expected, decodedB64) - if encoded := EncodeBase58RecoveryKey(decoded); encoded != recoveryKey { - t.Errorf("Expected recovery key to be `%v`, got `%v`", recoveryKey, encoded) - } + encoded := EncodeBase58RecoveryKey(decoded) + assert.Equal(t, recoveryKey, encoded) } func TestKeyDerivationAndHMAC(t *testing.T) { @@ -69,15 +63,11 @@ func TestKeyDerivationAndHMAC(t *testing.T) { aesKey, hmacKey := DeriveKeysSHA256(decoded[:], "m.cross_signing.master") ciphertextBytes, err := base64.StdEncoding.DecodeString("Fx16KlJ9vkd3Dd6CafIq5spaH5QmK5BALMzbtFbQznG2j1VARKK+klc4/Qo=") - if err != nil { - t.Error(err) - } + require.NoError(t, err) calcMac := HMACSHA256B64(ciphertextBytes, hmacKey) expectedMac := "0DABPNIZsP9iTOh1o6EM0s7BfHHXb96dN7Eca88jq2E" - if calcMac != expectedMac { - t.Errorf("Expected MAC `%v`, got `%v`", expectedMac, calcMac) - } + assert.Equal(t, expectedMac, calcMac) var ivBytes [AESCTRIVLength]byte decodedIV, _ := base64.StdEncoding.DecodeString("zxT/W5LpZ0Q819pfju6hZw==") @@ -85,7 +75,5 @@ func TestKeyDerivationAndHMAC(t *testing.T) { decrypted := string(XorA256CTR(ciphertextBytes, aesKey, ivBytes)) expectedDec := "Ec8eZDyvVkO3EDsEG6ej5c0cCHnX7PINqFXZjnaTV2s=" - if expectedDec != decrypted { - t.Errorf("Expected decrypted text to be `%v`, got `%v`", expectedDec, decrypted) - } + assert.Equal(t, expectedDec, decrypted) } diff --git a/crypto/verificationhelper/callbacks_test.go b/crypto/verificationhelper/callbacks_test.go new file mode 100644 index 00000000..3b943f28 --- /dev/null +++ b/crypto/verificationhelper/callbacks_test.go @@ -0,0 +1,167 @@ +// Copyright (c) 2024 Sumner Evans +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package verificationhelper_test + +import ( + "context" + + "maunium.net/go/mautrix/crypto/verificationhelper" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +type MockVerificationCallbacks interface { + GetRequestedVerifications() map[id.UserID][]id.VerificationTransactionID + GetScanQRCodeTransactions() []id.VerificationTransactionID + GetVerificationsReadyTransactions() []id.VerificationTransactionID + GetQRCodeShown(id.VerificationTransactionID) *verificationhelper.QRCode +} + +type baseVerificationCallbacks struct { + scanQRCodeTransactions []id.VerificationTransactionID + verificationsRequested map[id.UserID][]id.VerificationTransactionID + verificationsReady []id.VerificationTransactionID + qrCodesShown map[id.VerificationTransactionID]*verificationhelper.QRCode + qrCodesScanned map[id.VerificationTransactionID]struct{} + doneTransactions map[id.VerificationTransactionID]struct{} + verificationCancellation map[id.VerificationTransactionID]*event.VerificationCancelEventContent + emojisShown map[id.VerificationTransactionID][]rune + emojiDescriptionsShown map[id.VerificationTransactionID][]string + decimalsShown map[id.VerificationTransactionID][]int +} + +var _ verificationhelper.RequiredCallbacks = (*baseVerificationCallbacks)(nil) +var _ MockVerificationCallbacks = (*baseVerificationCallbacks)(nil) + +func newBaseVerificationCallbacks() *baseVerificationCallbacks { + return &baseVerificationCallbacks{ + verificationsRequested: map[id.UserID][]id.VerificationTransactionID{}, + qrCodesShown: map[id.VerificationTransactionID]*verificationhelper.QRCode{}, + qrCodesScanned: map[id.VerificationTransactionID]struct{}{}, + doneTransactions: map[id.VerificationTransactionID]struct{}{}, + verificationCancellation: map[id.VerificationTransactionID]*event.VerificationCancelEventContent{}, + emojisShown: map[id.VerificationTransactionID][]rune{}, + emojiDescriptionsShown: map[id.VerificationTransactionID][]string{}, + decimalsShown: map[id.VerificationTransactionID][]int{}, + } +} + +func (c *baseVerificationCallbacks) GetRequestedVerifications() map[id.UserID][]id.VerificationTransactionID { + return c.verificationsRequested +} + +func (c *baseVerificationCallbacks) GetScanQRCodeTransactions() []id.VerificationTransactionID { + return c.scanQRCodeTransactions +} + +func (c *baseVerificationCallbacks) GetVerificationsReadyTransactions() []id.VerificationTransactionID { + return c.verificationsReady +} + +func (c *baseVerificationCallbacks) GetQRCodeShown(txnID id.VerificationTransactionID) *verificationhelper.QRCode { + return c.qrCodesShown[txnID] +} + +func (c *baseVerificationCallbacks) WasOurQRCodeScanned(txnID id.VerificationTransactionID) bool { + _, ok := c.qrCodesScanned[txnID] + return ok +} + +func (c *baseVerificationCallbacks) IsVerificationDone(txnID id.VerificationTransactionID) bool { + _, ok := c.doneTransactions[txnID] + return ok +} + +func (c *baseVerificationCallbacks) GetVerificationCancellation(txnID id.VerificationTransactionID) *event.VerificationCancelEventContent { + return c.verificationCancellation[txnID] +} + +func (c *baseVerificationCallbacks) GetEmojisAndDescriptionsShown(txnID id.VerificationTransactionID) ([]rune, []string) { + return c.emojisShown[txnID], c.emojiDescriptionsShown[txnID] +} + +func (c *baseVerificationCallbacks) GetDecimalsShown(txnID id.VerificationTransactionID) []int { + return c.decimalsShown[txnID] +} + +func (c *baseVerificationCallbacks) VerificationRequested(ctx context.Context, txnID id.VerificationTransactionID, from id.UserID, fromDevice id.DeviceID) { + c.verificationsRequested[from] = append(c.verificationsRequested[from], txnID) +} + +func (c *baseVerificationCallbacks) VerificationReady(ctx context.Context, txnID id.VerificationTransactionID, otherDeviceID id.DeviceID, supportsSAS, allowScanQRCode bool, qrCode *verificationhelper.QRCode) { + c.verificationsReady = append(c.verificationsReady, txnID) + if allowScanQRCode { + c.scanQRCodeTransactions = append(c.scanQRCodeTransactions, txnID) + } + if qrCode != nil { + c.qrCodesShown[txnID] = qrCode + } +} + +func (c *baseVerificationCallbacks) VerificationCancelled(ctx context.Context, txnID id.VerificationTransactionID, code event.VerificationCancelCode, reason string) { + c.verificationCancellation[txnID] = &event.VerificationCancelEventContent{ + Code: code, + Reason: reason, + } +} + +func (c *baseVerificationCallbacks) VerificationDone(ctx context.Context, txnID id.VerificationTransactionID, method event.VerificationMethod) { + c.doneTransactions[txnID] = struct{}{} +} + +type sasVerificationCallbacks struct { + *baseVerificationCallbacks +} + +var _ verificationhelper.ShowSASCallbacks = (*sasVerificationCallbacks)(nil) + +func newSASVerificationCallbacks() *sasVerificationCallbacks { + return &sasVerificationCallbacks{newBaseVerificationCallbacks()} +} + +func newSASVerificationCallbacksWithBase(base *baseVerificationCallbacks) *sasVerificationCallbacks { + return &sasVerificationCallbacks{base} +} + +func (c *sasVerificationCallbacks) ShowSAS(ctx context.Context, txnID id.VerificationTransactionID, emojis []rune, emojiDescriptions []string, decimals []int) { + c.emojisShown[txnID] = emojis + c.emojiDescriptionsShown[txnID] = emojiDescriptions + c.decimalsShown[txnID] = decimals +} + +type showQRCodeVerificationCallbacks struct { + *baseVerificationCallbacks +} + +var _ verificationhelper.ShowQRCodeCallbacks = (*showQRCodeVerificationCallbacks)(nil) + +func newShowQRCodeVerificationCallbacks() *showQRCodeVerificationCallbacks { + return &showQRCodeVerificationCallbacks{newBaseVerificationCallbacks()} +} + +func newShowQRCodeVerificationCallbacksWithBase(base *baseVerificationCallbacks) *showQRCodeVerificationCallbacks { + return &showQRCodeVerificationCallbacks{base} +} + +func (c *showQRCodeVerificationCallbacks) QRCodeScanned(ctx context.Context, txnID id.VerificationTransactionID) { + c.qrCodesScanned[txnID] = struct{}{} +} + +type allVerificationCallbacks struct { + *baseVerificationCallbacks + *sasVerificationCallbacks + *showQRCodeVerificationCallbacks +} + +func newAllVerificationCallbacks() *allVerificationCallbacks { + base := newBaseVerificationCallbacks() + return &allVerificationCallbacks{ + base, + newSASVerificationCallbacksWithBase(base), + newShowQRCodeVerificationCallbacksWithBase(base), + } +} diff --git a/crypto/verificationhelper/ecdhkeys.go b/crypto/verificationhelper/ecdhkeys.go new file mode 100644 index 00000000..754530ed --- /dev/null +++ b/crypto/verificationhelper/ecdhkeys.go @@ -0,0 +1,57 @@ +package verificationhelper + +import ( + "crypto/ecdh" + "encoding/json" +) + +type ECDHPrivateKey struct { + *ecdh.PrivateKey +} + +func (e *ECDHPrivateKey) UnmarshalJSON(data []byte) (err error) { + if len(data) == 0 { + return nil + } + var raw []byte + err = json.Unmarshal(data, &raw) + if err != nil { + return + } + if len(raw) == 0 { + return nil + } + e.PrivateKey, err = ecdh.X25519().NewPrivateKey(raw) + return err +} + +func (e ECDHPrivateKey) MarshalJSON() ([]byte, error) { + if e.PrivateKey == nil { + return json.Marshal(nil) + } + return json.Marshal(e.Bytes()) +} + +type ECDHPublicKey struct { + *ecdh.PublicKey +} + +func (e *ECDHPublicKey) UnmarshalJSON(data []byte) (err error) { + if len(data) == 0 { + return nil + } + var raw []byte + err = json.Unmarshal(data, &raw) + if err != nil { + return + } + if len(raw) == 0 { + return nil + } + e.PublicKey, err = ecdh.X25519().NewPublicKey(raw) + return +} + +func (e ECDHPublicKey) MarshalJSON() ([]byte, error) { + return json.Marshal(e.Bytes()) +} diff --git a/crypto/verificationhelper/ecdhkeys_test.go b/crypto/verificationhelper/ecdhkeys_test.go new file mode 100644 index 00000000..109fbf88 --- /dev/null +++ b/crypto/verificationhelper/ecdhkeys_test.go @@ -0,0 +1,48 @@ +package verificationhelper_test + +import ( + "crypto/ecdh" + "crypto/rand" + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "maunium.net/go/mautrix/crypto/verificationhelper" +) + +func TestECDHPrivateKey(t *testing.T) { + pk, err := ecdh.X25519().GenerateKey(rand.Reader) + require.NoError(t, err) + private := verificationhelper.ECDHPrivateKey{pk} + marshalled, err := json.Marshal(private) + require.NoError(t, err) + + assert.Len(t, marshalled, 46) + + var unmarshalled verificationhelper.ECDHPrivateKey + err = json.Unmarshal(marshalled, &unmarshalled) + require.NoError(t, err) + + assert.True(t, private.Equal(unmarshalled.PrivateKey)) +} + +func TestECDHPublicKey(t *testing.T) { + private, err := ecdh.X25519().GenerateKey(rand.Reader) + require.NoError(t, err) + + public := private.PublicKey() + + pub := verificationhelper.ECDHPublicKey{public} + marshalled, err := json.Marshal(pub) + require.NoError(t, err) + + assert.Len(t, marshalled, 46) + + var unmarshalled verificationhelper.ECDHPublicKey + err = json.Unmarshal(marshalled, &unmarshalled) + require.NoError(t, err) + + assert.True(t, public.Equal(unmarshalled.PublicKey)) +} diff --git a/crypto/verificationhelper/qrcode.go b/crypto/verificationhelper/qrcode.go index a28d8fc3..11698152 100644 --- a/crypto/verificationhelper/qrcode.go +++ b/crypto/verificationhelper/qrcode.go @@ -82,6 +82,10 @@ func NewQRCodeFromBytes(data []byte) (*QRCode, error) { // // [Section 11.12.2.4.1]: https://spec.matrix.org/v1.9/client-server-api/#qr-code-format func (q *QRCode) Bytes() []byte { + if q == nil { + return nil + } + var buf bytes.Buffer buf.WriteString("MATRIX") // Header buf.WriteByte(0x02) // Version diff --git a/crypto/verificationhelper/qrcode_test.go b/crypto/verificationhelper/qrcode_test.go index d2767734..fda3de2c 100644 --- a/crypto/verificationhelper/qrcode_test.go +++ b/crypto/verificationhelper/qrcode_test.go @@ -8,51 +8,76 @@ package verificationhelper_test import ( "bytes" + "encoding/base64" + "strings" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "maunium.net/go/mautrix/crypto/verificationhelper" + "maunium.net/go/mautrix/id" ) func TestQRCode_Roundtrip(t *testing.T) { var key1, key2 [32]byte copy(key1[:], bytes.Repeat([]byte{0x01}, 32)) copy(key2[:], bytes.Repeat([]byte{0x02}, 32)) - qrCode := verificationhelper.NewQRCode(verificationhelper.QRCodeModeCrossSigning, "test", key1, key2) + txnID := id.VerificationTransactionID(strings.Repeat("a", 20)) + qrCode := verificationhelper.NewQRCode(verificationhelper.QRCodeModeCrossSigning, txnID, key1, key2) encoded := qrCode.Bytes() decoded, err := verificationhelper.NewQRCodeFromBytes(encoded) require.NoError(t, err) assert.Equal(t, verificationhelper.QRCodeModeCrossSigning, decoded.Mode) - assert.EqualValues(t, "test", decoded.TransactionID) + assert.EqualValues(t, txnID, decoded.TransactionID) assert.Equal(t, key1, decoded.Key1) assert.Equal(t, key2, decoded.Key2) } func TestQRCodeDecode(t *testing.T) { - qrcodeData := []byte{ - 0x4d, 0x41, 0x54, 0x52, 0x49, 0x58, 0x02, 0x01, 0x00, 0x20, 0x47, 0x6e, 0x41, 0x65, 0x43, 0x76, - 0x74, 0x57, 0x6a, 0x7a, 0x4d, 0x4f, 0x56, 0x57, 0x51, 0x54, 0x6b, 0x74, 0x33, 0x35, 0x59, 0x52, - 0x55, 0x72, 0x75, 0x6a, 0x6d, 0x52, 0x50, 0x63, 0x38, 0x61, 0x18, 0x32, 0x7c, 0xc3, 0x8c, 0xc2, - 0xa6, 0xc2, 0xb5, 0xc2, 0xa7, 0x50, 0x57, 0x67, 0x19, 0x5e, 0xc3, 0xaf, 0xc2, 0xa0, 0xc2, 0x98, - 0xc2, 0x9d, 0x36, 0xc3, 0xad, 0x7a, 0x10, 0x2e, 0x18, 0x3e, 0x4e, 0xc3, 0x84, 0xc3, 0x81, 0x45, - 0x0c, 0xc2, 0xae, 0x19, 0x78, 0xc2, 0x99, 0x06, 0xc2, 0x92, 0xc2, 0x94, 0xc2, 0x8e, 0xc2, 0xb7, - 0x59, 0xc2, 0x96, 0xc2, 0xad, 0xc3, 0xbd, 0x70, 0x6a, 0x11, 0xc2, 0xba, 0xc2, 0xa9, 0x29, 0xc3, - 0x8f, 0x0d, 0xc2, 0xb8, 0xc2, 0x88, 0x67, 0x5b, 0xc3, 0xb3, 0x01, 0xc2, 0xb0, 0x63, 0x2e, 0xc2, - 0xa5, 0xc3, 0xb3, 0x60, 0xc3, 0x82, 0x04, 0xc3, 0xa3, 0x72, 0x7d, 0x7c, 0x1d, 0xc2, 0xb6, 0xc2, - 0xba, 0xc2, 0x81, 0x1e, 0xc2, 0x99, 0xc2, 0xb8, 0x7f, 0x0a, + testCases := []struct { + b64 string + txnID string + key1 string + key2 string + sharedSecret string + }{ + { + "TUFUUklYAgEAIEduQWVDdnRXanpNT1ZXUVRrdDM1WVJVcnVqbVJQYzhhGDJ8w4zCpsK1wqdQV2cZXsOvwqDCmMKdNsOtehAuGD5Ow4TDgUUMwq4ZeMKZBsKSwpTCjsK3WcKWwq3DvXBqEcK6wqkpw48NwrjCiGdbw7MBwrBjLsKlw7Ngw4IEw6NyfXwdwrbCusKBHsKZwrh/Cg==", + "GnAeCvtWjzMOVWQTkt35YRUrujmRPc8a", + "GDJ8w4zCpsK1wqdQV2cZXsOvwqDCmMKdNsOtehAuGD4=", + "TsOEw4FFDMKuGXjCmQbCksKUwo7Ct1nClsKtw71wahE=", + "wrrCqSnDjw3CuMKIZ1vDswHCsGMuwqXDs2DDggTDo3J9fB3CtsK6woEewpnCuH8K", + }, + { + "TUFUUklYAgEAIGM1YjljNzE3ZWIzYjRmYzBiZDhhZjA0MDQ4NDY5MDdle4oLkpUdO1cTu5M3K3B4BlnpxtAbVgXCuQKOIqMmt+xAjVvaEXF39X0z5waRY9UE0b5PKiWvOBSJHEGkxX28Y2OEDLIWP/kCVUlyXXENlj0=", + "c5b9c717eb3b4fc0bd8af0404846907e", + "e4oLkpUdO1cTu5M3K3B4BlnpxtAbVgXCuQKOIqMmt+w=", + "QI1b2hFxd/V9M+cGkWPVBNG+TyolrzgUiRxBpMV9vGM=", + "Y4QMshY/+QJVSXJdcQ2WPQ==", + }, + } + + for _, tc := range testCases { + t.Run(tc.b64, func(t *testing.T) { + qrcodeData, err := base64.StdEncoding.DecodeString(tc.b64) + require.NoError(t, err) + expectedKey1, err := base64.StdEncoding.DecodeString(tc.key1) + require.NoError(t, err) + expectedKey2, err := base64.StdEncoding.DecodeString(tc.key2) + require.NoError(t, err) + expectedSharedSecret, err := base64.StdEncoding.DecodeString(tc.sharedSecret) + require.NoError(t, err) + + decoded, err := verificationhelper.NewQRCodeFromBytes(qrcodeData) + require.NoError(t, err) + assert.Equal(t, verificationhelper.QRCodeModeSelfVerifyingMasterKeyTrusted, decoded.Mode) + assert.EqualValues(t, tc.txnID, decoded.TransactionID) + assert.EqualValues(t, expectedKey1, decoded.Key1) + assert.EqualValues(t, expectedKey2, decoded.Key2) + assert.EqualValues(t, expectedSharedSecret, decoded.SharedSecret) + }) } - decoded, err := verificationhelper.NewQRCodeFromBytes(qrcodeData) - require.NoError(t, err) - assert.Equal(t, verificationhelper.QRCodeModeSelfVerifyingMasterKeyTrusted, decoded.Mode) - assert.EqualValues(t, "GnAeCvtWjzMOVWQTkt35YRUrujmRPc8a", decoded.TransactionID) - assert.Equal(t, - [32]byte{0x18, 0x32, 0x7c, 0xc3, 0x8c, 0xc2, 0xa6, 0xc2, 0xb5, 0xc2, 0xa7, 0x50, 0x57, 0x67, 0x19, 0x5e, 0xc3, 0xaf, 0xc2, 0xa0, 0xc2, 0x98, 0xc2, 0x9d, 0x36, 0xc3, 0xad, 0x7a, 0x10, 0x2e, 0x18, 0x3e}, - decoded.Key1) - assert.Equal(t, - [32]byte{0x4e, 0xc3, 0x84, 0xc3, 0x81, 0x45, 0xc, 0xc2, 0xae, 0x19, 0x78, 0xc2, 0x99, 0x6, 0xc2, 0x92, 0xc2, 0x94, 0xc2, 0x8e, 0xc2, 0xb7, 0x59, 0xc2, 0x96, 0xc2, 0xad, 0xc3, 0xbd, 0x70, 0x6a, 0x11}, - decoded.Key2) } diff --git a/crypto/verificationhelper/reciprocate.go b/crypto/verificationhelper/reciprocate.go index ab177eb9..d8827b8b 100644 --- a/crypto/verificationhelper/reciprocate.go +++ b/crypto/verificationhelper/reciprocate.go @@ -9,10 +9,12 @@ package verificationhelper import ( "bytes" "context" + "errors" "fmt" "golang.org/x/exp/slices" + "maunium.net/go/mautrix/crypto" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" ) @@ -30,41 +32,63 @@ func (vh *VerificationHelper) HandleScannedQRData(ctx context.Context, data []by Stringer("transaction_id", qrCode.TransactionID). Int("mode", int(qrCode.Mode)). Logger() + ctx = log.WithContext(ctx) + vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() - txn, ok := vh.activeTransactions[qrCode.TransactionID] - if !ok { - log.Warn().Msg("Ignoring QR code scan for an unknown transaction") - return nil - } else if txn.VerificationState != verificationStateReady { - log.Warn().Msg("Ignoring QR code scan for a transaction that is not in the ready state") - return nil + txn, err := vh.store.GetVerificationTransaction(ctx, qrCode.TransactionID) + if err != nil { + return fmt.Errorf("failed to get transaction %s: %w", qrCode.TransactionID, err) + } else if txn.VerificationState != VerificationStateReady { + return vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUnexpectedMessage, "transaction found in the QR code is not in the ready state") } - txn.VerificationState = verificationStateTheirQRScanned + txn.VerificationState = VerificationStateTheirQRScanned // Verify the keys log.Info().Msg("Verifying keys from QR code") + ownCrossSigningPublicKeys := vh.mach.GetOwnCrossSigningPublicKeys(ctx) + if ownCrossSigningPublicKeys == nil { + return crypto.ErrCrossSigningPubkeysNotCached + } + switch qrCode.Mode { case QRCodeModeCrossSigning: - panic("unimplemented") - // TODO verify and sign their master key + theirSigningKeys, err := vh.mach.GetCrossSigningPublicKeys(ctx, txn.TheirUserID) + if err != nil { + return vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeKeyMismatch, "couldn't get %s's cross-signing keys: %w", txn.TheirUserID, err) + } + if bytes.Equal(theirSigningKeys.MasterKey.Bytes(), qrCode.Key1[:]) { + log.Info().Msg("Verified that the other device has the master key we expected") + } else { + return vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeKeyMismatch, "the other device does not have the master key we expected") + } + + // Verify the master key is correct + if bytes.Equal(ownCrossSigningPublicKeys.MasterKey.Bytes(), qrCode.Key2[:]) { + log.Info().Msg("Verified that the other device has the same master key") + } else { + return vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeKeyMismatch, "the master key does not match") + } + + if err := vh.mach.SignUser(ctx, txn.TheirUserID, theirSigningKeys.MasterKey); err != nil { + return vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeInternalError, "failed to sign their master key: %w", err) + } case QRCodeModeSelfVerifyingMasterKeyTrusted: // The QR was created by a device that trusts the master key, which // means that we don't trust the key. Key1 is the master key public // key, and Key2 is what the other device thinks our device key is. - if vh.client.UserID != txn.TheirUser { - return fmt.Errorf("mode %d is only allowed when the other user is the same as the current user", qrCode.Mode) + if vh.client.UserID != txn.TheirUserID { + return vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUnexpectedMessage, "mode %d is only allowed when the other user is the same as the current user", qrCode.Mode) } // Verify the master key is correct - crossSigningPubkeys := vh.mach.GetOwnCrossSigningPublicKeys(ctx) - if bytes.Equal(crossSigningPubkeys.MasterKey.Bytes(), qrCode.Key1[:]) { + if bytes.Equal(ownCrossSigningPublicKeys.MasterKey.Bytes(), qrCode.Key1[:]) { log.Info().Msg("Verified that the other device has the same master key") } else { - return fmt.Errorf("the master key does not match") + return vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeKeyMismatch, "the master key does not match") } // Verify that the device key that the other device things we have is @@ -73,53 +97,63 @@ func (vh *VerificationHelper) HandleScannedQRData(ctx context.Context, data []by if bytes.Equal(myKeys.SigningKey.Bytes(), qrCode.Key2[:]) { log.Info().Msg("Verified that the other device has the correct key for this device") } else { - return fmt.Errorf("the other device has the wrong key for this device") + return vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeKeyMismatch, "the other device has the wrong key for this device") } + if err := vh.mach.SignOwnMasterKey(ctx); err != nil { + return vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeInternalError, "failed to sign own master key: %w", err) + } case QRCodeModeSelfVerifyingMasterKeyUntrusted: // The QR was created by a device that does not trust the master key, // which means that we do trust the master key. Key1 is the other // device's device key, and Key2 is what the other device thinks the // master key is. - if vh.client.UserID != txn.TheirUser { - return fmt.Errorf("mode %d is only allowed when the other user is the same as the current user", qrCode.Mode) + // Check that we actually trust the master key. + if trusted, err := vh.mach.CryptoStore.IsKeySignedBy(ctx, vh.client.UserID, ownCrossSigningPublicKeys.MasterKey, vh.client.UserID, vh.mach.OwnIdentity().SigningKey); err != nil { + return err + } else if !trusted { + return vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeMasterKeyNotTrusted, "the master key is not trusted by this device, cannot verify device that does not trust the master key") + } + + if vh.client.UserID != txn.TheirUserID { + return vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUnexpectedMessage, "mode %d is only allowed when the other user is the same as the current user", qrCode.Mode) } // Get their device - theirDevice, err := vh.mach.GetOrFetchDevice(ctx, txn.TheirUser, txn.TheirDevice) + theirDevice, err := vh.mach.GetOrFetchDevice(ctx, txn.TheirUserID, txn.TheirDeviceID) if err != nil { - return err + return vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeInternalError, "failed to get their device: %w", err) } // Verify that the other device's key is what we expect. if bytes.Equal(theirDevice.SigningKey.Bytes(), qrCode.Key1[:]) { log.Info().Msg("Verified that the other device key is what we expected") } else { - return fmt.Errorf("the other device's key is not what we expected") + return vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeKeyMismatch, "the other device's key is not what we expected") } // Verify that what they think the master key is is correct. - if bytes.Equal(vh.mach.GetOwnCrossSigningPublicKeys(ctx).MasterKey.Bytes(), qrCode.Key2[:]) { + if bytes.Equal(ownCrossSigningPublicKeys.MasterKey.Bytes(), qrCode.Key2[:]) { log.Info().Msg("Verified that the other device has the correct master key") } else { - return fmt.Errorf("the master key does not match") + return vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeKeyMismatch, "the master key does not match") } // Trust their device theirDevice.Trust = id.TrustStateVerified - err = vh.mach.CryptoStore.PutDevice(ctx, txn.TheirUser, theirDevice) + err = vh.mach.CryptoStore.PutDevice(ctx, txn.TheirUserID, theirDevice) if err != nil { - return fmt.Errorf("failed to update device trust state after verifying: %w", err) + return vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeInternalError, "failed to update device trust state after verifying: %+v", err) } // Cross-sign their device with the self-signing key err = vh.mach.SignOwnDevice(ctx, theirDevice) if err != nil { - return fmt.Errorf("failed to sign their device: %w", err) + return vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeInternalError, "failed to sign their device: %+v", err) } default: - return fmt.Errorf("unknown QR code mode %d", qrCode.Mode) + return vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUnexpectedMessage, "unknown QR code mode %d", qrCode.Mode) } // Send a m.key.verification.start event with the secret @@ -131,108 +165,150 @@ func (vh *VerificationHelper) HandleScannedQRData(ctx context.Context, data []by } err = vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationStart, txn.StartEventContent) if err != nil { - return err + return fmt.Errorf("failed to send m.key.verification.start event: %w", err) } + log.Debug().Msg("Successfully sent the m.key.verification.start event") // Immediately send the m.key.verification.done event, as our side of the // transaction is done. + err = vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationDone, &event.VerificationDoneEventContent{}) + if err != nil { + return fmt.Errorf("failed to send m.key.verification.done event: %w", err) + } + log.Debug().Msg("Successfully sent the m.key.verification.done event") + txn.SentOurDone = true + if txn.ReceivedTheirDone { + log.Debug().Msg("We already received their done event. Setting verification state to done.") + if err = vh.store.DeleteVerification(ctx, txn.TransactionID); err != nil { + return err + } + vh.verificationDone(ctx, txn.TransactionID, txn.StartEventContent.Method) + } else { + return vh.store.SaveVerificationTransaction(ctx, txn) + } + return nil +} + +// ConfirmQRCodeScanned confirms that our QR code has been scanned and sends +// the m.key.verification.done event to the other device for the given +// transaction ID. The transaction ID should be one received via the +// VerificationRequested callback in [RequiredCallbacks] or the +// [StartVerification] or [StartInRoomVerification] functions. +func (vh *VerificationHelper) ConfirmQRCodeScanned(ctx context.Context, txnID id.VerificationTransactionID) error { + log := vh.getLog(ctx).With(). + Str("verification_action", "confirm QR code scanned"). + Stringer("transaction_id", txnID). + Logger() + ctx = log.WithContext(ctx) + + vh.activeTransactionsLock.Lock() + defer vh.activeTransactionsLock.Unlock() + txn, err := vh.store.GetVerificationTransaction(ctx, txnID) + if err != nil { + return fmt.Errorf("failed to get transaction %s: %w", txnID, err) + } else if txn.VerificationState != VerificationStateOurQRScanned { + return fmt.Errorf("transaction is not in the scanned state") + } + + log.Info().Msg("Confirming QR code scanned") + + // Get their device + theirDevice, err := vh.mach.GetOrFetchDevice(ctx, txn.TheirUserID, txn.TheirDeviceID) + if err != nil { + return err + } + + // Trust their device + theirDevice.Trust = id.TrustStateVerified + err = vh.mach.CryptoStore.PutDevice(ctx, txn.TheirUserID, theirDevice) + if err != nil { + return fmt.Errorf("failed to update device trust state after verifying: %w", err) + } + + if txn.TheirUserID == vh.client.UserID { + // Self-signing situation. + // + // If we have the cross-signing keys, then we need to sign their device + // using the self-signing key. Otherwise, they have the master private + // key, so we need to trust the master public key. + if vh.mach.CrossSigningKeys != nil { + err = vh.mach.SignOwnDevice(ctx, theirDevice) + if err != nil { + return fmt.Errorf("failed to sign our own new device: %w", err) + } + } else { + err = vh.mach.SignOwnMasterKey(ctx) + if err != nil { + return fmt.Errorf("failed to sign our own master key: %w", err) + } + } + } else { + // Cross-signing situation. Sign their master key. + theirSigningKeys, err := vh.mach.GetCrossSigningPublicKeys(ctx, txn.TheirUserID) + if err != nil { + return vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeKeyMismatch, "couldn't get %s's cross-signing keys: %w", txn.TheirUserID, err) + } + + if err := vh.mach.SignUser(ctx, txn.TheirUserID, theirSigningKeys.MasterKey); err != nil { + return vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeInternalError, "failed to sign their master key: %w", err) + } + } + err = vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationDone, &event.VerificationDoneEventContent{}) if err != nil { return err } txn.SentOurDone = true if txn.ReceivedTheirDone { - txn.VerificationState = verificationStateDone - vh.verificationDone(ctx, txn.TransactionID) - } - return nil -} - -// ConfirmQRCodeScanned confirms that our QR code has been scanned and sends the -// m.key.verification.done event to the other device. -func (vh *VerificationHelper) ConfirmQRCodeScanned(ctx context.Context, txnID id.VerificationTransactionID) error { - log := vh.getLog(ctx).With(). - Str("verification_action", "confirm QR code scanned"). - Stringer("transaction_id", txnID). - Logger() - - vh.activeTransactionsLock.Lock() - defer vh.activeTransactionsLock.Unlock() - txn, ok := vh.activeTransactions[txnID] - if !ok { - log.Warn().Msg("Ignoring QR code scan confirmation for an unknown transaction") - return nil - } else if txn.VerificationState != verificationStateOurQRScanned { - log.Warn().Msg("Ignoring QR code scan confirmation for a transaction that is not in the started state") - return nil - } - - log.Info().Msg("Confirming QR code scanned") - - if txn.TheirUser == vh.client.UserID { - // Self-signing situation. Trust their device. - - // Get their device - theirDevice, err := vh.mach.GetOrFetchDevice(ctx, txn.TheirUser, txn.TheirDevice) - if err != nil { + if err = vh.store.DeleteVerification(ctx, txn.TransactionID); err != nil { return err } - - // Trust their device - theirDevice.Trust = id.TrustStateVerified - err = vh.mach.CryptoStore.PutDevice(ctx, txn.TheirUser, theirDevice) - if err != nil { - return fmt.Errorf("failed to update device trust state after verifying: %w", err) - } - - // Cross-sign their device with the self-signing key - if vh.mach.CrossSigningKeys != nil { - err = vh.mach.SignOwnDevice(ctx, theirDevice) - if err != nil { - return fmt.Errorf("failed to sign their device: %w", err) - } - } - } - // TODO: handle QR codes that are not self-signing situations - - err := vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationDone, &event.VerificationDoneEventContent{}) - if err != nil { - return err - } - txn.SentOurDone = true - if txn.ReceivedTheirDone { - txn.VerificationState = verificationStateDone - vh.verificationDone(ctx, txn.TransactionID) + vh.verificationDone(ctx, txn.TransactionID, txn.StartEventContent.Method) + } else { + return vh.store.SaveVerificationTransaction(ctx, txn) } return nil } -func (vh *VerificationHelper) generateAndShowQRCode(ctx context.Context, txn *verificationTransaction) error { +func (vh *VerificationHelper) generateQRCode(ctx context.Context, txn *VerificationTransaction) (*QRCode, error) { log := vh.getLog(ctx).With(). Str("verification_action", "generate and show QR code"). Stringer("transaction_id", txn.TransactionID). Logger() - if vh.showQRCode == nil { - log.Warn().Msg("Ignoring QR code generation request as showing a QR code is not enabled on this device") - return nil - } - if !slices.Contains(txn.TheirSupportedMethods, event.VerificationMethodQRCodeScan) { - log.Warn().Msg("Ignoring QR code generation request as other device cannot scan QR codes") - return nil + ctx = log.WithContext(ctx) + + if !slices.Contains(vh.supportedMethods, event.VerificationMethodReciprocate) || + !slices.Contains(txn.TheirSupportedMethods, event.VerificationMethodReciprocate) { + log.Info().Msg("Ignoring QR code generation request as reciprocating is not supported by both devices") + return nil, nil + } else if !slices.Contains(txn.TheirSupportedMethods, event.VerificationMethodQRCodeScan) { + log.Info().Msg("Ignoring QR code generation request as other device cannot scan QR codes") + return nil, nil } ownCrossSigningPublicKeys := vh.mach.GetOwnCrossSigningPublicKeys(ctx) + if ownCrossSigningPublicKeys == nil || len(ownCrossSigningPublicKeys.MasterKey) == 0 { + return nil, errors.New("failed to get own cross-signing master public key") + } + ownMasterKeyTrusted, err := vh.mach.CryptoStore.IsKeySignedBy(ctx, vh.client.UserID, ownCrossSigningPublicKeys.MasterKey, vh.client.UserID, vh.mach.OwnIdentity().SigningKey) + if err != nil { + return nil, err + } mode := QRCodeModeCrossSigning - if vh.client.UserID == txn.TheirUser { + if vh.client.UserID == txn.TheirUserID { // This is a self-signing situation. - if trusted, err := vh.mach.IsUserTrusted(ctx, vh.client.UserID); err != nil { - return err - } else if trusted { + if ownMasterKeyTrusted { mode = QRCodeModeSelfVerifyingMasterKeyTrusted } else { mode = QRCodeModeSelfVerifyingMasterKeyUntrusted } + } else { + // This is a cross-signing situation. + if !ownMasterKeyTrusted { + return nil, errors.New("cannot cross-sign other device when own master key is not trusted") + } + mode = QRCodeModeCrossSigning } var key1, key2 []byte @@ -242,9 +318,9 @@ func (vh *VerificationHelper) generateAndShowQRCode(ctx context.Context, txn *ve key1 = ownCrossSigningPublicKeys.MasterKey.Bytes() // Key 2 is the other user's master signing key. - theirSigningKeys, err := vh.mach.GetCrossSigningPublicKeys(ctx, txn.TheirUser) + theirSigningKeys, err := vh.mach.GetCrossSigningPublicKeys(ctx, txn.TheirUserID) if err != nil { - return err + return nil, err } key2 = theirSigningKeys.MasterKey.Bytes() case QRCodeModeSelfVerifyingMasterKeyTrusted: @@ -252,23 +328,22 @@ func (vh *VerificationHelper) generateAndShowQRCode(ctx context.Context, txn *ve key1 = ownCrossSigningPublicKeys.MasterKey.Bytes() // Key 2 is the other device's key. - theirDevice, err := vh.mach.GetOrFetchDevice(ctx, txn.TheirUser, txn.TheirDevice) + theirDevice, err := vh.mach.GetOrFetchDevice(ctx, txn.TheirUserID, txn.TheirDeviceID) if err != nil { - return err + return nil, err } - key2 = theirDevice.IdentityKey.Bytes() + key2 = theirDevice.SigningKey.Bytes() case QRCodeModeSelfVerifyingMasterKeyUntrusted: // Key 1 is the current device's key - key1 = vh.mach.OwnIdentity().IdentityKey.Bytes() + key1 = vh.mach.OwnIdentity().SigningKey.Bytes() // Key 2 is the master signing key. key2 = ownCrossSigningPublicKeys.MasterKey.Bytes() default: - log.Fatal().Str("mode", string(mode)).Msg("Unknown QR code mode") + log.Fatal().Int("mode", int(mode)).Msg("Unknown QR code mode") } qrCode := NewQRCode(mode, txn.TransactionID, [32]byte(key1), [32]byte(key2)) txn.QRCodeSharedSecret = qrCode.SharedSecret - vh.showQRCode(ctx, txn.TransactionID, qrCode) - return nil + return qrCode, nil } diff --git a/crypto/verificationhelper/sas.go b/crypto/verificationhelper/sas.go index 8160a4e1..e6392c79 100644 --- a/crypto/verificationhelper/sas.go +++ b/crypto/verificationhelper/sas.go @@ -13,6 +13,7 @@ import ( "crypto/hmac" "crypto/rand" "crypto/sha256" + "crypto/subtle" "encoding/base64" "encoding/json" "errors" @@ -28,41 +29,43 @@ import ( "maunium.net/go/mautrix/id" ) -// StartSAS starts a SAS verification flow. The transaction ID should be the -// transaction ID of a verification request that was received via the -// VerificationRequested callback in [RequiredCallbacks]. +// StartSAS starts a SAS verification flow for the given transaction ID. The +// transaction ID should be one received via the VerificationRequested callback +// in [RequiredCallbacks] or the [StartVerification] or +// [StartInRoomVerification] functions. func (vh *VerificationHelper) StartSAS(ctx context.Context, txnID id.VerificationTransactionID) error { log := vh.getLog(ctx).With(). - Str("verification_action", "accept verification"). + Str("verification_action", "start SAS"). Stringer("transaction_id", txnID). Logger() + ctx = log.WithContext(ctx) vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() - txn, ok := vh.activeTransactions[txnID] - if !ok { - return fmt.Errorf("unknown transaction ID") - } else if txn.VerificationState != verificationStateReady { - return errors.New("transaction is not in ready state") + txn, err := vh.store.GetVerificationTransaction(ctx, txnID) + if err != nil { + return fmt.Errorf("failed to get verification transaction %s: %w", txnID, err) + } else if txn.VerificationState != VerificationStateReady { + return fmt.Errorf("transaction is not in ready state: %s", txn.VerificationState.String()) } else if txn.StartEventContent != nil { return errors.New("start event already sent or received") } - txn.VerificationState = verificationStateSASStarted + txn.VerificationState = VerificationStateSASStarted txn.StartedByUs = true if !slices.Contains(txn.TheirSupportedMethods, event.VerificationMethodSAS) { return fmt.Errorf("the other device does not support SAS verification") } // Ensure that we have their device key. - _, err := vh.mach.GetOrFetchDevice(ctx, txn.TheirUser, txn.TheirDevice) + _, err = vh.mach.GetOrFetchDevice(ctx, txn.TheirUserID, txn.TheirDeviceID) if err != nil { log.Err(err).Msg("Failed to fetch device") return err } log.Info().Msg("Sending start event") - txn.StartEventContent = &event.VerificationStartEventContent{ + startEventContent := event.VerificationStartEventContent{ FromDevice: vh.client.DeviceID, Method: event.VerificationMethodSAS, @@ -77,35 +80,43 @@ func (vh *VerificationHelper) StartSAS(ctx context.Context, txnID id.Verificatio event.SASMethodEmoji, }, } - return vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationStart, txn.StartEventContent) + if err := vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationStart, &startEventContent); err != nil { + return err + } + txn.StartEventContent = &startEventContent + return vh.store.SaveVerificationTransaction(ctx, txn) } // ConfirmSAS indicates that the user has confirmed that the SAS matches SAS -// shown on the other user's device. +// shown on the other user's device for the given transaction ID. The +// transaction ID should be one received via the VerificationRequested callback +// in [RequiredCallbacks] or the [StartVerification] or +// [StartInRoomVerification] functions. func (vh *VerificationHelper) ConfirmSAS(ctx context.Context, txnID id.VerificationTransactionID) error { log := vh.getLog(ctx).With(). Str("verification_action", "confirm SAS"). Stringer("transaction_id", txnID). Logger() + ctx = log.WithContext(ctx) vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() - txn, ok := vh.activeTransactions[txnID] - if !ok { - return fmt.Errorf("unknown transaction ID") - } else if txn.VerificationState != verificationStateSASKeysExchanged { + txn, err := vh.store.GetVerificationTransaction(ctx, txnID) + if err != nil { + return fmt.Errorf("failed to get transaction %s: %w", txnID, err) + } else if txn.VerificationState != VerificationStateSASKeysExchanged { return errors.New("transaction is not in keys exchanged state") } - var err error keys := map[id.KeyID]jsonbytes.UnpaddedBytes{} log.Info().Msg("Signing keys") + var masterKey string // My device key myDevice := vh.mach.OwnIdentity() myDeviceKeyID := id.NewKeyID(id.KeyAlgorithmEd25519, myDevice.DeviceID.String()) - keys[myDeviceKeyID], err = vh.verificationMACHKDF(txn, vh.client.UserID, vh.client.DeviceID, txn.TheirUser, txn.TheirDevice, myDeviceKeyID.String(), myDevice.SigningKey.String()) + keys[myDeviceKeyID], err = vh.verificationMACHKDF(txn, vh.client.UserID, vh.client.DeviceID, txn.TheirUserID, txn.TheirDeviceID, myDeviceKeyID.String(), myDevice.SigningKey.String()) if err != nil { return err } @@ -113,8 +124,9 @@ func (vh *VerificationHelper) ConfirmSAS(ctx context.Context, txnID id.Verificat // Master signing key crossSigningKeys := vh.mach.GetOwnCrossSigningPublicKeys(ctx) if crossSigningKeys != nil { - crossSigningKeyID := id.NewKeyID(id.KeyAlgorithmEd25519, crossSigningKeys.MasterKey.String()) - keys[crossSigningKeyID], err = vh.verificationMACHKDF(txn, vh.client.UserID, vh.client.DeviceID, txn.TheirUser, txn.TheirDevice, crossSigningKeyID.String(), crossSigningKeys.MasterKey.String()) + masterKey = crossSigningKeys.MasterKey.String() + crossSigningKeyID := id.NewKeyID(id.KeyAlgorithmEd25519, masterKey) + keys[crossSigningKeyID], err = vh.verificationMACHKDF(txn, vh.client.UserID, vh.client.DeviceID, txn.TheirUserID, txn.TheirDeviceID, crossSigningKeyID.String(), masterKey) if err != nil { return err } @@ -125,7 +137,7 @@ func (vh *VerificationHelper) ConfirmSAS(ctx context.Context, txnID id.Verificat keyIDs = append(keyIDs, keyID.String()) } slices.Sort(keyIDs) - keysMAC, err := vh.verificationMACHKDF(txn, vh.client.UserID, vh.client.DeviceID, txn.TheirUser, txn.TheirDevice, "KEY_IDS", strings.Join(keyIDs, ",")) + keysMAC, err := vh.verificationMACHKDF(txn, vh.client.UserID, vh.client.DeviceID, txn.TheirUserID, txn.TheirDeviceID, "KEY_IDS", strings.Join(keyIDs, ",")) if err != nil { return err } @@ -138,17 +150,23 @@ func (vh *VerificationHelper) ConfirmSAS(ctx context.Context, txnID id.Verificat if err != nil { return err } + log.Info().Msg("Sent our MAC event") txn.SentOurMAC = true if txn.ReceivedTheirMAC { - txn.VerificationState = verificationStateSASMACExchanged + txn.VerificationState = VerificationStateSASMACExchanged + + if err := vh.trustKeysAfterMACCheck(ctx, txn, masterKey); err != nil { + return fmt.Errorf("failed to trust keys: %w", err) + } + err := vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationDone, &event.VerificationDoneEventContent{}) if err != nil { return err } txn.SentOurDone = true } - return nil + return vh.store.SaveVerificationTransaction(ctx, txn) } // onVerificationStartSAS handles the m.key.verification.start events with @@ -156,12 +174,13 @@ func (vh *VerificationHelper) ConfirmSAS(ctx context.Context, txnID id.Verificat // Spec. // // [Section 11.12.2.2]: https://spec.matrix.org/v1.9/client-server-api/#short-authentication-string-sas-verification -func (vh *VerificationHelper) onVerificationStartSAS(ctx context.Context, txn *verificationTransaction, evt *event.Event) error { +func (vh *VerificationHelper) onVerificationStartSAS(ctx context.Context, txn VerificationTransaction, evt *event.Event) error { startEvt := evt.Content.AsVerificationStart() log := vh.getLog(ctx).With(). - Str("verification_action", "start_sas"). + Str("verification_action", "start SAS"). Stringer("transaction_id", txn.TransactionID). Logger() + ctx = log.WithContext(ctx) log.Info().Msg("Received SAS verification start event") _, err := vh.mach.GetOrFetchDevice(ctx, evt.Sender, startEvt.FromDevice) @@ -204,29 +223,30 @@ func (vh *VerificationHelper) onVerificationStartSAS(ctx context.Context, txn *v return fmt.Errorf("failed to generate ephemeral key: %w", err) } txn.MACMethod = macMethod - txn.EphemeralKey = ephemeralKey - txn.StartEventContent = startEvt + txn.EphemeralKey = &ECDHPrivateKey{ephemeralKey} - commitment, err := calculateCommitment(ephemeralKey.PublicKey(), startEvt) - if err != nil { - return fmt.Errorf("failed to calculate commitment: %w", err) - } + if !txn.StartedByUs { + commitment, err := calculateCommitment(ephemeralKey.PublicKey(), txn) + if err != nil { + return fmt.Errorf("failed to calculate commitment: %w", err) + } - err = vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationAccept, &event.VerificationAcceptEventContent{ - Commitment: commitment, - Hash: hashAlgorithm, - KeyAgreementProtocol: keyAggreementProtocol, - MessageAuthenticationCode: macMethod, - ShortAuthenticationString: sasMethods, - }) - if err != nil { - return fmt.Errorf("failed to send accept event: %w", err) + err = vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationAccept, &event.VerificationAcceptEventContent{ + Commitment: commitment, + Hash: hashAlgorithm, + KeyAgreementProtocol: keyAggreementProtocol, + MessageAuthenticationCode: macMethod, + ShortAuthenticationString: sasMethods, + }) + if err != nil { + return fmt.Errorf("failed to send accept event: %w", err) + } + txn.VerificationState = VerificationStateSASAccepted } - txn.VerificationState = verificationStateSASAccepted - return nil + return vh.store.SaveVerificationTransaction(ctx, txn) } -func calculateCommitment(ephemeralPubKey *ecdh.PublicKey, startEvt *event.VerificationStartEventContent) ([]byte, error) { +func calculateCommitment(ephemeralPubKey *ecdh.PublicKey, txn VerificationTransaction) ([]byte, error) { // The commitmentHashInput is the hash (encoded as unpadded base64) of the // concatenation of the device's ephemeral public key (encoded as // unpadded base64) and the canonical JSON representation of the @@ -236,7 +256,7 @@ func calculateCommitment(ephemeralPubKey *ecdh.PublicKey, startEvt *event.Verifi // hashing it, but we are just stuck on that. commitmentHashInput := sha256.New() commitmentHashInput.Write([]byte(base64.RawStdEncoding.EncodeToString(ephemeralPubKey.Bytes()))) - encodedStartEvt, err := json.Marshal(startEvt) + encodedStartEvt, err := json.Marshal(txn.StartEventContent) if err != nil { return nil, err } @@ -248,7 +268,7 @@ func calculateCommitment(ephemeralPubKey *ecdh.PublicKey, startEvt *event.Verifi // event. This follows Step 4 of [Section 11.12.2.2] of the Spec. // // [Section 11.12.2.2]: https://spec.matrix.org/v1.9/client-server-api/#short-authentication-string-sas-verification -func (vh *VerificationHelper) onVerificationAccept(ctx context.Context, txn *verificationTransaction, evt *event.Event) { +func (vh *VerificationHelper) onVerificationAccept(ctx context.Context, txn VerificationTransaction, evt *event.Event) { acceptEvt := evt.Content.AsVerificationAccept() log := vh.getLog(ctx).With(). Str("verification_action", "accept"). @@ -259,11 +279,12 @@ func (vh *VerificationHelper) onVerificationAccept(ctx context.Context, txn *ver Str("message_authentication_code", string(acceptEvt.MessageAuthenticationCode)). Any("short_authentication_string", acceptEvt.ShortAuthenticationString). Logger() + ctx = log.WithContext(ctx) log.Info().Msg("Received SAS verification accept event") vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() - if txn.VerificationState != verificationStateSASStarted { + if txn.VerificationState != VerificationStateSASStarted { vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUnexpectedMessage, "received accept event for a transaction that is not in the started state") return @@ -283,49 +304,49 @@ func (vh *VerificationHelper) onVerificationAccept(ctx context.Context, txn *ver return } - txn.VerificationState = verificationStateSASAccepted + txn.VerificationState = VerificationStateSASAccepted txn.MACMethod = acceptEvt.MessageAuthenticationCode txn.Commitment = acceptEvt.Commitment - txn.EphemeralKey = ephemeralKey + txn.EphemeralKey = &ECDHPrivateKey{ephemeralKey} txn.EphemeralPublicKeyShared = true + + if err := vh.store.SaveVerificationTransaction(ctx, txn); err != nil { + log.Err(err).Msg("failed to save verification transaction") + } } -func (vh *VerificationHelper) onVerificationKey(ctx context.Context, txn *verificationTransaction, evt *event.Event) { +func (vh *VerificationHelper) onVerificationKey(ctx context.Context, txn VerificationTransaction, evt *event.Event) { log := vh.getLog(ctx).With(). Str("verification_action", "key"). Logger() + ctx = log.WithContext(ctx) keyEvt := evt.Content.AsVerificationKey() vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() - if txn.VerificationState != verificationStateSASAccepted { + if txn.VerificationState != VerificationStateSASAccepted { vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUnexpectedMessage, "received key event for a transaction that is not in the accepted state") return } var err error - txn.OtherPublicKey, err = ecdh.X25519().NewPublicKey(keyEvt.Key) + publicKey, err := ecdh.X25519().NewPublicKey(keyEvt.Key) if err != nil { log.Err(err).Msg("Failed to generate other public key") return } + txn.OtherPublicKey = &ECDHPublicKey{publicKey} if txn.EphemeralPublicKeyShared { // Verify that the commitment hash is correct - commitment, err := calculateCommitment(txn.OtherPublicKey, txn.StartEventContent) + commitment, err := calculateCommitment(publicKey, txn) if err != nil { log.Err(err).Msg("Failed to calculate commitment") return } if !bytes.Equal(commitment, txn.Commitment) { - err := vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationCancel, &event.VerificationCancelEventContent{ - Code: event.VerificationCancelCodeKeyMismatch, - Reason: "The key was not the one we expected.", - }) - if err != nil { - log.Err(err).Msg("Failed to send cancellation event") - } + vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeKeyMismatch, "The key was not the one we expected") return } } else { @@ -338,7 +359,7 @@ func (vh *VerificationHelper) onVerificationKey(ctx context.Context, txn *verifi } txn.EphemeralPublicKeyShared = true } - txn.VerificationState = verificationStateSASKeysExchanged + txn.VerificationState = VerificationStateSASKeysExchanged sasBytes, err := vh.verificationSASHKDF(txn) if err != nil { @@ -348,6 +369,7 @@ func (vh *VerificationHelper) onVerificationKey(ctx context.Context, txn *verifi var decimals []int var emojis []rune + var emojiDescriptions []string if slices.Contains(txn.StartEventContent.ShortAuthenticationString, event.SASMethodDecimal) { decimals = []int{ (int(sasBytes[0])<<5 | int(sasBytes[1])>>3) + 1000, @@ -363,13 +385,18 @@ func (vh *VerificationHelper) onVerificationKey(ctx context.Context, txn *verifi // Right shift the number and then mask the lowest 6 bits. emojiIdx := (sasNum >> uint(48-(i+1)*6)) & 0b111111 emojis = append(emojis, allEmojis[emojiIdx]) + emojiDescriptions = append(emojiDescriptions, allEmojiDescriptions[emojiIdx]) } } - vh.showSAS(ctx, txn.TransactionID, emojis, decimals) + vh.showSAS(ctx, txn.TransactionID, emojis, emojiDescriptions, decimals) + + if err := vh.store.SaveVerificationTransaction(ctx, txn); err != nil { + log.Err(err).Msg("failed to save verification transaction") + } } -func (vh *VerificationHelper) verificationSASHKDF(txn *verificationTransaction) ([]byte, error) { - sharedSecret, err := txn.EphemeralKey.ECDH(txn.OtherPublicKey) +func (vh *VerificationHelper) verificationSASHKDF(txn VerificationTransaction) ([]byte, error) { + sharedSecret, err := txn.EphemeralKey.ECDH(txn.OtherPublicKey.PublicKey) if err != nil { return nil, err } @@ -384,8 +411,8 @@ func (vh *VerificationHelper) verificationSASHKDF(txn *verificationTransaction) }, "|") theirInfo := strings.Join([]string{ - txn.TheirUser.String(), - txn.TheirDevice.String(), + txn.TheirUserID.String(), + txn.TheirDeviceID.String(), base64.RawStdEncoding.EncodeToString(txn.OtherPublicKey.Bytes()), }, "|") @@ -458,8 +485,8 @@ func BrokenB64Encode(input []byte) string { return string(output) } -func (vh *VerificationHelper) verificationMACHKDF(txn *verificationTransaction, senderUser id.UserID, senderDevice id.DeviceID, receivingUser id.UserID, receivingDevice id.DeviceID, keyID, key string) ([]byte, error) { - sharedSecret, err := txn.EphemeralKey.ECDH(txn.OtherPublicKey) +func (vh *VerificationHelper) verificationMACHKDF(txn VerificationTransaction, senderUser id.UserID, senderDevice id.DeviceID, receivingUser id.UserID, receivingDevice id.DeviceID, keyID, key string) ([]byte, error) { + sharedSecret, err := txn.EphemeralKey.ECDH(txn.OtherPublicKey.PublicKey) if err != nil { return nil, err } @@ -559,28 +586,101 @@ var allEmojis = []rune{ '📌', } -func (vh *VerificationHelper) onVerificationMAC(ctx context.Context, txn *verificationTransaction, evt *event.Event) { +var allEmojiDescriptions = []string{ + "Dog", + "Cat", + "Lion", + "Horse", + "Unicorn", + "Pig", + "Elephant", + "Rabbit", + "Panda", + "Rooster", + "Penguin", + "Turtle", + "Fish", + "Octopus", + "Butterfly", + "Flower", + "Tree", + "Cactus", + "Mushroom", + "Globe", + "Moon", + "Cloud", + "Fire", + "Banana", + "Apple", + "Strawberry", + "Corn", + "Pizza", + "Cake", + "Heart", + "Smiley", + "Robot", + "Hat", + "Glasses", + "Spanner", + "Santa", + "Thumbs Up", + "Umbrella", + "Hourglass", + "Clock", + "Gift", + "Light Bulb", + "Book", + "Pencil", + "Paperclip", + "Scissors", + "Lock", + "Key", + "Hammer", + "Telephone", + "Flag", + "Train", + "Bicycle", + "Aeroplane", + "Rocket", + "Trophy", + "Ball", + "Guitar", + "Trumpet", + "Bell", + "Anchor", + "Headphones", + "Folder", + "Pin", +} + +func (vh *VerificationHelper) onVerificationMAC(ctx context.Context, txn VerificationTransaction, evt *event.Event) { log := vh.getLog(ctx).With(). Str("verification_action", "mac"). Logger() + ctx = log.WithContext(ctx) log.Info().Msg("Received SAS verification MAC event") + vh.activeTransactionsLock.Lock() + defer vh.activeTransactionsLock.Unlock() macEvt := evt.Content.AsVerificationMAC() // Verifying Keys MAC log.Info().Msg("Verifying MAC for all sent keys") var hasTheirDeviceKey bool + var masterKey string var keyIDs []string for keyID := range macEvt.MAC { keyIDs = append(keyIDs, keyID.String()) _, kID := keyID.Parse() - if kID == txn.TheirDevice.String() { + if kID == txn.TheirDeviceID.String() { hasTheirDeviceKey = true + } else { + masterKey = kID } } slices.Sort(keyIDs) - expectedKeyMAC, err := vh.verificationMACHKDF(txn, txn.TheirUser, txn.TheirDevice, vh.client.UserID, vh.client.DeviceID, "KEY_IDS", strings.Join(keyIDs, ",")) + expectedKeyMAC, err := vh.verificationMACHKDF(txn, txn.TheirUserID, txn.TheirDeviceID, vh.client.UserID, vh.client.DeviceID, "KEY_IDS", strings.Join(keyIDs, ",")) if err != nil { - vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeSASMismatch, "failed to calculate key list MAC: %v", err) + vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeSASMismatch, "failed to calculate key list MAC: %w", err) return } if !bytes.Equal(expectedKeyMAC, macEvt.Keys) { @@ -593,8 +693,9 @@ func (vh *VerificationHelper) onVerificationMAC(ctx context.Context, txn *verifi } // Verify the MAC for each key + var theirDevice *id.Device for keyID, mac := range macEvt.MAC { - log.Info().Str("key_id", keyID.String()).Msg("Received MAC for key") + log.Info().Stringer("key_id", keyID).Msg("Received MAC for key") alg, kID := keyID.Parse() if alg != id.KeyAlgorithmEd25519 { @@ -603,11 +704,14 @@ func (vh *VerificationHelper) onVerificationMAC(ctx context.Context, txn *verifi } var key string - var theirDevice *id.Device - if kID == txn.TheirDevice.String() { - theirDevice, err = vh.mach.GetOrFetchDevice(ctx, txn.TheirUser, txn.TheirDevice) + if kID == txn.TheirDeviceID.String() { + if theirDevice != nil { + vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeInvalidMessage, "two keys found for their device ID") + return + } + theirDevice, err = vh.mach.GetOrFetchDevice(ctx, txn.TheirUserID, txn.TheirDeviceID) if err != nil { - vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUser, "failed to fetch their device: %v", err) + vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUser, "failed to fetch their device: %w", err) return } key = theirDevice.SigningKey.String() @@ -624,38 +728,85 @@ func (vh *VerificationHelper) onVerificationMAC(ctx context.Context, txn *verifi key = crossSigningKeys.MasterKey.String() } - expectedMAC, err := vh.verificationMACHKDF(txn, txn.TheirUser, txn.TheirDevice, vh.client.UserID, vh.client.DeviceID, keyID.String(), key) + expectedMAC, err := vh.verificationMACHKDF(txn, txn.TheirUserID, txn.TheirDeviceID, vh.client.UserID, vh.client.DeviceID, keyID.String(), key) if err != nil { - vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUser, "failed to calculate key MAC: %v", err) + vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUser, "failed to calculate key MAC: %w", err) return } - if !bytes.Equal(expectedMAC, mac) { + if subtle.ConstantTimeCompare(expectedMAC, mac) == 0 { vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeSASMismatch, "MAC mismatch for key %s", keyID) return } - - // Trust their device - if kID == txn.TheirDevice.String() { - theirDevice.Trust = id.TrustStateVerified - err = vh.mach.CryptoStore.PutDevice(ctx, txn.TheirUser, theirDevice) - if err != nil { - vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUser, "failed to update device trust state after verifying: %v", err) - return - } - } } log.Info().Msg("All MACs verified") - vh.activeTransactionsLock.Lock() - defer vh.activeTransactionsLock.Unlock() txn.ReceivedTheirMAC = true if txn.SentOurMAC { - txn.VerificationState = verificationStateSASMACExchanged + txn.VerificationState = VerificationStateSASMACExchanged + + if err := vh.trustKeysAfterMACCheck(ctx, txn, masterKey); err != nil { + vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUser, "failed to trust keys: %w", err) + return + } + err := vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationDone, &event.VerificationDoneEventContent{}) if err != nil { - vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUser, "failed to send verification done event: %v", err) + vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUser, "failed to send verification done event: %w", err) return } txn.SentOurDone = true } + + if err := vh.store.SaveVerificationTransaction(ctx, txn); err != nil { + log.Err(err).Msg("failed to save verification transaction") + } +} + +func (vh *VerificationHelper) trustKeysAfterMACCheck(ctx context.Context, txn VerificationTransaction, masterKey string) error { + theirDevice, err := vh.mach.GetOrFetchDevice(ctx, txn.TheirUserID, txn.TheirDeviceID) + if err != nil { + return fmt.Errorf("failed to fetch their device: %w", err) + } + // Trust their device + theirDevice.Trust = id.TrustStateVerified + err = vh.mach.CryptoStore.PutDevice(ctx, txn.TheirUserID, theirDevice) + if err != nil { + return fmt.Errorf("failed to update device trust state after verifying: %w", err) + } + + if txn.TheirUserID == vh.client.UserID { + // Self-signing situation. + // + // If we have the cross-signing keys, then we need to sign their device + // using the self-signing key. Otherwise, they have the master private + // key, so we need to trust the master public key. + if vh.mach.CrossSigningKeys != nil { + err = vh.mach.SignOwnDevice(ctx, theirDevice) + if err != nil { + return fmt.Errorf("failed to sign our own new device: %w", err) + } + } else { + err = vh.mach.SignOwnMasterKey(ctx) + if err != nil { + return fmt.Errorf("failed to sign our own master key: %w", err) + } + } + } else if masterKey != "" { + // Cross-signing situation. + // + // The master key was included in the list of keys to verify, so verify + // that it matches what we expect and sign their master key using the + // user-signing key. + theirSigningKeys, err := vh.mach.GetCrossSigningPublicKeys(ctx, txn.TheirUserID) + if err != nil { + return fmt.Errorf("couldn't get %s's cross-signing keys: %w", txn.TheirUserID, err) + } else if theirSigningKeys.MasterKey.String() != masterKey { + return fmt.Errorf("master keys do not match") + } + + if err := vh.mach.SignUser(ctx, txn.TheirUserID, theirSigningKeys.MasterKey); err != nil { + return fmt.Errorf("failed to sign %s's master key: %w", txn.TheirUserID, err) + } + } + return nil } diff --git a/crypto/verificationhelper/verificationhelper.go b/crypto/verificationhelper/verificationhelper.go index 025af25e..0a781c16 100644 --- a/crypto/verificationhelper/verificationhelper.go +++ b/crypto/verificationhelper/verificationhelper.go @@ -9,11 +9,13 @@ package verificationhelper import ( "bytes" "context" - "crypto/ecdh" + "errors" "fmt" "sync" + "time" "github.com/rs/zerolog" + "go.mau.fi/util/exslices" "go.mau.fi/util/jsontime" "golang.org/x/exp/maps" "golang.org/x/exp/slices" @@ -24,121 +26,33 @@ import ( "maunium.net/go/mautrix/id" ) -type verificationState int - -const ( - verificationStateRequested verificationState = iota - verificationStateReady - verificationStateCancelled - verificationStateDone - - verificationStateTheirQRScanned // We scanned their QR code - verificationStateOurQRScanned // They scanned our QR code - - verificationStateSASStarted // An SAS verification has been started - verificationStateSASAccepted // An SAS verification has been accepted - verificationStateSASKeysExchanged // An SAS verification has exchanged keys - verificationStateSASMACExchanged // An SAS verification has exchanged MACs -) - -func (step verificationState) String() string { - switch step { - case verificationStateRequested: - return "requested" - case verificationStateReady: - return "ready" - case verificationStateCancelled: - return "cancelled" - case verificationStateTheirQRScanned: - return "their_qr_scanned" - case verificationStateOurQRScanned: - return "our_qr_scanned" - case verificationStateSASStarted: - return "sas_started" - case verificationStateSASAccepted: - return "sas_accepted" - case verificationStateSASKeysExchanged: - return "sas_keys_exchanged" - case verificationStateSASMACExchanged: - return "sas_mac" - default: - return fmt.Sprintf("verificationStep(%d)", step) - } -} - -type verificationTransaction struct { - // RoomID is the room ID if the verification is happening in a room or - // empty if it is a to-device verification. - RoomID id.RoomID - - // VerificationState is the current step of the verification flow. - VerificationState verificationState - // TransactionID is the ID of the verification transaction. - TransactionID id.VerificationTransactionID - - // TheirDevice is the device ID of the device that either made the initial - // request or accepted our request. - TheirDevice id.DeviceID - // TheirUser is the user ID of the other user. - TheirUser id.UserID - // TheirSupportedMethods is a list of verification methods that the other - // device supports. - TheirSupportedMethods []event.VerificationMethod - - // SentToDeviceIDs is a list of devices which the initial request was sent - // to. This is only used for to-device verification requests, and is meant - // to be used to send cancellation requests to all other devices when a - // verification request is accepted via a m.key.verification.ready event. - SentToDeviceIDs []id.DeviceID - - // QRCodeSharedSecret is the shared secret that was encoded in the QR code - // that we showed. - QRCodeSharedSecret []byte - - StartedByUs bool // Whether the verification was started by us - StartEventContent *event.VerificationStartEventContent // The m.key.verification.start event content - Commitment []byte // The commitment from the m.key.verification.accept event - MACMethod event.MACMethod // The method used to calculate the MAC - EphemeralKey *ecdh.PrivateKey // The ephemeral key - EphemeralPublicKeyShared bool // Whether this device's ephemeral public key has been shared - OtherPublicKey *ecdh.PublicKey // The other device's ephemeral public key - ReceivedTheirMAC bool // Whether we have received their MAC - SentOurMAC bool // Whether we have sent our MAC - ReceivedTheirDone bool // Whether we have received their done event - SentOurDone bool // Whether we have sent our done event -} - // RequiredCallbacks is an interface representing the callbacks required for // the [VerificationHelper]. type RequiredCallbacks interface { // VerificationRequested is called when a verification request is received // from another device. - VerificationRequested(ctx context.Context, txnID id.VerificationTransactionID, from id.UserID) + VerificationRequested(ctx context.Context, txnID id.VerificationTransactionID, from id.UserID, fromDevice id.DeviceID) + + // VerificationReady is called when a verification request has been + // accepted by both parties. + VerificationReady(ctx context.Context, txnID id.VerificationTransactionID, otherDeviceID id.DeviceID, supportsSAS, supportsScanQRCode bool, qrCode *QRCode) // VerificationCancelled is called when the verification is cancelled. VerificationCancelled(ctx context.Context, txnID id.VerificationTransactionID, code event.VerificationCancelCode, reason string) // VerificationDone is called when the verification is done. - VerificationDone(ctx context.Context, txnID id.VerificationTransactionID) + VerificationDone(ctx context.Context, txnID id.VerificationTransactionID, method event.VerificationMethod) } -type showSASCallbacks interface { +type ShowSASCallbacks interface { // ShowSAS is a callback that is called when the SAS verification has // generated a short authentication string to show. It is guaranteed that - // either the emojis list, or the decimals list, or both will be present. - ShowSAS(ctx context.Context, txnID id.VerificationTransactionID, emojis []rune, decimals []int) + // either the emojis and emoji descriptions lists, or the decimals list, or + // both will be present. + ShowSAS(ctx context.Context, txnID id.VerificationTransactionID, emojis []rune, emojiDescriptions []string, decimals []int) } -type showQRCodeCallbacks interface { - // ScanQRCode is called when another device has sent a - // m.key.verification.ready event and indicated that they are capable of - // showing a QR code. - ScanQRCode(ctx context.Context, txnID id.VerificationTransactionID) - - // ShowQRCode is called when the verification has been accepted and a QR - // code should be shown to the user. - ShowQRCode(ctx context.Context, txnID id.VerificationTransactionID, qrCode *QRCode) - +type ShowQRCodeCallbacks interface { // QRCodeScanned is called when the other user has scanned the QR code and // sent the m.key.verification.start event. QRCodeScanned(ctx context.Context, txnID id.VerificationTransactionID) @@ -148,67 +62,80 @@ type VerificationHelper struct { client *mautrix.Client mach *crypto.OlmMachine - activeTransactions map[id.VerificationTransactionID]*verificationTransaction + store VerificationStore activeTransactionsLock sync.Mutex // supportedMethods are the methods that *we* support supportedMethods []event.VerificationMethod - verificationRequested func(ctx context.Context, txnID id.VerificationTransactionID, from id.UserID) + verificationRequested func(ctx context.Context, txnID id.VerificationTransactionID, from id.UserID, fromDevice id.DeviceID) + verificationReady func(ctx context.Context, txnID id.VerificationTransactionID, otherDeviceID id.DeviceID, supportsSAS, supportsScanQRCode bool, qrCode *QRCode) verificationCancelledCallback func(ctx context.Context, txnID id.VerificationTransactionID, code event.VerificationCancelCode, reason string) - verificationDone func(ctx context.Context, txnID id.VerificationTransactionID) + verificationDone func(ctx context.Context, txnID id.VerificationTransactionID, method event.VerificationMethod) - showSAS func(ctx context.Context, txnID id.VerificationTransactionID, emojis []rune, decimals []int) - - scanQRCode func(ctx context.Context, txnID id.VerificationTransactionID) - showQRCode func(ctx context.Context, txnID id.VerificationTransactionID, qrCode *QRCode) - qrCodeScaned func(ctx context.Context, txnID id.VerificationTransactionID) + // showSAS is a callback that will be called after the SAS verification + // dance is complete and we want the client to show the emojis/decimals + showSAS func(ctx context.Context, txnID id.VerificationTransactionID, emojis []rune, emojiDescriptions []string, decimals []int) + // qrCodeScanned is a callback that will be called when the other device + // scanned the QR code we are showing + qrCodeScanned func(ctx context.Context, txnID id.VerificationTransactionID) } var _ mautrix.VerificationHelper = (*VerificationHelper)(nil) -func NewVerificationHelper(client *mautrix.Client, mach *crypto.OlmMachine, callbacks any, supportsScan bool) *VerificationHelper { +func NewVerificationHelper(client *mautrix.Client, mach *crypto.OlmMachine, store VerificationStore, callbacks any, supportsQRShow, supportsQRScan, supportsSAS bool) *VerificationHelper { if client.Crypto == nil { panic("client.Crypto is nil") } + if store == nil { + store = NewInMemoryVerificationStore() + } + helper := VerificationHelper{ - client: client, - mach: mach, - activeTransactions: map[id.VerificationTransactionID]*verificationTransaction{}, + client: client, + mach: mach, + store: store, } if c, ok := callbacks.(RequiredCallbacks); !ok { - panic("callbacks must implement VerificationRequested") + panic("callbacks must implement RequiredCallbacks") } else { helper.verificationRequested = c.VerificationRequested + helper.verificationReady = c.VerificationReady helper.verificationCancelledCallback = c.VerificationCancelled helper.verificationDone = c.VerificationDone } - if c, ok := callbacks.(showSASCallbacks); ok { - helper.supportedMethods = append(helper.supportedMethods, event.VerificationMethodSAS) - helper.showSAS = c.ShowSAS + if supportsSAS { + if c, ok := callbacks.(ShowSASCallbacks); !ok { + panic("callbacks must implement showSAS if supportsSAS is true") + } else { + helper.supportedMethods = append(helper.supportedMethods, event.VerificationMethodSAS) + helper.showSAS = c.ShowSAS + } } - if c, ok := callbacks.(showQRCodeCallbacks); ok { - helper.supportedMethods = append(helper.supportedMethods, - event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate) - helper.scanQRCode = c.ScanQRCode - helper.showQRCode = c.ShowQRCode - helper.qrCodeScaned = c.QRCodeScanned + if supportsQRShow { + if c, ok := callbacks.(ShowQRCodeCallbacks); !ok { + panic("callbacks must implement ShowQRCodeCallbacks if supportsQRShow is true") + } else { + helper.supportedMethods = append(helper.supportedMethods, event.VerificationMethodQRCodeShow) + helper.supportedMethods = append(helper.supportedMethods, event.VerificationMethodReciprocate) + helper.qrCodeScanned = c.QRCodeScanned + } } - if supportsScan { - helper.supportedMethods = append(helper.supportedMethods, - event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate) + if supportsQRScan { + helper.supportedMethods = append(helper.supportedMethods, event.VerificationMethodQRCodeScan) + helper.supportedMethods = append(helper.supportedMethods, event.VerificationMethodReciprocate) } - - slices.Sort(helper.supportedMethods) - helper.supportedMethods = slices.Compact(helper.supportedMethods) + helper.supportedMethods = exslices.DeduplicateUnsorted(helper.supportedMethods) return &helper } func (vh *VerificationHelper) getLog(ctx context.Context) *zerolog.Logger { logger := zerolog.Ctx(ctx).With(). Str("component", "verification"). + Stringer("device_id", vh.client.DeviceID). + Stringer("user_id", vh.client.UserID). Any("supported_methods", vh.supportedMethods). Logger() return &logger @@ -236,67 +163,72 @@ func (vh *VerificationHelper) Init(ctx context.Context) error { // Wrapper for the event handlers to check that the transaction ID is known // and ignore the event if it isn't. - wrapHandler := func(callback func(context.Context, *verificationTransaction, *event.Event)) func(context.Context, *event.Event) { + wrapHandler := func(callback func(context.Context, VerificationTransaction, *event.Event)) func(context.Context, *event.Event) { return func(ctx context.Context, evt *event.Event) { log := vh.getLog(ctx).With(). Str("verification_action", "check transaction ID"). Stringer("sender", evt.Sender). Stringer("room_id", evt.RoomID). Stringer("event_id", evt.ID). + Stringer("event_type", evt.Type). Logger() + ctx = log.WithContext(ctx) var transactionID id.VerificationTransactionID if evt.ID != "" { transactionID = id.VerificationTransactionID(evt.ID) } else { - txnID, ok := evt.Content.Raw["transaction_id"].(string) - if !ok { + if txnID, ok := evt.Content.Parsed.(event.VerificationTransactionable); !ok { log.Warn().Msg("Ignoring verification event without a transaction ID") return + } else { + transactionID = txnID.GetTransactionID() } - transactionID = id.VerificationTransactionID(txnID) } log = log.With().Stringer("transaction_id", transactionID).Logger() vh.activeTransactionsLock.Lock() - txn, ok := vh.activeTransactions[transactionID] - vh.activeTransactionsLock.Unlock() - if !ok || txn.VerificationState == verificationStateCancelled || txn.VerificationState == verificationStateDone { - var code event.VerificationCancelCode - var reason string - if !ok { - log.Warn().Msg("Ignoring verification event for an unknown transaction and sending cancellation") - - // We have to create a fake transaction so that the call to - // verificationCancelled works. - txn = &verificationTransaction{ - RoomID: evt.RoomID, - TheirUser: evt.Sender, - } - txn.TransactionID = evt.Content.Parsed.(event.VerificationTransactionable).GetTransactionID() - if txn.TransactionID == "" { - txn.TransactionID = id.VerificationTransactionID(evt.ID) - } - if fromDevice, ok := evt.Content.Raw["from_device"]; ok { - txn.TheirDevice = id.DeviceID(fromDevice.(string)) - } - code = event.VerificationCancelCodeUnknownTransaction - reason = "The transaction ID was not recognized." - } else if txn.VerificationState == verificationStateCancelled { - log.Warn().Msg("Ignoring verification event for a cancelled transaction") - code = event.VerificationCancelCodeUnexpectedMessage - reason = "The transaction is cancelled." - } else if txn.VerificationState == verificationStateDone { - code = event.VerificationCancelCodeUnexpectedMessage - reason = "The transaction is done." + txn, err := vh.store.GetVerificationTransaction(ctx, transactionID) + if err != nil && errors.Is(err, ErrUnknownVerificationTransaction) { + log.Err(err).Msg("failed to get verification transaction") + vh.activeTransactionsLock.Unlock() + return + } else if errors.Is(err, ErrUnknownVerificationTransaction) { + // If it's a cancellation event for an unknown transaction, we + // can just ignore it. + if evt.Type == event.ToDeviceVerificationCancel || evt.Type == event.InRoomVerificationCancel { + log.Info().Msg("Ignoring verification cancellation event for an unknown transaction") + vh.activeTransactionsLock.Unlock() + return } - // Send the actual cancellation event. - vh.cancelVerificationTxn(ctx, txn, code, reason) + log.Warn().Msg("Sending cancellation event for unknown transaction ID") + + // We have to create a fake transaction so that the call to + // cancelVerificationTxn works. + txn = VerificationTransaction{ + ExpirationTime: jsontime.UnixMilli{Time: time.Now().Add(time.Minute * 10)}, + RoomID: evt.RoomID, + TheirUserID: evt.Sender, + } + if transactionable, ok := evt.Content.Parsed.(event.VerificationTransactionable); ok { + txn.TransactionID = transactionable.GetTransactionID() + } else { + txn.TransactionID = id.VerificationTransactionID(evt.ID) + } + if fromDevice, ok := evt.Content.Raw["from_device"]; ok { + txn.TheirDeviceID = id.DeviceID(fromDevice.(string)) + } + + // Send a cancellation event. + vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUnknownTransaction, "The transaction ID was not recognized.") + vh.activeTransactionsLock.Unlock() return + } else { + vh.activeTransactionsLock.Unlock() } - logCtx := vh.getLog(ctx).With(). + logCtx := log.With(). Stringer("transaction_step", txn.VerificationState). Stringer("sender", evt.Sender) if evt.RoomID != "" { @@ -326,34 +258,54 @@ func (vh *VerificationHelper) Init(ctx context.Context) error { syncer.OnEventType(event.InRoomVerificationKey, wrapHandler(vh.onVerificationKey)) // SAS syncer.OnEventType(event.InRoomVerificationMAC, wrapHandler(vh.onVerificationMAC)) // SAS - return nil + allTransactions, err := vh.store.GetAllVerificationTransactions(ctx) + for _, txn := range allTransactions { + vh.expireTransactionAt(txn.TransactionID, txn.ExpirationTime.Time) + } + return err } // StartVerification starts an interactive verification flow with the given // user via a to-device event. func (vh *VerificationHelper) StartVerification(ctx context.Context, to id.UserID) (id.VerificationTransactionID, error) { + if len(vh.supportedMethods) == 0 { + return "", fmt.Errorf("no supported verification methods") + } + txnID := id.NewVerificationTransactionID() devices, err := vh.mach.CryptoStore.GetDevices(ctx, to) if err != nil { return "", fmt.Errorf("failed to get devices for user: %w", err) + } else if len(devices) == 0 { + // HACK: we are doing this because the client doesn't wait until it has + // the devices before starting verification. + if keys, err := vh.mach.FetchKeys(ctx, []id.UserID{to}, true); err != nil { + return "", err + } else { + devices = keys[to] + } } - vh.getLog(ctx).Info(). + log := vh.getLog(ctx).With(). Str("verification_action", "start verification"). Stringer("transaction_id", txnID). Stringer("to", to). Any("device_ids", maps.Keys(devices)). - Msg("Sending verification request") + Logger() + ctx = log.WithContext(ctx) + log.Info().Msg("Sending verification request") + now := time.Now() content := &event.Content{ Parsed: &event.VerificationRequestEventContent{ ToDeviceVerificationEvent: event.ToDeviceVerificationEvent{TransactionID: txnID}, FromDevice: vh.client.DeviceID, Methods: vh.supportedMethods, - Timestamp: jsontime.UnixMilliNow(), + Timestamp: jsontime.UM(now), }, } + vh.expireTransactionAt(txnID, now.Add(time.Minute*10)) req := mautrix.ReqSendToDevice{Messages: map[id.UserID]map[id.DeviceID]*event.Content{to: {}}} for deviceID := range devices { @@ -372,28 +324,29 @@ func (vh *VerificationHelper) StartVerification(ctx context.Context, to id.UserI vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() - vh.activeTransactions[txnID] = &verificationTransaction{ - VerificationState: verificationStateRequested, + return txnID, vh.store.SaveVerificationTransaction(ctx, VerificationTransaction{ + ExpirationTime: jsontime.UnixMilli{Time: time.Now().Add(time.Minute * 10)}, + VerificationState: VerificationStateRequested, TransactionID: txnID, - TheirUser: to, + TheirUserID: to, SentToDeviceIDs: maps.Keys(devices), - } - return txnID, nil + }) } -// StartVerification starts an interactive verification flow with the given -// user in the given room. +// StartInRoomVerification starts an interactive verification flow with the +// given user in the given room. func (vh *VerificationHelper) StartInRoomVerification(ctx context.Context, roomID id.RoomID, to id.UserID) (id.VerificationTransactionID, error) { log := vh.getLog(ctx).With(). Str("verification_action", "start in-room verification"). Stringer("room_id", roomID). Stringer("to", to). Logger() + ctx = log.WithContext(ctx) log.Info().Msg("Sending verification request") content := event.MessageEventContent{ MsgType: event.MsgVerificationRequest, - Body: "Alice is requesting to verify your device, but your client does not support verification, so you may need to use a different verification method.", + Body: fmt.Sprintf("%s is requesting to verify your device, but your client does not support verification, so you may need to use a different verification method.", vh.client.UserID), FromDevice: vh.client.DeviceID, Methods: vh.supportedMethods, To: to, @@ -412,65 +365,145 @@ func (vh *VerificationHelper) StartInRoomVerification(ctx context.Context, roomI vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() - vh.activeTransactions[txnID] = &verificationTransaction{ + return txnID, vh.store.SaveVerificationTransaction(ctx, VerificationTransaction{ + ExpirationTime: jsontime.UnixMilli{Time: time.Now().Add(time.Minute * 10)}, RoomID: roomID, - VerificationState: verificationStateRequested, + VerificationState: VerificationStateRequested, TransactionID: txnID, - TheirUser: to, - } - return txnID, nil + TheirUserID: to, + }) } // AcceptVerification accepts a verification request. The transaction ID should // be the transaction ID of a verification request that was received via the // VerificationRequested callback in [RequiredCallbacks]. func (vh *VerificationHelper) AcceptVerification(ctx context.Context, txnID id.VerificationTransactionID) error { + vh.activeTransactionsLock.Lock() + defer vh.activeTransactionsLock.Unlock() + log := vh.getLog(ctx).With(). Str("verification_action", "accept verification"). Stringer("transaction_id", txnID). Logger() + ctx = log.WithContext(ctx) - txn, ok := vh.activeTransactions[txnID] - if !ok { - return fmt.Errorf("unknown transaction ID") - } - if txn.VerificationState != verificationStateRequested { + txn, err := vh.store.GetVerificationTransaction(ctx, txnID) + if err != nil { + return err + } else if txn.VerificationState != VerificationStateRequested { return fmt.Errorf("transaction is not in the requested state") } - log.Info().Msg("Sending ready event") + supportedMethods := map[event.VerificationMethod]struct{}{} + for _, method := range txn.TheirSupportedMethods { + switch method { + case event.VerificationMethodSAS: + if slices.Contains(vh.supportedMethods, event.VerificationMethodSAS) { + supportedMethods[event.VerificationMethodSAS] = struct{}{} + } + case event.VerificationMethodQRCodeShow: + if slices.Contains(vh.supportedMethods, event.VerificationMethodQRCodeScan) { + supportedMethods[event.VerificationMethodQRCodeScan] = struct{}{} + supportedMethods[event.VerificationMethodReciprocate] = struct{}{} + } + case event.VerificationMethodQRCodeScan: + if slices.Contains(vh.supportedMethods, event.VerificationMethodQRCodeShow) { + supportedMethods[event.VerificationMethodQRCodeShow] = struct{}{} + supportedMethods[event.VerificationMethodReciprocate] = struct{}{} + } + } + } + + log.Info().Any("methods", maps.Keys(supportedMethods)).Msg("Sending ready event") readyEvt := &event.VerificationReadyEventContent{ FromDevice: vh.client.DeviceID, - Methods: vh.supportedMethods, + Methods: maps.Keys(supportedMethods), } - err := vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationReady, readyEvt) + err = vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationReady, readyEvt) if err != nil { return err } - txn.VerificationState = verificationStateReady + txn.VerificationState = VerificationStateReady - if slices.Contains(txn.TheirSupportedMethods, event.VerificationMethodQRCodeShow) { - vh.scanQRCode(ctx, txn.TransactionID) + supportsSAS := slices.Contains(vh.supportedMethods, event.VerificationMethodSAS) && + slices.Contains(txn.TheirSupportedMethods, event.VerificationMethodSAS) + supportsReciprocate := slices.Contains(vh.supportedMethods, event.VerificationMethodReciprocate) && + slices.Contains(txn.TheirSupportedMethods, event.VerificationMethodReciprocate) + supportsScanQRCode := supportsReciprocate && + slices.Contains(vh.supportedMethods, event.VerificationMethodQRCodeScan) && + slices.Contains(txn.TheirSupportedMethods, event.VerificationMethodQRCodeShow) + + qrCode, err := vh.generateQRCode(ctx, &txn) + if err != nil { + return err } - - return vh.generateAndShowQRCode(ctx, txn) + vh.verificationReady(ctx, txn.TransactionID, txn.TheirDeviceID, supportsSAS, supportsScanQRCode, qrCode) + return vh.store.SaveVerificationTransaction(ctx, txn) } -// CancelVerification cancels a verification request. The transaction ID should -// be the transaction ID of a verification request that was received via the -// VerificationRequested callback in [RequiredCallbacks]. +// DismissVerification dismisses the verification request with the given +// transaction ID. The transaction ID should be one received via the +// VerificationRequested callback in [RequiredCallbacks] or the +// [StartVerification] or [StartInRoomVerification] functions. +func (vh *VerificationHelper) DismissVerification(ctx context.Context, txnID id.VerificationTransactionID) error { + vh.activeTransactionsLock.Lock() + defer vh.activeTransactionsLock.Unlock() + return vh.store.DeleteVerification(ctx, txnID) +} + +// DismissVerification cancels the verification request with the given +// transaction ID. The transaction ID should be one received via the +// VerificationRequested callback in [RequiredCallbacks] or the +// [StartVerification] or [StartInRoomVerification] functions. func (vh *VerificationHelper) CancelVerification(ctx context.Context, txnID id.VerificationTransactionID, code event.VerificationCancelCode, reason string) error { + vh.activeTransactionsLock.Lock() + defer vh.activeTransactionsLock.Unlock() + + txn, err := vh.store.GetVerificationTransaction(ctx, txnID) + if err != nil { + return err + } log := vh.getLog(ctx).With(). Str("verification_action", "cancel verification"). Stringer("transaction_id", txnID). + Str("code", string(code)). + Str("reason", reason). Logger() ctx = log.WithContext(ctx) - txn, ok := vh.activeTransactions[txnID] - if !ok { - return fmt.Errorf("unknown transaction ID") + log.Info().Msg("Sending cancellation event") + cancelEvt := &event.VerificationCancelEventContent{Code: code, Reason: reason} + if len(txn.RoomID) > 0 { + // Sending the cancellation event to the room. + err := vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationCancel, cancelEvt) + if err != nil { + return fmt.Errorf("failed to send cancel verification event (code: %s, reason: %s): %w", code, reason, err) + } + } else { + cancelEvt.SetTransactionID(txn.TransactionID) + req := mautrix.ReqSendToDevice{Messages: map[id.UserID]map[id.DeviceID]*event.Content{ + txn.TheirUserID: {}, + }} + if len(txn.TheirDeviceID) > 0 { + // Send the cancellation event to only the device that accepted the + // verification request. All of the other devices already received a + // cancellation event with code "m.acceped". + req.Messages[txn.TheirUserID][txn.TheirDeviceID] = &event.Content{Parsed: cancelEvt} + } else { + // Send the cancellation event to all of the devices that we sent the + // request to. + for _, deviceID := range txn.SentToDeviceIDs { + if deviceID != vh.client.DeviceID { + req.Messages[txn.TheirUserID][deviceID] = &event.Content{Parsed: cancelEvt} + } + } + } + _, err := vh.client.SendToDevice(ctx, event.ToDeviceVerificationCancel, &req) + if err != nil { + return fmt.Errorf("failed to send m.key.verification.cancel event to %v: %w", maps.Keys(req.Messages[txn.TheirUserID]), err) + } } - return vh.cancelVerificationTxn(ctx, txn, code, reason) + return vh.store.DeleteVerification(ctx, txn.TransactionID) } // sendVerificationEvent sends a verification event to the other user's device @@ -482,49 +515,56 @@ func (vh *VerificationHelper) CancelVerification(ctx context.Context, txnID id.V // [event.VerificationTransactionable]. // - evtType can be either the to-device or in-room version of the event type // as it is always stringified. -func (vh *VerificationHelper) sendVerificationEvent(ctx context.Context, txn *verificationTransaction, evtType event.Type, content any) error { +func (vh *VerificationHelper) sendVerificationEvent(ctx context.Context, txn VerificationTransaction, evtType event.Type, content any) error { if txn.RoomID != "" { content.(event.Relatable).SetRelatesTo(&event.RelatesTo{Type: event.RelReference, EventID: id.EventID(txn.TransactionID)}) _, err := vh.client.SendMessageEvent(ctx, txn.RoomID, evtType, &event.Content{ Parsed: content, }) if err != nil { - return fmt.Errorf("failed to send start event: %w", err) + return fmt.Errorf("failed to send %s event to %s: %w", evtType.String(), txn.RoomID, err) } } else { content.(event.VerificationTransactionable).SetTransactionID(txn.TransactionID) req := mautrix.ReqSendToDevice{Messages: map[id.UserID]map[id.DeviceID]*event.Content{ - txn.TheirUser: { - txn.TheirDevice: &event.Content{Parsed: content}, + txn.TheirUserID: { + txn.TheirDeviceID: &event.Content{Parsed: content}, }, }} _, err := vh.client.SendToDevice(ctx, evtType, &req) if err != nil { - return fmt.Errorf("failed to send start event: %w", err) + return fmt.Errorf("failed to send %s event to %s: %w", evtType.String(), txn.TheirDeviceID, err) } } return nil } -func (vh *VerificationHelper) cancelVerificationTxn(ctx context.Context, txn *verificationTransaction, code event.VerificationCancelCode, reasonFmtStr string, fmtArgs ...any) error { - log := vh.getLog(ctx) - reason := fmt.Sprintf(reasonFmtStr, fmtArgs...) - log.Info(). +// cancelVerificationTxn cancels a verification transaction with the given code +// and reason. It always returns an error, which is the formatted error message +// (this is allows the caller to return the result of this function call +// directly to expose the error to its caller). +// +// Must always be called with the activeTransactionsLock held. +func (vh *VerificationHelper) cancelVerificationTxn(ctx context.Context, txn VerificationTransaction, code event.VerificationCancelCode, reasonFmtStr string, fmtArgs ...any) error { + reason := fmt.Errorf(reasonFmtStr, fmtArgs...).Error() + log := vh.getLog(ctx).With(). Stringer("transaction_id", txn.TransactionID). Str("code", string(code)). Str("reason", reason). - Msg("Sending cancellation event") - cancelEvt := &event.VerificationCancelEventContent{ - Code: code, - Reason: reason, - } + Logger() + ctx = log.WithContext(ctx) + log.Info().Msg("Sending cancellation event") + cancelEvt := &event.VerificationCancelEventContent{Code: code, Reason: reason} err := vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationCancel, cancelEvt) if err != nil { - return err + log.Err(err).Msg("failed to send cancellation event") + return fmt.Errorf("failed to send cancel verification event (code: %s, reason: %s): %w", code, reason, err) + } + if err = vh.store.DeleteVerification(ctx, txn.TransactionID); err != nil { + log.Err(err).Msg("deleting verification failed") } - txn.VerificationState = verificationStateCancelled vh.verificationCancelledCallback(ctx, txn.TransactionID, code, reason) - return nil + return fmt.Errorf("verification cancelled (code: %s): %s", code, reason) } func (vh *VerificationHelper) onVerificationRequest(ctx context.Context, evt *event.Event) { @@ -560,55 +600,124 @@ func (vh *VerificationHelper) onVerificationRequest(ctx context.Context, evt *ev return } - if verificationRequest.TransactionID == "" { + if verificationRequest.Timestamp.Add(10 * time.Minute).Before(time.Now()) { + log.Warn().Msg("Ignoring verification request that is over ten minutes old") + return + } + + if len(verificationRequest.TransactionID) == 0 { log.Warn().Msg("Ignoring verification request without a transaction ID") return } - log = log.With().Any("requested_methods", verificationRequest.Methods).Logger() + log = log.With(). + Any("requested_methods", verificationRequest.Methods). + Stringer("transaction_id", verificationRequest.TransactionID). + Stringer("from_device", verificationRequest.FromDevice). + Logger() ctx = log.WithContext(ctx) log.Info().Msg("Received verification request") - vh.activeTransactionsLock.Lock() - _, ok := vh.activeTransactions[verificationRequest.TransactionID] - if ok { - vh.activeTransactionsLock.Unlock() - log.Info().Msg("Ignoring verification request for an already active transaction") + // Check if we support any of the methods listed + var supportsAnyMethod bool + for _, method := range verificationRequest.Methods { + switch method { + case event.VerificationMethodSAS: + supportsAnyMethod = slices.Contains(vh.supportedMethods, event.VerificationMethodSAS) + case event.VerificationMethodQRCodeScan: + supportsAnyMethod = slices.Contains(vh.supportedMethods, event.VerificationMethodQRCodeShow) && + slices.Contains(verificationRequest.Methods, event.VerificationMethodReciprocate) + case event.VerificationMethodQRCodeShow: + supportsAnyMethod = slices.Contains(vh.supportedMethods, event.VerificationMethodQRCodeScan) && + slices.Contains(verificationRequest.Methods, event.VerificationMethodReciprocate) + } + if supportsAnyMethod { + break + } + } + if !supportsAnyMethod { + log.Warn().Msg("Ignoring verification request that doesn't have any methods we support") return } - vh.activeTransactions[verificationRequest.TransactionID] = &verificationTransaction{ + + vh.activeTransactionsLock.Lock() + newTxn := VerificationTransaction{ + ExpirationTime: jsontime.UnixMilli{Time: verificationRequest.Timestamp.Add(time.Minute * 10)}, RoomID: evt.RoomID, - VerificationState: verificationStateRequested, + VerificationState: VerificationStateRequested, TransactionID: verificationRequest.TransactionID, - TheirDevice: verificationRequest.FromDevice, - TheirUser: evt.Sender, + TheirDeviceID: verificationRequest.FromDevice, + TheirUserID: evt.Sender, TheirSupportedMethods: verificationRequest.Methods, } + if txn, err := vh.store.FindVerificationTransactionForUserDevice(ctx, evt.Sender, verificationRequest.FromDevice); err != nil && !errors.Is(err, ErrUnknownVerificationTransaction) { + log.Err(err).Stringer("sender", evt.Sender).Stringer("device_id", verificationRequest.FromDevice).Msg("failed to find verification transaction") + vh.activeTransactionsLock.Unlock() + return + } else if !errors.Is(err, ErrUnknownVerificationTransaction) { + if txn.TransactionID == verificationRequest.TransactionID { + vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUnexpectedMessage, "received a new verification request for the same transaction ID") + } else { + vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUnexpectedMessage, "received multiple verification requests from the same device") + vh.cancelVerificationTxn(ctx, newTxn, event.VerificationCancelCodeUnexpectedMessage, "received multiple verification requests from the same device") + } + vh.activeTransactionsLock.Unlock() + return + } + if err := vh.store.SaveVerificationTransaction(ctx, newTxn); err != nil { + log.Err(err).Msg("failed to save verification transaction") + } vh.activeTransactionsLock.Unlock() - vh.verificationRequested(ctx, verificationRequest.TransactionID, evt.Sender) + vh.expireTransactionAt(verificationRequest.TransactionID, verificationRequest.Timestamp.Add(time.Minute*10)) + vh.verificationRequested(ctx, verificationRequest.TransactionID, evt.Sender, verificationRequest.FromDevice) } -func (vh *VerificationHelper) onVerificationReady(ctx context.Context, txn *verificationTransaction, evt *event.Event) { +func (vh *VerificationHelper) expireTransactionAt(txnID id.VerificationTransactionID, expiresAt time.Time) { + go func() { + time.Sleep(time.Until(expiresAt)) + + vh.activeTransactionsLock.Lock() + defer vh.activeTransactionsLock.Unlock() + + txn, err := vh.store.GetVerificationTransaction(context.Background(), txnID) + if err == ErrUnknownVerificationTransaction { + // Already deleted, nothing to expire + return + } else if err != nil { + vh.getLog(context.Background()).Err(err).Msg("failed to get verification transaction to expire") + } else { + vh.cancelVerificationTxn(context.Background(), txn, event.VerificationCancelCodeTimeout, "verification timed out") + } + }() +} + +func (vh *VerificationHelper) onVerificationReady(ctx context.Context, txn VerificationTransaction, evt *event.Event) { log := vh.getLog(ctx).With(). Str("verification_action", "verification ready"). Logger() - - if txn.VerificationState != verificationStateRequested { - log.Warn().Msg("Ignoring verification ready event for a transaction that is not in the requested state") - return - } + ctx = log.WithContext(ctx) vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() + if txn.VerificationState != VerificationStateRequested { + vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUnexpectedMessage, "verification ready event received for a transaction that is not in the requested state") + return + } + readyEvt := evt.Content.AsVerificationReady() // Update the transaction state. - txn.VerificationState = verificationStateReady - txn.TheirDevice = readyEvt.FromDevice + txn.VerificationState = VerificationStateReady + txn.TheirDeviceID = readyEvt.FromDevice txn.TheirSupportedMethods = readyEvt.Methods + log.Info(). + Stringer("their_device_id", txn.TheirDeviceID). + Any("their_supported_methods", txn.TheirSupportedMethods). + Msg("Received verification ready event") + // If we sent this verification request, send cancellations to all of the // other devices. if len(txn.SentToDeviceIDs) > 0 { @@ -619,56 +728,66 @@ func (vh *VerificationHelper) onVerificationReady(ctx context.Context, txn *veri Reason: "The verification was accepted on another device.", }, } - devices, err := vh.mach.CryptoStore.GetDevices(ctx, txn.TheirUser) - if err != nil { - vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUser, "failed to get devices for %s: %v", txn.TheirUser, err) - return - } - req := mautrix.ReqSendToDevice{Messages: map[id.UserID]map[id.DeviceID]*event.Content{txn.TheirUser: {}}} - for deviceID := range devices { - if deviceID == txn.TheirDevice { + req := mautrix.ReqSendToDevice{Messages: map[id.UserID]map[id.DeviceID]*event.Content{txn.TheirUserID: {}}} + for _, deviceID := range txn.SentToDeviceIDs { + if deviceID == txn.TheirDeviceID || deviceID == vh.client.DeviceID { // Don't ever send a cancellation to the device that accepted - // the request. + // the request or to our own device (which can happen if this + // is a self-verification). continue } - req.Messages[txn.TheirUser][deviceID] = content + req.Messages[txn.TheirUserID][deviceID] = content } - _, err = vh.client.SendToDevice(ctx, event.ToDeviceVerificationRequest, &req) + _, err := vh.client.SendToDevice(ctx, event.ToDeviceVerificationCancel, &req) if err != nil { log.Warn().Err(err).Msg("Failed to send cancellation requests") } } - if slices.Contains(txn.TheirSupportedMethods, event.VerificationMethodQRCodeShow) { - vh.scanQRCode(ctx, txn.TransactionID) + supportsSAS := slices.Contains(vh.supportedMethods, event.VerificationMethodSAS) && + slices.Contains(txn.TheirSupportedMethods, event.VerificationMethodSAS) + supportsReciprocate := slices.Contains(vh.supportedMethods, event.VerificationMethodReciprocate) && + slices.Contains(txn.TheirSupportedMethods, event.VerificationMethodReciprocate) + supportsScanQRCode := supportsReciprocate && + slices.Contains(vh.supportedMethods, event.VerificationMethodQRCodeScan) && + slices.Contains(txn.TheirSupportedMethods, event.VerificationMethodQRCodeShow) + + qrCode, err := vh.generateQRCode(ctx, &txn) + if err != nil { + vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeInternalError, "failed to generate QR code: %w", err) + return } - err := vh.generateAndShowQRCode(ctx, txn) - if err != nil { - vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUser, "failed to generate and show QR code: %v", err) + vh.verificationReady(ctx, txn.TransactionID, txn.TheirDeviceID, supportsSAS, supportsScanQRCode, qrCode) + + if err := vh.store.SaveVerificationTransaction(ctx, txn); err != nil { + vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeInternalError, "failed to save verification transaction: %w", err) } } -func (vh *VerificationHelper) onVerificationStart(ctx context.Context, txn *verificationTransaction, evt *event.Event) { +func (vh *VerificationHelper) onVerificationStart(ctx context.Context, txn VerificationTransaction, evt *event.Event) { startEvt := evt.Content.AsVerificationStart() log := vh.getLog(ctx).With(). Str("verification_action", "verification start"). Str("method", string(startEvt.Method)). + Stringer("their_device_id", txn.TheirDeviceID). + Any("their_supported_methods", txn.TheirSupportedMethods). + Bool("started_by_us", txn.StartedByUs). Logger() ctx = log.WithContext(ctx) + log.Info().Msg("Received verification start event") vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() - if txn.VerificationState == verificationStateSASStarted || txn.VerificationState == verificationStateOurQRScanned || txn.VerificationState == verificationStateTheirQRScanned { + if txn.VerificationState == VerificationStateSASStarted || txn.VerificationState == VerificationStateOurQRScanned || txn.VerificationState == VerificationStateTheirQRScanned { // We might have sent the event, and they also sent an event. if txn.StartEventContent == nil || !txn.StartedByUs { // We didn't sent a start event yet, so we have gotten ourselves // into a bad state. They've either sent two start events, or we // have gone on to a new state. - vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUnexpectedMessage, - "got repeat start event from other user") + vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUnexpectedMessage, "got repeat start event from other user") return } @@ -694,71 +813,120 @@ func (vh *VerificationHelper) onVerificationStart(ctx context.Context, txn *veri return } - if txn.TheirUser < vh.client.UserID || (txn.TheirUser == vh.client.UserID && txn.TheirDevice < vh.client.DeviceID) { - // Use their start event instead of ours + if txn.TheirUserID < vh.client.UserID || (txn.TheirUserID == vh.client.UserID && txn.TheirDeviceID < vh.client.DeviceID) { + log.Debug().Msg("Using their start event instead of ours because they are alphabetically before us") txn.StartedByUs = false txn.StartEventContent = startEvt } - } else if txn.VerificationState != verificationStateReady { - vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUnexpectedMessage, - "got start event for transaction that is not in ready state") + } else if txn.VerificationState != VerificationStateReady { + vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUnexpectedMessage, "got start event for transaction that is not in ready state") return + } else { + txn.StartEventContent = startEvt } switch startEvt.Method { case event.VerificationMethodSAS: - txn.VerificationState = verificationStateSASStarted + log.Info().Msg("Received SAS start event") + txn.VerificationState = VerificationStateSASStarted if err := vh.onVerificationStartSAS(ctx, txn, evt); err != nil { - vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUser, "failed to handle SAS verification start: %v", err) + vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUser, "failed to handle SAS verification start: %w", err) } case event.VerificationMethodReciprocate: log.Info().Msg("Received reciprocate start event") - if !bytes.Equal(txn.QRCodeSharedSecret, startEvt.Secret) { + if !bytes.Equal(txn.QRCodeSharedSecret, txn.StartEventContent.Secret) { vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeKeyMismatch, "reciprocated shared secret does not match") return } - txn.VerificationState = verificationStateOurQRScanned - vh.qrCodeScaned(ctx, txn.TransactionID) + txn.VerificationState = VerificationStateOurQRScanned + vh.qrCodeScanned(ctx, txn.TransactionID) + if err := vh.store.SaveVerificationTransaction(ctx, txn); err != nil { + log.Err(err).Msg("failed to save verification transaction") + } default: // Note that we should never get m.qr_code.show.v1 or m.qr_code.scan.v1 // here, since the start command for scanning and showing QR codes // should be of type m.reciprocate.v1. - log.Error().Str("method", string(startEvt.Method)).Msg("Unsupported verification method in start event") - vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUnknownMethod, fmt.Sprintf("unknown method %s", startEvt.Method)) + log.Error().Str("method", string(txn.StartEventContent.Method)).Msg("Unsupported verification method in start event") + vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUnknownMethod, "unknown method %s", txn.StartEventContent.Method) } } -func (vh *VerificationHelper) onVerificationDone(ctx context.Context, txn *verificationTransaction, evt *event.Event) { - vh.getLog(ctx).Info(). +func (vh *VerificationHelper) onVerificationDone(ctx context.Context, txn VerificationTransaction, evt *event.Event) { + log := vh.getLog(ctx).With(). Str("verification_action", "done"). Stringer("transaction_id", txn.TransactionID). - Msg("Verification done") + Bool("sent_our_done", txn.SentOurDone). + Logger() + ctx = log.WithContext(ctx) + log.Info().Msg("Verification done") vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() - if txn.VerificationState != verificationStateTheirQRScanned && txn.VerificationState != verificationStateSASMACExchanged { - vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUnexpectedMessage, - "got done event for transaction that is not in QR-scanned or MAC-exchanged state") + if !slices.Contains([]VerificationState{ + VerificationStateTheirQRScanned, VerificationStateOurQRScanned, VerificationStateSASMACExchanged, + }, txn.VerificationState) { + vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUnexpectedMessage, "got done event for transaction that is not in QR-scanned or MAC-exchanged state") return } - txn.VerificationState = verificationStateDone txn.ReceivedTheirDone = true if txn.SentOurDone { - vh.verificationDone(ctx, txn.TransactionID) + if err := vh.store.DeleteVerification(ctx, txn.TransactionID); err != nil { + log.Err(err).Msg("Delete verification failed") + } + vh.verificationDone(ctx, txn.TransactionID, txn.StartEventContent.Method) + } else if err := vh.store.SaveVerificationTransaction(ctx, txn); err != nil { + log.Err(err).Msg("failed to save verification transaction") } } -func (vh *VerificationHelper) onVerificationCancel(ctx context.Context, txn *verificationTransaction, evt *event.Event) { +func (vh *VerificationHelper) onVerificationCancel(ctx context.Context, txn VerificationTransaction, evt *event.Event) { cancelEvt := evt.Content.AsVerificationCancel() - vh.getLog(ctx).Info(). + log := vh.getLog(ctx).With(). Str("verification_action", "cancel"). Stringer("transaction_id", txn.TransactionID). Str("cancel_code", string(cancelEvt.Code)). Str("reason", cancelEvt.Reason). - Msg("Verification was cancelled") + Logger() + ctx = log.WithContext(ctx) + log.Info().Msg("Verification was cancelled") vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() - txn.VerificationState = verificationStateCancelled + + // Element (and at least the old desktop client) send cancellation events + // when the user rejects the verification request. This is really dumb, + // because they should just instead ignore the request and not send a + // cancellation. + // + // The above behavior causes a problem with the other devices that we sent + // the verification request to because they don't know that the request was + // cancelled. + // + // As a workaround, if we receive a cancellation event to a transaction + // that is currently in the REQUESTED state, then we will send + // cancellations to all of the devices that we sent the request to. This + // will ensure that all of the clients know that the request was cancelled. + if txn.VerificationState == VerificationStateRequested && len(txn.SentToDeviceIDs) > 0 { + content := &event.Content{ + Parsed: &event.VerificationCancelEventContent{ + ToDeviceVerificationEvent: event.ToDeviceVerificationEvent{TransactionID: txn.TransactionID}, + Code: event.VerificationCancelCodeUser, + Reason: "The verification was rejected from another device.", + }, + } + req := mautrix.ReqSendToDevice{Messages: map[id.UserID]map[id.DeviceID]*event.Content{txn.TheirUserID: {}}} + for _, deviceID := range txn.SentToDeviceIDs { + req.Messages[txn.TheirUserID][deviceID] = content + } + _, err := vh.client.SendToDevice(ctx, event.ToDeviceVerificationCancel, &req) + if err != nil { + log.Warn().Err(err).Msg("Failed to send cancellation requests") + } + } + + if err := vh.store.DeleteVerification(ctx, txn.TransactionID); err != nil { + log.Err(err).Msg("Delete verification failed") + } vh.verificationCancelledCallback(ctx, txn.TransactionID, cancelEvt.Code, cancelEvt.Reason) } diff --git a/crypto/verificationhelper/verificationhelper_qr_crosssign_test.go b/crypto/verificationhelper/verificationhelper_qr_crosssign_test.go new file mode 100644 index 00000000..5e3f146b --- /dev/null +++ b/crypto/verificationhelper/verificationhelper_qr_crosssign_test.go @@ -0,0 +1,153 @@ +// Copyright (c) 2024 Sumner Evans +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package verificationhelper_test + +import ( + "context" + "fmt" + "testing" + + "github.com/rs/zerolog/log" // zerolog-allow-global-log + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +func TestCrossSignVerification_ScanQRAndConfirmScan(t *testing.T) { + ctx := log.Logger.WithContext(context.TODO()) + + testCases := []struct { + sendingScansQR bool // false indicates that receiving device should emulate a scan + }{ + {false}, + {true}, + } + + for _, tc := range testCases { + t.Run(fmt.Sprintf("sendingScansQR=%t", tc.sendingScansQR), func(t *testing.T) { + ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLoginAliceBob(t, ctx) + sendingCallbacks, receivingCallbacks, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine) + var err error + + // Generate cross-signing keys for both users + _, _, err = sendingMachine.GenerateAndUploadCrossSigningKeys(ctx, nil, "") + require.NoError(t, err) + _, _, err = receivingMachine.GenerateAndUploadCrossSigningKeys(ctx, nil, "") + require.NoError(t, err) + + // Fetch each other's keys + sendingMachine.FetchKeys(ctx, []id.UserID{bobUserID}, true) + receivingMachine.FetchKeys(ctx, []id.UserID{aliceUserID}, true) + + // Send the verification request from the sender device and accept + // it on the receiving device and receive the verification ready + // event on the sending device. + txnID, err := sendingHelper.StartVerification(ctx, bobUserID) + require.NoError(t, err) + ts.DispatchToDevice(t, ctx, receivingClient) + err = receivingHelper.AcceptVerification(ctx, txnID) + require.NoError(t, err) + ts.DispatchToDevice(t, ctx, sendingClient) + + receivingShownQRCode := receivingCallbacks.GetQRCodeShown(txnID) + require.NotNil(t, receivingShownQRCode) + sendingShownQRCode := sendingCallbacks.GetQRCodeShown(txnID) + require.NotNil(t, sendingShownQRCode) + + if tc.sendingScansQR { + // Emulate scanning the QR code shown by the receiving device + // on the sending device. + err := sendingHelper.HandleScannedQRData(ctx, receivingShownQRCode.Bytes()) + require.NoError(t, err) + + // Ensure that the receiving device received a verification + // start event and a verification done event. + receivingInbox := ts.DeviceInbox[bobUserID][receivingDeviceID] + assert.Len(t, receivingInbox, 2) + + startEvt := receivingInbox[0].Content.AsVerificationStart() + assert.Equal(t, txnID, startEvt.TransactionID) + assert.Equal(t, sendingDeviceID, startEvt.FromDevice) + assert.Equal(t, event.VerificationMethodReciprocate, startEvt.Method) + assert.EqualValues(t, receivingShownQRCode.SharedSecret, startEvt.Secret) + + doneEvt := receivingInbox[1].Content.AsVerificationDone() + assert.Equal(t, txnID, doneEvt.TransactionID) + + // Handle the start and done events on the receiving client and + // confirm the scan. + ts.DispatchToDevice(t, ctx, receivingClient) + + // Ensure that the receiving device detected that its QR code + // was scanned. + assert.True(t, receivingCallbacks.WasOurQRCodeScanned(txnID)) + err = receivingHelper.ConfirmQRCodeScanned(ctx, txnID) + require.NoError(t, err) + + // Ensure that the sending device received a verification done + // event. + sendingInbox := ts.DeviceInbox[aliceUserID][sendingDeviceID] + require.Len(t, sendingInbox, 1) + doneEvt = sendingInbox[0].Content.AsVerificationDone() + assert.Equal(t, txnID, doneEvt.TransactionID) + + ts.DispatchToDevice(t, ctx, sendingClient) + } else { // receiving scans QR + // Emulate scanning the QR code shown by the sending device on + // the receiving device. + err := receivingHelper.HandleScannedQRData(ctx, sendingShownQRCode.Bytes()) + require.NoError(t, err) + + // Ensure that the sending device received a verification + // start event and a verification done event. + sendingInbox := ts.DeviceInbox[aliceUserID][sendingDeviceID] + assert.Len(t, sendingInbox, 2) + + startEvt := sendingInbox[0].Content.AsVerificationStart() + assert.Equal(t, txnID, startEvt.TransactionID) + assert.Equal(t, receivingDeviceID, startEvt.FromDevice) + assert.Equal(t, event.VerificationMethodReciprocate, startEvt.Method) + assert.EqualValues(t, sendingShownQRCode.SharedSecret, startEvt.Secret) + + doneEvt := sendingInbox[1].Content.AsVerificationDone() + assert.Equal(t, txnID, doneEvt.TransactionID) + + // Handle the start and done events on the receiving client and + // confirm the scan. + ts.DispatchToDevice(t, ctx, sendingClient) + + // Ensure that the sending device detected that its QR code was + // scanned. + assert.True(t, sendingCallbacks.WasOurQRCodeScanned(txnID)) + err = sendingHelper.ConfirmQRCodeScanned(ctx, txnID) + require.NoError(t, err) + + // Ensure that the receiving device received a verification + // done event. + receivingInbox := ts.DeviceInbox[bobUserID][receivingDeviceID] + require.Len(t, receivingInbox, 1) + doneEvt = receivingInbox[0].Content.AsVerificationDone() + assert.Equal(t, txnID, doneEvt.TransactionID) + + ts.DispatchToDevice(t, ctx, receivingClient) + } + + // Ensure that both devices have marked the verification as done. + assert.True(t, sendingCallbacks.IsVerificationDone(txnID)) + assert.True(t, receivingCallbacks.IsVerificationDone(txnID)) + + bobTrustsAlice, err := receivingMachine.IsUserTrusted(ctx, aliceUserID) + assert.NoError(t, err) + assert.True(t, bobTrustsAlice) + aliceTrustsBob, err := sendingMachine.IsUserTrusted(ctx, bobUserID) + assert.NoError(t, err) + assert.True(t, aliceTrustsBob) + }) + } +} diff --git a/crypto/verificationhelper/verificationhelper_qr_self_test.go b/crypto/verificationhelper/verificationhelper_qr_self_test.go new file mode 100644 index 00000000..ea918cd4 --- /dev/null +++ b/crypto/verificationhelper/verificationhelper_qr_self_test.go @@ -0,0 +1,369 @@ +// Copyright (c) 2024 Sumner Evans +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package verificationhelper_test + +import ( + "context" + "fmt" + "testing" + + "github.com/rs/zerolog/log" // zerolog-allow-global-log + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "maunium.net/go/mautrix/crypto" + "maunium.net/go/mautrix/crypto/verificationhelper" + "maunium.net/go/mautrix/event" +) + +func TestSelfVerification_Accept_QRContents(t *testing.T) { + ctx := log.Logger.WithContext(context.TODO()) + + testCases := []struct { + sendingGeneratedCrossSigningKeys bool + receivingGeneratedCrossSigningKeys bool + expectedAcceptError string + }{ + {true, false, ""}, + {false, true, ""}, + {false, false, "failed to get own cross-signing master public key"}, + } + + for _, tc := range testCases { + t.Run(fmt.Sprintf("sendingGenerated=%t receivingGenerated=%t err=%s", tc.sendingGeneratedCrossSigningKeys, tc.receivingGeneratedCrossSigningKeys, tc.expectedAcceptError), func(t *testing.T) { + ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLoginTwoAlice(t, ctx) + sendingCallbacks, receivingCallbacks, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine) + var err error + + var sendingRecoveryKey, receivingRecoveryKey string + var sendingCrossSigningKeysCache, receivingCrossSigningKeysCache *crypto.CrossSigningKeysCache + + if tc.sendingGeneratedCrossSigningKeys { + sendingRecoveryKey, sendingCrossSigningKeysCache, err = sendingMachine.GenerateAndUploadCrossSigningKeys(ctx, nil, "") + require.NoError(t, err) + assert.NotEmpty(t, sendingRecoveryKey) + assert.NotNil(t, sendingCrossSigningKeysCache) + } + + if tc.receivingGeneratedCrossSigningKeys { + receivingRecoveryKey, receivingCrossSigningKeysCache, err = receivingMachine.GenerateAndUploadCrossSigningKeys(ctx, nil, "") + require.NoError(t, err) + assert.NotEmpty(t, receivingRecoveryKey) + assert.NotNil(t, receivingCrossSigningKeysCache) + } + + // Send the verification request from the sender device and accept + // it on the receiving device and receive the verification ready + // event on the sending device. + txnID, err := sendingHelper.StartVerification(ctx, aliceUserID) + require.NoError(t, err) + ts.DispatchToDevice(t, ctx, receivingClient) + + err = receivingHelper.AcceptVerification(ctx, txnID) + if tc.expectedAcceptError != "" { + assert.ErrorContains(t, err, tc.expectedAcceptError) + return + } else { + require.NoError(t, err) + } + + ts.DispatchToDevice(t, ctx, sendingClient) + + receivingShownQRCode := receivingCallbacks.GetQRCodeShown(txnID) + require.NotNil(t, receivingShownQRCode) + assert.NotEmpty(t, receivingShownQRCode.SharedSecret) + assert.Equal(t, txnID, receivingShownQRCode.TransactionID) + + sendingShownQRCode := sendingCallbacks.GetQRCodeShown(txnID) + require.NotNil(t, sendingShownQRCode) + assert.NotEmpty(t, sendingShownQRCode.SharedSecret) + assert.Equal(t, txnID, sendingShownQRCode.TransactionID) + + // See the spec for the QR Code format: + // https://spec.matrix.org/v1.10/client-server-api/#qr-code-format + if tc.receivingGeneratedCrossSigningKeys { + masterKeyBytes := receivingMachine.GetOwnCrossSigningPublicKeys(ctx).MasterKey.Bytes() + + // The receiving device should have shown a QR Code with + // trusted mode + assert.Equal(t, verificationhelper.QRCodeModeSelfVerifyingMasterKeyTrusted, receivingShownQRCode.Mode) + assert.EqualValues(t, masterKeyBytes, receivingShownQRCode.Key1) // master key + assert.EqualValues(t, sendingMachine.OwnIdentity().SigningKey.Bytes(), receivingShownQRCode.Key2) // other device key + + // The sending device should have shown a QR code with + // untrusted mode. + assert.Equal(t, verificationhelper.QRCodeModeSelfVerifyingMasterKeyUntrusted, sendingShownQRCode.Mode) + assert.EqualValues(t, sendingMachine.OwnIdentity().SigningKey.Bytes(), sendingShownQRCode.Key1) // own device key + assert.EqualValues(t, masterKeyBytes, sendingShownQRCode.Key2) // master key + } else if tc.sendingGeneratedCrossSigningKeys { + masterKeyBytes := sendingMachine.GetOwnCrossSigningPublicKeys(ctx).MasterKey.Bytes() + + // The receiving device should have shown a QR code with + // untrusted mode + assert.Equal(t, verificationhelper.QRCodeModeSelfVerifyingMasterKeyUntrusted, receivingShownQRCode.Mode) + assert.EqualValues(t, receivingMachine.OwnIdentity().SigningKey.Bytes(), receivingShownQRCode.Key1) // own device key + assert.EqualValues(t, masterKeyBytes, receivingShownQRCode.Key2) // master key + + // The sending device should have shown a QR code with trusted + // mode. + assert.Equal(t, verificationhelper.QRCodeModeSelfVerifyingMasterKeyTrusted, sendingShownQRCode.Mode) + assert.EqualValues(t, masterKeyBytes, sendingShownQRCode.Key1) // master key + assert.EqualValues(t, receivingMachine.OwnIdentity().SigningKey.Bytes(), sendingShownQRCode.Key2) // other device key + } + }) + } +} + +func TestSelfVerification_ScanQRAndConfirmScan(t *testing.T) { + ctx := log.Logger.WithContext(context.TODO()) + + testCases := []struct { + sendingGeneratedCrossSigningKeys bool + sendingScansQR bool // false indicates that receiving device should emulate a scan + }{ + {false, false}, + {false, true}, + {true, false}, + {true, true}, + } + + for _, tc := range testCases { + t.Run(fmt.Sprintf("sendingGeneratedCrossSigningKeys=%t sendingScansQR=%t", tc.sendingGeneratedCrossSigningKeys, tc.sendingScansQR), func(t *testing.T) { + ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLoginTwoAlice(t, ctx) + sendingCallbacks, receivingCallbacks, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine) + var err error + + if tc.sendingGeneratedCrossSigningKeys { + _, _, err = sendingMachine.GenerateAndUploadCrossSigningKeys(ctx, nil, "") + require.NoError(t, err) + } else { + _, _, err = receivingMachine.GenerateAndUploadCrossSigningKeys(ctx, nil, "") + require.NoError(t, err) + } + + // Send the verification request from the sender device and accept + // it on the receiving device and receive the verification ready + // event on the sending device. + txnID, err := sendingHelper.StartVerification(ctx, aliceUserID) + require.NoError(t, err) + ts.DispatchToDevice(t, ctx, receivingClient) + err = receivingHelper.AcceptVerification(ctx, txnID) + require.NoError(t, err) + ts.DispatchToDevice(t, ctx, sendingClient) + + receivingShownQRCode := receivingCallbacks.GetQRCodeShown(txnID) + require.NotNil(t, receivingShownQRCode) + sendingShownQRCode := sendingCallbacks.GetQRCodeShown(txnID) + require.NotNil(t, sendingShownQRCode) + + if tc.sendingScansQR { + // Emulate scanning the QR code shown by the receiving device + // on the sending device. + err := sendingHelper.HandleScannedQRData(ctx, receivingShownQRCode.Bytes()) + require.NoError(t, err) + + // Ensure that the receiving device received a verification + // start event and a verification done event. + receivingInbox := ts.DeviceInbox[aliceUserID][receivingDeviceID] + assert.Len(t, receivingInbox, 2) + + startEvt := receivingInbox[0].Content.AsVerificationStart() + assert.Equal(t, txnID, startEvt.TransactionID) + assert.Equal(t, sendingDeviceID, startEvt.FromDevice) + assert.Equal(t, event.VerificationMethodReciprocate, startEvt.Method) + assert.EqualValues(t, receivingShownQRCode.SharedSecret, startEvt.Secret) + + doneEvt := receivingInbox[1].Content.AsVerificationDone() + assert.Equal(t, txnID, doneEvt.TransactionID) + + // Handle the start and done events on the receiving client and + // confirm the scan. + ts.DispatchToDevice(t, ctx, receivingClient) + + // Ensure that the receiving device detected that its QR code + // was scanned. + assert.True(t, receivingCallbacks.WasOurQRCodeScanned(txnID)) + err = receivingHelper.ConfirmQRCodeScanned(ctx, txnID) + require.NoError(t, err) + + // Ensure that the sending device received a verification done + // event. + sendingInbox := ts.DeviceInbox[aliceUserID][sendingDeviceID] + require.Len(t, sendingInbox, 1) + doneEvt = sendingInbox[0].Content.AsVerificationDone() + assert.Equal(t, txnID, doneEvt.TransactionID) + + ts.DispatchToDevice(t, ctx, sendingClient) + } else { // receiving scans QR + // Emulate scanning the QR code shown by the sending device on + // the receiving device. + err := receivingHelper.HandleScannedQRData(ctx, sendingShownQRCode.Bytes()) + require.NoError(t, err) + + // Ensure that the sending device received a verification + // start event and a verification done event. + sendingInbox := ts.DeviceInbox[aliceUserID][sendingDeviceID] + assert.Len(t, sendingInbox, 2) + + startEvt := sendingInbox[0].Content.AsVerificationStart() + assert.Equal(t, txnID, startEvt.TransactionID) + assert.Equal(t, receivingDeviceID, startEvt.FromDevice) + assert.Equal(t, event.VerificationMethodReciprocate, startEvt.Method) + assert.EqualValues(t, sendingShownQRCode.SharedSecret, startEvt.Secret) + + doneEvt := sendingInbox[1].Content.AsVerificationDone() + assert.Equal(t, txnID, doneEvt.TransactionID) + + // Handle the start and done events on the receiving client and + // confirm the scan. + ts.DispatchToDevice(t, ctx, sendingClient) + + // Ensure that the sending device detected that its QR code was + // scanned. + assert.True(t, sendingCallbacks.WasOurQRCodeScanned(txnID)) + err = sendingHelper.ConfirmQRCodeScanned(ctx, txnID) + require.NoError(t, err) + + // Ensure that the receiving device received a verification + // done event. + receivingInbox := ts.DeviceInbox[aliceUserID][receivingDeviceID] + require.Len(t, receivingInbox, 1) + doneEvt = receivingInbox[0].Content.AsVerificationDone() + assert.Equal(t, txnID, doneEvt.TransactionID) + + ts.DispatchToDevice(t, ctx, receivingClient) + } + + // Ensure that both devices have marked the verification as done. + assert.True(t, sendingCallbacks.IsVerificationDone(txnID)) + assert.True(t, receivingCallbacks.IsVerificationDone(txnID)) + }) + } +} + +func TestSelfVerification_ScanQRTransactionIDCorrupted(t *testing.T) { + ctx := log.Logger.WithContext(context.TODO()) + + ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLoginTwoAlice(t, ctx) + sendingCallbacks, receivingCallbacks, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine) + var err error + + _, _, err = sendingMachine.GenerateAndUploadCrossSigningKeys(ctx, nil, "") + require.NoError(t, err) + + // Send the verification request from the sender device and accept + // it on the receiving device and receive the verification ready + // event on the sending device. + txnID, err := sendingHelper.StartVerification(ctx, aliceUserID) + require.NoError(t, err) + ts.DispatchToDevice(t, ctx, receivingClient) + err = receivingHelper.AcceptVerification(ctx, txnID) + require.NoError(t, err) + ts.DispatchToDevice(t, ctx, sendingClient) + + receivingShownQRCodeBytes := receivingCallbacks.GetQRCodeShown(txnID).Bytes() + sendingShownQRCodeBytes := sendingCallbacks.GetQRCodeShown(txnID).Bytes() + + // Corrupt the QR codes (the 20th byte should be in the transaction ID) + receivingShownQRCodeBytes[20]++ + sendingShownQRCodeBytes[20]++ + + // Emulate scanning the QR code shown by the receiving device + // on the sending device. + err = sendingHelper.HandleScannedQRData(ctx, receivingShownQRCodeBytes) + assert.ErrorContains(t, err, "unknown transaction ID") + + // Emulate scanning the QR code shown by the sending device on + // the receiving device. + err = receivingHelper.HandleScannedQRData(ctx, sendingShownQRCodeBytes) + assert.ErrorContains(t, err, "unknown transaction ID") +} + +func TestSelfVerification_ScanQRKeyCorrupted(t *testing.T) { + ctx := log.Logger.WithContext(context.TODO()) + + testCases := []struct { + sendingGeneratedCrossSigningKeys bool + sendingScansQR bool // false indicates that receiving device should emulate a scan + corruptByte int + expectedError string + }{ + // The 50th byte should be in the first key + {false, false, 50, "the other device's key is not what we expected"}, // receiver scans sender QR code, sender doesn't trust the master key => mode 0x02 => key1 == sender device key + {false, true, 50, "the master key does not match"}, // sender scans receiver QR code, receiver trusts the master key => mode 0x01 => key1 == master key + {true, false, 50, "the master key does not match"}, // receiver scans sender QR code, sender trusts the master key => mode 0x01 => key1 == master key + {true, true, 50, "the other device's key is not what we expected"}, // sender scans receiver QR Code, receiver doesn't trust the master key => mode 0x02 => key1 == receiver device key + // The 100th byte should be in the second key + {false, false, 100, "the master key does not match"}, // receiver scans sender QR code, sender doesn't trust the master key => mode 0x02 => key2 == master key + {false, true, 100, "the other device has the wrong key for this device"}, // sender scans receiver QR code, receiver trusts the master key => mode 0x01 => key2 == sender device key + {true, false, 100, "the other device has the wrong key for this device"}, // receiver scans sender QR code, sender trusts the master key => mode 0x01 => key2 == receiver device key + {true, true, 100, "the master key does not match"}, // sender scans receiver QR Code, receiver doesn't trust the master key => mode 0x02 => key2 == master key + } + + for _, tc := range testCases { + t.Run(fmt.Sprintf("sendingGeneratedCrossSigningKeys=%t sendingScansQR=%t corrupt=%d", tc.sendingGeneratedCrossSigningKeys, tc.sendingScansQR, tc.corruptByte), func(t *testing.T) { + ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLoginTwoAlice(t, ctx) + sendingCallbacks, receivingCallbacks, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine) + var err error + + if tc.sendingGeneratedCrossSigningKeys { + _, _, err = sendingMachine.GenerateAndUploadCrossSigningKeys(ctx, nil, "") + require.NoError(t, err) + } else { + _, _, err = receivingMachine.GenerateAndUploadCrossSigningKeys(ctx, nil, "") + require.NoError(t, err) + } + + // Send the verification request from the sender device and accept + // it on the receiving device and receive the verification ready + // event on the sending device. + txnID, err := sendingHelper.StartVerification(ctx, aliceUserID) + require.NoError(t, err) + ts.DispatchToDevice(t, ctx, receivingClient) + err = receivingHelper.AcceptVerification(ctx, txnID) + require.NoError(t, err) + ts.DispatchToDevice(t, ctx, sendingClient) + + receivingShownQRCodeBytes := receivingCallbacks.GetQRCodeShown(txnID).Bytes() + sendingShownQRCodeBytes := sendingCallbacks.GetQRCodeShown(txnID).Bytes() + + // Corrupt the QR codes + receivingShownQRCodeBytes[tc.corruptByte]++ + sendingShownQRCodeBytes[tc.corruptByte]++ + + if tc.sendingScansQR { + // Emulate scanning the QR code shown by the receiving device + // on the sending device. + err := sendingHelper.HandleScannedQRData(ctx, receivingShownQRCodeBytes) + assert.ErrorContains(t, err, tc.expectedError) + + // Ensure that the receiving device received a cancellation. + receivingInbox := ts.DeviceInbox[aliceUserID][receivingDeviceID] + assert.Len(t, receivingInbox, 1) + ts.DispatchToDevice(t, ctx, receivingClient) + cancellation := receivingCallbacks.GetVerificationCancellation(txnID) + require.NotNil(t, cancellation) + assert.Equal(t, event.VerificationCancelCodeKeyMismatch, cancellation.Code) + assert.Equal(t, tc.expectedError, cancellation.Reason) + } else { // receiving scans QR + // Emulate scanning the QR code shown by the sending device on + // the receiving device. + err := receivingHelper.HandleScannedQRData(ctx, sendingShownQRCodeBytes) + assert.ErrorContains(t, err, tc.expectedError) + + // Ensure that the sending device received a cancellation. + sendingInbox := ts.DeviceInbox[aliceUserID][sendingDeviceID] + assert.Len(t, sendingInbox, 1) + ts.DispatchToDevice(t, ctx, sendingClient) + cancellation := sendingCallbacks.GetVerificationCancellation(txnID) + require.NotNil(t, cancellation) + assert.Equal(t, event.VerificationCancelCodeKeyMismatch, cancellation.Code) + assert.Equal(t, tc.expectedError, cancellation.Reason) + } + }) + } +} diff --git a/crypto/verificationhelper/verificationhelper_sas_test.go b/crypto/verificationhelper/verificationhelper_sas_test.go new file mode 100644 index 00000000..283eca84 --- /dev/null +++ b/crypto/verificationhelper/verificationhelper_sas_test.go @@ -0,0 +1,360 @@ +package verificationhelper_test + +import ( + "context" + "fmt" + "testing" + + "github.com/rs/zerolog/log" // zerolog-allow-global-log + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/exp/maps" + + "maunium.net/go/mautrix/crypto" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +func TestVerification_SAS(t *testing.T) { + ctx := log.Logger.WithContext(context.TODO()) + + testCases := []struct { + sendingGeneratedCrossSigningKeys bool + sendingStartsSAS bool + sendingConfirmsFirst bool + }{ + {true, true, true}, + {true, true, false}, + {true, false, true}, + {true, false, false}, + {false, true, true}, + {false, true, false}, + {false, false, true}, + {false, false, false}, + } + + for _, tc := range testCases { + t.Run(fmt.Sprintf("sendingGenerated=%t sendingStartsSAS=%t sendingConfirmsFirst=%t", tc.sendingGeneratedCrossSigningKeys, tc.sendingStartsSAS, tc.sendingConfirmsFirst), func(t *testing.T) { + ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLoginTwoAlice(t, ctx) + sendingCallbacks, receivingCallbacks, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine) + var err error + + var sendingRecoveryKey, receivingRecoveryKey string + var sendingCrossSigningKeysCache, receivingCrossSigningKeysCache *crypto.CrossSigningKeysCache + + if tc.sendingGeneratedCrossSigningKeys { + sendingRecoveryKey, sendingCrossSigningKeysCache, err = sendingMachine.GenerateAndUploadCrossSigningKeys(ctx, nil, "") + require.NoError(t, err) + assert.NotEmpty(t, sendingRecoveryKey) + assert.NotNil(t, sendingCrossSigningKeysCache) + } else { + receivingRecoveryKey, receivingCrossSigningKeysCache, err = receivingMachine.GenerateAndUploadCrossSigningKeys(ctx, nil, "") + require.NoError(t, err) + assert.NotEmpty(t, receivingRecoveryKey) + assert.NotNil(t, receivingCrossSigningKeysCache) + } + + // Send the verification request from the sender device and accept + // it on the receiving device and receive the verification ready + // event on the sending device. + txnID, err := sendingHelper.StartVerification(ctx, aliceUserID) + require.NoError(t, err) + ts.DispatchToDevice(t, ctx, receivingClient) + err = receivingHelper.AcceptVerification(ctx, txnID) + require.NoError(t, err) + ts.DispatchToDevice(t, ctx, sendingClient) + + // Test that the start event is correct + var startEvt *event.VerificationStartEventContent + if tc.sendingStartsSAS { + err = sendingHelper.StartSAS(ctx, txnID) + require.NoError(t, err) + + // Ensure that the receiving device received a verification + // start event. + receivingInbox := ts.DeviceInbox[aliceUserID][receivingDeviceID] + assert.Len(t, receivingInbox, 1) + startEvt = receivingInbox[0].Content.AsVerificationStart() + assert.Equal(t, sendingDeviceID, startEvt.FromDevice) + } else { + err = receivingHelper.StartSAS(ctx, txnID) + require.NoError(t, err) + + // Ensure that the receiving device received a verification + // start event. + sendingInbox := ts.DeviceInbox[aliceUserID][sendingDeviceID] + assert.Len(t, sendingInbox, 1) + startEvt = sendingInbox[0].Content.AsVerificationStart() + assert.Equal(t, receivingDeviceID, startEvt.FromDevice) + } + assert.Equal(t, txnID, startEvt.TransactionID) + assert.Equal(t, event.VerificationMethodSAS, startEvt.Method) + assert.Contains(t, startEvt.Hashes, event.VerificationHashMethodSHA256) + assert.Contains(t, startEvt.KeyAgreementProtocols, event.KeyAgreementProtocolCurve25519HKDFSHA256) + assert.Contains(t, startEvt.MessageAuthenticationCodes, event.MACMethodHKDFHMACSHA256) + assert.Contains(t, startEvt.MessageAuthenticationCodes, event.MACMethodHKDFHMACSHA256V2) + assert.Contains(t, startEvt.ShortAuthenticationString, event.SASMethodDecimal) + assert.Contains(t, startEvt.ShortAuthenticationString, event.SASMethodEmoji) + + // Test that the accept event is correct + var acceptEvt *event.VerificationAcceptEventContent + if tc.sendingStartsSAS { + // Process the verification start event on the receiving + // device. + ts.DispatchToDevice(t, ctx, receivingClient) + + // Receiving device sent the accept event to the sending device + sendingInbox := ts.DeviceInbox[aliceUserID][sendingDeviceID] + assert.Len(t, sendingInbox, 1) + acceptEvt = sendingInbox[0].Content.AsVerificationAccept() + } else { + // Process the verification start event on the sending device. + ts.DispatchToDevice(t, ctx, sendingClient) + + // Sending device sent the accept event to the receiving device + receivingInbox := ts.DeviceInbox[aliceUserID][receivingDeviceID] + assert.Len(t, receivingInbox, 1) + acceptEvt = receivingInbox[0].Content.AsVerificationAccept() + } + assert.Equal(t, txnID, acceptEvt.TransactionID) + assert.Equal(t, acceptEvt.Hash, event.VerificationHashMethodSHA256) + assert.Equal(t, acceptEvt.KeyAgreementProtocol, event.KeyAgreementProtocolCurve25519HKDFSHA256) + assert.Equal(t, acceptEvt.MessageAuthenticationCode, event.MACMethodHKDFHMACSHA256V2) + assert.Contains(t, acceptEvt.ShortAuthenticationString, event.SASMethodDecimal) + assert.Contains(t, acceptEvt.ShortAuthenticationString, event.SASMethodEmoji) + assert.NotEmpty(t, acceptEvt.Commitment) + + // Test that the first key event is correct + var firstKeyEvt *event.VerificationKeyEventContent + if tc.sendingStartsSAS { + // Process the verification accept event on the sending device. + ts.DispatchToDevice(t, ctx, sendingClient) + + // Sending device sends first key event to the receiving + // device. + receivingInbox := ts.DeviceInbox[aliceUserID][receivingDeviceID] + assert.Len(t, receivingInbox, 1) + firstKeyEvt = receivingInbox[0].Content.AsVerificationKey() + } else { + // Process the verification accept event on the receiving + // device. + ts.DispatchToDevice(t, ctx, receivingClient) + + // Receiving device sends first key event to the sending + // device. + sendingInbox := ts.DeviceInbox[aliceUserID][sendingDeviceID] + assert.Len(t, sendingInbox, 1) + firstKeyEvt = sendingInbox[0].Content.AsVerificationKey() + } + assert.Equal(t, txnID, firstKeyEvt.TransactionID) + assert.NotEmpty(t, firstKeyEvt.Key) + assert.Len(t, firstKeyEvt.Key, 32) + + // Test that the second key event is correct + var secondKeyEvt *event.VerificationKeyEventContent + if tc.sendingStartsSAS { + // Process the first key event on the receiving device. + ts.DispatchToDevice(t, ctx, receivingClient) + + // Receiving device sends second key event to the sending + // device. + sendingInbox := ts.DeviceInbox[aliceUserID][sendingDeviceID] + assert.Len(t, sendingInbox, 1) + secondKeyEvt = sendingInbox[0].Content.AsVerificationKey() + + // Ensure that the receiving device showed emojis and SAS numbers. + assert.Len(t, receivingCallbacks.GetDecimalsShown(txnID), 3) + emojis, descriptions := receivingCallbacks.GetEmojisAndDescriptionsShown(txnID) + assert.Len(t, emojis, 7) + assert.Len(t, descriptions, 7) + } else { + // Process the first key event on the sending device. + ts.DispatchToDevice(t, ctx, sendingClient) + + // Sending device sends second key event to the receiving + // device. + receivingInbox := ts.DeviceInbox[aliceUserID][receivingDeviceID] + assert.Len(t, receivingInbox, 1) + secondKeyEvt = receivingInbox[0].Content.AsVerificationKey() + + // Ensure that the sending device showed emojis and SAS numbers. + assert.Len(t, sendingCallbacks.GetDecimalsShown(txnID), 3) + emojis, descriptions := sendingCallbacks.GetEmojisAndDescriptionsShown(txnID) + assert.Len(t, emojis, 7) + assert.Len(t, descriptions, 7) + } + assert.Equal(t, txnID, secondKeyEvt.TransactionID) + assert.NotEmpty(t, secondKeyEvt.Key) + assert.Len(t, secondKeyEvt.Key, 32) + + // Ensure that the SAS codes are the same. + if tc.sendingStartsSAS { + // Process the second key event on the sending device. + ts.DispatchToDevice(t, ctx, sendingClient) + } else { + // Process the second key event on the receiving device. + ts.DispatchToDevice(t, ctx, receivingClient) + } + assert.Equal(t, sendingCallbacks.GetDecimalsShown(txnID), receivingCallbacks.GetDecimalsShown(txnID)) + sendingEmojis, sendingDescriptions := sendingCallbacks.GetEmojisAndDescriptionsShown(txnID) + receivingEmojis, receivingDescriptions := receivingCallbacks.GetEmojisAndDescriptionsShown(txnID) + assert.Equal(t, sendingEmojis, receivingEmojis) + assert.Equal(t, sendingDescriptions, receivingDescriptions) + + // Test that the first MAC event is correct + var firstMACEvt *event.VerificationMACEventContent + if tc.sendingConfirmsFirst { + err = sendingHelper.ConfirmSAS(ctx, txnID) + require.NoError(t, err) + + // The receiving device should have received the MAC event. + receivingInbox := ts.DeviceInbox[aliceUserID][receivingDeviceID] + assert.Len(t, receivingInbox, 1) + firstMACEvt = receivingInbox[0].Content.AsVerificationMAC() + + // The MAC event should have a MAC for the sending device ID. + assert.Contains(t, maps.Keys(firstMACEvt.MAC), id.NewKeyID(id.KeyAlgorithmEd25519, sendingDeviceID.String())) + } else { + err = receivingHelper.ConfirmSAS(ctx, txnID) + require.NoError(t, err) + + // The sending device should have received the MAC event. + sendingInbox := ts.DeviceInbox[aliceUserID][sendingDeviceID] + assert.Len(t, sendingInbox, 1) + firstMACEvt = sendingInbox[0].Content.AsVerificationMAC() + + // The MAC event should have a MAC for the receiving device ID. + assert.Contains(t, maps.Keys(firstMACEvt.MAC), id.NewKeyID(id.KeyAlgorithmEd25519, receivingDeviceID.String())) + } + assert.Equal(t, txnID, firstMACEvt.TransactionID) + + // The master key and the sending device ID should be in the + // MAC event's mac keys. + if tc.sendingGeneratedCrossSigningKeys { + assert.Contains(t, maps.Keys(firstMACEvt.MAC), id.NewKeyID(id.KeyAlgorithmEd25519, sendingCrossSigningKeysCache.MasterKey.PublicKey().String())) + } else { + assert.Contains(t, maps.Keys(firstMACEvt.MAC), id.NewKeyID(id.KeyAlgorithmEd25519, receivingCrossSigningKeysCache.MasterKey.PublicKey().String())) + } + + // Test that the second MAC event is correct + var secondMACEvt *event.VerificationMACEventContent + if tc.sendingConfirmsFirst { + err = receivingHelper.ConfirmSAS(ctx, txnID) + require.NoError(t, err) + + // The sending device should have received the MAC event. + sendingInbox := ts.DeviceInbox[aliceUserID][sendingDeviceID] + assert.Len(t, sendingInbox, 1) + secondMACEvt = sendingInbox[0].Content.AsVerificationMAC() + + // The MAC event should have a MAC for the receiving device ID. + assert.Contains(t, maps.Keys(secondMACEvt.MAC), id.NewKeyID(id.KeyAlgorithmEd25519, receivingDeviceID.String())) + } else { + err = sendingHelper.ConfirmSAS(ctx, txnID) + require.NoError(t, err) + + // The receiving device should have received the MAC event. + receivingInbox := ts.DeviceInbox[aliceUserID][receivingDeviceID] + assert.Len(t, receivingInbox, 1) + secondMACEvt = receivingInbox[0].Content.AsVerificationMAC() + + // The MAC event should have a MAC for the sending device ID. + assert.Contains(t, maps.Keys(secondMACEvt.MAC), id.NewKeyID(id.KeyAlgorithmEd25519, sendingDeviceID.String())) + } + assert.Equal(t, txnID, secondMACEvt.TransactionID) + + // The master key and the sending device ID should be in the + // MAC event's mac keys. + if tc.sendingGeneratedCrossSigningKeys { + assert.Contains(t, maps.Keys(firstMACEvt.MAC), id.NewKeyID(id.KeyAlgorithmEd25519, sendingCrossSigningKeysCache.MasterKey.PublicKey().String())) + } else { + assert.Contains(t, maps.Keys(firstMACEvt.MAC), id.NewKeyID(id.KeyAlgorithmEd25519, receivingCrossSigningKeysCache.MasterKey.PublicKey().String())) + } + + // Test the transaction is done on both sides. We have to dispatch + // twice to process and drain all of the events. + ts.DispatchToDevice(t, ctx, sendingClient) + ts.DispatchToDevice(t, ctx, receivingClient) + ts.DispatchToDevice(t, ctx, sendingClient) + ts.DispatchToDevice(t, ctx, receivingClient) + assert.True(t, sendingCallbacks.IsVerificationDone(txnID)) + assert.True(t, receivingCallbacks.IsVerificationDone(txnID)) + }) + } +} + +func TestVerification_SAS_BothCallStart(t *testing.T) { + ctx := log.Logger.WithContext(context.TODO()) + + ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLoginTwoAlice(t, ctx) + sendingCallbacks, receivingCallbacks, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine) + var err error + + var sendingRecoveryKey string + var sendingCrossSigningKeysCache *crypto.CrossSigningKeysCache + + sendingRecoveryKey, sendingCrossSigningKeysCache, err = sendingMachine.GenerateAndUploadCrossSigningKeys(ctx, nil, "") + require.NoError(t, err) + assert.NotEmpty(t, sendingRecoveryKey) + assert.NotNil(t, sendingCrossSigningKeysCache) + + // Send the verification request from the sender device and accept + // it on the receiving device and receive the verification ready + // event on the sending device. + txnID, err := sendingHelper.StartVerification(ctx, aliceUserID) + require.NoError(t, err) + ts.DispatchToDevice(t, ctx, receivingClient) + err = receivingHelper.AcceptVerification(ctx, txnID) + require.NoError(t, err) + ts.DispatchToDevice(t, ctx, sendingClient) + + err = sendingHelper.StartSAS(ctx, txnID) + require.NoError(t, err) + + err = receivingHelper.StartSAS(ctx, txnID) + require.NoError(t, err) + + // Ensure that both devices have received the verification start event. + receivingInbox := ts.DeviceInbox[aliceUserID][receivingDeviceID] + assert.Len(t, receivingInbox, 1) + assert.Equal(t, txnID, receivingInbox[0].Content.AsVerificationStart().TransactionID) + sendingInbox := ts.DeviceInbox[aliceUserID][sendingDeviceID] + assert.Len(t, sendingInbox, 1) + assert.Equal(t, txnID, sendingInbox[0].Content.AsVerificationStart().TransactionID) + + // Process the start event from the receiving client to the sending client. + ts.DispatchToDevice(t, ctx, sendingClient) + receivingInbox = ts.DeviceInbox[aliceUserID][receivingDeviceID] + assert.Len(t, receivingInbox, 2) + assert.Equal(t, txnID, receivingInbox[0].Content.AsVerificationStart().TransactionID) + assert.Equal(t, txnID, receivingInbox[1].Content.AsVerificationAccept().TransactionID) + + // Process the rest of the events until we need to confirm the SAS. + for len(ts.DeviceInbox[aliceUserID][sendingDeviceID]) > 0 || len(ts.DeviceInbox[aliceUserID][receivingDeviceID]) > 0 { + ts.DispatchToDevice(t, ctx, receivingClient) + ts.DispatchToDevice(t, ctx, sendingClient) + } + + // Confirm the SAS only the receiving device. + receivingHelper.ConfirmSAS(ctx, txnID) + ts.DispatchToDevice(t, ctx, sendingClient) + + // Verification is not done until both devices confirm the SAS. + assert.False(t, sendingCallbacks.IsVerificationDone(txnID)) + assert.False(t, receivingCallbacks.IsVerificationDone(txnID)) + + // Now, confirm it on the sending device. + sendingHelper.ConfirmSAS(ctx, txnID) + + // Dispatching the events to the receiving device should get us to the done + // state on the receiving device. + ts.DispatchToDevice(t, ctx, receivingClient) + assert.False(t, sendingCallbacks.IsVerificationDone(txnID)) + assert.True(t, receivingCallbacks.IsVerificationDone(txnID)) + + // Dispatching the events to the sending client should get us to the done + // state on the sending device. + ts.DispatchToDevice(t, ctx, sendingClient) + assert.True(t, sendingCallbacks.IsVerificationDone(txnID)) + assert.True(t, receivingCallbacks.IsVerificationDone(txnID)) +} diff --git a/crypto/verificationhelper/verificationhelper_test.go b/crypto/verificationhelper/verificationhelper_test.go new file mode 100644 index 00000000..ce5ec5b4 --- /dev/null +++ b/crypto/verificationhelper/verificationhelper_test.go @@ -0,0 +1,517 @@ +package verificationhelper_test + +import ( + "context" + "database/sql" + "fmt" + "os" + "testing" + "time" + + "github.com/rs/zerolog" + "github.com/rs/zerolog/log" // zerolog-allow-global-log + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/crypto" + "maunium.net/go/mautrix/crypto/cryptohelper" + "maunium.net/go/mautrix/crypto/verificationhelper" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" + "maunium.net/go/mautrix/mockserver" +) + +var aliceUserID = id.UserID("@alice:example.org") +var bobUserID = id.UserID("@bob:example.org") +var sendingDeviceID = id.DeviceID("sending") +var receivingDeviceID = id.DeviceID("receiving") + +func init() { + log.Logger = zerolog.New(zerolog.ConsoleWriter{Out: os.Stdout, TimeFormat: time.RFC3339}).With().Timestamp().Logger().Level(zerolog.TraceLevel) + zerolog.DefaultContextLogger = &log.Logger +} + +func addDeviceID(ctx context.Context, cryptoStore crypto.Store, userID id.UserID, deviceID id.DeviceID) { + err := cryptoStore.PutDevice(ctx, userID, &id.Device{ + UserID: userID, + DeviceID: deviceID, + }) + if err != nil { + panic(err) + } +} + +func initServerAndLoginTwoAlice(t *testing.T, ctx context.Context) (ts *mockserver.MockServer, sendingClient, receivingClient *mautrix.Client, sendingCryptoStore, receivingCryptoStore crypto.Store, sendingMachine, receivingMachine *crypto.OlmMachine) { + t.Helper() + ts = mockserver.Create(t) + + sendingClient, sendingCryptoStore = ts.Login(t, ctx, aliceUserID, sendingDeviceID) + sendingMachine = sendingClient.Crypto.(*cryptohelper.CryptoHelper).Machine() + receivingClient, receivingCryptoStore = ts.Login(t, ctx, aliceUserID, receivingDeviceID) + receivingMachine = receivingClient.Crypto.(*cryptohelper.CryptoHelper).Machine() + + require.NoError(t, sendingCryptoStore.PutDevice(ctx, aliceUserID, sendingMachine.OwnIdentity())) + require.NoError(t, sendingCryptoStore.PutDevice(ctx, aliceUserID, receivingMachine.OwnIdentity())) + require.NoError(t, receivingCryptoStore.PutDevice(ctx, aliceUserID, sendingMachine.OwnIdentity())) + require.NoError(t, receivingCryptoStore.PutDevice(ctx, aliceUserID, receivingMachine.OwnIdentity())) + return +} + +func initServerAndLoginAliceBob(t *testing.T, ctx context.Context) (ts *mockserver.MockServer, sendingClient, receivingClient *mautrix.Client, sendingCryptoStore, receivingCryptoStore crypto.Store, sendingMachine, receivingMachine *crypto.OlmMachine) { + t.Helper() + ts = mockserver.Create(t) + + sendingClient, sendingCryptoStore = ts.Login(t, ctx, aliceUserID, sendingDeviceID) + sendingMachine = sendingClient.Crypto.(*cryptohelper.CryptoHelper).Machine() + receivingClient, receivingCryptoStore = ts.Login(t, ctx, bobUserID, receivingDeviceID) + receivingMachine = receivingClient.Crypto.(*cryptohelper.CryptoHelper).Machine() + + require.NoError(t, sendingCryptoStore.PutDevice(ctx, aliceUserID, sendingMachine.OwnIdentity())) + require.NoError(t, sendingCryptoStore.PutDevice(ctx, bobUserID, receivingMachine.OwnIdentity())) + require.NoError(t, receivingCryptoStore.PutDevice(ctx, aliceUserID, sendingMachine.OwnIdentity())) + require.NoError(t, receivingCryptoStore.PutDevice(ctx, bobUserID, receivingMachine.OwnIdentity())) + return +} + +func initDefaultCallbacks(t *testing.T, ctx context.Context, sendingClient, receivingClient *mautrix.Client, sendingMachine, receivingMachine *crypto.OlmMachine) (sendingCallbacks, receivingCallbacks *allVerificationCallbacks, sendingHelper, receivingHelper *verificationhelper.VerificationHelper) { + t.Helper() + sendingCallbacks = newAllVerificationCallbacks() + senderVerificationDB, err := sql.Open("sqlite3", ":memory:") + require.NoError(t, err) + senderVerificationStore, err := NewSQLiteVerificationStore(ctx, senderVerificationDB) + require.NoError(t, err) + + sendingHelper = verificationhelper.NewVerificationHelper(sendingClient, sendingMachine, senderVerificationStore, sendingCallbacks, true, true, true) + require.NoError(t, sendingHelper.Init(ctx)) + + receivingCallbacks = newAllVerificationCallbacks() + receiverVerificationDB, err := sql.Open("sqlite3", ":memory:") + require.NoError(t, err) + receiverVerificationStore, err := NewSQLiteVerificationStore(ctx, receiverVerificationDB) + require.NoError(t, err) + receivingHelper = verificationhelper.NewVerificationHelper(receivingClient, receivingMachine, receiverVerificationStore, receivingCallbacks, true, true, true) + require.NoError(t, receivingHelper.Init(ctx)) + return +} + +func TestVerification_Start(t *testing.T) { + ctx := log.Logger.WithContext(context.TODO()) + receivingDeviceID2 := id.DeviceID("receiving2") + + testCases := []struct { + supportsShow bool + supportsScan bool + supportsSAS bool + callbacks MockVerificationCallbacks + startVerificationErrMsg string + expectedVerificationMethods []event.VerificationMethod + }{ + {false, false, false, newBaseVerificationCallbacks(), "no supported verification methods", nil}, + {false, true, false, newBaseVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate}}, + + {false, false, false, newShowQRCodeVerificationCallbacks(), "no supported verification methods", nil}, + {true, false, false, newShowQRCodeVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate}}, + {false, true, false, newShowQRCodeVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate}}, + {true, true, false, newShowQRCodeVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeShow, event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate}}, + + {false, false, true, newSASVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodSAS}}, + {false, true, true, newSASVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodSAS, event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate}}, + + {false, false, false, newAllVerificationCallbacks(), "no supported verification methods", nil}, + {false, false, true, newAllVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodSAS}}, + {false, true, true, newAllVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodSAS, event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate}}, + {true, false, true, newAllVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodSAS, event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate}}, + {true, true, true, newAllVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodSAS, event.VerificationMethodQRCodeShow, event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate}}, + } + + for i, tc := range testCases { + t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { + ts := mockserver.Create(t) + + client, cryptoStore := ts.Login(t, ctx, aliceUserID, sendingDeviceID) + addDeviceID(ctx, cryptoStore, aliceUserID, sendingDeviceID) + addDeviceID(ctx, cryptoStore, aliceUserID, receivingDeviceID) + addDeviceID(ctx, cryptoStore, aliceUserID, receivingDeviceID2) + + senderHelper := verificationhelper.NewVerificationHelper(client, client.Crypto.(*cryptohelper.CryptoHelper).Machine(), nil, tc.callbacks, tc.supportsShow, tc.supportsScan, tc.supportsSAS) + err := senderHelper.Init(ctx) + require.NoError(t, err) + + txnID, err := senderHelper.StartVerification(ctx, aliceUserID) + if tc.startVerificationErrMsg != "" { + assert.ErrorContains(t, err, tc.startVerificationErrMsg) + return + } + + require.NoError(t, err) + assert.NotEmpty(t, txnID) + + toDeviceInbox := ts.DeviceInbox[aliceUserID] + + // Ensure that we didn't send a verification request to the + // sending device. + assert.Empty(t, toDeviceInbox[sendingDeviceID]) + + // Ensure that the verification request was sent to both of + // the other devices. + assert.NotEmpty(t, toDeviceInbox[receivingDeviceID]) + assert.NotEmpty(t, toDeviceInbox[receivingDeviceID2]) + assert.Equal(t, toDeviceInbox[receivingDeviceID], toDeviceInbox[receivingDeviceID2]) + require.Len(t, toDeviceInbox[receivingDeviceID], 1) + + // Ensure that the verification request is correct. + verificationRequest := toDeviceInbox[receivingDeviceID][0].Content.AsVerificationRequest() + assert.Equal(t, sendingDeviceID, verificationRequest.FromDevice) + assert.Equal(t, txnID, verificationRequest.TransactionID) + assert.ElementsMatch(t, tc.expectedVerificationMethods, verificationRequest.Methods) + }) + } +} + +func TestVerification_StartThenCancel(t *testing.T) { + ctx := log.Logger.WithContext(context.TODO()) + bystanderDeviceID := id.DeviceID("bystander") + + for _, sendingCancels := range []bool{true, false} { + t.Run(fmt.Sprintf("sendingCancels=%t", sendingCancels), func(t *testing.T) { + ts, sendingClient, receivingClient, sendingCryptoStore, receivingCryptoStore, sendingMachine, receivingMachine := initServerAndLoginTwoAlice(t, ctx) + _, _, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine) + + bystanderClient, _ := ts.Login(t, ctx, aliceUserID, bystanderDeviceID) + bystanderMachine := bystanderClient.Crypto.(*cryptohelper.CryptoHelper).Machine() + bystanderHelper := verificationhelper.NewVerificationHelper(bystanderClient, bystanderMachine, nil, newAllVerificationCallbacks(), true, true, true) + require.NoError(t, bystanderHelper.Init(ctx)) + + require.NoError(t, sendingCryptoStore.PutDevice(ctx, aliceUserID, bystanderMachine.OwnIdentity())) + require.NoError(t, receivingCryptoStore.PutDevice(ctx, aliceUserID, bystanderMachine.OwnIdentity())) + + txnID, err := sendingHelper.StartVerification(ctx, aliceUserID) + require.NoError(t, err) + + assert.Empty(t, ts.DeviceInbox[aliceUserID][sendingDeviceID]) + + // Process the request event on the receiving device. + receivingInbox := ts.DeviceInbox[aliceUserID][receivingDeviceID] + assert.Len(t, receivingInbox, 1) + assert.Equal(t, txnID, receivingInbox[0].Content.AsVerificationRequest().TransactionID) + ts.DispatchToDevice(t, ctx, receivingClient) + + // Process the request event on the bystander device. + bystanderInbox := ts.DeviceInbox[aliceUserID][bystanderDeviceID] + assert.Len(t, bystanderInbox, 1) + assert.Equal(t, txnID, bystanderInbox[0].Content.AsVerificationRequest().TransactionID) + ts.DispatchToDevice(t, ctx, bystanderClient) + + // Cancel the verification request. + var cancelEvt *event.VerificationCancelEventContent + if sendingCancels { + err = sendingHelper.CancelVerification(ctx, txnID, event.VerificationCancelCodeUser, "Recovery code preferred") + assert.NoError(t, err) + + // The sending device should not have a cancellation event. + assert.Empty(t, ts.DeviceInbox[aliceUserID][sendingDeviceID]) + + // Ensure that the cancellation event was sent to the receiving device. + assert.Len(t, ts.DeviceInbox[aliceUserID][receivingDeviceID], 1) + cancelEvt = ts.DeviceInbox[aliceUserID][receivingDeviceID][0].Content.AsVerificationCancel() + + // Ensure that the cancellation event was sent to the bystander device. + assert.Len(t, ts.DeviceInbox[aliceUserID][bystanderDeviceID], 1) + bystanderCancelEvt := ts.DeviceInbox[aliceUserID][bystanderDeviceID][0].Content.AsVerificationCancel() + assert.Equal(t, cancelEvt, bystanderCancelEvt) + } else { + err = receivingHelper.CancelVerification(ctx, txnID, event.VerificationCancelCodeUser, "Recovery code preferred") + assert.NoError(t, err) + + // The receiving device should not have a cancellation event. + assert.Empty(t, ts.DeviceInbox[aliceUserID][receivingDeviceID]) + + // Ensure that the cancellation event was sent to the sending device. + assert.Len(t, ts.DeviceInbox[aliceUserID][sendingDeviceID], 1) + cancelEvt = ts.DeviceInbox[aliceUserID][sendingDeviceID][0].Content.AsVerificationCancel() + + // The bystander device should not have a cancellation event. + assert.Empty(t, ts.DeviceInbox[aliceUserID][bystanderDeviceID]) + } + assert.Equal(t, txnID, cancelEvt.TransactionID) + assert.Equal(t, event.VerificationCancelCodeUser, cancelEvt.Code) + assert.Equal(t, "Recovery code preferred", cancelEvt.Reason) + + if !sendingCancels { + // Process the cancellation event on the sending device. + ts.DispatchToDevice(t, ctx, sendingClient) + + // Ensure that the cancellation event was sent to the bystander device. + assert.Len(t, ts.DeviceInbox[aliceUserID][bystanderDeviceID], 1) + bystanderCancelEvt := ts.DeviceInbox[aliceUserID][bystanderDeviceID][0].Content.AsVerificationCancel() + assert.Equal(t, txnID, bystanderCancelEvt.TransactionID) + assert.Equal(t, event.VerificationCancelCodeUser, bystanderCancelEvt.Code) + assert.Equal(t, "The verification was rejected from another device.", bystanderCancelEvt.Reason) + } + }) + } +} + +func TestVerification_Accept_NoSupportedMethods(t *testing.T) { + ctx := log.Logger.WithContext(context.TODO()) + + ts := mockserver.Create(t) + + sendingClient, sendingCryptoStore := ts.Login(t, ctx, aliceUserID, sendingDeviceID) + receivingClient, _ := ts.Login(t, ctx, aliceUserID, receivingDeviceID) + addDeviceID(ctx, sendingCryptoStore, aliceUserID, sendingDeviceID) + addDeviceID(ctx, sendingCryptoStore, aliceUserID, receivingDeviceID) + + sendingMachine := sendingClient.Crypto.(*cryptohelper.CryptoHelper).Machine() + recoveryKey, cache, err := sendingMachine.GenerateAndUploadCrossSigningKeys(ctx, nil, "") + assert.NoError(t, err) + assert.NotEmpty(t, recoveryKey) + assert.NotNil(t, cache) + + sendingHelper := verificationhelper.NewVerificationHelper(sendingClient, sendingMachine, nil, newAllVerificationCallbacks(), true, true, true) + err = sendingHelper.Init(ctx) + require.NoError(t, err) + + receivingCallbacks := newBaseVerificationCallbacks() + receivingHelper := verificationhelper.NewVerificationHelper(receivingClient, receivingClient.Crypto.(*cryptohelper.CryptoHelper).Machine(), nil, receivingCallbacks, false, false, false) + err = receivingHelper.Init(ctx) + require.NoError(t, err) + + txnID, err := sendingHelper.StartVerification(ctx, aliceUserID) + require.NoError(t, err) + require.NotEmpty(t, txnID) + + ts.DispatchToDevice(t, ctx, receivingClient) + + // Ensure that the receiver ignored the request because it + // doesn't support any of the verification methods in the + // request. + assert.Empty(t, receivingCallbacks.GetRequestedVerifications()) +} + +func TestVerification_Accept_CorrectMethodsPresented(t *testing.T) { + ctx := log.Logger.WithContext(context.TODO()) + + testCases := []struct { + sendingSupportsScan bool + sendingSupportsShow bool + receivingSupportsScan bool + receivingSupportsShow bool + sendingSupportsSAS bool + receivingSupportsSAS bool + sendingCallbacks MockVerificationCallbacks + receivingCallbacks MockVerificationCallbacks + expectedVerificationMethods []event.VerificationMethod + }{ + // TODO + {false, false, false, false, true, true, newSASVerificationCallbacks(), newSASVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodSAS}}, + {true, false, true, false, true, true, newSASVerificationCallbacks(), newSASVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodSAS}}, + + {true, false, false, true, false, false, newShowQRCodeVerificationCallbacks(), newShowQRCodeVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodReciprocate, event.VerificationMethodQRCodeShow}}, + {false, true, true, false, false, false, newShowQRCodeVerificationCallbacks(), newShowQRCodeVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodReciprocate, event.VerificationMethodQRCodeScan}}, + {true, false, true, true, false, false, newShowQRCodeVerificationCallbacks(), newShowQRCodeVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodReciprocate, event.VerificationMethodQRCodeShow}}, + {false, true, true, true, false, false, newShowQRCodeVerificationCallbacks(), newShowQRCodeVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodReciprocate, event.VerificationMethodQRCodeScan}}, + {true, true, true, false, false, false, newShowQRCodeVerificationCallbacks(), newShowQRCodeVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodReciprocate, event.VerificationMethodQRCodeScan}}, + {true, true, false, true, false, false, newShowQRCodeVerificationCallbacks(), newShowQRCodeVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodReciprocate, event.VerificationMethodQRCodeShow}}, + {true, true, true, true, false, false, newShowQRCodeVerificationCallbacks(), newShowQRCodeVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodReciprocate, event.VerificationMethodQRCodeScan, event.VerificationMethodQRCodeShow}}, + + {true, true, true, true, true, true, newAllVerificationCallbacks(), newAllVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodSAS, event.VerificationMethodReciprocate, event.VerificationMethodQRCodeScan, event.VerificationMethodQRCodeShow}}, + } + + for i, tc := range testCases { + t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { + ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLoginTwoAlice(t, ctx) + + recoveryKey, sendingCrossSigningKeysCache, err := sendingMachine.GenerateAndUploadCrossSigningKeys(ctx, nil, "") + assert.NoError(t, err) + assert.NotEmpty(t, recoveryKey) + assert.NotNil(t, sendingCrossSigningKeysCache) + + sendingHelper := verificationhelper.NewVerificationHelper(sendingClient, sendingMachine, nil, tc.sendingCallbacks, tc.sendingSupportsShow, tc.sendingSupportsScan, tc.sendingSupportsSAS) + err = sendingHelper.Init(ctx) + require.NoError(t, err) + + receivingHelper := verificationhelper.NewVerificationHelper(receivingClient, receivingMachine, nil, tc.receivingCallbacks, tc.receivingSupportsShow, tc.receivingSupportsScan, tc.receivingSupportsSAS) + err = receivingHelper.Init(ctx) + require.NoError(t, err) + + txnID, err := sendingHelper.StartVerification(ctx, aliceUserID) + require.NoError(t, err) + + // Process the verification request on the receiving device. + ts.DispatchToDevice(t, ctx, receivingClient) + + // Ensure that the receiving device received a verification + // request with the correct transaction ID. + assert.ElementsMatch(t, []id.VerificationTransactionID{txnID}, tc.receivingCallbacks.GetRequestedVerifications()[aliceUserID]) + + // Have the receiving device accept the verification request. + err = receivingHelper.AcceptVerification(ctx, txnID) + require.NoError(t, err) + + // Ensure that the receiving device get a notification about the + // transaction being ready. + assert.Contains(t, tc.receivingCallbacks.GetVerificationsReadyTransactions(), txnID) + + // Ensure that if the receiving device should show a QR code that + // it has the correct content. + if tc.sendingSupportsScan && tc.receivingSupportsShow { + receivingShownQRCode := tc.receivingCallbacks.GetQRCodeShown(txnID) + require.NotNil(t, receivingShownQRCode) + assert.Equal(t, txnID, receivingShownQRCode.TransactionID) + assert.NotEmpty(t, receivingShownQRCode.SharedSecret) + } + + // Check for whether the receiving device should be scanning a QR + // code. + if tc.receivingSupportsScan && tc.sendingSupportsShow { + assert.Contains(t, tc.receivingCallbacks.GetScanQRCodeTransactions(), txnID) + } + + // Check that the m.key.verification.ready event has the correct + // content. + sendingInbox := ts.DeviceInbox[aliceUserID][sendingDeviceID] + assert.Len(t, sendingInbox, 1) + readyEvt := sendingInbox[0].Content.AsVerificationReady() + assert.Equal(t, txnID, readyEvt.TransactionID) + assert.Equal(t, receivingDeviceID, readyEvt.FromDevice) + assert.ElementsMatch(t, tc.expectedVerificationMethods, readyEvt.Methods) + + // Receive the m.key.verification.ready event on the sending + // device. + ts.DispatchToDevice(t, ctx, sendingClient) + + // Ensure that the sending device got a notification about the + // transaction being ready. + assert.Contains(t, tc.sendingCallbacks.GetVerificationsReadyTransactions(), txnID) + + // Ensure that if the sending device should show a QR code that it + // has the correct content. + if tc.receivingSupportsScan && tc.sendingSupportsShow { + sendingShownQRCode := tc.sendingCallbacks.GetQRCodeShown(txnID) + require.NotNil(t, sendingShownQRCode) + assert.Equal(t, txnID, sendingShownQRCode.TransactionID) + assert.NotEmpty(t, sendingShownQRCode.SharedSecret) + } + + // Check for whether the sending device should be scanning a QR + // code. + if tc.sendingSupportsScan && tc.receivingSupportsShow { + assert.Contains(t, tc.sendingCallbacks.GetScanQRCodeTransactions(), txnID) + } + }) + } +} + +// TestAcceptSelfVerificationCancelOnNonParticipatingDevices ensures that we do +// not regress https://github.com/mautrix/go/pull/230. +func TestVerification_Accept_CancelOnNonParticipatingDevices(t *testing.T) { + ctx := log.Logger.WithContext(context.TODO()) + ts, sendingClient, receivingClient, sendingCryptoStore, receivingCryptoStore, sendingMachine, receivingMachine := initServerAndLoginTwoAlice(t, ctx) + _, _, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine) + + nonParticipatingDeviceID1 := id.DeviceID("non-participating1") + nonParticipatingDeviceID2 := id.DeviceID("non-participating2") + addDeviceID(ctx, sendingCryptoStore, aliceUserID, nonParticipatingDeviceID1) + addDeviceID(ctx, sendingCryptoStore, aliceUserID, nonParticipatingDeviceID2) + addDeviceID(ctx, receivingCryptoStore, aliceUserID, nonParticipatingDeviceID1) + addDeviceID(ctx, receivingCryptoStore, aliceUserID, nonParticipatingDeviceID2) + + _, _, err := sendingMachine.GenerateAndUploadCrossSigningKeys(ctx, nil, "") + assert.NoError(t, err) + + // Send the verification request from the sender device and accept it on + // the receiving device. + txnID, err := sendingHelper.StartVerification(ctx, aliceUserID) + require.NoError(t, err) + ts.DispatchToDevice(t, ctx, receivingClient) + err = receivingHelper.AcceptVerification(ctx, txnID) + require.NoError(t, err) + + // Receive the m.key.verification.ready event on the sending device. + ts.DispatchToDevice(t, ctx, sendingClient) + + // The sending and receiving devices should not have any cancellation + // events in their inboxes. + assert.Empty(t, ts.DeviceInbox[aliceUserID][sendingDeviceID]) + assert.Empty(t, ts.DeviceInbox[aliceUserID][receivingDeviceID]) + + // There should now be cancellation events in the non-participating devices + // inboxes (in addition to the request event). + assert.Len(t, ts.DeviceInbox[aliceUserID][nonParticipatingDeviceID1], 2) + assert.Len(t, ts.DeviceInbox[aliceUserID][nonParticipatingDeviceID2], 2) + assert.Equal(t, ts.DeviceInbox[aliceUserID][nonParticipatingDeviceID1][1], ts.DeviceInbox[aliceUserID][nonParticipatingDeviceID2][1]) + cancellationEvent := ts.DeviceInbox[aliceUserID][nonParticipatingDeviceID1][1].Content.AsVerificationCancel() + assert.Equal(t, txnID, cancellationEvent.TransactionID) + assert.Equal(t, event.VerificationCancelCodeAccepted, cancellationEvent.Code) +} + +func TestVerification_ErrorOnDoubleAccept(t *testing.T) { + ctx := log.Logger.WithContext(context.TODO()) + ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLoginTwoAlice(t, ctx) + _, _, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine) + + _, _, err := sendingMachine.GenerateAndUploadCrossSigningKeys(ctx, nil, "") + require.NoError(t, err) + + txnID, err := sendingHelper.StartVerification(ctx, aliceUserID) + require.NoError(t, err) + ts.DispatchToDevice(t, ctx, receivingClient) + err = receivingHelper.AcceptVerification(ctx, txnID) + require.NoError(t, err) + err = receivingHelper.AcceptVerification(ctx, txnID) + assert.ErrorContains(t, err, "transaction is not in the requested state") +} + +// TestVerification_CancelOnDoubleStart ensures that the receiving device +// cancels both transactions if the sending device starts two verifications. +// +// This test ensures that the following bullet point from [Section 10.12.2.2.1 +// of the Spec] is followed: +// +// - When the same device attempts to initiate multiple verification attempts, +// the recipient should cancel all attempts with that device. +// +// [Section 10.12.2.2.1 of the Spec]: https://spec.matrix.org/v1.10/client-server-api/#error-and-exception-handling +func TestVerification_CancelOnDoubleStart(t *testing.T) { + ctx := log.Logger.WithContext(context.TODO()) + ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLoginTwoAlice(t, ctx) + sendingCallbacks, receivingCallbacks, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine) + + _, _, err := sendingMachine.GenerateAndUploadCrossSigningKeys(ctx, nil, "") + require.NoError(t, err) + + // Send and accept the first verification request. + txnID1, err := sendingHelper.StartVerification(ctx, aliceUserID) + require.NoError(t, err) + ts.DispatchToDevice(t, ctx, receivingClient) + err = receivingHelper.AcceptVerification(ctx, txnID1) + require.NoError(t, err) + ts.DispatchToDevice(t, ctx, sendingClient) // Process the m.key.verification.ready event + + // Send a second verification request + txnID2, err := sendingHelper.StartVerification(ctx, aliceUserID) + require.NoError(t, err) + ts.DispatchToDevice(t, ctx, receivingClient) + + // Ensure that the sending device received a cancellation event for both of + // the ongoing transactions. + sendingInbox := ts.DeviceInbox[aliceUserID][sendingDeviceID] + require.Len(t, sendingInbox, 2) + cancelEvt1 := sendingInbox[0].Content.AsVerificationCancel() + cancelEvt2 := sendingInbox[1].Content.AsVerificationCancel() + cancelledTxnIDs := []id.VerificationTransactionID{cancelEvt1.TransactionID, cancelEvt2.TransactionID} + assert.Contains(t, cancelledTxnIDs, txnID1) + assert.Contains(t, cancelledTxnIDs, txnID2) + assert.Equal(t, event.VerificationCancelCodeUnexpectedMessage, cancelEvt1.Code) + assert.Equal(t, event.VerificationCancelCodeUnexpectedMessage, cancelEvt2.Code) + assert.Equal(t, "received multiple verification requests from the same device", cancelEvt1.Reason) + assert.Equal(t, "received multiple verification requests from the same device", cancelEvt2.Reason) + + assert.NotNil(t, receivingCallbacks.GetVerificationCancellation(txnID1)) + assert.NotNil(t, receivingCallbacks.GetVerificationCancellation(txnID2)) + ts.DispatchToDevice(t, ctx, sendingClient) // Process the m.key.verification.cancel events + assert.NotNil(t, sendingCallbacks.GetVerificationCancellation(txnID1)) + assert.NotNil(t, sendingCallbacks.GetVerificationCancellation(txnID2)) +} diff --git a/crypto/verificationhelper/verificationstore.go b/crypto/verificationhelper/verificationstore.go new file mode 100644 index 00000000..1eb8f752 --- /dev/null +++ b/crypto/verificationhelper/verificationstore.go @@ -0,0 +1,159 @@ +package verificationhelper + +import ( + "context" + "errors" + "fmt" + + "go.mau.fi/util/jsontime" + + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +var ErrUnknownVerificationTransaction = errors.New("unknown transaction ID") + +type VerificationState int + +const ( + VerificationStateRequested VerificationState = iota + VerificationStateReady + + VerificationStateTheirQRScanned // We scanned their QR code + VerificationStateOurQRScanned // They scanned our QR code + + VerificationStateSASStarted // An SAS verification has been started + VerificationStateSASAccepted // An SAS verification has been accepted + VerificationStateSASKeysExchanged // An SAS verification has exchanged keys + VerificationStateSASMACExchanged // An SAS verification has exchanged MACs +) + +func (step VerificationState) String() string { + switch step { + case VerificationStateRequested: + return "requested" + case VerificationStateReady: + return "ready" + case VerificationStateTheirQRScanned: + return "their_qr_scanned" + case VerificationStateOurQRScanned: + return "our_qr_scanned" + case VerificationStateSASStarted: + return "sas_started" + case VerificationStateSASAccepted: + return "sas_accepted" + case VerificationStateSASKeysExchanged: + return "sas_keys_exchanged" + case VerificationStateSASMACExchanged: + return "sas_mac" + default: + return fmt.Sprintf("VerificationState(%d)", step) + } +} + +type VerificationTransaction struct { + ExpirationTime jsontime.UnixMilli `json:"expiration_time,omitempty"` + + // RoomID is the room ID if the verification is happening in a room or + // empty if it is a to-device verification. + RoomID id.RoomID `json:"room_id,omitempty"` + + // VerificationState is the current step of the verification flow. + VerificationState VerificationState `json:"verification_state"` + // TransactionID is the ID of the verification transaction. + TransactionID id.VerificationTransactionID `json:"transaction_id"` + + // TheirDeviceID is the device ID of the device that either made the + // initial request or accepted our request. + TheirDeviceID id.DeviceID `json:"their_device_id,omitempty"` + // TheirUserID is the user ID of the other user. + TheirUserID id.UserID `json:"their_user_id,omitempty"` + // TheirSupportedMethods is a list of verification methods that the other + // device supports. + TheirSupportedMethods []event.VerificationMethod `json:"their_supported_methods,omitempty"` + + // SentToDeviceIDs is a list of devices which the initial request was sent + // to. This is only used for to-device verification requests, and is meant + // to be used to send cancellation requests to all other devices when a + // verification request is accepted via a m.key.verification.ready event. + SentToDeviceIDs []id.DeviceID `json:"sent_to_device_ids,omitempty"` + + // QRCodeSharedSecret is the shared secret that was encoded in the QR code + // that we showed. + QRCodeSharedSecret []byte `json:"qr_code_shared_secret,omitempty"` + + StartedByUs bool `json:"started_by_us,omitempty"` // Whether the verification was started by us + StartEventContent *event.VerificationStartEventContent `json:"start_event_content,omitempty"` // The m.key.verification.start event content + Commitment []byte `json:"committment,omitempty"` // The commitment from the m.key.verification.accept event + MACMethod event.MACMethod `json:"mac_method,omitempty"` // The method used to calculate the MAC + EphemeralKey *ECDHPrivateKey `json:"ephemeral_key,omitempty"` // The ephemeral key + EphemeralPublicKeyShared bool `json:"ephemeral_public_key_shared,omitempty"` // Whether this device's ephemeral public key has been shared + OtherPublicKey *ECDHPublicKey `json:"other_public_key,omitempty"` // The other device's ephemeral public key + ReceivedTheirMAC bool `json:"received_their_mac,omitempty"` // Whether we have received their MAC + SentOurMAC bool `json:"sent_our_mac,omitempty"` // Whether we have sent our MAC + ReceivedTheirDone bool `json:"received_their_done,omitempty"` // Whether we have received their done event + SentOurDone bool `json:"sent_our_done,omitempty"` // Whether we have sent our done event +} + +type VerificationStore interface { + // DeleteVerification deletes a verification transaction by ID + DeleteVerification(ctx context.Context, txnID id.VerificationTransactionID) error + // GetVerificationTransaction gets a verification transaction by ID + GetVerificationTransaction(ctx context.Context, txnID id.VerificationTransactionID) (VerificationTransaction, error) + // SaveVerificationTransaction saves a verification transaction by ID + SaveVerificationTransaction(ctx context.Context, txn VerificationTransaction) error + // FindVerificationTransactionForUserDevice finds a verification + // transaction by user and device ID + FindVerificationTransactionForUserDevice(ctx context.Context, userID id.UserID, deviceID id.DeviceID) (VerificationTransaction, error) + // GetAllVerificationTransactions returns all of the verification + // transactions. This is used to reset the cancellation timeouts. + GetAllVerificationTransactions(ctx context.Context) ([]VerificationTransaction, error) +} + +type InMemoryVerificationStore struct { + txns map[id.VerificationTransactionID]VerificationTransaction +} + +var _ VerificationStore = (*InMemoryVerificationStore)(nil) + +func NewInMemoryVerificationStore() *InMemoryVerificationStore { + return &InMemoryVerificationStore{ + txns: map[id.VerificationTransactionID]VerificationTransaction{}, + } +} + +func (i *InMemoryVerificationStore) DeleteVerification(ctx context.Context, txnID id.VerificationTransactionID) error { + if _, ok := i.txns[txnID]; !ok { + return ErrUnknownVerificationTransaction + } + delete(i.txns, txnID) + return nil +} + +func (i *InMemoryVerificationStore) GetVerificationTransaction(ctx context.Context, txnID id.VerificationTransactionID) (VerificationTransaction, error) { + if _, ok := i.txns[txnID]; !ok { + return VerificationTransaction{}, ErrUnknownVerificationTransaction + } + return i.txns[txnID], nil +} + +func (i *InMemoryVerificationStore) SaveVerificationTransaction(ctx context.Context, txn VerificationTransaction) error { + i.txns[txn.TransactionID] = txn + return nil +} + +func (i *InMemoryVerificationStore) FindVerificationTransactionForUserDevice(ctx context.Context, userID id.UserID, deviceID id.DeviceID) (VerificationTransaction, error) { + for _, existingTxn := range i.txns { + if existingTxn.TheirUserID == userID && existingTxn.TheirDeviceID == deviceID { + return existingTxn, nil + } + } + return VerificationTransaction{}, ErrUnknownVerificationTransaction +} + +func (i *InMemoryVerificationStore) GetAllVerificationTransactions(ctx context.Context) (txns []VerificationTransaction, err error) { + for _, txn := range i.txns { + txns = append(txns, txn) + } + return +} diff --git a/crypto/verificationhelper/verificationstore_test.go b/crypto/verificationhelper/verificationstore_test.go new file mode 100644 index 00000000..e64153b1 --- /dev/null +++ b/crypto/verificationhelper/verificationstore_test.go @@ -0,0 +1,85 @@ +package verificationhelper_test + +import ( + "context" + "database/sql" + "errors" + + _ "github.com/mattn/go-sqlite3" + "github.com/rs/zerolog" + "go.mau.fi/util/dbutil" + + "maunium.net/go/mautrix/crypto/verificationhelper" + "maunium.net/go/mautrix/id" +) + +type SQLiteVerificationStore struct { + db *sql.DB +} + +const ( + selectVerifications = `SELECT transaction_data FROM verifications` + getVerificationByTransactionID = selectVerifications + ` WHERE transaction_id = ?1` + getVerificationByUserDeviceID = selectVerifications + ` + WHERE transaction_data->>'their_user_id' = ?1 + AND transaction_data->>'their_device_id' = ?2 + ` + deleteVerificationsQuery = `DELETE FROM verifications WHERE transaction_id = ?1` +) + +var _ verificationhelper.VerificationStore = (*SQLiteVerificationStore)(nil) + +func NewSQLiteVerificationStore(ctx context.Context, db *sql.DB) (*SQLiteVerificationStore, error) { + _, err := db.ExecContext(ctx, ` + CREATE TABLE verifications ( + transaction_id TEXT PRIMARY KEY NOT NULL, + transaction_data JSONB NOT NULL + ); + CREATE INDEX verifications_user_device_id ON + verifications(transaction_data->>'their_user_id', transaction_data->>'their_device_id'); + `) + return &SQLiteVerificationStore{db}, err +} + +func (s *SQLiteVerificationStore) GetAllVerificationTransactions(ctx context.Context) ([]verificationhelper.VerificationTransaction, error) { + rows, err := s.db.QueryContext(ctx, selectVerifications) + return dbutil.NewRowIterWithError(rows, func(dbutil.Scannable) (txn verificationhelper.VerificationTransaction, err error) { + err = rows.Scan(&dbutil.JSON{Data: &txn}) + return + }, err).AsList() +} + +func (vq *SQLiteVerificationStore) GetVerificationTransaction(ctx context.Context, txnID id.VerificationTransactionID) (txn verificationhelper.VerificationTransaction, err error) { + zerolog.Ctx(ctx).Warn().Stringer("transaction_id", txnID).Msg("Getting verification transaction") + row := vq.db.QueryRowContext(ctx, getVerificationByTransactionID, txnID) + err = row.Scan(&dbutil.JSON{Data: &txn}) + if errors.Is(err, sql.ErrNoRows) { + err = verificationhelper.ErrUnknownVerificationTransaction + } + return +} + +func (vq *SQLiteVerificationStore) FindVerificationTransactionForUserDevice(ctx context.Context, userID id.UserID, deviceID id.DeviceID) (txn verificationhelper.VerificationTransaction, err error) { + row := vq.db.QueryRowContext(ctx, getVerificationByUserDeviceID, userID, deviceID) + err = row.Scan(&dbutil.JSON{Data: &txn}) + if errors.Is(err, sql.ErrNoRows) { + err = verificationhelper.ErrUnknownVerificationTransaction + } + return +} + +func (vq *SQLiteVerificationStore) SaveVerificationTransaction(ctx context.Context, txn verificationhelper.VerificationTransaction) (err error) { + zerolog.Ctx(ctx).Debug().Any("transaction", &txn).Msg("Saving verification transaction") + _, err = vq.db.ExecContext(ctx, ` + INSERT INTO verifications (transaction_id, transaction_data) + VALUES (?1, ?2) + ON CONFLICT (transaction_id) DO UPDATE + SET transaction_data=excluded.transaction_data + `, txn.TransactionID, &dbutil.JSON{Data: &txn}) + return +} + +func (vq *SQLiteVerificationStore) DeleteVerification(ctx context.Context, txnID id.VerificationTransactionID) (err error) { + _, err = vq.db.ExecContext(ctx, deleteVerificationsQuery, txnID) + return +} diff --git a/error.go b/error.go index bcd568d8..4711b3dc 100644 --- a/error.go +++ b/error.go @@ -12,6 +12,8 @@ import ( "fmt" "net/http" + "go.mau.fi/util/exhttp" + "go.mau.fi/util/exmaps" "golang.org/x/exp/maps" ) @@ -24,49 +26,69 @@ import ( // // logout // } var ( + // Generic error for when the server encounters an error and it does not have a more specific error code. + // Note that `errors.Is` will check the error message rather than code for M_UNKNOWNs. + MUnknown = RespError{ErrCode: "M_UNKNOWN", StatusCode: http.StatusInternalServerError} // Forbidden access, e.g. joining a room without permission, failed login. - MForbidden = RespError{ErrCode: "M_FORBIDDEN"} + MForbidden = RespError{ErrCode: "M_FORBIDDEN", StatusCode: http.StatusForbidden} // Unrecognized request, e.g. the endpoint does not exist or is not implemented. - MUnrecognized = RespError{ErrCode: "M_UNRECOGNIZED"} + MUnrecognized = RespError{ErrCode: "M_UNRECOGNIZED", StatusCode: http.StatusNotFound} // The access token specified was not recognised. - MUnknownToken = RespError{ErrCode: "M_UNKNOWN_TOKEN"} + MUnknownToken = RespError{ErrCode: "M_UNKNOWN_TOKEN", StatusCode: http.StatusUnauthorized} // No access token was specified for the request. - MMissingToken = RespError{ErrCode: "M_MISSING_TOKEN"} + MMissingToken = RespError{ErrCode: "M_MISSING_TOKEN", StatusCode: http.StatusUnauthorized} // Request contained valid JSON, but it was malformed in some way, e.g. missing required keys, invalid values for keys. - MBadJSON = RespError{ErrCode: "M_BAD_JSON"} + MBadJSON = RespError{ErrCode: "M_BAD_JSON", StatusCode: http.StatusBadRequest} // Request did not contain valid JSON. - MNotJSON = RespError{ErrCode: "M_NOT_JSON"} + MNotJSON = RespError{ErrCode: "M_NOT_JSON", StatusCode: http.StatusBadRequest} // No resource was found for this request. - MNotFound = RespError{ErrCode: "M_NOT_FOUND"} + MNotFound = RespError{ErrCode: "M_NOT_FOUND", StatusCode: http.StatusNotFound} // Too many requests have been sent in a short period of time. Wait a while then try again. - MLimitExceeded = RespError{ErrCode: "M_LIMIT_EXCEEDED"} + MLimitExceeded = RespError{ErrCode: "M_LIMIT_EXCEEDED", StatusCode: http.StatusTooManyRequests} // The user ID associated with the request has been deactivated. // Typically for endpoints that prove authentication, such as /login. MUserDeactivated = RespError{ErrCode: "M_USER_DEACTIVATED"} // Encountered when trying to register a user ID which has been taken. - MUserInUse = RespError{ErrCode: "M_USER_IN_USE"} + MUserInUse = RespError{ErrCode: "M_USER_IN_USE", StatusCode: http.StatusBadRequest} // Encountered when trying to register a user ID which is not valid. - MInvalidUsername = RespError{ErrCode: "M_INVALID_USERNAME"} + MInvalidUsername = RespError{ErrCode: "M_INVALID_USERNAME", StatusCode: http.StatusBadRequest} // Sent when the room alias given to the createRoom API is already in use. - MRoomInUse = RespError{ErrCode: "M_ROOM_IN_USE"} + MRoomInUse = RespError{ErrCode: "M_ROOM_IN_USE", StatusCode: http.StatusBadRequest} // The state change requested cannot be performed, such as attempting to unban a user who is not banned. MBadState = RespError{ErrCode: "M_BAD_STATE"} // The request or entity was too large. - MTooLarge = RespError{ErrCode: "M_TOO_LARGE"} + MTooLarge = RespError{ErrCode: "M_TOO_LARGE", StatusCode: http.StatusRequestEntityTooLarge} // The resource being requested is reserved by an application service, or the application service making the request has not created the resource. - MExclusive = RespError{ErrCode: "M_EXCLUSIVE"} + MExclusive = RespError{ErrCode: "M_EXCLUSIVE", StatusCode: http.StatusBadRequest} // The client's request to create a room used a room version that the server does not support. MUnsupportedRoomVersion = RespError{ErrCode: "M_UNSUPPORTED_ROOM_VERSION"} // The client attempted to join a room that has a version the server does not support. // Inspect the room_version property of the error response for the room's version. MIncompatibleRoomVersion = RespError{ErrCode: "M_INCOMPATIBLE_ROOM_VERSION"} // The client specified a parameter that has the wrong value. - MInvalidParam = RespError{ErrCode: "M_INVALID_PARAM"} + MInvalidParam = RespError{ErrCode: "M_INVALID_PARAM", StatusCode: http.StatusBadRequest} + // The client specified a room key backup version that is not the current room key backup version for the user. + MWrongRoomKeysVersion = RespError{ErrCode: "M_WRONG_ROOM_KEYS_VERSION", StatusCode: http.StatusForbidden} MURLNotSet = RespError{ErrCode: "M_URL_NOT_SET"} MBadStatus = RespError{ErrCode: "M_BAD_STATUS"} MConnectionTimeout = RespError{ErrCode: "M_CONNECTION_TIMEOUT"} MConnectionFailed = RespError{ErrCode: "M_CONNECTION_FAILED"} + + MUnredactedContentDeleted = RespError{ErrCode: "FI.MAU.MSC2815_UNREDACTED_CONTENT_DELETED"} + MUnredactedContentNotReceived = RespError{ErrCode: "FI.MAU.MSC2815_UNREDACTED_CONTENT_NOT_RECEIVED"} +) + +var ( + ErrClientIsNil = errors.New("client is nil") + ErrClientHasNoHomeserver = errors.New("client has no homeserver set") + + ErrResponseTooLong = errors.New("response content length too long") + ErrBodyReadReachedLimit = errors.New("reached response size limit while reading body") + + // Special error that indicates we should retry canceled contexts. Note that on it's own this + // is useless, the context itself must also be replaced. + ErrContextCancelRetry = errors.New("retry canceled context") ) // HTTPError An HTTP Error response, which may wrap an underlying native Go Error. @@ -92,10 +114,9 @@ func (e HTTPError) Error() string { if e.WrappedError != nil { return fmt.Sprintf("%s: %v", e.Message, e.WrappedError) } else if e.RespError != nil { - return fmt.Sprintf("failed to %s %s: %s (HTTP %d): %s", e.Request.Method, e.Request.URL.Path, - e.RespError.ErrCode, e.Response.StatusCode, e.RespError.Err) + return fmt.Sprintf("%s (HTTP %d): %s", e.RespError.ErrCode, e.Response.StatusCode, e.RespError.Err) } else { - msg := fmt.Sprintf("failed to %s %s: HTTP %d", e.Request.Method, e.Request.URL.Path, e.Response.StatusCode) + msg := fmt.Sprintf("HTTP %d", e.Response.StatusCode) if len(e.ResponseBody) > 0 { msg = fmt.Sprintf("%s: %s", msg, e.ResponseBody) } @@ -117,7 +138,12 @@ func (e HTTPError) Unwrap() error { type RespError struct { ErrCode string Err string - ExtraData map[string]interface{} + ExtraData map[string]any + + StatusCode int + ExtraHeader map[string]string + + CanRetry bool } func (e *RespError) UnmarshalJSON(data []byte) error { @@ -127,19 +153,70 @@ func (e *RespError) UnmarshalJSON(data []byte) error { } e.ErrCode, _ = e.ExtraData["errcode"].(string) e.Err, _ = e.ExtraData["error"].(string) + e.CanRetry, _ = e.ExtraData["com.beeper.can_retry"].(bool) return nil } func (e *RespError) MarshalJSON() ([]byte, error) { - data := maps.Clone(e.ExtraData) - if data == nil { - data = make(map[string]any) - } + data := exmaps.NonNilClone(e.ExtraData) data["errcode"] = e.ErrCode data["error"] = e.Err + if e.CanRetry { + data["com.beeper.can_retry"] = e.CanRetry + } return json.Marshal(data) } +func (e RespError) Write(w http.ResponseWriter) { + if w == nil { + return + } + statusCode := e.StatusCode + if statusCode == 0 { + statusCode = http.StatusInternalServerError + } + for key, value := range e.ExtraHeader { + w.Header().Set(key, value) + } + exhttp.WriteJSONResponse(w, statusCode, &e) +} + +func (e RespError) WithMessage(msg string, args ...any) RespError { + if len(args) > 0 { + msg = fmt.Sprintf(msg, args...) + } + e.Err = msg + return e +} + +func (e RespError) WithStatus(status int) RespError { + e.StatusCode = status + return e +} + +func (e RespError) WithCanRetry(canRetry bool) RespError { + e.CanRetry = canRetry + return e +} + +func (e RespError) WithExtraData(extraData map[string]any) RespError { + e.ExtraData = exmaps.NonNilClone(e.ExtraData) + maps.Copy(e.ExtraData, extraData) + return e +} + +func (e RespError) WithExtraHeader(key, value string) RespError { + e.ExtraHeader = exmaps.NonNilClone(e.ExtraHeader) + e.ExtraHeader[key] = value + return e +} + +func (e RespError) WithExtraHeaders(headers map[string]string) RespError { + e.ExtraHeader = exmaps.NonNilClone(e.ExtraHeader) + maps.Copy(e.ExtraHeader, headers) + return e +} + // Error returns the errcode and error message. func (e RespError) Error() string { return e.ErrCode + ": " + e.Err diff --git a/event/accountdata.go b/event/accountdata.go index 6637fcfe..223919a1 100644 --- a/event/accountdata.go +++ b/event/accountdata.go @@ -8,6 +8,8 @@ package event import ( "encoding/json" + "strings" + "time" "maunium.net/go/mautrix/id" ) @@ -18,10 +20,47 @@ type TagEventContent struct { Tags Tags `json:"tags"` } -type Tags map[string]Tag +type Tags map[RoomTag]TagMetadata -type Tag struct { +type RoomTag string + +const ( + RoomTagFavourite RoomTag = "m.favourite" + RoomTagLowPriority RoomTag = "m.lowpriority" + RoomTagServerNotice RoomTag = "m.server_notice" +) + +func (rt RoomTag) IsUserDefined() bool { + return strings.HasPrefix(string(rt), "u.") +} + +func (rt RoomTag) String() string { + return string(rt) +} + +func (rt RoomTag) Name() string { + if rt.IsUserDefined() { + return string(rt[2:]) + } + switch rt { + case RoomTagFavourite: + return "Favourite" + case RoomTagLowPriority: + return "Low priority" + case RoomTagServerNotice: + return "Server notice" + default: + return "" + } +} + +// Deprecated: type alias +type Tag = TagMetadata + +type TagMetadata struct { Order json.Number `json:"order,omitempty"` + + MauDoublePuppetSource string `json:"fi.mau.double_puppet_source,omitempty"` } // DirectChatsEventContent represents the content of a m.direct account data event. @@ -43,3 +82,38 @@ type IgnoredUserListEventContent struct { type IgnoredUser struct { // This is an empty object } + +type MarkedUnreadEventContent struct { + Unread bool `json:"unread"` +} + +type BeeperMuteEventContent struct { + MutedUntil int64 `json:"muted_until,omitempty"` +} + +func (bmec *BeeperMuteEventContent) IsMuted() bool { + return bmec.MutedUntil < 0 || (bmec.MutedUntil > 0 && bmec.GetMutedUntilTime().After(time.Now())) +} + +var MutedForever = time.Date(9999, 12, 31, 23, 59, 59, 999999999, time.UTC) + +func (bmec *BeeperMuteEventContent) GetMutedUntilTime() time.Time { + if bmec.MutedUntil < 0 { + return MutedForever + } else if bmec.MutedUntil > 0 { + return time.UnixMilli(bmec.MutedUntil) + } + return time.Time{} +} + +func (bmec *BeeperMuteEventContent) GetMuteDuration() time.Duration { + ts := bmec.GetMutedUntilTime() + now := time.Now() + if ts.Before(now) { + return 0 + } else if ts == MutedForever { + return -1 + } else { + return ts.Sub(now) + } +} diff --git a/event/audio.go b/event/audio.go new file mode 100644 index 00000000..9eeb8edb --- /dev/null +++ b/event/audio.go @@ -0,0 +1,21 @@ +package event + +import ( + "encoding/json" +) + +type MSC1767Audio struct { + Duration int `json:"duration"` + Waveform []int `json:"waveform"` +} + +type serializableMSC1767Audio MSC1767Audio + +func (ma *MSC1767Audio) MarshalJSON() ([]byte, error) { + if ma.Waveform == nil { + ma.Waveform = []int{} + } + return json.Marshal((*serializableMSC1767Audio)(ma)) +} + +type MSC3245Voice struct{} diff --git a/event/beeper.go b/event/beeper.go index 51ddd77f..a1a60b35 100644 --- a/event/beeper.go +++ b/event/beeper.go @@ -7,6 +7,15 @@ package event import ( + "encoding/base32" + "encoding/binary" + "encoding/json" + "fmt" + "html" + "regexp" + "strconv" + "strings" + "maunium.net/go/mautrix/id" ) @@ -32,15 +41,20 @@ const ( ) type BeeperMessageStatusEventContent struct { - Network string `json:"network"` + Network string `json:"network,omitempty"` RelatesTo RelatesTo `json:"m.relates_to"` Status MessageStatus `json:"status"` Reason MessageStatusReason `json:"reason,omitempty"` - Error string `json:"error,omitempty"` - Message string `json:"message,omitempty"` + // Deprecated: clients were showing this to users even though they aren't supposed to. + // Use InternalError for error messages that should be included in bug reports, but not shown in the UI. + Error string `json:"error,omitempty"` + InternalError string `json:"internal_error,omitempty"` + Message string `json:"message,omitempty"` LastRetry id.EventID `json:"last_retry,omitempty"` + TargetTxnID string `json:"relates_to_txn_id,omitempty"` + MutateEventKey string `json:"mutate_event_key,omitempty"` // Indicates the set of users to whom the event was delivered. If nil, then @@ -50,6 +64,18 @@ type BeeperMessageStatusEventContent struct { DeliveredToUsers *[]id.UserID `json:"delivered_to_users,omitempty"` } +type BeeperRelatesTo struct { + EventID id.EventID `json:"event_id,omitempty"` + RoomID id.RoomID `json:"room_id,omitempty"` + Type RelationType `json:"rel_type,omitempty"` +} + +type BeeperTranscriptionEventContent struct { + Text []ExtensibleText `json:"m.text,omitempty"` + Model string `json:"com.beeper.transcription.model,omitempty"` + RelatesTo BeeperRelatesTo `json:"com.beeper.relates_to,omitempty"` +} + type BeeperRetryMetadata struct { OriginalEventID id.EventID `json:"original_event_id"` RetryCount int `json:"retry_count"` @@ -62,18 +88,54 @@ type BeeperRoomKeyAckEventContent struct { FirstMessageIndex int `json:"first_message_index"` } +type BeeperChatDeleteEventContent struct { + DeleteForEveryone bool `json:"delete_for_everyone,omitempty"` + FromMessageRequest bool `json:"from_message_request,omitempty"` +} + +type BeeperAcceptMessageRequestEventContent struct { + // Whether this was triggered by a message rather than an explicit event + IsImplicit bool `json:"-"` +} + +type BeeperSendStateEventContent struct { + Type string `json:"type"` + StateKey string `json:"state_key"` + Content Content `json:"content"` +} + +type IntOrString int + +func (ios *IntOrString) UnmarshalJSON(data []byte) error { + if len(data) > 0 && data[0] == '"' { + var str string + err := json.Unmarshal(data, &str) + if err != nil { + return err + } + intVal, err := strconv.Atoi(str) + if err != nil { + return err + } + *ios = IntOrString(intVal) + return nil + } + return json.Unmarshal(data, (*int)(ios)) +} + type LinkPreview struct { CanonicalURL string `json:"og:url,omitempty"` Title string `json:"og:title,omitempty"` Type string `json:"og:type,omitempty"` Description string `json:"og:description,omitempty"` + SiteName string `json:"og:site_name,omitempty"` ImageURL id.ContentURIString `json:"og:image,omitempty"` - ImageSize int `json:"matrix:image:size,omitempty"` - ImageWidth int `json:"og:image:width,omitempty"` - ImageHeight int `json:"og:image:height,omitempty"` - ImageType string `json:"og:image:type,omitempty"` + ImageSize IntOrString `json:"matrix:image:size,omitempty"` + ImageWidth IntOrString `json:"og:image:width,omitempty"` + ImageHeight IntOrString `json:"og:image:height,omitempty"` + ImageType string `json:"og:image:type,omitempty"` } // BeeperLinkPreview contains the data for a bundled URL preview as specified in MSC4095 @@ -84,4 +146,186 @@ type BeeperLinkPreview struct { MatchedURL string `json:"matched_url,omitempty"` ImageEncryption *EncryptedFileInfo `json:"beeper:image:encryption,omitempty"` + ImageBlurhash string `json:"matrix:image:blurhash,omitempty"` +} + +type BeeperProfileExtra struct { + RemoteID string `json:"com.beeper.bridge.remote_id,omitempty"` + Identifiers []string `json:"com.beeper.bridge.identifiers,omitempty"` + Service string `json:"com.beeper.bridge.service,omitempty"` + Network string `json:"com.beeper.bridge.network,omitempty"` + IsBridgeBot bool `json:"com.beeper.bridge.is_bridge_bot,omitempty"` + IsNetworkBot bool `json:"com.beeper.bridge.is_network_bot,omitempty"` +} + +type BeeperPerMessageProfile struct { + ID string `json:"id"` + Displayname string `json:"displayname,omitempty"` + AvatarURL *id.ContentURIString `json:"avatar_url,omitempty"` + AvatarFile *EncryptedFileInfo `json:"avatar_file,omitempty"` + HasFallback bool `json:"has_fallback,omitempty"` +} + +type BeeperActionMessageType string + +const ( + BeeperActionMessageCall BeeperActionMessageType = "call" +) + +type BeeperActionMessageCallType string + +const ( + BeeperActionMessageCallTypeVoice BeeperActionMessageCallType = "voice" + BeeperActionMessageCallTypeVideo BeeperActionMessageCallType = "video" +) + +type BeeperActionMessage struct { + Type BeeperActionMessageType `json:"type"` + CallType BeeperActionMessageCallType `json:"call_type,omitempty"` +} + +func (content *MessageEventContent) AddPerMessageProfileFallback() { + if content.BeeperPerMessageProfile == nil || content.BeeperPerMessageProfile.HasFallback || content.BeeperPerMessageProfile.Displayname == "" { + return + } + content.BeeperPerMessageProfile.HasFallback = true + content.EnsureHasHTML() + content.Body = fmt.Sprintf("%s: %s", content.BeeperPerMessageProfile.Displayname, content.Body) + content.FormattedBody = fmt.Sprintf( + "%s: %s", + html.EscapeString(content.BeeperPerMessageProfile.Displayname), + content.FormattedBody, + ) +} + +var HTMLProfileFallbackRegex = regexp.MustCompile(`([^<]+): `) + +func (content *MessageEventContent) RemovePerMessageProfileFallback() { + if content.NewContent != nil && content.NewContent != content { + content.NewContent.RemovePerMessageProfileFallback() + } + if content == nil || content.BeeperPerMessageProfile == nil || !content.BeeperPerMessageProfile.HasFallback || content.BeeperPerMessageProfile.Displayname == "" { + return + } + content.BeeperPerMessageProfile.HasFallback = false + content.Body = strings.TrimPrefix(content.Body, content.BeeperPerMessageProfile.Displayname+": ") + if content.Format == FormatHTML { + content.FormattedBody = HTMLProfileFallbackRegex.ReplaceAllLiteralString(content.FormattedBody, "") + } +} + +type BeeperAIStreamEventContent struct { + TurnID string `json:"turn_id"` + Seq int `json:"seq"` + Part map[string]any `json:"part"` + TargetEvent id.EventID `json:"target_event,omitempty"` + AgentID string `json:"agent_id,omitempty"` + RelatesTo *RelatesTo `json:"m.relates_to,omitempty"` +} + +type BeeperEncodedOrder struct { + order int64 + suborder int16 +} + +func NewBeeperEncodedOrder(order int64, suborder int16) *BeeperEncodedOrder { + return &BeeperEncodedOrder{order: order, suborder: suborder} +} + +func BeeperEncodedOrderFromString(str string) (*BeeperEncodedOrder, error) { + order, suborder, err := decodeIntPair(str) + if err != nil { + return nil, err + } + return &BeeperEncodedOrder{order: order, suborder: suborder}, nil +} + +func (b *BeeperEncodedOrder) String() string { + if b == nil { + return "" + } + return encodeIntPair(b.order, b.suborder) +} + +func (b *BeeperEncodedOrder) OrderPair() (int64, int16) { + if b == nil { + return 0, 0 + } + return b.order, b.suborder +} + +func (b *BeeperEncodedOrder) IsZero() bool { + return b == nil || (b.order == 0 && b.suborder == 0) +} + +func (b *BeeperEncodedOrder) MarshalJSON() ([]byte, error) { + return []byte(`"` + b.String() + `"`), nil +} + +func (b *BeeperEncodedOrder) UnmarshalJSON(data []byte) error { + if b == nil { + return fmt.Errorf("BeeperEncodedOrder: receiver is nil") + } + str := string(data) + if len(str) < 2 { + return fmt.Errorf("invalid encoded order string: %s", str) + } + decoded, err := BeeperEncodedOrderFromString(str[1 : len(str)-1]) + if err != nil { + return err + } + b.order, b.suborder = decoded.order, decoded.suborder + return nil +} + +// encodeIntPair encodes an int64 and an int16 into a lexicographically sortable string +func encodeIntPair(a int64, b int16) string { + // Create a buffer to hold the binary representation of the integers. + // Will need 8 bytes for the int64 and 2 bytes for the int16. + var buf [10]byte + + // Flip the sign bit of each integer to map the entire int range to uint + // in a way that preserves the order of the original integers. + // + // Explanation: + // - By XORing with (1 << 63), we flip the most significant bit (sign bit) of the int64 value. + // - Negative numbers (which have a sign bit of 1) become smaller uint64 values. + // - Non-negative numbers (with a sign bit of 0) become larger uint64 values. + // - This mapping preserves the original ordering when the uint64 values are compared. + binary.BigEndian.PutUint64(buf[0:8], uint64(a)^(1<<63)) + binary.BigEndian.PutUint16(buf[8:10], uint16(b)^(1<<15)) + + // Encode the buffer into a Base32 string without padding using the Hex encoding. + // + // Explanation: + // - Base32 encoding converts binary data into a text representation using 32 ASCII characters. + // - Using Base32HexEncoding ensures that the characters are in lexicographical order. + // - Disabling padding results in a consistent string length, which is important for sorting. + encoded := base32.HexEncoding.WithPadding(base32.NoPadding).EncodeToString(buf[:]) + + return encoded +} + +// decodeIntPair decodes a string produced by encodeIntPair back into the original int64 and int16 values +func decodeIntPair(encoded string) (int64, int16, error) { + // Decode the Base32 string back into the original byte buffer. + buf, err := base32.HexEncoding.WithPadding(base32.NoPadding).DecodeString(encoded) + if err != nil { + return 0, 0, fmt.Errorf("failed to decode string: %w", err) + } + + // Check that the decoded buffer has the expected length. + if len(buf) != 10 { + return 0, 0, fmt.Errorf("invalid encoded string length: expected 10 bytes, got %d", len(buf)) + } + + // Read the uint values from the buffer using big-endian byte order. + aPos := binary.BigEndian.Uint64(buf[0:8]) + bPos := binary.BigEndian.Uint16(buf[8:10]) + + // Reverse the sign bit flip to retrieve the original values. + a := int64(aPos ^ (1 << 63)) + b := int16(bPos ^ (1 << 15)) + + return a, b, nil } diff --git a/event/capabilities.d.ts b/event/capabilities.d.ts new file mode 100644 index 00000000..26aeb347 --- /dev/null +++ b/event/capabilities.d.ts @@ -0,0 +1,225 @@ +/** + * The content of the `com.beeper.room_features` state event. + */ +export interface RoomFeatures { + /** + * Supported formatting features. If omitted, no formatting is supported. + * + * Capability level 0 means the corresponding HTML tags/attributes are ignored + * and will be treated as if they don't exist, which means that children will + * be rendered, but attributes will be dropped. + */ + formatting?: Record + /** + * Supported file message types and their features. + * + * If a message type isn't listed here, it should be treated as support level -2 (will be rejected). + */ + file?: Record + /** + * Supported state event types and their parameters. Currently, there are no parameters, + * but it is likely there will be some in the future (like max name/topic length, avatar mime types, etc.). + * + * Events that are not listed or have a support level of zero or below should be treated as unsupported. + * + * Clients should at least check `m.room.name`, `m.room.topic`, and `m.room.avatar` here. + * `m.room.member` will not be listed here, as it's controlled by the member_actions field. + * `com.beeper.disappearing_timer` should be listed here, but the parameters are in the disappearing_timer field for now. + */ + state?: Record + /** + * Supported member actions and their support levels. + * + * Actions that are not listed or have a support level of zero or below should be treated as unsupported. + */ + member_actions?: Record + + /** Maximum length of normal text messages. */ + max_text_length?: integer + + /** Whether location messages (`m.location`) are supported. */ + location_message?: CapabilitySupportLevel + /** Whether polls are supported. */ + poll?: CapabilitySupportLevel + /** Whether replying in a thread is supported. */ + thread?: CapabilitySupportLevel + /** Whether replying to a specific message is supported. */ + reply?: CapabilitySupportLevel + + /** Whether edits are supported. */ + edit?: CapabilitySupportLevel + /** How many times can an individual message be edited. */ + edit_max_count?: integer + /** How old messages can be edited, in seconds. */ + edit_max_age?: seconds + /** Whether deleting messages for everyone is supported */ + delete?: CapabilitySupportLevel + /** How old messages can be deleted for everyone, in seconds. */ + delete_max_age?: seconds + /** Whether deleting messages just for yourself is supported. No message age limit. */ + delete_for_me?: boolean + /** Allowed configuration options for disappearing timers. */ + disappearing_timer?: DisappearingTimerCapability + + /** Whether reactions are supported. */ + reaction?: CapabilitySupportLevel + /** How many reactions can be added to a single message. */ + reaction_count?: integer + /** + * The Unicode emojis allowed for reactions. If omitted, all emojis are allowed. + * Emojis in this list must include variation selector 16 if allowed in the Unicode spec. + */ + allowed_reactions?: string[] + /** Whether custom emoji reactions are allowed. */ + custom_emoji_reactions?: boolean + + /** Whether deleting the chat for yourself is supported. */ + delete_chat?: boolean + /** Whether deleting the chat for all participants is supported. */ + delete_chat_for_everyone?: boolean + /** What can be done with message requests? */ + message_request?: { + accept_with_message?: CapabilitySupportLevel + accept_with_button?: CapabilitySupportLevel + } +} + +declare type integer = number +declare type seconds = integer +declare type milliseconds = integer +declare type MIMEClass = "image" | "audio" | "video" | "text" | "font" | "model" | "application" +declare type MIMETypeOrPattern = + "*/*" + | `${MIMEClass}/*` + | `${MIMEClass}/${string}` + | `${MIMEClass}/${string}; ${string}` + +export enum MemberAction { + Ban = "ban", + Kick = "kick", + Leave = "leave", + RevokeInvite = "revoke_invite", + Invite = "invite", +} + +declare type EventType = string + +// This is an object for future extensibility (e.g. max name/topic length) +export interface StateFeatures { + level: CapabilitySupportLevel +} + +export enum CapabilityMsgType { + // Real message types used in the `msgtype` field + Image = "m.image", + File = "m.file", + Audio = "m.audio", + Video = "m.video", + + // Pseudo types only used in capabilities + /** An `m.audio` message that has `"org.matrix.msc3245.voice": {}` */ + Voice = "org.matrix.msc3245.voice", + /** An `m.video` message that has `"info": {"fi.mau.gif": true}`, or an `m.image` message of type `image/gif` */ + GIF = "fi.mau.gif", + /** An `m.sticker` event, no `msgtype` field */ + Sticker = "m.sticker", +} + +export interface FileFeatures { + /** + * The supported MIME types or type patterns and their support levels. + * + * If a mime type doesn't match any pattern provided, + * it should be treated as support level -2 (will be rejected). + */ + mime_types: Record + + /** The support level for captions within this file message type */ + caption?: CapabilitySupportLevel + /** The maximum length for captions (only applicable if captions are supported). */ + max_caption_length?: integer + /** The maximum file size as bytes. */ + max_size?: integer + /** For images and videos, the maximum width as pixels. */ + max_width?: integer + /** For images and videos, the maximum height as pixels. */ + max_height?: integer + /** For videos and audio files, the maximum duration as seconds. */ + max_duration?: seconds + + /** Can this type of file be sent as view-once media? */ + view_once?: boolean +} + +export enum DisappearingType { + None = "", + AfterRead = "after_read", + AfterSend = "after_send", +} + +export interface DisappearingTimerCapability { + types: DisappearingType[] + /** Allowed timer values. If omitted, any timer is allowed. */ + timers?: milliseconds[] + /** + * Whether clients should omit the empty disappearing_timer object in messages that they don't want to disappear + * + * Generally, bridged rooms will want the object to be always present, while native Matrix rooms don't, + * so the hardcoded features for Matrix rooms should set this to true, while bridges will not. + */ + omit_empty_timer?: true +} + +/** + * The support level for a feature. These are integers rather than booleans + * to accurately represent what the bridge is doing and hopefully make the + * state event more generally useful. Our clients should check for > 0 to + * determine if the feature should be allowed. + */ +export enum CapabilitySupportLevel { + /** The feature is unsupported and messages using it will be rejected. */ + Rejected = -2, + /** The feature is unsupported and has no fallback. The message will go through, but data may be lost. */ + Dropped = -1, + /** The feature is unsupported, but may have a fallback. The nature of the fallback depends on the context. */ + Unsupported = 0, + /** The feature is partially supported (e.g. it may be converted to a different format). */ + PartialSupport = 1, + /** The feature is fully supported and can be safely used. */ + FullySupported = 2, +} + +/** + * A formatting feature that consists of specific HTML tags and/or attributes. + */ +export enum FormattingFeature { + Bold = "bold", // strong, b + Italic = "italic", // em, i + Underline = "underline", // u + Strikethrough = "strikethrough", // del, s + InlineCode = "inline_code", // code + CodeBlock = "code_block", // pre + code + SyntaxHighlighting = "code_block.syntax_highlighting", //

+	Blockquote = "blockquote", // blockquote
+	InlineLink = "inline_link", // a
+	UserLink = "user_link", // 
+	RoomLink = "room_link", // 
+	EventLink = "event_link", // 
+	AtRoomMention = "at_room_mention", // @room (no html tag)
+	UnorderedList = "unordered_list", // ul + li
+	OrderedList = "ordered_list", // ol + li
+	ListStart = "ordered_list.start", // 
    + ListJumpValue = "ordered_list.jump_value", //
  1. + CustomEmoji = "custom_emoji", // + Spoiler = "spoiler", // + SpoilerReason = "spoiler.reason", // + TextForegroundColor = "color.foreground", // + TextBackgroundColor = "color.background", // + HorizontalLine = "horizontal_line", // hr + Headers = "headers", // h1, h2, h3, h4, h5, h6 + Superscript = "superscript", // sup + Subscript = "subscript", // sub + Math = "math", // + DetailsSummary = "details_summary", //
    ......
    + Table = "table", // table, thead, tbody, tr, th, td +} diff --git a/event/capabilities.go b/event/capabilities.go new file mode 100644 index 00000000..a86c726b --- /dev/null +++ b/event/capabilities.go @@ -0,0 +1,414 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package event + +import ( + "crypto/sha256" + "encoding/base64" + "encoding/binary" + "fmt" + "io" + "mime" + "slices" + "strings" + + "go.mau.fi/util/exerrors" + "go.mau.fi/util/jsontime" + "go.mau.fi/util/ptr" + "golang.org/x/exp/constraints" + "golang.org/x/exp/maps" +) + +type RoomFeatures struct { + ID string `json:"id,omitempty"` + + // N.B. New fields need to be added to the Hash function to be included in the deduplication hash. + + Formatting FormattingFeatureMap `json:"formatting,omitempty"` + File FileFeatureMap `json:"file,omitempty"` + State StateFeatureMap `json:"state,omitempty"` + MemberActions MemberFeatureMap `json:"member_actions,omitempty"` + + MaxTextLength int `json:"max_text_length,omitempty"` + + LocationMessage CapabilitySupportLevel `json:"location_message,omitempty"` + Poll CapabilitySupportLevel `json:"poll,omitempty"` + Thread CapabilitySupportLevel `json:"thread,omitempty"` + Reply CapabilitySupportLevel `json:"reply,omitempty"` + + Edit CapabilitySupportLevel `json:"edit,omitempty"` + EditMaxCount int `json:"edit_max_count,omitempty"` + EditMaxAge *jsontime.Seconds `json:"edit_max_age,omitempty"` + Delete CapabilitySupportLevel `json:"delete,omitempty"` + DeleteForMe bool `json:"delete_for_me,omitempty"` + DeleteMaxAge *jsontime.Seconds `json:"delete_max_age,omitempty"` + + DisappearingTimer *DisappearingTimerCapability `json:"disappearing_timer,omitempty"` + + Reaction CapabilitySupportLevel `json:"reaction,omitempty"` + ReactionCount int `json:"reaction_count,omitempty"` + AllowedReactions []string `json:"allowed_reactions,omitempty"` + CustomEmojiReactions bool `json:"custom_emoji_reactions,omitempty"` + + ReadReceipts bool `json:"read_receipts,omitempty"` + TypingNotifications bool `json:"typing_notifications,omitempty"` + Archive bool `json:"archive,omitempty"` + MarkAsUnread bool `json:"mark_as_unread,omitempty"` + DeleteChat bool `json:"delete_chat,omitempty"` + DeleteChatForEveryone bool `json:"delete_chat_for_everyone,omitempty"` + + MessageRequest *MessageRequestFeatures `json:"message_request,omitempty"` + + PerMessageProfileRelay bool `json:"-"` +} + +func (rf *RoomFeatures) GetID() string { + if rf.ID != "" { + return rf.ID + } + return base64.RawURLEncoding.EncodeToString(rf.Hash()) +} + +func (rf *RoomFeatures) Clone() *RoomFeatures { + if rf == nil { + return nil + } + clone := *rf + clone.File = clone.File.Clone() + clone.Formatting = maps.Clone(clone.Formatting) + clone.State = clone.State.Clone() + clone.MemberActions = clone.MemberActions.Clone() + clone.EditMaxAge = ptr.Clone(clone.EditMaxAge) + clone.DeleteMaxAge = ptr.Clone(clone.DeleteMaxAge) + clone.DisappearingTimer = clone.DisappearingTimer.Clone() + clone.AllowedReactions = slices.Clone(clone.AllowedReactions) + clone.MessageRequest = clone.MessageRequest.Clone() + return &clone +} + +type MemberFeatureMap map[MemberAction]CapabilitySupportLevel + +func (mfm MemberFeatureMap) Clone() MemberFeatureMap { + return maps.Clone(mfm) +} + +type MemberAction string + +const ( + MemberActionBan MemberAction = "ban" + MemberActionKick MemberAction = "kick" + MemberActionLeave MemberAction = "leave" + MemberActionRevokeInvite MemberAction = "revoke_invite" + MemberActionInvite MemberAction = "invite" +) + +type StateFeatureMap map[string]*StateFeatures + +func (sfm StateFeatureMap) Clone() StateFeatureMap { + dup := maps.Clone(sfm) + for key, value := range dup { + dup[key] = value.Clone() + } + return dup +} + +type StateFeatures struct { + Level CapabilitySupportLevel `json:"level"` +} + +func (sf *StateFeatures) Clone() *StateFeatures { + if sf == nil { + return nil + } + clone := *sf + return &clone +} + +func (sf *StateFeatures) Hash() []byte { + return sf.Level.Hash() +} + +type FormattingFeatureMap map[FormattingFeature]CapabilitySupportLevel + +type FileFeatureMap map[CapabilityMsgType]*FileFeatures + +func (ffm FileFeatureMap) Clone() FileFeatureMap { + dup := maps.Clone(ffm) + for key, value := range dup { + dup[key] = value.Clone() + } + return dup +} + +type DisappearingTimerCapability struct { + Types []DisappearingType `json:"types"` + Timers []jsontime.Milliseconds `json:"timers,omitempty"` + + OmitEmptyTimer bool `json:"omit_empty_timer,omitempty"` +} + +func (dtc *DisappearingTimerCapability) Clone() *DisappearingTimerCapability { + if dtc == nil { + return nil + } + clone := *dtc + clone.Types = slices.Clone(clone.Types) + clone.Timers = slices.Clone(clone.Timers) + return &clone +} + +func (dtc *DisappearingTimerCapability) Supports(content *BeeperDisappearingTimer) bool { + if dtc == nil || content == nil || content.Type == DisappearingTypeNone { + return true + } + return slices.Contains(dtc.Types, content.Type) && (dtc.Timers == nil || slices.Contains(dtc.Timers, content.Timer)) +} + +type MessageRequestFeatures struct { + AcceptWithMessage CapabilitySupportLevel `json:"accept_with_message,omitempty"` + AcceptWithButton CapabilitySupportLevel `json:"accept_with_button,omitempty"` +} + +func (mrf *MessageRequestFeatures) Clone() *MessageRequestFeatures { + return ptr.Clone(mrf) +} + +func (mrf *MessageRequestFeatures) Hash() []byte { + if mrf == nil { + return nil + } + hasher := sha256.New() + hashValue(hasher, "accept_with_message", mrf.AcceptWithMessage) + hashValue(hasher, "accept_with_button", mrf.AcceptWithButton) + return hasher.Sum(nil) +} + +type CapabilityMsgType = MessageType + +// Message types which are used for event capability signaling, but aren't real values for the msgtype field. +const ( + CapMsgVoice CapabilityMsgType = "org.matrix.msc3245.voice" + CapMsgGIF CapabilityMsgType = "fi.mau.gif" + CapMsgSticker CapabilityMsgType = "m.sticker" +) + +type CapabilitySupportLevel int + +func (csl CapabilitySupportLevel) Partial() bool { + return csl >= CapLevelPartialSupport +} + +func (csl CapabilitySupportLevel) Full() bool { + return csl >= CapLevelFullySupported +} + +func (csl CapabilitySupportLevel) Reject() bool { + return csl <= CapLevelRejected +} + +const ( + CapLevelRejected CapabilitySupportLevel = -2 // The feature is unsupported and messages using it will be rejected. + CapLevelDropped CapabilitySupportLevel = -1 // The feature is unsupported and has no fallback. The message will go through, but data may be lost. + CapLevelUnsupported CapabilitySupportLevel = 0 // The feature is unsupported, but may have a fallback. + CapLevelPartialSupport CapabilitySupportLevel = 1 // The feature is partially supported (e.g. it may be converted to a different format). + CapLevelFullySupported CapabilitySupportLevel = 2 // The feature is fully supported and can be safely used. +) + +type FormattingFeature string + +const ( + FmtBold FormattingFeature = "bold" // strong, b + FmtItalic FormattingFeature = "italic" // em, i + FmtUnderline FormattingFeature = "underline" // u + FmtStrikethrough FormattingFeature = "strikethrough" // del, s + FmtInlineCode FormattingFeature = "inline_code" // code + FmtCodeBlock FormattingFeature = "code_block" // pre + code + FmtSyntaxHighlighting FormattingFeature = "code_block.syntax_highlighting" //
    
    +	FmtBlockquote          FormattingFeature = "blockquote"                     // blockquote
    +	FmtInlineLink          FormattingFeature = "inline_link"                    // a
    +	FmtUserLink            FormattingFeature = "user_link"                      // 
    +	FmtRoomLink            FormattingFeature = "room_link"                      // 
    +	FmtEventLink           FormattingFeature = "event_link"                     // 
    +	FmtAtRoomMention       FormattingFeature = "at_room_mention"                // @room (no html tag)
    +	FmtUnorderedList       FormattingFeature = "unordered_list"                 // ul + li
    +	FmtOrderedList         FormattingFeature = "ordered_list"                   // ol + li
    +	FmtListStart           FormattingFeature = "ordered_list.start"             // 
      + FmtListJumpValue FormattingFeature = "ordered_list.jump_value" //
    1. + FmtCustomEmoji FormattingFeature = "custom_emoji" // + FmtSpoiler FormattingFeature = "spoiler" // + FmtSpoilerReason FormattingFeature = "spoiler.reason" // + FmtTextForegroundColor FormattingFeature = "color.foreground" // + FmtTextBackgroundColor FormattingFeature = "color.background" // + FmtHorizontalLine FormattingFeature = "horizontal_line" // hr + FmtHeaders FormattingFeature = "headers" // h1, h2, h3, h4, h5, h6 + FmtSuperscript FormattingFeature = "superscript" // sup + FmtSubscript FormattingFeature = "subscript" // sub + FmtMath FormattingFeature = "math" // + FmtDetailsSummary FormattingFeature = "details_summary" //
      ......
      + FmtTable FormattingFeature = "table" // table, thead, tbody, tr, th, td +) + +type FileFeatures struct { + // N.B. New fields need to be added to the Hash function to be included in the deduplication hash. + + MimeTypes map[string]CapabilitySupportLevel `json:"mime_types"` + + Caption CapabilitySupportLevel `json:"caption,omitempty"` + MaxCaptionLength int `json:"max_caption_length,omitempty"` + + MaxSize int64 `json:"max_size,omitempty"` + MaxWidth int `json:"max_width,omitempty"` + MaxHeight int `json:"max_height,omitempty"` + MaxDuration *jsontime.Seconds `json:"max_duration,omitempty"` + + ViewOnce bool `json:"view_once,omitempty"` +} + +func (ff *FileFeatures) GetMimeSupport(inputType string) CapabilitySupportLevel { + match, ok := ff.MimeTypes[inputType] + if ok { + return match + } + if strings.IndexByte(inputType, ';') != -1 { + plainMime, _, _ := mime.ParseMediaType(inputType) + if plainMime != "" { + if match, ok = ff.MimeTypes[plainMime]; ok { + return match + } + } + } + if slash := strings.IndexByte(inputType, '/'); slash > 0 { + generalType := fmt.Sprintf("%s/*", inputType[:slash]) + if match, ok = ff.MimeTypes[generalType]; ok { + return match + } + } + match, ok = ff.MimeTypes["*/*"] + if ok { + return match + } + return CapLevelRejected +} + +type hashable interface { + Hash() []byte +} + +func hashMap[Key ~string, Value hashable](w io.Writer, name string, data map[Key]Value) { + keys := maps.Keys(data) + slices.Sort(keys) + exerrors.Must(w.Write([]byte(name))) + for _, key := range keys { + exerrors.Must(w.Write([]byte(key))) + exerrors.Must(w.Write(data[key].Hash())) + exerrors.Must(w.Write([]byte{0})) + } +} + +func hashValue(w io.Writer, name string, data hashable) { + exerrors.Must(w.Write([]byte(name))) + exerrors.Must(w.Write(data.Hash())) +} + +func hashInt[T constraints.Integer](w io.Writer, name string, data T) { + exerrors.Must(w.Write(binary.BigEndian.AppendUint64([]byte(name), uint64(data)))) +} + +func hashBool[T ~bool](w io.Writer, name string, data T) { + exerrors.Must(w.Write([]byte(name))) + if data { + exerrors.Must(w.Write([]byte{1})) + } else { + exerrors.Must(w.Write([]byte{0})) + } +} + +func (csl CapabilitySupportLevel) Hash() []byte { + return []byte{byte(csl + 128)} +} + +func (rf *RoomFeatures) Hash() []byte { + hasher := sha256.New() + + hashMap(hasher, "formatting", rf.Formatting) + hashMap(hasher, "file", rf.File) + hashMap(hasher, "state", rf.State) + hashMap(hasher, "member_actions", rf.MemberActions) + + hashInt(hasher, "max_text_length", rf.MaxTextLength) + + hashValue(hasher, "location_message", rf.LocationMessage) + hashValue(hasher, "poll", rf.Poll) + hashValue(hasher, "thread", rf.Thread) + hashValue(hasher, "reply", rf.Reply) + + hashValue(hasher, "edit", rf.Edit) + hashInt(hasher, "edit_max_count", rf.EditMaxCount) + hashInt(hasher, "edit_max_age", rf.EditMaxAge.Get()) + + hashValue(hasher, "delete", rf.Delete) + hashBool(hasher, "delete_for_me", rf.DeleteForMe) + hashInt(hasher, "delete_max_age", rf.DeleteMaxAge.Get()) + hashValue(hasher, "disappearing_timer", rf.DisappearingTimer) + + hashValue(hasher, "reaction", rf.Reaction) + hashInt(hasher, "reaction_count", rf.ReactionCount) + hasher.Write([]byte("allowed_reactions")) + for _, reaction := range rf.AllowedReactions { + hasher.Write([]byte(reaction)) + } + hashBool(hasher, "custom_emoji_reactions", rf.CustomEmojiReactions) + + hashBool(hasher, "read_receipts", rf.ReadReceipts) + hashBool(hasher, "typing_notifications", rf.TypingNotifications) + hashBool(hasher, "archive", rf.Archive) + hashBool(hasher, "mark_as_unread", rf.MarkAsUnread) + hashBool(hasher, "delete_chat", rf.DeleteChat) + hashBool(hasher, "delete_chat_for_everyone", rf.DeleteChatForEveryone) + hashValue(hasher, "message_request", rf.MessageRequest) + + return hasher.Sum(nil) +} + +func (dtc *DisappearingTimerCapability) Hash() []byte { + if dtc == nil { + return nil + } + hasher := sha256.New() + hasher.Write([]byte("types")) + for _, t := range dtc.Types { + hasher.Write([]byte(t)) + } + hasher.Write([]byte("timers")) + for _, timer := range dtc.Timers { + hashInt(hasher, "", timer.Milliseconds()) + } + return hasher.Sum(nil) +} + +func (ff *FileFeatures) Hash() []byte { + hasher := sha256.New() + hashMap(hasher, "mime_types", ff.MimeTypes) + hashValue(hasher, "caption", ff.Caption) + hashInt(hasher, "max_caption_length", ff.MaxCaptionLength) + hashInt(hasher, "max_size", ff.MaxSize) + hashInt(hasher, "max_width", ff.MaxWidth) + hashInt(hasher, "max_height", ff.MaxHeight) + hashInt(hasher, "max_duration", ff.MaxDuration.Get()) + hashBool(hasher, "view_once", ff.ViewOnce) + return hasher.Sum(nil) +} + +func (ff *FileFeatures) Clone() *FileFeatures { + if ff == nil { + return nil + } + clone := *ff + clone.MimeTypes = maps.Clone(clone.MimeTypes) + clone.MaxDuration = ptr.Clone(clone.MaxDuration) + return &clone +} diff --git a/event/cmdschema/content.go b/event/cmdschema/content.go new file mode 100644 index 00000000..ce07c4c0 --- /dev/null +++ b/event/cmdschema/content.go @@ -0,0 +1,78 @@ +// Copyright (c) 2026 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package cmdschema + +import ( + "crypto/sha256" + "encoding/base64" + "fmt" + "reflect" + "slices" + + "go.mau.fi/util/exsync" + "go.mau.fi/util/ptr" + + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +type EventContent struct { + Command string `json:"command"` + Aliases []string `json:"aliases,omitempty"` + Parameters []*Parameter `json:"parameters,omitempty"` + Description *event.ExtensibleTextContainer `json:"description,omitempty"` + TailParam string `json:"fi.mau.tail_parameter,omitempty"` +} + +func (ec *EventContent) Validate() error { + if ec == nil { + return fmt.Errorf("event content is nil") + } else if ec.Command == "" { + return fmt.Errorf("command is empty") + } + var tailFound bool + dupMap := exsync.NewSet[string]() + for i, p := range ec.Parameters { + if err := p.Validate(); err != nil { + return fmt.Errorf("parameter %q (#%d) is invalid: %w", ptr.Val(p).Key, i+1, err) + } else if !dupMap.Add(p.Key) { + return fmt.Errorf("duplicate parameter key %q at #%d", p.Key, i+1) + } else if p.Key == ec.TailParam { + tailFound = true + } else if tailFound && !p.Optional { + return fmt.Errorf("required parameter %q (#%d) is after tail parameter %q", p.Key, i+1, ec.TailParam) + } + } + if ec.TailParam != "" && !tailFound { + return fmt.Errorf("tail parameter %q not found in parameters", ec.TailParam) + } + return nil +} + +func (ec *EventContent) IsValid() bool { + return ec.Validate() == nil +} + +func (ec *EventContent) StateKey(owner id.UserID) string { + hash := sha256.Sum256([]byte(ec.Command + owner.String())) + return base64.StdEncoding.EncodeToString(hash[:]) +} + +func (ec *EventContent) Equals(other *EventContent) bool { + if ec == nil || other == nil { + return ec == other + } + return ec.Command == other.Command && + slices.Equal(ec.Aliases, other.Aliases) && + slices.EqualFunc(ec.Parameters, other.Parameters, (*Parameter).Equals) && + ec.Description.Equals(other.Description) && + ec.TailParam == other.TailParam +} + +func init() { + event.TypeMap[event.StateMSC4391BotCommand] = reflect.TypeOf(EventContent{}) +} diff --git a/event/cmdschema/parameter.go b/event/cmdschema/parameter.go new file mode 100644 index 00000000..4193b297 --- /dev/null +++ b/event/cmdschema/parameter.go @@ -0,0 +1,286 @@ +// Copyright (c) 2026 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package cmdschema + +import ( + "encoding/json" + "fmt" + "slices" + + "go.mau.fi/util/exslices" + + "maunium.net/go/mautrix/event" +) + +type Parameter struct { + Key string `json:"key"` + Schema *ParameterSchema `json:"schema"` + Optional bool `json:"optional,omitempty"` + Description *event.ExtensibleTextContainer `json:"description,omitempty"` + DefaultValue any `json:"fi.mau.default_value,omitempty"` +} + +func (p *Parameter) Equals(other *Parameter) bool { + if p == nil || other == nil { + return p == other + } + return p.Key == other.Key && + p.Schema.Equals(other.Schema) && + p.Optional == other.Optional && + p.Description.Equals(other.Description) && + p.DefaultValue == other.DefaultValue // TODO this won't work for room/event ID values +} + +func (p *Parameter) Validate() error { + if p == nil { + return fmt.Errorf("parameter is nil") + } else if p.Key == "" { + return fmt.Errorf("key is empty") + } + return p.Schema.Validate() +} + +func (p *Parameter) IsValid() bool { + return p.Validate() == nil +} + +func (p *Parameter) GetDefaultValue() any { + if p != nil && p.DefaultValue != nil { + return p.DefaultValue + } else if p == nil || p.Optional { + return nil + } + return p.Schema.GetDefaultValue() +} + +type PrimitiveType string + +const ( + PrimitiveTypeString PrimitiveType = "string" + PrimitiveTypeInteger PrimitiveType = "integer" + PrimitiveTypeBoolean PrimitiveType = "boolean" + PrimitiveTypeServerName PrimitiveType = "server_name" + PrimitiveTypeUserID PrimitiveType = "user_id" + PrimitiveTypeRoomID PrimitiveType = "room_id" + PrimitiveTypeRoomAlias PrimitiveType = "room_alias" + PrimitiveTypeEventID PrimitiveType = "event_id" +) + +func (pt PrimitiveType) Schema() *ParameterSchema { + return &ParameterSchema{ + SchemaType: SchemaTypePrimitive, + Type: pt, + } +} + +func (pt PrimitiveType) IsValid() bool { + switch pt { + case PrimitiveTypeString, + PrimitiveTypeInteger, + PrimitiveTypeBoolean, + PrimitiveTypeServerName, + PrimitiveTypeUserID, + PrimitiveTypeRoomID, + PrimitiveTypeRoomAlias, + PrimitiveTypeEventID: + return true + default: + return false + } +} + +type SchemaType string + +const ( + SchemaTypePrimitive SchemaType = "primitive" + SchemaTypeArray SchemaType = "array" + SchemaTypeUnion SchemaType = "union" + SchemaTypeLiteral SchemaType = "literal" +) + +type ParameterSchema struct { + SchemaType SchemaType `json:"schema_type"` + Type PrimitiveType `json:"type,omitempty"` // Only for primitive + Items *ParameterSchema `json:"items,omitempty"` // Only for array + Variants []*ParameterSchema `json:"variants,omitempty"` // Only for union + Value any `json:"value,omitempty"` // Only for literal +} + +func Literal(value any) *ParameterSchema { + return &ParameterSchema{ + SchemaType: SchemaTypeLiteral, + Value: value, + } +} + +func Enum(values ...any) *ParameterSchema { + return Union(exslices.CastFunc(values, Literal)...) +} + +func flattenUnion(variants []*ParameterSchema) []*ParameterSchema { + var flattened []*ParameterSchema + for _, variant := range variants { + switch variant.SchemaType { + case SchemaTypeArray: + panic(fmt.Errorf("illegal array schema in union")) + case SchemaTypeUnion: + flattened = append(flattened, flattenUnion(variant.Variants)...) + default: + flattened = append(flattened, variant) + } + } + return flattened +} + +func Union(variants ...*ParameterSchema) *ParameterSchema { + needsFlattening := false + for _, variant := range variants { + if variant.SchemaType == SchemaTypeArray { + panic(fmt.Errorf("illegal array schema in union")) + } else if variant.SchemaType == SchemaTypeUnion { + needsFlattening = true + } + } + if needsFlattening { + variants = flattenUnion(variants) + } + return &ParameterSchema{ + SchemaType: SchemaTypeUnion, + Variants: variants, + } +} + +func Array(items *ParameterSchema) *ParameterSchema { + if items.SchemaType == SchemaTypeArray { + panic(fmt.Errorf("illegal array schema in array")) + } + return &ParameterSchema{ + SchemaType: SchemaTypeArray, + Items: items, + } +} + +func (ps *ParameterSchema) GetDefaultValue() any { + if ps == nil { + return nil + } + switch ps.SchemaType { + case SchemaTypePrimitive: + switch ps.Type { + case PrimitiveTypeInteger: + return 0 + case PrimitiveTypeBoolean: + return false + default: + return "" + } + case SchemaTypeArray: + return []any{} + case SchemaTypeUnion: + if len(ps.Variants) > 0 { + return ps.Variants[0].GetDefaultValue() + } + return nil + case SchemaTypeLiteral: + return ps.Value + default: + return nil + } +} + +func (ps *ParameterSchema) IsValid() bool { + return ps.validate("") == nil +} + +func (ps *ParameterSchema) Validate() error { + return ps.validate("") +} + +func (ps *ParameterSchema) validate(parent SchemaType) error { + if ps == nil { + return fmt.Errorf("schema is nil") + } + switch ps.SchemaType { + case SchemaTypePrimitive: + if !ps.Type.IsValid() { + return fmt.Errorf("invalid primitive type %s", ps.Type) + } else if ps.Items != nil || ps.Variants != nil || ps.Value != nil { + return fmt.Errorf("primitive schema has extra fields") + } + return nil + case SchemaTypeArray: + if parent != "" { + return fmt.Errorf("arrays can't be nested in other types") + } else if err := ps.Items.validate(ps.SchemaType); err != nil { + return fmt.Errorf("item schema is invalid: %w", err) + } else if ps.Type != "" || ps.Variants != nil || ps.Value != nil { + return fmt.Errorf("array schema has extra fields") + } + return nil + case SchemaTypeUnion: + if len(ps.Variants) == 0 { + return fmt.Errorf("no variants specified for union") + } else if parent != "" && parent != SchemaTypeArray { + return fmt.Errorf("unions can't be nested in anything other than arrays") + } + for i, v := range ps.Variants { + if err := v.validate(ps.SchemaType); err != nil { + return fmt.Errorf("variant #%d is invalid: %w", i+1, err) + } + } + if ps.Type != "" || ps.Items != nil || ps.Value != nil { + return fmt.Errorf("union schema has extra fields") + } + return nil + case SchemaTypeLiteral: + switch typedVal := ps.Value.(type) { + case string, float64, int, int64, json.Number, bool, RoomIDValue, *RoomIDValue: + // ok + case map[string]any: + if typedVal["type"] != "event_id" && typedVal["type"] != "room_id" { + return fmt.Errorf("literal value has invalid map data") + } + default: + return fmt.Errorf("literal value has unsupported type %T", ps.Value) + } + if ps.Type != "" || ps.Items != nil || ps.Variants != nil { + return fmt.Errorf("literal schema has extra fields") + } + return nil + default: + return fmt.Errorf("invalid schema type %s", ps.SchemaType) + } +} + +func (ps *ParameterSchema) Equals(other *ParameterSchema) bool { + if ps == nil || other == nil { + return ps == other + } + return ps.SchemaType == other.SchemaType && + ps.Type == other.Type && + ps.Items.Equals(other.Items) && + slices.EqualFunc(ps.Variants, other.Variants, (*ParameterSchema).Equals) && + ps.Value == other.Value // TODO this won't work for room/event ID values +} + +func (ps *ParameterSchema) AllowsPrimitive(prim PrimitiveType) bool { + switch ps.SchemaType { + case SchemaTypePrimitive: + return ps.Type == prim + case SchemaTypeUnion: + for _, variant := range ps.Variants { + if variant.AllowsPrimitive(prim) { + return true + } + } + return false + case SchemaTypeArray: + return ps.Items.AllowsPrimitive(prim) + default: + return false + } +} diff --git a/event/cmdschema/parse.go b/event/cmdschema/parse.go new file mode 100644 index 00000000..92e69b60 --- /dev/null +++ b/event/cmdschema/parse.go @@ -0,0 +1,478 @@ +// Copyright (c) 2026 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package cmdschema + +import ( + "encoding/json" + "errors" + "fmt" + "regexp" + "strconv" + "strings" + + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +const botArrayOpener = "<" +const botArrayCloser = ">" + +func parseQuoted(val string) (parsed, remaining string, quoted bool) { + if len(val) == 0 { + return + } + if !strings.HasPrefix(val, `"`) { + spaceIdx := strings.IndexByte(val, ' ') + if spaceIdx == -1 { + parsed = val + } else { + parsed = val[:spaceIdx] + remaining = strings.TrimLeft(val[spaceIdx+1:], " ") + } + return + } + val = val[1:] + var buf strings.Builder + for { + quoteIdx := strings.IndexByte(val, '"') + var valUntilQuote string + if quoteIdx == -1 { + valUntilQuote = val + } else { + valUntilQuote = val[:quoteIdx] + } + escapeIdx := strings.IndexByte(valUntilQuote, '\\') + if escapeIdx >= 0 { + buf.WriteString(val[:escapeIdx]) + if len(val) > escapeIdx+1 { + buf.WriteByte(val[escapeIdx+1]) + } + val = val[min(escapeIdx+2, len(val)):] + } else if quoteIdx >= 0 { + buf.WriteString(val[:quoteIdx]) + val = val[quoteIdx+1:] + break + } else if buf.Len() == 0 { + // Unterminated quote, no escape characters, val is the whole input + return val, "", true + } else { + // Unterminated quote, but there were escape characters previously + buf.WriteString(val) + val = "" + break + } + } + return buf.String(), strings.TrimLeft(val, " "), true +} + +// ParseInput tries to parse the given text into a bot command event matching this command definition. +// +// If the prefix doesn't match, this will return a nil content and nil error. +// If the prefix does match, some content is always returned, but there may still be an error if parsing failed. +func (ec *EventContent) ParseInput(owner id.UserID, sigils []string, input string) (content *event.MessageEventContent, err error) { + prefix := ec.parsePrefix(input, sigils, owner.String()) + if prefix == "" { + return nil, nil + } + content = &event.MessageEventContent{ + MsgType: event.MsgText, + Body: input, + Mentions: &event.Mentions{UserIDs: []id.UserID{owner}}, + MSC4391BotCommand: &event.MSC4391BotCommandInput{ + Command: ec.Command, + }, + } + content.MSC4391BotCommand.Arguments, err = ec.ParseArguments(input[len(prefix):]) + return content, err +} + +func (ec *EventContent) ParseArguments(input string) (json.RawMessage, error) { + args := make(map[string]any) + var retErr error + setError := func(err error) { + if err != nil && retErr == nil { + retErr = err + } + } + processParameter := func(param *Parameter, isLast, isTail, isNamed bool) { + origInput := input + var nextVal string + var wasQuoted bool + if param.Schema.SchemaType == SchemaTypeArray { + hasOpener := strings.HasPrefix(input, botArrayOpener) + arrayClosed := false + if hasOpener { + input = input[len(botArrayOpener):] + if strings.HasPrefix(input, botArrayCloser) { + input = strings.TrimLeft(input[len(botArrayCloser):], " ") + arrayClosed = true + } + } + var collector []any + for len(input) > 0 && !arrayClosed { + //origInput = input + nextVal, input, wasQuoted = parseQuoted(input) + if !wasQuoted && hasOpener && strings.HasSuffix(nextVal, botArrayCloser) { + // The value wasn't quoted and has the array delimiter at the end, close the array + nextVal = strings.TrimRight(nextVal, botArrayCloser) + arrayClosed = true + } else if hasOpener && strings.HasPrefix(input, botArrayCloser) { + // The value was quoted or there was a space, and the next character is the + // array delimiter, close the array + input = strings.TrimLeft(input[len(botArrayCloser):], " ") + arrayClosed = true + } else if !hasOpener && !isLast { + // For array arguments in the middle without the <> delimiters, stop after the first item + arrayClosed = true + } + parsedVal, err := param.Schema.Items.ParseString(nextVal) + if err == nil { + collector = append(collector, parsedVal) + } else if hasOpener || isLast { + setError(fmt.Errorf("failed to parse item #%d of array %s: %w", len(collector)+1, param.Key, err)) + } else { + //input = origInput + } + } + args[param.Key] = collector + } else { + nextVal, input, wasQuoted = parseQuoted(input) + if (isLast || isTail) && !wasQuoted && len(input) > 0 { + // If the last argument is not quoted, just treat the rest of the string + // as the argument without escapes (arguments with escapes should be quoted). + nextVal += " " + input + input = "" + } + // Special case for named boolean parameters: if no value is given, treat it as true + if nextVal == "" && !wasQuoted && isNamed && param.Schema.AllowsPrimitive(PrimitiveTypeBoolean) { + args[param.Key] = true + return + } + if nextVal == "" && !wasQuoted && !isNamed && !param.Optional { + setError(fmt.Errorf("missing value for required parameter %s", param.Key)) + } + parsedVal, err := param.Schema.ParseString(nextVal) + if err != nil { + args[param.Key] = param.GetDefaultValue() + // For optional parameters that fail to parse, restore the input and try passing it as the next parameter + if param.Optional && !isLast && !isNamed { + input = strings.TrimLeft(origInput, " ") + } else if !param.Optional || isNamed { + setError(fmt.Errorf("failed to parse %s: %w", param.Key, err)) + } + } else { + args[param.Key] = parsedVal + } + } + } + skipParams := make([]bool, len(ec.Parameters)) + for i, param := range ec.Parameters { + for strings.HasPrefix(input, "--") { + nameEndIdx := strings.IndexAny(input, " =") + if nameEndIdx == -1 { + nameEndIdx = len(input) + } + overrideParam, paramIdx := ec.parameterByName(input[2:nameEndIdx]) + if overrideParam != nil { + // Trim the equals sign, but leave spaces alone to let parseQuoted treat it as empty input + input = strings.TrimPrefix(input[nameEndIdx:], "=") + skipParams[paramIdx] = true + processParameter(overrideParam, false, false, true) + } else { + break + } + } + isTail := param.Key == ec.TailParam + if skipParams[i] || (param.Optional && !isTail) { + continue + } + processParameter(param, i == len(ec.Parameters)-1, isTail, false) + } + jsonArgs, marshalErr := json.Marshal(args) + if marshalErr != nil { + return nil, fmt.Errorf("failed to marshal arguments: %w", marshalErr) + } + return jsonArgs, retErr +} + +func (ec *EventContent) parameterByName(name string) (*Parameter, int) { + for i, param := range ec.Parameters { + if strings.EqualFold(param.Key, name) { + return param, i + } + } + return nil, -1 +} + +func (ec *EventContent) parsePrefix(origInput string, sigils []string, owner string) (prefix string) { + input := origInput + var chosenSigil string + for _, sigil := range sigils { + if strings.HasPrefix(input, sigil) { + chosenSigil = sigil + break + } + } + if chosenSigil == "" { + return "" + } + input = input[len(chosenSigil):] + var chosenAlias string + if !strings.HasPrefix(input, ec.Command) { + for _, alias := range ec.Aliases { + if strings.HasPrefix(input, alias) { + chosenAlias = alias + break + } + } + if chosenAlias == "" { + return "" + } + } else { + chosenAlias = ec.Command + } + input = strings.TrimPrefix(input[len(chosenAlias):], owner) + if input == "" || input[0] == ' ' { + input = strings.TrimLeft(input, " ") + return origInput[:len(origInput)-len(input)] + } + return "" +} + +func (pt PrimitiveType) ValidateValue(value any) bool { + _, err := pt.NormalizeValue(value) + return err == nil +} + +func normalizeNumber(value any) (int, error) { + switch typedValue := value.(type) { + case int: + return typedValue, nil + case int64: + return int(typedValue), nil + case float64: + return int(typedValue), nil + case json.Number: + if i, err := typedValue.Int64(); err != nil { + return 0, fmt.Errorf("failed to parse json.Number: %w", err) + } else { + return int(i), nil + } + default: + return 0, fmt.Errorf("unsupported type %T for integer", value) + } +} + +func (pt PrimitiveType) NormalizeValue(value any) (any, error) { + switch pt { + case PrimitiveTypeInteger: + return normalizeNumber(value) + case PrimitiveTypeBoolean: + bv, ok := value.(bool) + if !ok { + return nil, fmt.Errorf("unsupported type %T for boolean", value) + } + return bv, nil + case PrimitiveTypeString, PrimitiveTypeServerName: + str, ok := value.(string) + if !ok { + return nil, fmt.Errorf("unsupported type %T for string", value) + } + return str, pt.validateStringValue(str) + case PrimitiveTypeUserID, PrimitiveTypeRoomAlias: + str, ok := value.(string) + if !ok { + return nil, fmt.Errorf("unsupported type %T for user ID or room alias", value) + } else if plainErr := pt.validateStringValue(str); plainErr == nil { + return str, nil + } else if parsed, err := id.ParseMatrixURIOrMatrixToURL(str); err != nil { + return nil, fmt.Errorf("couldn't parse %q as plain ID nor matrix URI: %w / %w", value, plainErr, err) + } else if parsed.Sigil1 == '@' && pt == PrimitiveTypeUserID { + return parsed.UserID(), nil + } else if parsed.Sigil1 == '#' && pt == PrimitiveTypeRoomAlias { + return parsed.RoomAlias(), nil + } else { + return nil, fmt.Errorf("unexpected sigil %c for user ID or room alias", parsed.Sigil1) + } + case PrimitiveTypeRoomID, PrimitiveTypeEventID: + riv, err := NormalizeRoomIDValue(value) + if err != nil { + return nil, err + } + return riv, riv.Validate() + default: + return nil, fmt.Errorf("cannot normalize value for argument type %s", pt) + } +} + +func (pt PrimitiveType) validateStringValue(value string) error { + switch pt { + case PrimitiveTypeString: + return nil + case PrimitiveTypeServerName: + if !id.ValidateServerName(value) { + return fmt.Errorf("invalid server name: %q", value) + } + return nil + case PrimitiveTypeUserID: + _, _, err := id.UserID(value).ParseAndValidateRelaxed() + return err + case PrimitiveTypeRoomAlias: + sigil, localpart, serverName := id.ParseCommonIdentifier(value) + if sigil != '#' || localpart == "" || serverName == "" { + return fmt.Errorf("invalid room alias: %q", value) + } else if !id.ValidateServerName(serverName) { + return fmt.Errorf("invalid server name in room alias: %q", serverName) + } + return nil + default: + panic(fmt.Errorf("validateStringValue called with invalid type %s", pt)) + } +} + +func parseBoolean(val string) (bool, error) { + if len(val) == 0 { + return false, fmt.Errorf("cannot parse empty string as boolean") + } + switch strings.ToLower(val) { + case "t", "true", "y", "yes", "1": + return true, nil + case "f", "false", "n", "no", "0": + return false, nil + default: + return false, fmt.Errorf("invalid boolean string: %q", val) + } +} + +var markdownLinkRegex = regexp.MustCompile(`^\[.+]\(([^)]+)\)$`) + +func parseRoomOrEventID(value string) (*RoomIDValue, error) { + if strings.HasPrefix(value, "[") && strings.Contains(value, "](") && strings.HasSuffix(value, ")") { + matches := markdownLinkRegex.FindStringSubmatch(value) + if len(matches) == 2 { + value = matches[1] + } + } + parsed, err := id.ParseMatrixURIOrMatrixToURL(value) + if err != nil && strings.HasPrefix(value, "!") { + return &RoomIDValue{ + Type: PrimitiveTypeRoomID, + RoomID: id.RoomID(value), + }, nil + } + if err != nil { + return nil, err + } else if parsed.Sigil1 != '!' { + return nil, fmt.Errorf("unexpected sigil %c for room ID", parsed.Sigil1) + } else if parsed.MXID2 != "" && parsed.Sigil2 != '$' { + return nil, fmt.Errorf("unexpected sigil %c for event ID", parsed.Sigil2) + } + valType := PrimitiveTypeRoomID + if parsed.MXID2 != "" { + valType = PrimitiveTypeEventID + } + return &RoomIDValue{ + Type: valType, + RoomID: parsed.RoomID(), + Via: parsed.Via, + EventID: parsed.EventID(), + }, nil +} + +func (pt PrimitiveType) ParseString(value string) (any, error) { + switch pt { + case PrimitiveTypeInteger: + return strconv.Atoi(value) + case PrimitiveTypeBoolean: + return parseBoolean(value) + case PrimitiveTypeString, PrimitiveTypeServerName, PrimitiveTypeUserID: + return value, pt.validateStringValue(value) + case PrimitiveTypeRoomAlias: + plainErr := pt.validateStringValue(value) + if plainErr == nil { + return value, nil + } + parsed, err := id.ParseMatrixURIOrMatrixToURL(value) + if err != nil { + return nil, fmt.Errorf("couldn't parse %q as plain room alias nor matrix URI: %w / %w", value, plainErr, err) + } else if parsed.Sigil1 != '#' { + return nil, fmt.Errorf("unexpected sigil %c for room alias", parsed.Sigil1) + } + return parsed.RoomAlias(), nil + case PrimitiveTypeRoomID, PrimitiveTypeEventID: + parsed, err := parseRoomOrEventID(value) + if err != nil { + return nil, err + } else if pt != parsed.Type { + return nil, fmt.Errorf("mismatching argument type: expected %s but got %s", pt, parsed.Type) + } + return parsed, nil + default: + return nil, fmt.Errorf("cannot parse string for argument type %s", pt) + } +} + +func (ps *ParameterSchema) ParseString(value string) (any, error) { + if ps == nil { + return nil, fmt.Errorf("parameter schema is nil") + } + switch ps.SchemaType { + case SchemaTypePrimitive: + return ps.Type.ParseString(value) + case SchemaTypeLiteral: + switch typedValue := ps.Value.(type) { + case string: + if value == typedValue { + return typedValue, nil + } else { + return nil, fmt.Errorf("literal value %q does not match %q", typedValue, value) + } + case int, int64, float64, json.Number: + expectedVal, _ := normalizeNumber(typedValue) + intVal, err := strconv.Atoi(value) + if err != nil { + return nil, fmt.Errorf("failed to parse integer literal: %w", err) + } else if intVal != expectedVal { + return nil, fmt.Errorf("literal value %d does not match %d", expectedVal, intVal) + } + return intVal, nil + case bool: + boolVal, err := parseBoolean(value) + if err != nil { + return nil, fmt.Errorf("failed to parse boolean literal: %w", err) + } else if boolVal != typedValue { + return nil, fmt.Errorf("literal value %t does not match %t", typedValue, boolVal) + } + return boolVal, nil + case RoomIDValue, *RoomIDValue, map[string]any, json.RawMessage: + expectedVal, _ := NormalizeRoomIDValue(typedValue) + parsed, err := parseRoomOrEventID(value) + if err != nil { + return nil, fmt.Errorf("failed to parse room or event ID literal: %w", err) + } else if !parsed.Equals(expectedVal) { + return nil, fmt.Errorf("literal value %s does not match %s", expectedVal, parsed) + } + return parsed, nil + default: + return nil, fmt.Errorf("unsupported literal type %T", ps.Value) + } + case SchemaTypeUnion: + var errs []error + for _, variant := range ps.Variants { + if parsed, err := variant.ParseString(value); err == nil { + return parsed, nil + } else { + errs = append(errs, err) + } + } + return nil, fmt.Errorf("no union variant matched: %w", errors.Join(errs...)) + case SchemaTypeArray: + return nil, fmt.Errorf("cannot parse string for array schema type") + default: + return nil, fmt.Errorf("unknown schema type %s", ps.SchemaType) + } +} diff --git a/event/cmdschema/parse_test.go b/event/cmdschema/parse_test.go new file mode 100644 index 00000000..1e0d1817 --- /dev/null +++ b/event/cmdschema/parse_test.go @@ -0,0 +1,118 @@ +// Copyright (c) 2026 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package cmdschema + +import ( + "bytes" + "encoding/json" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "go.mau.fi/util/exbytes" + "go.mau.fi/util/exerrors" + + "maunium.net/go/mautrix/event/cmdschema/testdata" +) + +type QuoteParseOutput struct { + Parsed string + Remaining string + Quoted bool +} + +func (qpo *QuoteParseOutput) UnmarshalJSON(data []byte) error { + var arr []any + if err := json.Unmarshal(data, &arr); err != nil { + return err + } + qpo.Parsed = arr[0].(string) + qpo.Remaining = arr[1].(string) + qpo.Quoted = arr[2].(bool) + return nil +} + +type QuoteParseTestData struct { + Name string `json:"name"` + Input string `json:"input"` + Output QuoteParseOutput `json:"output"` +} + +func loadFile[T any](name string) (into T) { + quoteData := exerrors.Must(testdata.FS.ReadFile(name)) + exerrors.PanicIfNotNil(json.Unmarshal(quoteData, &into)) + return +} + +func TestParseQuoted(t *testing.T) { + qptd := loadFile[[]QuoteParseTestData]("parse_quote.json") + for _, test := range qptd { + t.Run(test.Name, func(t *testing.T) { + parsed, remaining, quoted := parseQuoted(test.Input) + assert.Equalf(t, test.Output, QuoteParseOutput{ + Parsed: parsed, + Remaining: remaining, + Quoted: quoted, + }, "Failed with input `%s`", test.Input) + // Note: can't just test that requoted == input, because some inputs + // have unnecessary escapes which won't survive roundtripping + t.Run("roundtrip", func(t *testing.T) { + requoted := quoteString(parsed) + " " + remaining + reparsed, newRemaining, _ := parseQuoted(requoted) + assert.Equal(t, parsed, reparsed) + assert.Equal(t, remaining, newRemaining) + }) + }) + } +} + +type CommandTestData struct { + Spec *EventContent + Tests []*CommandTestUnit +} + +type CommandTestUnit struct { + Name string `json:"name"` + Input string `json:"input"` + Broken string `json:"broken,omitempty"` + Error bool `json:"error"` + Output json.RawMessage `json:"output"` +} + +func compactJSON(input json.RawMessage) json.RawMessage { + var buf bytes.Buffer + exerrors.PanicIfNotNil(json.Compact(&buf, input)) + return buf.Bytes() +} + +func TestMSC4391BotCommandEventContent_ParseInput(t *testing.T) { + for _, cmd := range exerrors.Must(testdata.FS.ReadDir("commands")) { + t.Run(strings.TrimSuffix(cmd.Name(), ".json"), func(t *testing.T) { + ctd := loadFile[CommandTestData]("commands/" + cmd.Name()) + for _, test := range ctd.Tests { + outputStr := exbytes.UnsafeString(compactJSON(test.Output)) + t.Run(test.Name, func(t *testing.T) { + if test.Broken != "" { + t.Skip(test.Broken) + } + output, err := ctd.Spec.ParseInput("@testbot", []string{"/"}, test.Input) + if test.Error { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + if outputStr == "null" { + assert.Nil(t, output) + } else { + assert.Equal(t, ctd.Spec.Command, output.MSC4391BotCommand.Command) + assert.Equalf(t, outputStr, exbytes.UnsafeString(output.MSC4391BotCommand.Arguments), "Input: %s", test.Input) + } + }) + } + }) + } +} diff --git a/event/cmdschema/roomid.go b/event/cmdschema/roomid.go new file mode 100644 index 00000000..98c421fc --- /dev/null +++ b/event/cmdschema/roomid.go @@ -0,0 +1,135 @@ +// Copyright (c) 2026 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package cmdschema + +import ( + "encoding/json" + "fmt" + "slices" + "strings" + + "maunium.net/go/mautrix/id" +) + +var ParameterSchemaJoinableRoom = Union( + PrimitiveTypeRoomID.Schema(), + PrimitiveTypeRoomAlias.Schema(), +) + +type RoomIDValue struct { + Type PrimitiveType `json:"type"` + RoomID id.RoomID `json:"id"` + Via []string `json:"via,omitempty"` + EventID id.EventID `json:"event_id,omitempty"` +} + +func NormalizeRoomIDValue(input any) (riv *RoomIDValue, err error) { + switch typedValue := input.(type) { + case map[string]any, json.RawMessage: + var raw json.RawMessage + if raw, err = json.Marshal(input); err != nil { + err = fmt.Errorf("failed to roundtrip room ID value: %w", err) + } else if err = json.Unmarshal(raw, &riv); err != nil { + err = fmt.Errorf("failed to roundtrip room ID value: %w", err) + } + case *RoomIDValue: + riv = typedValue + case RoomIDValue: + riv = &typedValue + default: + err = fmt.Errorf("unsupported type %T for room or event ID", input) + } + return +} + +func (riv *RoomIDValue) String() string { + return riv.URI().String() +} + +func (riv *RoomIDValue) URI() *id.MatrixURI { + if riv == nil { + return nil + } + switch riv.Type { + case PrimitiveTypeRoomID: + return riv.RoomID.URI(riv.Via...) + case PrimitiveTypeEventID: + return riv.RoomID.EventURI(riv.EventID, riv.Via...) + default: + return nil + } +} + +func (riv *RoomIDValue) Equals(other *RoomIDValue) bool { + if riv == nil || other == nil { + return riv == other + } + return riv.Type == other.Type && + riv.RoomID == other.RoomID && + riv.EventID == other.EventID && + slices.Equal(riv.Via, other.Via) +} + +func (riv *RoomIDValue) Validate() error { + if riv == nil { + return fmt.Errorf("value is nil") + } + switch riv.Type { + case PrimitiveTypeRoomID: + if riv.EventID != "" { + return fmt.Errorf("event ID must be empty for room ID type") + } + case PrimitiveTypeEventID: + if !strings.HasPrefix(riv.EventID.String(), "$") { + return fmt.Errorf("event ID not valid: %q", riv.EventID) + } + default: + return fmt.Errorf("unexpected type %s for room/event ID value", riv.Type) + } + for _, via := range riv.Via { + if !id.ValidateServerName(via) { + return fmt.Errorf("invalid server name %q in vias", via) + } + } + sigil, localpart, serverName := id.ParseCommonIdentifier(riv.RoomID) + if sigil != '!' { + return fmt.Errorf("room ID does not start with !: %q", riv.RoomID) + } else if localpart == "" && serverName == "" { + return fmt.Errorf("room ID has empty localpart and server name: %q", riv.RoomID) + } else if serverName != "" && !id.ValidateServerName(serverName) { + return fmt.Errorf("invalid server name %q in room ID", serverName) + } + return nil +} + +func (riv *RoomIDValue) IsValid() bool { + return riv.Validate() == nil +} + +type RoomIDOrString string + +func (ros *RoomIDOrString) UnmarshalJSON(data []byte) error { + if len(data) == 0 { + return fmt.Errorf("empty data for room ID or string") + } + if data[0] == '"' { + var str string + if err := json.Unmarshal(data, &str); err != nil { + return err + } + *ros = RoomIDOrString(str) + return nil + } + var riv RoomIDValue + if err := json.Unmarshal(data, &riv); err != nil { + return err + } else if err = riv.Validate(); err != nil { + return err + } + *ros = RoomIDOrString(riv.String()) + return nil +} diff --git a/event/cmdschema/stringify.go b/event/cmdschema/stringify.go new file mode 100644 index 00000000..c5c57c53 --- /dev/null +++ b/event/cmdschema/stringify.go @@ -0,0 +1,122 @@ +// Copyright (c) 2026 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package cmdschema + +import ( + "encoding/json" + "strconv" + "strings" +) + +var quoteEscaper = strings.NewReplacer( + `"`, `\"`, + `\`, `\\`, +) + +const charsToQuote = ` \` + botArrayOpener + botArrayCloser + +func quoteString(val string) string { + if val == "" { + return `""` + } + val = quoteEscaper.Replace(val) + if strings.ContainsAny(val, charsToQuote) { + return `"` + val + `"` + } + return val +} + +func (ec *EventContent) StringifyArgs(args any) string { + var argMap map[string]any + switch typedArgs := args.(type) { + case json.RawMessage: + err := json.Unmarshal(typedArgs, &argMap) + if err != nil { + return "" + } + case map[string]any: + argMap = typedArgs + default: + if b, err := json.Marshal(args); err != nil { + return "" + } else if err = json.Unmarshal(b, &argMap); err != nil { + return "" + } + } + parts := make([]string, 0, len(ec.Parameters)) + for i, param := range ec.Parameters { + isLast := i == len(ec.Parameters)-1 + val := argMap[param.Key] + if val == nil { + val = param.DefaultValue + if val == nil && !param.Optional { + val = param.Schema.GetDefaultValue() + } + } + if val == nil { + continue + } + var stringified string + if param.Schema.SchemaType == SchemaTypeArray { + stringified = arrayArgumentToString(val, isLast) + } else { + stringified = singleArgumentToString(val) + } + if stringified != "" { + parts = append(parts, stringified) + } + } + return strings.Join(parts, " ") +} + +func arrayArgumentToString(val any, isLast bool) string { + valArr, ok := val.([]any) + if !ok { + return "" + } + parts := make([]string, 0, len(valArr)) + for _, elem := range valArr { + stringified := singleArgumentToString(elem) + if stringified != "" { + parts = append(parts, stringified) + } + } + joinedParts := strings.Join(parts, " ") + if isLast && len(parts) > 0 { + return joinedParts + } + return botArrayOpener + joinedParts + botArrayCloser +} + +func singleArgumentToString(val any) string { + switch typedVal := val.(type) { + case string: + return quoteString(typedVal) + case json.Number: + return typedVal.String() + case bool: + return strconv.FormatBool(typedVal) + case int: + return strconv.Itoa(typedVal) + case int64: + return strconv.FormatInt(typedVal, 10) + case float64: + return strconv.FormatInt(int64(typedVal), 10) + case map[string]any, json.RawMessage, RoomIDValue, *RoomIDValue: + normalized, err := NormalizeRoomIDValue(typedVal) + if err != nil { + return "" + } + uri := normalized.URI() + if uri == nil { + return "" + } + return quoteString(uri.String()) + default: + return "" + } +} diff --git a/event/cmdschema/testdata/commands.schema.json b/event/cmdschema/testdata/commands.schema.json new file mode 100644 index 00000000..e53382db --- /dev/null +++ b/event/cmdschema/testdata/commands.schema.json @@ -0,0 +1,281 @@ +{ + "$schema": "https://json-schema.org/draft/2020-12/schema#", + "$id": "commands.schema.json", + "title": "ParseInput test cases", + "description": "JSON schema for test case files containing command specifications and test cases", + "type": "object", + "required": [ + "spec", + "tests" + ], + "additionalProperties": false, + "properties": { + "spec": { + "title": "MSC4391 Command Description", + "description": "JSON schema defining the structure of a bot command event content", + "type": "object", + "required": [ + "command" + ], + "additionalProperties": false, + "properties": { + "command": { + "type": "string", + "description": "The command name that triggers this bot command" + }, + "aliases": { + "type": "array", + "description": "Alternative names/aliases for this command", + "items": { + "type": "string" + } + }, + "parameters": { + "type": "array", + "description": "List of parameters accepted by this command", + "items": { + "$ref": "#/$defs/Parameter" + } + }, + "description": { + "$ref": "#/$defs/ExtensibleTextContainer", + "description": "Human-readable description of the command" + }, + "fi.mau.tail_parameter": { + "type": "string", + "description": "The key of the parameter that accepts remaining arguments as tail text" + }, + "source": { + "type": "string", + "description": "The user ID of the bot that responds to this command" + } + } + }, + "tests": { + "type": "array", + "description": "Array of test cases for the command", + "items": { + "type": "object", + "description": "A single test case for command parsing", + "required": [ + "name", + "input" + ], + "additionalProperties": false, + "properties": { + "name": { + "type": "string", + "description": "The name of the test case" + }, + "input": { + "type": "string", + "description": "The command input string to parse" + }, + "output": { + "description": "The expected parsed parameter values, or null if the parsing is expected to fail", + "oneOf": [ + { + "type": "object", + "additionalProperties": true + }, + { + "type": "null" + } + ] + }, + "error": { + "type": "boolean", + "description": "Whether parsing should result in an error. May still produce output.", + "default": false + } + } + } + } + }, + "$defs": { + "ExtensibleTextContainer": { + "type": "object", + "description": "Container for text that can have multiple representations", + "required": [ + "m.text" + ], + "properties": { + "m.text": { + "type": "array", + "description": "Array of text representations in different formats", + "items": { + "$ref": "#/$defs/ExtensibleText" + } + } + } + }, + "ExtensibleText": { + "type": "object", + "description": "A text representation with a specific MIME type", + "required": [ + "body" + ], + "properties": { + "body": { + "type": "string", + "description": "The text content" + }, + "mimetype": { + "type": "string", + "description": "The MIME type of the text (e.g., text/plain, text/html)", + "default": "text/plain", + "examples": [ + "text/plain", + "text/html" + ] + } + } + }, + "Parameter": { + "type": "object", + "description": "A parameter definition for a command", + "required": [ + "key", + "schema" + ], + "additionalProperties": false, + "properties": { + "key": { + "type": "string", + "description": "The identifier for this parameter" + }, + "schema": { + "$ref": "#/$defs/ParameterSchema", + "description": "The schema defining the type and structure of this parameter" + }, + "optional": { + "type": "boolean", + "description": "Whether this parameter is optional", + "default": false + }, + "description": { + "$ref": "#/$defs/ExtensibleTextContainer", + "description": "Human-readable description of this parameter" + }, + "fi.mau.default_value": { + "description": "Default value for this parameter if not provided" + } + } + }, + "ParameterSchema": { + "type": "object", + "description": "Schema definition for a parameter value", + "required": [ + "schema_type" + ], + "additionalProperties": false, + "properties": { + "schema_type": { + "type": "string", + "enum": [ + "primitive", + "array", + "union", + "literal" + ], + "description": "The type of schema" + } + }, + "allOf": [ + { + "if": { + "properties": { + "schema_type": { + "const": "primitive" + } + } + }, + "then": { + "required": [ + "type" + ], + "properties": { + "type": { + "type": "string", + "enum": [ + "string", + "integer", + "boolean", + "server_name", + "user_id", + "room_id", + "room_alias", + "event_id" + ], + "description": "The primitive type (only for schema_type: primitive)" + } + } + } + }, + { + "if": { + "properties": { + "schema_type": { + "const": "array" + } + } + }, + "then": { + "required": [ + "items" + ], + "properties": { + "items": { + "$ref": "#/$defs/ParameterSchema", + "description": "The schema for array items (only for schema_type: array)" + } + } + } + }, + { + "if": { + "properties": { + "schema_type": { + "const": "union" + } + } + }, + "then": { + "required": [ + "variants" + ], + "properties": { + "variants": { + "type": "array", + "description": "The possible variants (only for schema_type: union)", + "items": { + "$ref": "#/$defs/ParameterSchema" + }, + "minItems": 1 + } + } + } + }, + { + "if": { + "properties": { + "schema_type": { + "const": "literal" + } + } + }, + "then": { + "required": [ + "value" + ], + "properties": { + "value": { + "description": "The literal value (only for schema_type: literal)" + } + } + } + } + ] + } + } +} diff --git a/event/cmdschema/testdata/commands/flags.json b/event/cmdschema/testdata/commands/flags.json new file mode 100644 index 00000000..6ce1f4da --- /dev/null +++ b/event/cmdschema/testdata/commands/flags.json @@ -0,0 +1,126 @@ +{ + "$schema": "../commands.schema.json#", + "spec": { + "command": "flag", + "source": "@testbot", + "parameters": [ + { + "key": "meow", + "schema": { + "schema_type": "primitive", + "type": "string" + } + }, + { + "key": "user", + "schema": { + "schema_type": "primitive", + "type": "user_id" + }, + "optional": true + }, + { + "key": "woof", + "schema": { + "schema_type": "primitive", + "type": "boolean" + }, + "optional": true, + "fi.mau.default_value": false + } + ], + "fi.mau.tail_parameter": "user" + }, + "tests": [ + { + "name": "no flags", + "input": "/flag mrrp", + "output": { + "meow": "mrrp", + "user": null + } + }, + { + "name": "no flags, has tail", + "input": "/flag mrrp @user:example.com", + "output": { + "meow": "mrrp", + "user": "@user:example.com" + } + }, + { + "name": "named flag at start", + "input": "/flag --woof=yes mrrp @user:example.com", + "output": { + "meow": "mrrp", + "user": "@user:example.com", + "woof": true + } + }, + { + "name": "boolean flag without value", + "input": "/flag --woof mrrp @user:example.com", + "output": { + "meow": "mrrp", + "user": "@user:example.com", + "woof": true + } + }, + { + "name": "user id flag without value", + "input": "/flag --user --woof mrrp", + "error": true, + "output": { + "meow": "mrrp", + "user": null, + "woof": true + } + }, + { + "name": "named flag in the middle", + "input": "/flag mrrp --woof=yes @user:example.com", + "output": { + "meow": "mrrp", + "user": "@user:example.com", + "woof": true + } + }, + { + "name": "named flag in the middle with different value", + "input": "/flag mrrp --woof=no @user:example.com", + "output": { + "meow": "mrrp", + "user": "@user:example.com", + "woof": false + } + }, + { + "name": "all variables named", + "input": "/flag --woof=no --meow=mrrp --user=@user:example.com", + "output": { + "meow": "mrrp", + "user": "@user:example.com", + "woof": false + } + }, + { + "name": "all variables named with quotes", + "input": "/flag --woof --meow=\"meow meow mrrp\" --user=\"@user:example.com\"", + "output": { + "meow": "meow meow mrrp", + "user": "@user:example.com", + "woof": true + } + }, + { + "name": "invalid value for named parameter", + "input": "/flag --user=meowings mrrp --woof", + "error": true, + "output": { + "meow": "mrrp", + "user": null, + "woof": true + } + } + ] +} diff --git a/event/cmdschema/testdata/commands/room_id_or_alias.json b/event/cmdschema/testdata/commands/room_id_or_alias.json new file mode 100644 index 00000000..1351c292 --- /dev/null +++ b/event/cmdschema/testdata/commands/room_id_or_alias.json @@ -0,0 +1,85 @@ +{ + "$schema": "../commands.schema.json#", + "spec": { + "command": "test room reference", + "source": "@testbot", + "parameters": [ + { + "key": "room", + "schema": { + "schema_type": "union", + "variants": [ + { + "schema_type": "primitive", + "type": "room_id" + }, + { + "schema_type": "primitive", + "type": "room_alias" + } + ] + } + } + ] + }, + "tests": [ + { + "name": "room alias", + "input": "/test room reference #test:matrix.org", + "output": { + "room": "#test:matrix.org" + } + }, + { + "name": "room id", + "input": "/test room reference !aiwVrNhPwbGBNjqlNu:matrix.org", + "output": { + "room": { + "type": "room_id", + "id": "!aiwVrNhPwbGBNjqlNu:matrix.org" + } + } + }, + { + "name": "room id matrix.to link", + "input": "/test room reference https://matrix.to/#/!aiwVrNhPwbGBNjqlNu:matrix.org?via=example.com", + "output": { + "room": { + "type": "room_id", + "id": "!aiwVrNhPwbGBNjqlNu:matrix.org", + "via": [ + "example.com" + ] + } + } + }, + { + "name": "room id matrix.to link with url encoding", + "input": "/test room reference https://matrix.to/#/!%23test%2Froom%0Aversion%20%3Cu%3E11%3C%2Fu%3E%2C%20with%20%40%F0%9F%90%88%EF%B8%8F%3Amaunium.net?via=maunium.net", + "broken": "Go's url.URL does url decoding on the fragment, which breaks splitting the path segments properly", + "output": { + "room": { + "type": "room_id", + "id": "!#test/room\nversion 11, with @🐈️:maunium.net", + "via": [ + "maunium.net" + ] + } + } + }, + { + "name": "room id matrix: URI", + "input": "/test room reference matrix:roomid/mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ?via=maunium.net&via=matrix.org", + "output": { + "room": { + "type": "room_id", + "id": "!mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ", + "via": [ + "maunium.net", + "matrix.org" + ] + } + } + } + ] +} diff --git a/event/cmdschema/testdata/commands/room_reference_list.json b/event/cmdschema/testdata/commands/room_reference_list.json new file mode 100644 index 00000000..aa266054 --- /dev/null +++ b/event/cmdschema/testdata/commands/room_reference_list.json @@ -0,0 +1,106 @@ +{ + "$schema": "../commands.schema.json#", + "spec": { + "command": "test room reference", + "source": "@testbot", + "parameters": [ + { + "key": "rooms", + "schema": { + "schema_type": "array", + "items": { + "schema_type": "union", + "variants": [ + { + "schema_type": "primitive", + "type": "room_id" + }, + { + "schema_type": "primitive", + "type": "room_alias" + } + ] + } + } + } + ] + }, + "tests": [ + { + "name": "room alias", + "input": "/test room reference #test:matrix.org", + "output": { + "rooms": [ + "#test:matrix.org" + ] + } + }, + { + "name": "room id", + "input": "/test room reference !aiwVrNhPwbGBNjqlNu:matrix.org", + "output": { + "rooms": [ + { + "type": "room_id", + "id": "!aiwVrNhPwbGBNjqlNu:matrix.org" + } + ] + } + }, + { + "name": "two room ids", + "input": "/test room reference !mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ !aiwVrNhPwbGBNjqlNu:matrix.org", + "output": { + "rooms": [ + { + "type": "room_id", + "id": "!mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ" + }, + { + "type": "room_id", + "id": "!aiwVrNhPwbGBNjqlNu:matrix.org" + } + ] + } + }, + { + "name": "room id matrix: URI", + "input": "/test room reference matrix:roomid/mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ?via=maunium.net&via=matrix.org", + "output": { + "rooms": [ + { + "type": "room_id", + "id": "!mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ", + "via": [ + "maunium.net", + "matrix.org" + ] + } + ] + } + }, + { + "name": "room id matrix: URI and matrix.to URL", + "input": "/test room reference https://matrix.to/#/!aiwVrNhPwbGBNjqlNu:matrix.org?via=example.com matrix:roomid/mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ?via=maunium.net&via=matrix.org", + "output": { + "rooms": [ + { + "type": "room_id", + "id": "!aiwVrNhPwbGBNjqlNu:matrix.org", + "via": [ + "example.com" + ] + }, + { + "type": "room_id", + "id": "!mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ", + "via": [ + "maunium.net", + "matrix.org" + ] + } + ] + } + } + ] +} diff --git a/event/cmdschema/testdata/commands/simple.json b/event/cmdschema/testdata/commands/simple.json new file mode 100644 index 00000000..94667323 --- /dev/null +++ b/event/cmdschema/testdata/commands/simple.json @@ -0,0 +1,46 @@ +{ + "$schema": "../commands.schema.json#", + "spec": { + "command": "test simple", + "source": "@testbot", + "parameters": [ + { + "key": "meow", + "schema": { + "schema_type": "primitive", + "type": "string" + } + } + ] + }, + "tests": [ + { + "name": "success", + "input": "/test simple mrrp", + "output": { + "meow": "mrrp" + } + }, + { + "name": "directed success", + "input": "/test simple@testbot mrrp", + "output": { + "meow": "mrrp" + } + }, + { + "name": "missing parameter", + "input": "/test simple", + "error": true, + "output": { + "meow": "" + } + }, + { + "name": "directed at another bot", + "input": "/test simple@anotherbot mrrp", + "error": false, + "output": null + } + ] +} diff --git a/event/cmdschema/testdata/commands/tail.json b/event/cmdschema/testdata/commands/tail.json new file mode 100644 index 00000000..9782f8ec --- /dev/null +++ b/event/cmdschema/testdata/commands/tail.json @@ -0,0 +1,60 @@ +{ + "$schema": "../commands.schema.json#", + "spec": { + "command": "tail", + "source": "@testbot", + "parameters": [ + { + "key": "meow", + "schema": { + "schema_type": "primitive", + "type": "string" + } + }, + { + "key": "reason", + "schema": { + "schema_type": "primitive", + "type": "string" + }, + "optional": true + }, + { + "key": "woof", + "schema": { + "schema_type": "primitive", + "type": "boolean" + }, + "optional": true + } + ], + "fi.mau.tail_parameter": "reason" + }, + "tests": [ + { + "name": "no tail or flag", + "input": "/tail mrrp", + "output": { + "meow": "mrrp", + "reason": "" + } + }, + { + "name": "tail, no flag", + "input": "/tail mrrp meow meow", + "output": { + "meow": "mrrp", + "reason": "meow meow" + } + }, + { + "name": "flag before tail", + "input": "/tail mrrp --woof meow meow", + "output": { + "meow": "mrrp", + "reason": "meow meow", + "woof": true + } + } + ] +} diff --git a/event/cmdschema/testdata/data.go b/event/cmdschema/testdata/data.go new file mode 100644 index 00000000..eceea3d2 --- /dev/null +++ b/event/cmdschema/testdata/data.go @@ -0,0 +1,14 @@ +// Copyright (c) 2026 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package testdata + +import ( + "embed" +) + +//go:embed * +var FS embed.FS diff --git a/event/cmdschema/testdata/parse_quote.json b/event/cmdschema/testdata/parse_quote.json new file mode 100644 index 00000000..8f52b7f5 --- /dev/null +++ b/event/cmdschema/testdata/parse_quote.json @@ -0,0 +1,30 @@ +[ + {"name": "empty string", "input": "", "output": ["", "", false]}, + {"name": "single word", "input": "meow", "output": ["meow", "", false]}, + {"name": "two words", "input": "meow woof", "output": ["meow", "woof", false]}, + {"name": "many words", "input": "meow meow mrrp", "output": ["meow", "meow mrrp", false]}, + {"name": "extra spaces", "input": "meow meow mrrp", "output": ["meow", "meow mrrp", false]}, + {"name": "trailing space", "input": "meow ", "output": ["meow", "", false]}, + {"name": "only spaces", "input": " ", "output": ["", "", false]}, + {"name": "leading spaces", "input": " meow woof", "output": ["", "meow woof", false]}, + {"name": "backslash at end unquoted", "input": "meow\\ woof", "output": ["meow\\", "woof", false]}, + {"name": "quoted word", "input": "\"meow\" meow mrrp", "output": ["meow", "meow mrrp", true]}, + {"name": "quoted words", "input": "\"meow meow\" mrrp", "output": ["meow meow", "mrrp", true]}, + {"name": "spaces in quotes", "input": "\" meow meow \" mrrp", "output": [" meow meow ", "mrrp", true]}, + {"name": "empty quoted string", "input": "\"\"", "output": ["", "", true]}, + {"name": "empty quoted with trailing", "input": "\"\" meow", "output": ["", "meow", true]}, + {"name": "quote no space before next", "input": "\"meow\"woof", "output": ["meow", "woof", true]}, + {"name": "just opening quote", "input": "\"", "output": ["", "", true]}, + {"name": "quote then space then text", "input": "\" meow", "output": [" meow", "", true]}, + {"name": "quotes after word", "input": "meow \" meow mrrp \"", "output": ["meow", "\" meow mrrp \"", false]}, + {"name": "escaped quote", "input": "\"meow\\\" meow\" mrrp", "output": ["meow\" meow", "mrrp", true]}, + {"name": "missing end quote", "input": "\"meow meow mrrp", "output": ["meow meow mrrp", "", true]}, + {"name": "missing end quote with escaped quote", "input": "\"meow\\\" meow mrrp", "output": ["meow\" meow mrrp", "", true]}, + {"name": "quote in the middle", "input": "me\"ow meow mrrp", "output": ["me\"ow", "meow mrrp", false]}, + {"name": "backslash in the middle", "input": "me\\ow meow mrrp", "output": ["me\\ow", "meow mrrp", false]}, + {"name": "other escaped character", "input": "\"m\\eow\" meow mrrp", "output": ["meow", "meow mrrp", true]}, + {"name": "escaped backslashes", "input": "\"m\\\\e\\\"ow\\\\\" meow mrrp", "output": ["m\\e\"ow\\", "meow mrrp", true]}, + {"name": "just quotes", "input": "\"\\\"\\\"\\\\\\\"\" meow", "output": ["\"\"\\\"", "meow", true]}, + {"name": "escape at eof", "input": "\"meow\\", "output": ["meow", "", true]}, + {"name": "escaped backslash at eof", "input": "\"meow\\\\", "output": ["meow\\", "", true]} +] diff --git a/event/cmdschema/testdata/parse_quote.schema.json b/event/cmdschema/testdata/parse_quote.schema.json new file mode 100644 index 00000000..9f249116 --- /dev/null +++ b/event/cmdschema/testdata/parse_quote.schema.json @@ -0,0 +1,46 @@ +{ + "$schema": "https://json-schema.org/draft/2020-12/schema#", + "$id": "parse_quote.schema.json", + "title": "parseQuote test cases", + "description": "Test cases for the parseQuoted function", + "type": "array", + "items": { + "type": "object", + "required": [ + "name", + "input", + "output" + ], + "properties": { + "name": { + "type": "string", + "description": "Name of the test case" + }, + "input": { + "type": "string", + "description": "Input string to be parsed" + }, + "output": { + "type": "array", + "description": "Expected output of parsing: [first word, remaining text, was quoted]", + "minItems": 3, + "maxItems": 3, + "prefixItems": [ + { + "type": "string", + "description": "First parsed word" + }, + { + "type": "string", + "description": "Remaining text after the first word" + }, + { + "type": "boolean", + "description": "Whether the first word was quoted" + } + ] + } + }, + "additionalProperties": false + } +} diff --git a/event/content.go b/event/content.go index bdb3eeb8..814aeec4 100644 --- a/event/content.go +++ b/event/content.go @@ -18,6 +18,7 @@ import ( // This is used by Content.ParseRaw() for creating the correct type of struct. var TypeMap = map[Type]reflect.Type{ StateMember: reflect.TypeOf(MemberEventContent{}), + StateThirdPartyInvite: reflect.TypeOf(ThirdPartyInviteEventContent{}), StatePowerLevels: reflect.TypeOf(PowerLevelsEventContent{}), StateCanonicalAlias: reflect.TypeOf(CanonicalAliasEventContent{}), StateRoomName: reflect.TypeOf(RoomNameEventContent{}), @@ -38,7 +39,20 @@ var TypeMap = map[Type]reflect.Type{ StateHalfShotBridge: reflect.TypeOf(BridgeEventContent{}), StateSpaceParent: reflect.TypeOf(SpaceParentEventContent{}), StateSpaceChild: reflect.TypeOf(SpaceChildEventContent{}), - StateInsertionMarker: reflect.TypeOf(InsertionMarkerContent{}), + + StateRoomPolicy: reflect.TypeOf(RoomPolicyEventContent{}), + StateUnstableRoomPolicy: reflect.TypeOf(RoomPolicyEventContent{}), + + StateLegacyPolicyRoom: reflect.TypeOf(ModPolicyContent{}), + StateLegacyPolicyServer: reflect.TypeOf(ModPolicyContent{}), + StateLegacyPolicyUser: reflect.TypeOf(ModPolicyContent{}), + StateUnstablePolicyRoom: reflect.TypeOf(ModPolicyContent{}), + StateUnstablePolicyServer: reflect.TypeOf(ModPolicyContent{}), + StateUnstablePolicyUser: reflect.TypeOf(ModPolicyContent{}), + + StateElementFunctionalMembers: reflect.TypeOf(ElementFunctionalMembersContent{}), + StateBeeperRoomFeatures: reflect.TypeOf(RoomFeatures{}), + StateBeeperDisappearingTimer: reflect.TypeOf(BeeperDisappearingTimer{}), EventMessage: reflect.TypeOf(MessageEventContent{}), EventSticker: reflect.TypeOf(MessageEventContent{}), @@ -46,16 +60,27 @@ var TypeMap = map[Type]reflect.Type{ EventRedaction: reflect.TypeOf(RedactionEventContent{}), EventReaction: reflect.TypeOf(ReactionEventContent{}), - BeeperMessageStatus: reflect.TypeOf(BeeperMessageStatusEventContent{}), + EventUnstablePollStart: reflect.TypeOf(PollStartEventContent{}), + EventUnstablePollResponse: reflect.TypeOf(PollResponseEventContent{}), + + BeeperMessageStatus: reflect.TypeOf(BeeperMessageStatusEventContent{}), + BeeperTranscription: reflect.TypeOf(BeeperTranscriptionEventContent{}), + BeeperDeleteChat: reflect.TypeOf(BeeperChatDeleteEventContent{}), + BeeperAcceptMessageRequest: reflect.TypeOf(BeeperAcceptMessageRequestEventContent{}), + BeeperSendState: reflect.TypeOf(BeeperSendStateEventContent{}), AccountDataRoomTags: reflect.TypeOf(TagEventContent{}), AccountDataDirectChats: reflect.TypeOf(DirectChatsEventContent{}), AccountDataFullyRead: reflect.TypeOf(FullyReadEventContent{}), AccountDataIgnoredUserList: reflect.TypeOf(IgnoredUserListEventContent{}), + AccountDataMarkedUnread: reflect.TypeOf(MarkedUnreadEventContent{}), + AccountDataBeeperMute: reflect.TypeOf(BeeperMuteEventContent{}), - EphemeralEventTyping: reflect.TypeOf(TypingEventContent{}), - EphemeralEventReceipt: reflect.TypeOf(ReceiptEventContent{}), - EphemeralEventPresence: reflect.TypeOf(PresenceEventContent{}), + EphemeralEventTyping: reflect.TypeOf(TypingEventContent{}), + EphemeralEventReceipt: reflect.TypeOf(ReceiptEventContent{}), + EphemeralEventPresence: reflect.TypeOf(PresenceEventContent{}), + EphemeralEventEncrypted: reflect.TypeOf(EncryptedEventContent{}), + BeeperEphemeralEventAIStream: reflect.TypeOf(BeeperAIStreamEventContent{}), InRoomVerificationReady: reflect.TypeOf(VerificationReadyEventContent{}), InRoomVerificationStart: reflect.TypeOf(VerificationStartEventContent{}), @@ -107,7 +132,7 @@ var TypeMap = map[Type]reflect.Type{ // When being marshaled into JSON, the data in Parsed will be marshaled first and then recursively merged // with the data in Raw. Values in Raw are preferred, but nested objects will be recursed into before merging, // rather than overriding the whole object with the one in Raw). -// If one of them is nil, the only the other is used. If both (Parsed and Raw) are nil, VeryRaw is used instead. +// If one of them is nil, then only the other is used. If both (Parsed and Raw) are nil, VeryRaw is used instead. type Content struct { VeryRaw json.RawMessage Raw map[string]interface{} @@ -174,6 +199,13 @@ func IsUnsupportedContentType(err error) bool { var ErrContentAlreadyParsed = errors.New("content is already parsed") var ErrUnsupportedContentType = errors.New("unsupported event type") +func (content *Content) GetRaw() map[string]interface{} { + if content.Raw == nil { + content.Raw = make(map[string]interface{}) + } + return content.Raw +} + func (content *Content) ParseRaw(evtType Type) error { if content.Parsed != nil { return ErrContentAlreadyParsed @@ -211,6 +243,7 @@ func init() { gob.Register(&BridgeEventContent{}) gob.Register(&SpaceChildEventContent{}) gob.Register(&SpaceParentEventContent{}) + gob.Register(&ElementFunctionalMembersContent{}) gob.Register(&RoomNameEventContent{}) gob.Register(&RoomAvatarEventContent{}) gob.Register(&TopicEventContent{}) @@ -238,6 +271,15 @@ func init() { gob.Register(&RoomKeyWithheldEventContent{}) } +func CastOrDefault[T any](content *Content) *T { + casted, ok := content.Parsed.(*T) + if ok { + return casted + } + casted2, _ := content.Parsed.(T) + return &casted2 +} + // Helper cast functions below func (content *Content) AsMember() *MemberEventContent { @@ -352,6 +394,13 @@ func (content *Content) AsSpaceParent() *SpaceParentEventContent { } return casted } +func (content *Content) AsElementFunctionalMembers() *ElementFunctionalMembersContent { + casted, ok := content.Parsed.(*ElementFunctionalMembersContent) + if !ok { + return &ElementFunctionalMembersContent{} + } + return casted +} func (content *Content) AsMessage() *MessageEventContent { casted, ok := content.Parsed.(*MessageEventContent) if !ok { @@ -408,6 +457,13 @@ func (content *Content) AsIgnoredUserList() *IgnoredUserListEventContent { } return casted } +func (content *Content) AsMarkedUnread() *MarkedUnreadEventContent { + casted, ok := content.Parsed.(*MarkedUnreadEventContent) + if !ok { + return &MarkedUnreadEventContent{} + } + return casted +} func (content *Content) AsTyping() *TypingEventContent { casted, ok := content.Parsed.(*TypingEventContent) if !ok { diff --git a/event/delayed.go b/event/delayed.go new file mode 100644 index 00000000..fefb62af --- /dev/null +++ b/event/delayed.go @@ -0,0 +1,70 @@ +package event + +import ( + "encoding/json" + + "go.mau.fi/util/jsontime" + + "maunium.net/go/mautrix/id" +) + +type ScheduledDelayedEvent struct { + DelayID id.DelayID `json:"delay_id"` + RoomID id.RoomID `json:"room_id"` + Type Type `json:"type"` + StateKey *string `json:"state_key,omitempty"` + Delay int64 `json:"delay"` + RunningSince jsontime.UnixMilli `json:"running_since"` + Content Content `json:"content"` +} + +func (e ScheduledDelayedEvent) AsEvent(eventID id.EventID, ts jsontime.UnixMilli) (*Event, error) { + evt := &Event{ + ID: eventID, + RoomID: e.RoomID, + Type: e.Type, + StateKey: e.StateKey, + Content: e.Content, + Timestamp: ts.UnixMilli(), + } + return evt, evt.Content.ParseRaw(evt.Type) +} + +type FinalisedDelayedEvent struct { + DelayedEvent *ScheduledDelayedEvent `json:"scheduled_event"` + Outcome DelayOutcome `json:"outcome"` + Reason DelayReason `json:"reason"` + Error json.RawMessage `json:"error,omitempty"` + EventID id.EventID `json:"event_id,omitempty"` + Timestamp jsontime.UnixMilli `json:"origin_server_ts"` +} + +type DelayStatus string + +var ( + DelayStatusScheduled DelayStatus = "scheduled" + DelayStatusFinalised DelayStatus = "finalised" +) + +type DelayAction string + +var ( + DelayActionSend DelayAction = "send" + DelayActionCancel DelayAction = "cancel" + DelayActionRestart DelayAction = "restart" +) + +type DelayOutcome string + +var ( + DelayOutcomeSend DelayOutcome = "send" + DelayOutcomeCancel DelayOutcome = "cancel" +) + +type DelayReason string + +var ( + DelayReasonAction DelayReason = "action" + DelayReasonError DelayReason = "error" + DelayReasonDelay DelayReason = "delay" +) diff --git a/event/encryption.go b/event/encryption.go index cf9c2814..c60cb91a 100644 --- a/event/encryption.go +++ b/event/encryption.go @@ -63,7 +63,7 @@ func (content *EncryptedEventContent) UnmarshalJSON(data []byte) error { return json.Unmarshal(content.Ciphertext, &content.OlmCiphertext) case id.AlgorithmMegolmV1: if len(content.Ciphertext) == 0 || content.Ciphertext[0] != '"' || content.Ciphertext[len(content.Ciphertext)-1] != '"' { - return id.InputNotJSONString + return fmt.Errorf("ciphertext %w", id.ErrInputNotJSONString) } content.MegolmCiphertext = content.Ciphertext[1 : len(content.Ciphertext)-1] } @@ -132,8 +132,9 @@ type RoomKeyRequestEventContent struct { type RequestedKeyInfo struct { Algorithm id.Algorithm `json:"algorithm"` RoomID id.RoomID `json:"room_id"` - SenderKey id.SenderKey `json:"sender_key"` SessionID id.SessionID `json:"session_id"` + // Deprecated: Matrix v1.3 + SenderKey id.SenderKey `json:"sender_key"` } type RoomKeyWithheldCode string diff --git a/event/events.go b/event/events.go index f7b4d4d6..72c1e161 100644 --- a/event/events.go +++ b/event/events.go @@ -118,6 +118,9 @@ type MautrixInfo struct { DecryptionDuration time.Duration CheckpointSent bool + // When using MSC4222 and the state_after field, this field is set + // for timeline events to indicate they shouldn't update room state. + IgnoreState bool } func (evt *Event) GetStateKey() string { @@ -127,27 +130,29 @@ func (evt *Event) GetStateKey() string { return "" } -type StrippedState struct { - Content Content `json:"content"` - Type Type `json:"type"` - StateKey string `json:"state_key"` - Sender id.UserID `json:"sender"` -} - type Unsigned struct { - PrevContent *Content `json:"prev_content,omitempty"` - PrevSender id.UserID `json:"prev_sender,omitempty"` - ReplacesState id.EventID `json:"replaces_state,omitempty"` - Age int64 `json:"age,omitempty"` - TransactionID string `json:"transaction_id,omitempty"` - Relations *Relations `json:"m.relations,omitempty"` - RedactedBecause *Event `json:"redacted_because,omitempty"` - InviteRoomState []StrippedState `json:"invite_room_state,omitempty"` + PrevContent *Content `json:"prev_content,omitempty"` + PrevSender id.UserID `json:"prev_sender,omitempty"` + Membership Membership `json:"membership,omitempty"` + ReplacesState id.EventID `json:"replaces_state,omitempty"` + Age int64 `json:"age,omitempty"` + TransactionID string `json:"transaction_id,omitempty"` + Relations *Relations `json:"m.relations,omitempty"` + RedactedBecause *Event `json:"redacted_because,omitempty"` + InviteRoomState []*Event `json:"invite_room_state,omitempty"` - BeeperHSOrder int64 `json:"com.beeper.hs.order,omitempty"` + BeeperHSOrder int64 `json:"com.beeper.hs.order,omitempty"` + BeeperHSSuborder int16 `json:"com.beeper.hs.suborder,omitempty"` + BeeperHSOrderString *BeeperEncodedOrder `json:"com.beeper.hs.order_string,omitempty"` + BeeperFromBackup bool `json:"com.beeper.from_backup,omitempty"` + + ElementSoftFailed bool `json:"io.element.synapse.soft_failed,omitempty"` + ElementPolicyServerSpammy bool `json:"io.element.synapse.policy_server_spammy,omitempty"` } func (us *Unsigned) IsEmpty() bool { - return us.PrevContent == nil && us.PrevSender == "" && us.ReplacesState == "" && us.Age == 0 && - us.TransactionID == "" && us.RedactedBecause == nil && us.InviteRoomState == nil && us.Relations == nil + return us.PrevContent == nil && us.PrevSender == "" && us.ReplacesState == "" && us.Age == 0 && us.Membership == "" && + us.TransactionID == "" && us.RedactedBecause == nil && us.InviteRoomState == nil && us.Relations == nil && + us.BeeperHSOrder == 0 && us.BeeperHSSuborder == 0 && us.BeeperHSOrderString.IsZero() && + !us.ElementSoftFailed } diff --git a/event/member.go b/event/member.go index ebafdcb7..9956a36b 100644 --- a/event/member.go +++ b/event/member.go @@ -7,8 +7,6 @@ package event import ( - "encoding/json" - "maunium.net/go/mautrix/id" ) @@ -35,19 +33,37 @@ const ( // MemberEventContent represents the content of a m.room.member state event. // https://spec.matrix.org/v1.2/client-server-api/#mroommember type MemberEventContent struct { - Membership Membership `json:"membership"` - AvatarURL id.ContentURIString `json:"avatar_url,omitempty"` - Displayname string `json:"displayname,omitempty"` - IsDirect bool `json:"is_direct,omitempty"` - ThirdPartyInvite *ThirdPartyInvite `json:"third_party_invite,omitempty"` - Reason string `json:"reason,omitempty"` + Membership Membership `json:"membership"` + AvatarURL id.ContentURIString `json:"avatar_url,omitempty"` + Displayname string `json:"displayname,omitempty"` + IsDirect bool `json:"is_direct,omitempty"` + ThirdPartyInvite *ThirdPartyInvite `json:"third_party_invite,omitempty"` + Reason string `json:"reason,omitempty"` + JoinAuthorisedViaUsersServer id.UserID `json:"join_authorised_via_users_server,omitempty"` + MSC3414File *EncryptedFileInfo `json:"org.matrix.msc3414.file,omitempty"` + + MSC4293RedactEvents bool `json:"org.matrix.msc4293.redact_events,omitempty"` +} + +type SignedThirdPartyInvite struct { + Token string `json:"token"` + Signatures map[string]map[id.KeyID]string `json:"signatures,omitempty"` + MXID string `json:"mxid"` } type ThirdPartyInvite struct { - DisplayName string `json:"display_name"` - Signed struct { - Token string `json:"token"` - Signatures json.RawMessage `json:"signatures"` - MXID string `json:"mxid"` - } + DisplayName string `json:"display_name"` + Signed SignedThirdPartyInvite `json:"signed"` +} + +type ThirdPartyInviteEventContent struct { + DisplayName string `json:"display_name"` + KeyValidityURL string `json:"key_validity_url"` + PublicKey id.Ed25519 `json:"public_key"` + PublicKeys []ThirdPartyInviteKey `json:"public_keys,omitempty"` +} + +type ThirdPartyInviteKey struct { + KeyValidityURL string `json:"key_validity_url,omitempty"` + PublicKey id.Ed25519 `json:"public_key"` } diff --git a/event/message.go b/event/message.go index d8b27c3d..3fb3dc82 100644 --- a/event/message.go +++ b/event/message.go @@ -8,11 +8,11 @@ package event import ( "encoding/json" + "html" + "slices" "strconv" "strings" - "golang.org/x/net/html" - "maunium.net/go/mautrix/crypto/attachment" "maunium.net/go/mautrix/id" ) @@ -21,6 +21,24 @@ import ( // https://spec.matrix.org/v1.2/client-server-api/#mroommessage-msgtypes type MessageType string +func (mt MessageType) IsText() bool { + switch mt { + case MsgText, MsgNotice, MsgEmote: + return true + default: + return false + } +} + +func (mt MessageType) IsMedia() bool { + switch mt { + case MsgImage, MsgVideo, MsgAudio, MsgFile, CapMsgSticker: + return true + default: + return false + } +} + // Msgtypes const ( MsgText MessageType = "m.text" @@ -112,12 +130,68 @@ type MessageEventContent struct { replyFallbackRemoved bool - MessageSendRetry *BeeperRetryMetadata `json:"com.beeper.message_send_retry,omitempty"` - BeeperGalleryImages []*MessageEventContent `json:"com.beeper.gallery.images,omitempty"` - BeeperGalleryCaption string `json:"com.beeper.gallery.caption,omitempty"` - BeeperGalleryCaptionHTML string `json:"com.beeper.gallery.caption_html,omitempty"` + MessageSendRetry *BeeperRetryMetadata `json:"com.beeper.message_send_retry,omitempty"` + BeeperGalleryImages []*MessageEventContent `json:"com.beeper.gallery.images,omitempty"` + BeeperGalleryCaption string `json:"com.beeper.gallery.caption,omitempty"` + BeeperGalleryCaptionHTML string `json:"com.beeper.gallery.caption_html,omitempty"` + BeeperPerMessageProfile *BeeperPerMessageProfile `json:"com.beeper.per_message_profile,omitempty"` + BeeperActionMessage *BeeperActionMessage `json:"com.beeper.action_message,omitempty"` BeeperLinkPreviews []*BeeperLinkPreview `json:"com.beeper.linkpreviews,omitempty"` + + BeeperDisappearingTimer *BeeperDisappearingTimer `json:"com.beeper.disappearing_timer,omitempty"` + + MSC1767Audio *MSC1767Audio `json:"org.matrix.msc1767.audio,omitempty"` + MSC3245Voice *MSC3245Voice `json:"org.matrix.msc3245.voice,omitempty"` + + MSC4391BotCommand *MSC4391BotCommandInput `json:"org.matrix.msc4391.command,omitempty"` +} + +func (content *MessageEventContent) GetCapMsgType() CapabilityMsgType { + switch content.MsgType { + case CapMsgSticker: + return CapMsgSticker + case "": + if content.URL != "" || content.File != nil { + return CapMsgSticker + } + case MsgImage: + return MsgImage + case MsgAudio: + if content.MSC3245Voice != nil { + return CapMsgVoice + } + return MsgAudio + case MsgVideo: + if content.Info != nil && content.Info.MauGIF { + return CapMsgGIF + } + return MsgVideo + case MsgFile: + return MsgFile + } + return "" +} + +func (content *MessageEventContent) GetFileName() string { + if content.FileName != "" { + return content.FileName + } + return content.Body +} + +func (content *MessageEventContent) GetCaption() string { + if content.FileName != "" && content.Body != "" && content.Body != content.FileName { + return content.Body + } + return "" +} + +func (content *MessageEventContent) GetFormattedCaption() string { + if content.Format == FormatHTML && content.FormattedBody != "" { + return content.FormattedBody + } + return "" } func (content *MessageEventContent) GetRelatesTo() *RelatesTo { @@ -141,6 +215,7 @@ func (content *MessageEventContent) SetEdit(original id.EventID) { content.RelatesTo = (&RelatesTo{}).SetReplace(original) if content.MsgType == MsgText || content.MsgType == MsgNotice { content.Body = "* " + content.Body + content.Mentions = &Mentions{} if content.Format == FormatHTML && len(content.FormattedBody) > 0 { content.FormattedBody = "* " + content.FormattedBody } @@ -191,24 +266,56 @@ type Mentions struct { Room bool `json:"room,omitempty"` } +func (m *Mentions) Add(userID id.UserID) { + if userID != "" && !slices.Contains(m.UserIDs, userID) { + m.UserIDs = append(m.UserIDs, userID) + } +} + +func (m *Mentions) Has(userID id.UserID) bool { + return m != nil && slices.Contains(m.UserIDs, userID) +} + +func (m *Mentions) Merge(other *Mentions) *Mentions { + if m == nil { + return other + } else if other == nil { + return m + } + return &Mentions{ + UserIDs: slices.Concat(m.UserIDs, other.UserIDs), + Room: m.Room || other.Room, + } +} + +type MSC4391BotCommandInputCustom[T any] struct { + Command string `json:"command"` + Arguments T `json:"arguments,omitempty"` +} + +type MSC4391BotCommandInput = MSC4391BotCommandInputCustom[json.RawMessage] + type EncryptedFileInfo struct { attachment.EncryptedFile URL id.ContentURIString `json:"url"` } type FileInfo struct { - MimeType string `json:"mimetype,omitempty"` - ThumbnailInfo *FileInfo `json:"thumbnail_info,omitempty"` - ThumbnailURL id.ContentURIString `json:"thumbnail_url,omitempty"` - ThumbnailFile *EncryptedFileInfo `json:"thumbnail_file,omitempty"` + MimeType string + ThumbnailInfo *FileInfo + ThumbnailURL id.ContentURIString + ThumbnailFile *EncryptedFileInfo - Blurhash string `json:"blurhash,omitempty"` - AnoaBlurhash string `json:"xyz.amorgan.blurhash,omitempty"` + Blurhash string + AnoaBlurhash string - Width int `json:"-"` - Height int `json:"-"` - Duration int `json:"-"` - Size int `json:"-"` + MauGIF bool + IsAnimated bool + + Width int + Height int + Duration int + Size int } type serializableFileInfo struct { @@ -220,6 +327,9 @@ type serializableFileInfo struct { Blurhash string `json:"blurhash,omitempty"` AnoaBlurhash string `json:"xyz.amorgan.blurhash,omitempty"` + MauGIF bool `json:"fi.mau.gif,omitempty"` + IsAnimated bool `json:"is_animated,omitempty"` + Width json.Number `json:"w,omitempty"` Height json.Number `json:"h,omitempty"` Duration json.Number `json:"duration,omitempty"` @@ -236,6 +346,9 @@ func (sfi *serializableFileInfo) CopyFrom(fileInfo *FileInfo) *serializableFileI ThumbnailInfo: (&serializableFileInfo{}).CopyFrom(fileInfo.ThumbnailInfo), ThumbnailFile: fileInfo.ThumbnailFile, + MauGIF: fileInfo.MauGIF, + IsAnimated: fileInfo.IsAnimated, + Blurhash: fileInfo.Blurhash, AnoaBlurhash: fileInfo.AnoaBlurhash, } @@ -264,6 +377,8 @@ func (sfi *serializableFileInfo) CopyTo(fileInfo *FileInfo) { MimeType: sfi.MimeType, ThumbnailURL: sfi.ThumbnailURL, ThumbnailFile: sfi.ThumbnailFile, + MauGIF: sfi.MauGIF, + IsAnimated: sfi.IsAnimated, Blurhash: sfi.Blurhash, AnoaBlurhash: sfi.AnoaBlurhash, } diff --git a/event/message_test.go b/event/message_test.go index 562a6622..c721df35 100644 --- a/event/message_test.go +++ b/event/message_test.go @@ -33,7 +33,7 @@ const invalidMessageEvent = `{ func TestMessageEventContent__ParseInvalid(t *testing.T) { var evt *event.Event err := json.Unmarshal([]byte(invalidMessageEvent), &evt) - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal(t, id.UserID("@tulir:maunium.net"), evt.Sender) assert.Equal(t, event.EventMessage, evt.Type) @@ -42,7 +42,7 @@ func TestMessageEventContent__ParseInvalid(t *testing.T) { assert.Equal(t, id.RoomID("!bar"), evt.RoomID) err = evt.Content.ParseRaw(evt.Type) - assert.NotNil(t, err) + assert.Error(t, err) } const messageEvent = `{ @@ -68,7 +68,7 @@ const messageEvent = `{ func TestMessageEventContent__ParseEdit(t *testing.T) { var evt *event.Event err := json.Unmarshal([]byte(messageEvent), &evt) - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal(t, id.UserID("@tulir:maunium.net"), evt.Sender) assert.Equal(t, event.EventMessage, evt.Type) @@ -110,7 +110,7 @@ const imageMessageEvent = `{ func TestMessageEventContent__ParseMedia(t *testing.T) { var evt *event.Event err := json.Unmarshal([]byte(imageMessageEvent), &evt) - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal(t, id.UserID("@tulir:maunium.net"), evt.Sender) assert.Equal(t, event.EventMessage, evt.Type) @@ -125,7 +125,7 @@ func TestMessageEventContent__ParseMedia(t *testing.T) { content := evt.Content.Parsed.(*event.MessageEventContent) assert.Equal(t, event.MsgImage, content.MsgType) parsedURL, err := content.URL.Parse() - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal(t, id.ContentURI{Homeserver: "example.com", FileID: "image"}, parsedURL) assert.Nil(t, content.NewContent) assert.Equal(t, "image/png", content.GetInfo().MimeType) @@ -145,7 +145,7 @@ const expectedMarshalResult = `{"msgtype":"m.text","body":"test"}` func TestMessageEventContent__Marshal(t *testing.T) { data, err := json.Marshal(parsedMessage) - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal(t, expectedMarshalResult, string(data)) } @@ -163,6 +163,6 @@ const expectedCustomMarshalResult = `{"body":"test","msgtype":"m.text","net.maun func TestMessageEventContent__Marshal_Custom(t *testing.T) { data, err := json.Marshal(customParsedMessage) - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal(t, expectedCustomMarshalResult, string(data)) } diff --git a/event/poll.go b/event/poll.go new file mode 100644 index 00000000..9082f65e --- /dev/null +++ b/event/poll.go @@ -0,0 +1,64 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package event + +type PollResponseEventContent struct { + RelatesTo RelatesTo `json:"m.relates_to"` + Response struct { + Answers []string `json:"answers"` + } `json:"org.matrix.msc3381.poll.response"` +} + +func (content *PollResponseEventContent) GetRelatesTo() *RelatesTo { + return &content.RelatesTo +} + +func (content *PollResponseEventContent) OptionalGetRelatesTo() *RelatesTo { + if content.RelatesTo.Type == "" { + return nil + } + return &content.RelatesTo +} + +func (content *PollResponseEventContent) SetRelatesTo(rel *RelatesTo) { + content.RelatesTo = *rel +} + +type MSC1767Message struct { + Text string `json:"org.matrix.msc1767.text,omitempty"` + HTML string `json:"org.matrix.msc1767.html,omitempty"` + Message []ExtensibleText `json:"org.matrix.msc1767.message,omitempty"` +} + +type PollStartEventContent struct { + RelatesTo *RelatesTo `json:"m.relates_to,omitempty"` + Mentions *Mentions `json:"m.mentions,omitempty"` + PollStart struct { + Kind string `json:"kind"` + MaxSelections int `json:"max_selections"` + Question MSC1767Message `json:"question"` + Answers []struct { + ID string `json:"id"` + MSC1767Message + } `json:"answers"` + } `json:"org.matrix.msc3381.poll.start"` +} + +func (content *PollStartEventContent) GetRelatesTo() *RelatesTo { + if content.RelatesTo == nil { + content.RelatesTo = &RelatesTo{} + } + return content.RelatesTo +} + +func (content *PollStartEventContent) OptionalGetRelatesTo() *RelatesTo { + return content.RelatesTo +} + +func (content *PollStartEventContent) SetRelatesTo(rel *RelatesTo) { + content.RelatesTo = rel +} diff --git a/event/powerlevels.go b/event/powerlevels.go index 91d56611..668eb6d3 100644 --- a/event/powerlevels.go +++ b/event/powerlevels.go @@ -7,8 +7,13 @@ package event import ( + "math" + "slices" "sync" + "go.mau.fi/util/ptr" + "golang.org/x/exp/maps" + "maunium.net/go/mautrix/id" ) @@ -23,6 +28,9 @@ type PowerLevelsEventContent struct { Events map[string]int `json:"events,omitempty"` EventsDefault int `json:"events_default,omitempty"` + beeperEphemeralLock sync.RWMutex + BeeperEphemeral map[string]int `json:"com.beeper.ephemeral,omitempty"` + Notifications *NotificationPowerLevels `json:"notifications,omitempty"` StateDefaultPtr *int `json:"state_default,omitempty"` @@ -31,25 +39,12 @@ type PowerLevelsEventContent struct { KickPtr *int `json:"kick,omitempty"` BanPtr *int `json:"ban,omitempty"` RedactPtr *int `json:"redact,omitempty"` -} -func copyPtr(ptr *int) *int { - if ptr == nil { - return nil - } - val := *ptr - return &val -} + BeeperEphemeralDefaultPtr *int `json:"com.beeper.ephemeral_default,omitempty"` -func copyMap[Key comparable](m map[Key]int) map[Key]int { - if m == nil { - return nil - } - copied := make(map[Key]int, len(m)) - for k, v := range m { - copied[k] = v - } - return copied + // This is not a part of power levels, it's added by mautrix-go internally in certain places + // in order to detect creator power accurately. + CreateEvent *Event `json:"-"` } func (pl *PowerLevelsEventContent) Clone() *PowerLevelsEventContent { @@ -57,18 +52,23 @@ func (pl *PowerLevelsEventContent) Clone() *PowerLevelsEventContent { return nil } return &PowerLevelsEventContent{ - Users: copyMap(pl.Users), + Users: maps.Clone(pl.Users), UsersDefault: pl.UsersDefault, - Events: copyMap(pl.Events), + Events: maps.Clone(pl.Events), EventsDefault: pl.EventsDefault, - StateDefaultPtr: copyPtr(pl.StateDefaultPtr), + BeeperEphemeral: maps.Clone(pl.BeeperEphemeral), + StateDefaultPtr: ptr.Clone(pl.StateDefaultPtr), Notifications: pl.Notifications.Clone(), - InvitePtr: copyPtr(pl.InvitePtr), - KickPtr: copyPtr(pl.KickPtr), - BanPtr: copyPtr(pl.BanPtr), - RedactPtr: copyPtr(pl.RedactPtr), + InvitePtr: ptr.Clone(pl.InvitePtr), + KickPtr: ptr.Clone(pl.KickPtr), + BanPtr: ptr.Clone(pl.BanPtr), + RedactPtr: ptr.Clone(pl.RedactPtr), + + BeeperEphemeralDefaultPtr: ptr.Clone(pl.BeeperEphemeralDefaultPtr), + + CreateEvent: pl.CreateEvent, } } @@ -81,7 +81,7 @@ func (npl *NotificationPowerLevels) Clone() *NotificationPowerLevels { return nil } return &NotificationPowerLevels{ - RoomPtr: copyPtr(npl.RoomPtr), + RoomPtr: ptr.Clone(npl.RoomPtr), } } @@ -96,7 +96,7 @@ func (pl *PowerLevelsEventContent) Invite() int { if pl.InvitePtr != nil { return *pl.InvitePtr } - return 50 + return 0 } func (pl *PowerLevelsEventContent) Kick() int { @@ -127,7 +127,17 @@ func (pl *PowerLevelsEventContent) StateDefault() int { return 50 } +func (pl *PowerLevelsEventContent) BeeperEphemeralDefault() int { + if pl.BeeperEphemeralDefaultPtr != nil { + return *pl.BeeperEphemeralDefaultPtr + } + return pl.EventsDefault +} + func (pl *PowerLevelsEventContent) GetUserLevel(userID id.UserID) int { + if pl.isCreator(userID) { + return math.MaxInt + } pl.usersLock.RLock() defer pl.usersLock.RUnlock() level, ok := pl.Users[userID] @@ -137,20 +147,58 @@ func (pl *PowerLevelsEventContent) GetUserLevel(userID id.UserID) int { return level } +const maxPL = 1<<53 - 1 + func (pl *PowerLevelsEventContent) SetUserLevel(userID id.UserID, level int) { pl.usersLock.Lock() defer pl.usersLock.Unlock() + if pl.isCreator(userID) { + return + } + if level == math.MaxInt && maxPL < math.MaxInt { + // Hack to avoid breaking on 32-bit systems (they're only slightly supported) + x := int64(maxPL) + level = int(x) + } if level == pl.UsersDefault { delete(pl.Users, userID) } else { + if pl.Users == nil { + pl.Users = make(map[id.UserID]int) + } pl.Users[userID] = level } } -func (pl *PowerLevelsEventContent) EnsureUserLevel(userID id.UserID, level int) bool { - existingLevel := pl.GetUserLevel(userID) +func (pl *PowerLevelsEventContent) EnsureUserLevel(target id.UserID, level int) bool { + return pl.EnsureUserLevelAs("", target, level) +} + +func (pl *PowerLevelsEventContent) createContent() *CreateEventContent { + if pl.CreateEvent == nil { + return &CreateEventContent{} + } + return pl.CreateEvent.Content.AsCreate() +} + +func (pl *PowerLevelsEventContent) isCreator(userID id.UserID) bool { + cc := pl.createContent() + return cc.SupportsCreatorPower() && (userID == pl.CreateEvent.Sender || slices.Contains(cc.AdditionalCreators, userID)) +} + +func (pl *PowerLevelsEventContent) EnsureUserLevelAs(actor, target id.UserID, level int) bool { + if pl.isCreator(target) { + return false + } + existingLevel := pl.GetUserLevel(target) + if actor != "" && !pl.isCreator(actor) { + actorLevel := pl.GetUserLevel(actor) + if actorLevel <= existingLevel || actorLevel < level { + return false + } + } if existingLevel != level { - pl.SetUserLevel(userID, level) + pl.SetUserLevel(target, level) return true } return false @@ -169,18 +217,54 @@ func (pl *PowerLevelsEventContent) GetEventLevel(eventType Type) int { return level } +func (pl *PowerLevelsEventContent) GetBeeperEphemeralLevel(eventType Type) int { + pl.beeperEphemeralLock.RLock() + defer pl.beeperEphemeralLock.RUnlock() + level, ok := pl.BeeperEphemeral[eventType.String()] + if !ok { + return pl.BeeperEphemeralDefault() + } + return level +} + +func (pl *PowerLevelsEventContent) SetBeeperEphemeralLevel(eventType Type, level int) { + pl.beeperEphemeralLock.Lock() + defer pl.beeperEphemeralLock.Unlock() + if level == pl.BeeperEphemeralDefault() { + delete(pl.BeeperEphemeral, eventType.String()) + } else { + if pl.BeeperEphemeral == nil { + pl.BeeperEphemeral = make(map[string]int) + } + pl.BeeperEphemeral[eventType.String()] = level + } +} + func (pl *PowerLevelsEventContent) SetEventLevel(eventType Type, level int) { pl.eventsLock.Lock() defer pl.eventsLock.Unlock() if (eventType.IsState() && level == pl.StateDefault()) || (!eventType.IsState() && level == pl.EventsDefault) { delete(pl.Events, eventType.String()) } else { + if pl.Events == nil { + pl.Events = make(map[string]int) + } pl.Events[eventType.String()] = level } } func (pl *PowerLevelsEventContent) EnsureEventLevel(eventType Type, level int) bool { + return pl.EnsureEventLevelAs("", eventType, level) +} + +func (pl *PowerLevelsEventContent) EnsureEventLevelAs(actor id.UserID, eventType Type, level int) bool { existingLevel := pl.GetEventLevel(eventType) + if actor != "" && !pl.isCreator(actor) { + actorLevel := pl.GetUserLevel(actor) + if existingLevel > actorLevel || level > actorLevel { + return false + } + } if existingLevel != level { pl.SetEventLevel(eventType, level) return true diff --git a/event/powerlevels_ephemeral_test.go b/event/powerlevels_ephemeral_test.go new file mode 100644 index 00000000..f5861583 --- /dev/null +++ b/event/powerlevels_ephemeral_test.go @@ -0,0 +1,67 @@ +// Copyright (c) 2026 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package event_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "maunium.net/go/mautrix/event" +) + +func TestPowerLevelsEventContent_BeeperEphemeralDefaultFallsBackToEventsDefault(t *testing.T) { + pl := &event.PowerLevelsEventContent{ + EventsDefault: 45, + } + + assert.Equal(t, 45, pl.BeeperEphemeralDefault()) + + override := 60 + pl.BeeperEphemeralDefaultPtr = &override + assert.Equal(t, 60, pl.BeeperEphemeralDefault()) +} + +func TestPowerLevelsEventContent_GetSetBeeperEphemeralLevel(t *testing.T) { + pl := &event.PowerLevelsEventContent{ + EventsDefault: 25, + } + evtType := event.Type{Type: "com.example.ephemeral", Class: event.EphemeralEventType} + + assert.Equal(t, 25, pl.GetBeeperEphemeralLevel(evtType)) + + pl.SetBeeperEphemeralLevel(evtType, 50) + assert.Equal(t, 50, pl.GetBeeperEphemeralLevel(evtType)) + require.NotNil(t, pl.BeeperEphemeral) + assert.Equal(t, 50, pl.BeeperEphemeral[evtType.String()]) + + pl.SetBeeperEphemeralLevel(evtType, 25) + _, exists := pl.BeeperEphemeral[evtType.String()] + assert.False(t, exists) +} + +func TestPowerLevelsEventContent_CloneCopiesBeeperEphemeralFields(t *testing.T) { + override := 70 + pl := &event.PowerLevelsEventContent{ + EventsDefault: 35, + BeeperEphemeral: map[string]int{"com.example.ephemeral": 90}, + BeeperEphemeralDefaultPtr: &override, + } + + cloned := pl.Clone() + require.NotNil(t, cloned) + require.NotNil(t, cloned.BeeperEphemeralDefaultPtr) + assert.Equal(t, 70, *cloned.BeeperEphemeralDefaultPtr) + assert.Equal(t, 90, cloned.BeeperEphemeral["com.example.ephemeral"]) + + cloned.BeeperEphemeral["com.example.ephemeral"] = 99 + *cloned.BeeperEphemeralDefaultPtr = 71 + + assert.Equal(t, 90, pl.BeeperEphemeral["com.example.ephemeral"]) + assert.Equal(t, 70, *pl.BeeperEphemeralDefaultPtr) +} diff --git a/event/relations.go b/event/relations.go index ecd7a959..2316cbc7 100644 --- a/event/relations.go +++ b/event/relations.go @@ -15,10 +15,11 @@ import ( type RelationType string const ( - RelReplace RelationType = "m.replace" - RelReference RelationType = "m.reference" - RelAnnotation RelationType = "m.annotation" - RelThread RelationType = "m.thread" + RelReplace RelationType = "m.replace" + RelReference RelationType = "m.reference" + RelAnnotation RelationType = "m.annotation" + RelThread RelationType = "m.thread" + RelBeeperTranscription RelationType = "com.beeper.transcription" ) type RelatesTo struct { @@ -33,7 +34,7 @@ type RelatesTo struct { type InReplyTo struct { EventID id.EventID `json:"event_id,omitempty"` - UnstableRoomID id.RoomID `json:"room_id,omitempty"` + UnstableRoomID id.RoomID `json:"com.beeper.cross_room_id,omitempty"` } func (rel *RelatesTo) Copy() *RelatesTo { @@ -73,7 +74,7 @@ func (rel *RelatesTo) GetReplyTo() id.EventID { } func (rel *RelatesTo) GetNonFallbackReplyTo() id.EventID { - if rel != nil && rel.InReplyTo != nil && !rel.IsFallingBack { + if rel != nil && rel.InReplyTo != nil && (rel.Type != RelThread || !rel.IsFallingBack) { return rel.InReplyTo.EventID } return "" @@ -100,6 +101,10 @@ func (rel *RelatesTo) SetReplace(mxid id.EventID) *RelatesTo { } func (rel *RelatesTo) SetReplyTo(mxid id.EventID) *RelatesTo { + if rel.Type != RelThread { + rel.Type = "" + rel.EventID = "" + } rel.InReplyTo = &InReplyTo{EventID: mxid} rel.IsFallingBack = false return rel diff --git a/event/reply.go b/event/reply.go index 73f8cfc7..5f55bb80 100644 --- a/event/reply.go +++ b/event/reply.go @@ -1,4 +1,4 @@ -// Copyright (c) 2020 Tulir Asokan +// Copyright (c) 2024 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this @@ -7,7 +7,6 @@ package event import ( - "fmt" "regexp" "strings" @@ -33,12 +32,13 @@ func TrimReplyFallbackText(text string) string { } func (content *MessageEventContent) RemoveReplyFallback() { - if len(content.RelatesTo.GetReplyTo()) > 0 && !content.replyFallbackRemoved { - if content.Format == FormatHTML { - content.FormattedBody = TrimReplyFallbackHTML(content.FormattedBody) + if len(content.RelatesTo.GetReplyTo()) > 0 && !content.replyFallbackRemoved && content.Format == FormatHTML { + origHTML := content.FormattedBody + content.FormattedBody = TrimReplyFallbackHTML(content.FormattedBody) + if content.FormattedBody != origHTML { + content.Body = TrimReplyFallbackText(content.Body) + content.replyFallbackRemoved = true } - content.Body = TrimReplyFallbackText(content.Body) - content.replyFallbackRemoved = true } } @@ -47,52 +47,28 @@ func (content *MessageEventContent) GetReplyTo() id.EventID { return content.RelatesTo.GetReplyTo() } -const ReplyFormat = `
      In reply to %s
      %s
      ` - -func (evt *Event) GenerateReplyFallbackHTML() string { - parsedContent, ok := evt.Content.Parsed.(*MessageEventContent) - if !ok { - return "" - } - parsedContent.RemoveReplyFallback() - body := parsedContent.FormattedBody - if len(body) == 0 { - body = TextToHTML(parsedContent.Body) - } - - senderDisplayName := evt.Sender - - return fmt.Sprintf(ReplyFormat, evt.RoomID, evt.ID, evt.Sender, senderDisplayName, body) -} - -func (evt *Event) GenerateReplyFallbackText() string { - parsedContent, ok := evt.Content.Parsed.(*MessageEventContent) - if !ok { - return "" - } - parsedContent.RemoveReplyFallback() - body := parsedContent.Body - lines := strings.Split(strings.TrimSpace(body), "\n") - firstLine, lines := lines[0], lines[1:] - - senderDisplayName := evt.Sender - - var fallbackText strings.Builder - _, _ = fmt.Fprintf(&fallbackText, "> <%s> %s", senderDisplayName, firstLine) - for _, line := range lines { - _, _ = fmt.Fprintf(&fallbackText, "\n> %s", line) - } - fallbackText.WriteString("\n\n") - return fallbackText.String() -} - func (content *MessageEventContent) SetReply(inReplyTo *Event) { - content.RelatesTo = (&RelatesTo{}).SetReplyTo(inReplyTo.ID) - - if content.MsgType == MsgText || content.MsgType == MsgNotice { - content.EnsureHasHTML() - content.FormattedBody = inReplyTo.GenerateReplyFallbackHTML() + content.FormattedBody - content.Body = inReplyTo.GenerateReplyFallbackText() + content.Body - content.replyFallbackRemoved = false + if content.RelatesTo == nil { + content.RelatesTo = &RelatesTo{} } + content.RelatesTo.SetReplyTo(inReplyTo.ID) + if content.Mentions == nil { + content.Mentions = &Mentions{} + } + content.Mentions.Add(inReplyTo.Sender) +} + +func (content *MessageEventContent) SetThread(inReplyTo *Event) { + root := inReplyTo.ID + relatable, ok := inReplyTo.Content.Parsed.(Relatable) + if ok { + targetRoot := relatable.OptionalGetRelatesTo().GetThreadParent() + if targetRoot != "" { + root = targetRoot + } + } + if content.RelatesTo == nil { + content.RelatesTo = &RelatesTo{} + } + content.RelatesTo.SetThread(root, inReplyTo.ID) } diff --git a/event/state.go b/event/state.go index d6b6cf70..ace170a5 100644 --- a/event/state.go +++ b/event/state.go @@ -7,6 +7,12 @@ package event import ( + "encoding/base64" + "encoding/json" + "slices" + + "go.mau.fi/util/jsontime" + "maunium.net/go/mautrix/id" ) @@ -26,8 +32,9 @@ type RoomNameEventContent struct { // RoomAvatarEventContent represents the content of a m.room.avatar state event. // https://spec.matrix.org/v1.2/client-server-api/#mroomavatar type RoomAvatarEventContent struct { - URL id.ContentURI `json:"url"` - Info *FileInfo `json:"info,omitempty"` + URL id.ContentURIString `json:"url,omitempty"` + Info *FileInfo `json:"info,omitempty"` + MSC3414File *EncryptedFileInfo `json:"org.matrix.msc3414.file,omitempty"` } // ServerACLEventContent represents the content of a m.room.server_acl state event. @@ -41,7 +48,52 @@ type ServerACLEventContent struct { // TopicEventContent represents the content of a m.room.topic state event. // https://spec.matrix.org/v1.2/client-server-api/#mroomtopic type TopicEventContent struct { - Topic string `json:"topic"` + Topic string `json:"topic"` + ExtensibleTopic *ExtensibleTopic `json:"m.topic,omitempty"` +} + +// ExtensibleTopic represents the contents of the m.topic field within the +// m.room.topic state event as described in [MSC3765]. +// +// [MSC3765]: https://github.com/matrix-org/matrix-spec-proposals/pull/3765 +type ExtensibleTopic = ExtensibleTextContainer + +type ExtensibleTextContainer struct { + Text []ExtensibleText `json:"m.text"` +} + +func (c *ExtensibleTextContainer) Equals(description *ExtensibleTextContainer) bool { + if c == nil || description == nil { + return c == description + } + return slices.Equal(c.Text, description.Text) +} + +func MakeExtensibleText(text string) *ExtensibleTextContainer { + return &ExtensibleTextContainer{ + Text: []ExtensibleText{{ + Body: text, + MimeType: "text/plain", + }}, + } +} + +func MakeExtensibleFormattedText(plaintext, html string) *ExtensibleTextContainer { + return &ExtensibleTextContainer{ + Text: []ExtensibleText{{ + Body: plaintext, + MimeType: "text/plain", + }, { + Body: html, + MimeType: "text/html", + }}, + } +} + +// ExtensibleText represents the contents of an m.text field. +type ExtensibleText struct { + MimeType string `json:"mimetype,omitempty"` + Body string `json:"body"` } // TombstoneEventContent represents the content of a m.room.tombstone state event. @@ -51,19 +103,64 @@ type TombstoneEventContent struct { ReplacementRoom id.RoomID `json:"replacement_room"` } +func (tec *TombstoneEventContent) GetReplacementRoom() id.RoomID { + if tec == nil { + return "" + } + return tec.ReplacementRoom +} + type Predecessor struct { RoomID id.RoomID `json:"room_id"` EventID id.EventID `json:"event_id"` } +// Deprecated: use id.RoomVersion instead +type RoomVersion = id.RoomVersion + +// Deprecated: use id.RoomVX constants instead +const ( + RoomV1 = id.RoomV1 + RoomV2 = id.RoomV2 + RoomV3 = id.RoomV3 + RoomV4 = id.RoomV4 + RoomV5 = id.RoomV5 + RoomV6 = id.RoomV6 + RoomV7 = id.RoomV7 + RoomV8 = id.RoomV8 + RoomV9 = id.RoomV9 + RoomV10 = id.RoomV10 + RoomV11 = id.RoomV11 + RoomV12 = id.RoomV12 +) + // CreateEventContent represents the content of a m.room.create state event. // https://spec.matrix.org/v1.2/client-server-api/#mroomcreate type CreateEventContent struct { - Type RoomType `json:"type,omitempty"` - Creator id.UserID `json:"creator,omitempty"` - Federate bool `json:"m.federate,omitempty"` - RoomVersion string `json:"room_version,omitempty"` - Predecessor *Predecessor `json:"predecessor,omitempty"` + Type RoomType `json:"type,omitempty"` + Federate *bool `json:"m.federate,omitempty"` + RoomVersion id.RoomVersion `json:"room_version,omitempty"` + Predecessor *Predecessor `json:"predecessor,omitempty"` + + // Room v12+ only + AdditionalCreators []id.UserID `json:"additional_creators,omitempty"` + + // Deprecated: use the event sender instead + Creator id.UserID `json:"creator,omitempty"` +} + +func (cec *CreateEventContent) GetPredecessor() (p Predecessor) { + if cec != nil && cec.Predecessor != nil { + p = *cec.Predecessor + } + return +} + +func (cec *CreateEventContent) SupportsCreatorPower() bool { + if cec == nil { + return false + } + return cec.RoomVersion.PrivilegedRoomCreators() } // JoinRule specifies how open a room is to new members. @@ -71,11 +168,12 @@ type CreateEventContent struct { type JoinRule string const ( - JoinRulePublic JoinRule = "public" - JoinRuleKnock JoinRule = "knock" - JoinRuleInvite JoinRule = "invite" - JoinRuleRestricted JoinRule = "restricted" - JoinRulePrivate JoinRule = "private" + JoinRulePublic JoinRule = "public" + JoinRuleKnock JoinRule = "knock" + JoinRuleInvite JoinRule = "invite" + JoinRuleRestricted JoinRule = "restricted" + JoinRuleKnockRestricted JoinRule = "knock_restricted" + JoinRulePrivate JoinRule = "private" ) // JoinRulesEventContent represents the content of a m.room.join_rules state event. @@ -139,6 +237,9 @@ type BridgeInfoSection struct { DisplayName string `json:"displayname,omitempty"` AvatarURL id.ContentURIString `json:"avatar_url,omitempty"` ExternalURL string `json:"external_url,omitempty"` + + Receiver string `json:"fi.mau.receiver,omitempty"` + MessageRequest bool `json:"com.beeper.message_request,omitempty"` } // BridgeEventContent represents the content of a m.bridge state event. @@ -149,6 +250,35 @@ type BridgeEventContent struct { Protocol BridgeInfoSection `json:"protocol"` Network *BridgeInfoSection `json:"network,omitempty"` Channel BridgeInfoSection `json:"channel"` + + BeeperRoomType string `json:"com.beeper.room_type,omitempty"` + BeeperRoomTypeV2 string `json:"com.beeper.room_type.v2,omitempty"` + + TempSlackRemoteIDMigratedFlag bool `json:"com.beeper.slack_remote_id_migrated,omitempty"` + TempSlackRemoteIDMigratedFlag2 bool `json:"com.beeper.slack_remote_id_really_migrated,omitempty"` +} + +// DisappearingType represents the type of a disappearing message timer. +type DisappearingType string + +const ( + DisappearingTypeNone DisappearingType = "" + DisappearingTypeAfterRead DisappearingType = "after_read" + DisappearingTypeAfterSend DisappearingType = "after_send" +) + +type BeeperDisappearingTimer struct { + Type DisappearingType `json:"type"` + Timer jsontime.Milliseconds `json:"timer"` +} + +type marshalableBeeperDisappearingTimer BeeperDisappearingTimer + +func (bdt *BeeperDisappearingTimer) MarshalJSON() ([]byte, error) { + if bdt == nil || bdt.Type == DisappearingTypeNone { + return []byte("{}"), nil + } + return json.Marshal((*marshalableBeeperDisappearingTimer)(bdt)) } type SpaceChildEventContent struct { @@ -162,16 +292,66 @@ type SpaceParentEventContent struct { Canonical bool `json:"canonical,omitempty"` } +type PolicyRecommendation string + +const ( + PolicyRecommendationBan PolicyRecommendation = "m.ban" + PolicyRecommendationUnstableTakedown PolicyRecommendation = "org.matrix.msc4204.takedown" + PolicyRecommendationUnstableBan PolicyRecommendation = "org.matrix.mjolnir.ban" + PolicyRecommendationUnban PolicyRecommendation = "fi.mau.meowlnir.unban" +) + +type PolicyHashes struct { + SHA256 string `json:"sha256"` +} + +func (ph *PolicyHashes) DecodeSHA256() *[32]byte { + if ph == nil || ph.SHA256 == "" { + return nil + } + decoded, _ := base64.StdEncoding.DecodeString(ph.SHA256) + if len(decoded) == 32 { + return (*[32]byte)(decoded) + } + return nil +} + // ModPolicyContent represents the content of a m.room.rule.user, m.room.rule.room, and m.room.rule.server state event. // https://spec.matrix.org/v1.2/client-server-api/#moderation-policy-lists type ModPolicyContent struct { - Entity string `json:"entity"` - Reason string `json:"reason"` - Recommendation string `json:"recommendation"` + Entity string `json:"entity,omitempty"` + Reason string `json:"reason"` + Recommendation PolicyRecommendation `json:"recommendation"` + UnstableHashes *PolicyHashes `json:"org.matrix.msc4205.hashes,omitempty"` } -// Deprecated: MSC2716 has been abandoned -type InsertionMarkerContent struct { - InsertionID id.EventID `json:"org.matrix.msc2716.marker.insertion"` - Timestamp int64 `json:"com.beeper.timestamp,omitempty"` +func (mpc *ModPolicyContent) EntityOrHash() string { + if mpc.UnstableHashes != nil && mpc.UnstableHashes.SHA256 != "" { + return mpc.UnstableHashes.SHA256 + } + return mpc.Entity +} + +type ElementFunctionalMembersContent struct { + ServiceMembers []id.UserID `json:"service_members"` +} + +func (efmc *ElementFunctionalMembersContent) Add(mxid id.UserID) bool { + if slices.Contains(efmc.ServiceMembers, mxid) { + return false + } + efmc.ServiceMembers = append(efmc.ServiceMembers, mxid) + return true +} + +type PolicyServerPublicKeys struct { + Ed25519 id.Ed25519 `json:"ed25519,omitempty"` +} + +type RoomPolicyEventContent struct { + Via string `json:"via,omitempty"` + PublicKeys *PolicyServerPublicKeys `json:"public_keys,omitempty"` + + // Deprecated, only for legacy use + PublicKey id.Ed25519 `json:"public_key,omitempty"` } diff --git a/event/type.go b/event/type.go index a4b36392..80b86728 100644 --- a/event/type.go +++ b/event/type.go @@ -108,15 +108,17 @@ func (et *Type) IsCustom() bool { func (et *Type) GuessClass() TypeClass { switch et.Type { - case StateAliases.Type, StateCanonicalAlias.Type, StateCreate.Type, StateJoinRules.Type, StateMember.Type, + case StateAliases.Type, StateCanonicalAlias.Type, StateCreate.Type, StateJoinRules.Type, StateMember.Type, StateThirdPartyInvite.Type, StatePowerLevels.Type, StateRoomName.Type, StateRoomAvatar.Type, StateServerACL.Type, StateTopic.Type, StatePinnedEvents.Type, StateTombstone.Type, StateEncryption.Type, StateBridge.Type, StateHalfShotBridge.Type, StateSpaceParent.Type, StateSpaceChild.Type, StatePolicyRoom.Type, StatePolicyServer.Type, StatePolicyUser.Type, - StateInsertionMarker.Type: + StateElementFunctionalMembers.Type, StateBeeperRoomFeatures.Type, StateBeeperDisappearingTimer.Type, + StateMSC4391BotCommand.Type, StateRoomPolicy.Type, StateUnstableRoomPolicy.Type: return StateEventType - case EphemeralEventReceipt.Type, EphemeralEventTyping.Type, EphemeralEventPresence.Type: + case EphemeralEventReceipt.Type, EphemeralEventTyping.Type, EphemeralEventPresence.Type, BeeperEphemeralEventAIStream.Type: return EphemeralEventType case AccountDataDirectChats.Type, AccountDataPushRules.Type, AccountDataRoomTags.Type, + AccountDataFullyRead.Type, AccountDataIgnoredUserList.Type, AccountDataMarkedUnread.Type, AccountDataSecretStorageKey.Type, AccountDataSecretStorageDefaultKey.Type, AccountDataCrossSigningMaster.Type, AccountDataCrossSigningSelf.Type, AccountDataCrossSigningUser.Type, AccountDataFullyRead.Type, AccountDataMegolmBackupKey.Type: @@ -125,7 +127,8 @@ func (et *Type) GuessClass() TypeClass { InRoomVerificationStart.Type, InRoomVerificationReady.Type, InRoomVerificationAccept.Type, InRoomVerificationKey.Type, InRoomVerificationMAC.Type, InRoomVerificationCancel.Type, CallInvite.Type, CallCandidates.Type, CallAnswer.Type, CallReject.Type, CallSelectAnswer.Type, - CallNegotiate.Type, CallHangup.Type, BeeperMessageStatus.Type: + CallNegotiate.Type, CallHangup.Type, BeeperMessageStatus.Type, EventUnstablePollStart.Type, EventUnstablePollResponse.Type, + EventUnstablePollEnd.Type, BeeperTranscription.Type, BeeperDeleteChat.Type, BeeperAcceptMessageRequest.Type: return MessageEventType case ToDeviceRoomKey.Type, ToDeviceRoomKeyRequest.Type, ToDeviceForwardedRoomKey.Type, ToDeviceRoomKeyWithheld.Type, ToDeviceBeeperRoomKeyAck.Type: @@ -148,7 +151,7 @@ func (et *Type) MarshalJSON() ([]byte, error) { return json.Marshal(&et.Type) } -func (et Type) UnmarshalText(data []byte) error { +func (et *Type) UnmarshalText(data []byte) error { et.Type = string(data) et.Class = et.GuessClass() return nil @@ -158,11 +161,11 @@ func (et Type) MarshalText() ([]byte, error) { return []byte(et.Type), nil } -func (et *Type) String() string { +func (et Type) String() string { return et.Type } -func (et *Type) Repr() string { +func (et Type) Repr() string { return fmt.Sprintf("%s (%s)", et.Type, et.Class.Name()) } @@ -175,6 +178,7 @@ var ( StateHistoryVisibility = Type{"m.room.history_visibility", StateEventType} StateGuestAccess = Type{"m.room.guest_access", StateEventType} StateMember = Type{"m.room.member", StateEventType} + StateThirdPartyInvite = Type{"m.room.third_party_invite", StateEventType} StatePowerLevels = Type{"m.room.power_levels", StateEventType} StateRoomName = Type{"m.room.name", StateEventType} StateTopic = Type{"m.room.topic", StateEventType} @@ -191,8 +195,20 @@ var ( StateSpaceChild = Type{"m.space.child", StateEventType} StateSpaceParent = Type{"m.space.parent", StateEventType} - // Deprecated: MSC2716 has been abandoned - StateInsertionMarker = Type{"org.matrix.msc2716.marker", StateEventType} + StateRoomPolicy = Type{"m.room.policy", StateEventType} + StateUnstableRoomPolicy = Type{"org.matrix.msc4284.policy", StateEventType} + + StateLegacyPolicyRoom = Type{"m.room.rule.room", StateEventType} + StateLegacyPolicyServer = Type{"m.room.rule.server", StateEventType} + StateLegacyPolicyUser = Type{"m.room.rule.user", StateEventType} + StateUnstablePolicyRoom = Type{"org.matrix.mjolnir.rule.room", StateEventType} + StateUnstablePolicyServer = Type{"org.matrix.mjolnir.rule.server", StateEventType} + StateUnstablePolicyUser = Type{"org.matrix.mjolnir.rule.user", StateEventType} + + StateElementFunctionalMembers = Type{"io.element.functional_members", StateEventType} + StateBeeperRoomFeatures = Type{"com.beeper.room_features", StateEventType} + StateBeeperDisappearingTimer = Type{"com.beeper.disappearing_timer", StateEventType} + StateMSC4391BotCommand = Type{"org.matrix.msc4391.command_description", StateEventType} ) // Message events @@ -221,14 +237,24 @@ var ( CallNegotiate = Type{"m.call.negotiate", MessageEventType} CallHangup = Type{"m.call.hangup", MessageEventType} - BeeperMessageStatus = Type{"com.beeper.message_send_status", MessageEventType} + BeeperMessageStatus = Type{"com.beeper.message_send_status", MessageEventType} + BeeperTranscription = Type{"com.beeper.transcription", MessageEventType} + BeeperDeleteChat = Type{"com.beeper.delete_chat", MessageEventType} + BeeperAcceptMessageRequest = Type{"com.beeper.accept_message_request", MessageEventType} + BeeperSendState = Type{"com.beeper.send_state", MessageEventType} + + EventUnstablePollStart = Type{Type: "org.matrix.msc3381.poll.start", Class: MessageEventType} + EventUnstablePollResponse = Type{Type: "org.matrix.msc3381.poll.response", Class: MessageEventType} + EventUnstablePollEnd = Type{Type: "org.matrix.msc3381.poll.end", Class: MessageEventType} ) // Ephemeral events var ( - EphemeralEventReceipt = Type{"m.receipt", EphemeralEventType} - EphemeralEventTyping = Type{"m.typing", EphemeralEventType} - EphemeralEventPresence = Type{"m.presence", EphemeralEventType} + EphemeralEventReceipt = Type{"m.receipt", EphemeralEventType} + EphemeralEventTyping = Type{"m.typing", EphemeralEventType} + EphemeralEventPresence = Type{"m.presence", EphemeralEventType} + EphemeralEventEncrypted = Type{"m.room.encrypted", EphemeralEventType} + BeeperEphemeralEventAIStream = Type{"com.beeper.ai.stream_event", EphemeralEventType} ) // Account data events @@ -238,13 +264,15 @@ var ( AccountDataRoomTags = Type{"m.tag", AccountDataEventType} AccountDataFullyRead = Type{"m.fully_read", AccountDataEventType} AccountDataIgnoredUserList = Type{"m.ignored_user_list", AccountDataEventType} + AccountDataMarkedUnread = Type{"m.marked_unread", AccountDataEventType} + AccountDataBeeperMute = Type{"com.beeper.mute", AccountDataEventType} AccountDataSecretStorageDefaultKey = Type{"m.secret_storage.default_key", AccountDataEventType} AccountDataSecretStorageKey = Type{"m.secret_storage.key", AccountDataEventType} AccountDataCrossSigningMaster = Type{string(id.SecretXSMaster), AccountDataEventType} AccountDataCrossSigningUser = Type{string(id.SecretXSUserSigning), AccountDataEventType} AccountDataCrossSigningSelf = Type{string(id.SecretXSSelfSigning), AccountDataEventType} - AccountDataMegolmBackupKey = Type{"m.megolm_backup.v1", AccountDataEventType} + AccountDataMegolmBackupKey = Type{string(id.SecretMegolmBackupV1), AccountDataEventType} ) // Device-to-device events diff --git a/event/verification.go b/event/verification.go index b1851de3..6101896f 100644 --- a/event/verification.go +++ b/event/verification.go @@ -220,6 +220,10 @@ const ( VerificationCancelCodeAccepted VerificationCancelCode = "m.accepted" VerificationCancelCodeSASMismatch VerificationCancelCode = "m.mismatched_sas" VerificationCancelCodeCommitmentMismatch VerificationCancelCode = "m.mismatched_commitment" + + // Non-spec codes + VerificationCancelCodeInternalError VerificationCancelCode = "com.beeper.internal_error" + VerificationCancelCodeMasterKeyNotTrusted VerificationCancelCode = "com.beeper.master_key_not_trusted" // the master key is not trusted by this device, but the QR code that was scanned was from a device that doesn't trust the master key ) // VerificationCancelEventContent represents the content of an diff --git a/event/voip.go b/event/voip.go index 28f56c95..cd8364a1 100644 --- a/event/voip.go +++ b/event/voip.go @@ -76,7 +76,7 @@ func (cv *CallVersion) Int() (int, error) { type BaseCallEventContent struct { CallID string `json:"call_id"` PartyID string `json:"party_id"` - Version CallVersion `json:"version"` + Version CallVersion `json:"version,omitempty"` } type CallInviteEventContent struct { diff --git a/example/go.mod b/example/go.mod deleted file mode 100644 index 60583640..00000000 --- a/example/go.mod +++ /dev/null @@ -1,29 +0,0 @@ -module maunium.net/go/mautrix/example - -go 1.21 - -toolchain go1.22.0 - -require ( - github.com/chzyer/readline v1.5.1 - github.com/mattn/go-sqlite3 v1.14.22 - github.com/rs/zerolog v1.32.0 - go.mau.fi/util v0.3.1-0.20240208085450-32294da153ab - maunium.net/go/mautrix v0.17.1-0.20240208085632-2d1786ced444 -) - -require ( - github.com/mattn/go-colorable v0.1.13 // indirect - github.com/mattn/go-isatty v0.0.19 // indirect - github.com/tidwall/gjson v1.17.0 // indirect - github.com/tidwall/match v1.1.1 // indirect - github.com/tidwall/pretty v1.2.0 // indirect - github.com/tidwall/sjson v1.2.5 // indirect - golang.org/x/crypto v0.19.0 // indirect - golang.org/x/exp v0.0.0-20240205201215-2c58cdc269a3 // indirect - golang.org/x/net v0.21.0 // indirect - golang.org/x/sys v0.17.0 // indirect - maunium.net/go/maulogger/v2 v2.4.1 // indirect -) - -//replace maunium.net/go/mautrix => ../ diff --git a/example/go.sum b/example/go.sum deleted file mode 100644 index f81f31c2..00000000 --- a/example/go.sum +++ /dev/null @@ -1,56 +0,0 @@ -github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU= -github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU= -github.com/chzyer/logex v1.2.1 h1:XHDu3E6q+gdHgsdTPH6ImJMIp436vR6MPtH8gP05QzM= -github.com/chzyer/logex v1.2.1/go.mod h1:JLbx6lG2kDbNRFnfkgvh4eRJRPX1QCoOIWomwysCBrQ= -github.com/chzyer/readline v1.5.1 h1:upd/6fQk4src78LMRzh5vItIt361/o4uq553V8B5sGI= -github.com/chzyer/readline v1.5.1/go.mod h1:Eh+b79XXUwfKfcPLepksvw2tcLE/Ct21YObkaSkeBlk= -github.com/chzyer/test v1.0.0 h1:p3BQDXSxOhOG0P9z6/hGnII4LGiEPOYBhs8asl/fC04= -github.com/chzyer/test v1.0.0/go.mod h1:2JlltgoNkt4TW/z9V/IzDdFaMTM2JPIi26O1pF38GC8= -github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= -github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= -github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= -github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= -github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= -github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= -github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= -github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= -github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU= -github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= -github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= -github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= -github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= -github.com/rs/zerolog v1.32.0 h1:keLypqrlIjaFsbmJOBdB/qvyF8KEtCWHwobLp5l/mQ0= -github.com/rs/zerolog v1.32.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss= -github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= -github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= -github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= -github.com/tidwall/gjson v1.17.0 h1:/Jocvlh98kcTfpN2+JzGQWQcqrPQwDrVEMApx/M5ZwM= -github.com/tidwall/gjson v1.17.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= -github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= -github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= -github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs= -github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= -github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= -github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= -go.mau.fi/util v0.3.1-0.20240208085450-32294da153ab h1:XZ8W5vHWlXSGmHn1U+Fvbh+xZr9wuHTvbY+qV7aybDY= -go.mau.fi/util v0.3.1-0.20240208085450-32294da153ab/go.mod h1:rRypwgXVEPILomtFPyQcnbOeuRqf+nRN84vh/CICq4w= -golang.org/x/crypto v0.19.0 h1:ENy+Az/9Y1vSrlrvBSyna3PITt4tiZLf7sgCjZBX7Wo= -golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= -golang.org/x/exp v0.0.0-20240205201215-2c58cdc269a3 h1:/RIbNt/Zr7rVhIkQhooTxCxFcdWLGIKnZA4IXNFSrvo= -golang.org/x/exp v0.0.0-20240205201215-2c58cdc269a3/go.mod h1:idGWGoKP1toJGkd5/ig9ZLuPcZBC3ewk7SzmH0uou08= -golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4= -golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= -golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.17.0 h1:25cE3gD+tdBA7lp7QfhuV+rJiE9YXTcS3VG1SqssI/Y= -golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= -gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -maunium.net/go/maulogger/v2 v2.4.1 h1:N7zSdd0mZkB2m2JtFUsiGTQQAdP0YeFWT7YMc80yAL8= -maunium.net/go/maulogger/v2 v2.4.1/go.mod h1:omPuYwYBILeVQobz8uO3XC8DIRuEb5rXYlQSuqrbCho= -maunium.net/go/mautrix v0.17.1-0.20240208085632-2d1786ced444 h1:PkpCzQotFakHkGKAatiQdb+XjP/HLQM40xuiy2JtHes= -maunium.net/go/mautrix v0.17.1-0.20240208085632-2d1786ced444/go.mod h1:tMIBWuMXrtjXAqMtaD1VHiT0B3TCxraYlqtncLIyKF0= diff --git a/example/main.go b/example/main.go index d8006d46..2bf4bef3 100644 --- a/example/main.go +++ b/example/main.go @@ -143,7 +143,7 @@ func main() { if err != nil { log.Error().Err(err).Msg("Failed to send event") } else { - log.Info().Str("event_id", resp.EventID.String()).Msg("Event sent") + log.Info().Stringer("event_id", resp.EventID).Msg("Event sent") } } cancelSync() diff --git a/federation/cache.go b/federation/cache.go new file mode 100644 index 00000000..24154974 --- /dev/null +++ b/federation/cache.go @@ -0,0 +1,153 @@ +// Copyright (c) 2025 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package federation + +import ( + "errors" + "fmt" + "math" + "sync" + "time" +) + +// ResolutionCache is an interface for caching resolved server names. +type ResolutionCache interface { + StoreResolution(*ResolvedServerName) + // LoadResolution loads a resolved server name from the cache. + // Expired entries MUST NOT be returned. + LoadResolution(serverName string) (*ResolvedServerName, error) +} + +type KeyCache interface { + StoreKeys(*ServerKeyResponse) + StoreFetchError(serverName string, err error) + ShouldReQuery(serverName string) bool + LoadKeys(serverName string) (*ServerKeyResponse, error) +} + +type InMemoryCache struct { + MinKeyRefetchDelay time.Duration + + resolutions map[string]*ResolvedServerName + resolutionsLock sync.RWMutex + keys map[string]*ServerKeyResponse + lastReQueryAt map[string]time.Time + lastError map[string]*resolutionErrorCache + keysLock sync.RWMutex +} + +var ( + _ ResolutionCache = (*InMemoryCache)(nil) + _ KeyCache = (*InMemoryCache)(nil) +) + +func NewInMemoryCache() *InMemoryCache { + return &InMemoryCache{ + resolutions: make(map[string]*ResolvedServerName), + keys: make(map[string]*ServerKeyResponse), + lastReQueryAt: make(map[string]time.Time), + lastError: make(map[string]*resolutionErrorCache), + MinKeyRefetchDelay: 1 * time.Hour, + } +} + +func (c *InMemoryCache) StoreResolution(resolution *ResolvedServerName) { + c.resolutionsLock.Lock() + defer c.resolutionsLock.Unlock() + c.resolutions[resolution.ServerName] = resolution +} + +func (c *InMemoryCache) LoadResolution(serverName string) (*ResolvedServerName, error) { + c.resolutionsLock.RLock() + defer c.resolutionsLock.RUnlock() + resolution, ok := c.resolutions[serverName] + if !ok || time.Until(resolution.Expires) < 0 { + return nil, nil + } + return resolution, nil +} + +func (c *InMemoryCache) StoreKeys(keys *ServerKeyResponse) { + c.keysLock.Lock() + defer c.keysLock.Unlock() + c.keys[keys.ServerName] = keys + delete(c.lastError, keys.ServerName) +} + +type resolutionErrorCache struct { + Error error + Time time.Time + Count int +} + +const MaxBackoff = 7 * 24 * time.Hour + +func (rec *resolutionErrorCache) ShouldRetry() bool { + backoff := time.Duration(math.Exp(float64(rec.Count))) * time.Second + return time.Since(rec.Time) > backoff +} + +var ErrRecentKeyQueryFailed = errors.New("last retry was too recent") + +func (c *InMemoryCache) LoadKeys(serverName string) (*ServerKeyResponse, error) { + c.keysLock.RLock() + defer c.keysLock.RUnlock() + keys, ok := c.keys[serverName] + if !ok || time.Until(keys.ValidUntilTS.Time) < 0 { + err, ok := c.lastError[serverName] + if ok && !err.ShouldRetry() { + return nil, fmt.Errorf( + "%w (%s ago) and failed with %w", + ErrRecentKeyQueryFailed, + time.Since(err.Time).String(), + err.Error, + ) + } + return nil, nil + } + return keys, nil +} + +func (c *InMemoryCache) StoreFetchError(serverName string, err error) { + c.keysLock.Lock() + defer c.keysLock.Unlock() + errorCache, ok := c.lastError[serverName] + if ok { + errorCache.Time = time.Now() + errorCache.Error = err + errorCache.Count++ + } else { + c.lastError[serverName] = &resolutionErrorCache{Error: err, Time: time.Now(), Count: 1} + } +} + +func (c *InMemoryCache) ShouldReQuery(serverName string) bool { + c.keysLock.Lock() + defer c.keysLock.Unlock() + lastQuery, ok := c.lastReQueryAt[serverName] + if ok && time.Since(lastQuery) < c.MinKeyRefetchDelay { + return false + } + c.lastReQueryAt[serverName] = time.Now() + return true +} + +type noopCache struct{} + +func (*noopCache) StoreKeys(_ *ServerKeyResponse) {} +func (*noopCache) LoadKeys(_ string) (*ServerKeyResponse, error) { return nil, nil } +func (*noopCache) StoreFetchError(_ string, _ error) {} +func (*noopCache) ShouldReQuery(_ string) bool { return true } +func (*noopCache) StoreResolution(_ *ResolvedServerName) {} +func (*noopCache) LoadResolution(_ string) (*ResolvedServerName, error) { return nil, nil } + +var ( + _ ResolutionCache = (*noopCache)(nil) + _ KeyCache = (*noopCache)(nil) +) + +var NoopCache *noopCache diff --git a/federation/client.go b/federation/client.go new file mode 100644 index 00000000..183fb5d1 --- /dev/null +++ b/federation/client.go @@ -0,0 +1,605 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package federation + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strconv" + "time" + + "go.mau.fi/util/exslices" + "go.mau.fi/util/jsontime" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/federation/signutil" + "maunium.net/go/mautrix/id" +) + +type Client struct { + HTTP *http.Client + ServerName string + UserAgent string + Key *SigningKey + + ResponseSizeLimit int64 +} + +func NewClient(serverName string, key *SigningKey, cache ResolutionCache) *Client { + return &Client{ + HTTP: &http.Client{ + Transport: NewServerResolvingTransport(cache), + Timeout: 120 * time.Second, + CheckRedirect: func(req *http.Request, via []*http.Request) error { + // Federation requests do not allow redirects. + return http.ErrUseLastResponse + }, + }, + UserAgent: mautrix.DefaultUserAgent, + ServerName: serverName, + Key: key, + + ResponseSizeLimit: mautrix.DefaultResponseSizeLimit, + } +} + +func (c *Client) Version(ctx context.Context, serverName string) (resp *RespServerVersion, err error) { + err = c.MakeRequest(ctx, serverName, false, http.MethodGet, URLPath{"v1", "version"}, nil, &resp) + return +} + +func (c *Client) ServerKeys(ctx context.Context, serverName string) (resp *ServerKeyResponse, err error) { + err = c.MakeRequest(ctx, serverName, false, http.MethodGet, KeyURLPath{"v2", "server"}, nil, &resp) + return +} + +func (c *Client) QueryKeys(ctx context.Context, serverName string, req *ReqQueryKeys) (resp *QueryKeysResponse, err error) { + err = c.MakeRequest(ctx, serverName, false, http.MethodPost, KeyURLPath{"v2", "query"}, req, &resp) + return +} + +type PDU = json.RawMessage +type EDU = json.RawMessage + +type ReqSendTransaction struct { + Destination string `json:"destination"` + TxnID string `json:"-"` + + Origin string `json:"origin"` + OriginServerTS jsontime.UnixMilli `json:"origin_server_ts"` + PDUs []PDU `json:"pdus"` + EDUs []EDU `json:"edus,omitempty"` +} + +type PDUProcessingResult struct { + Error string `json:"error,omitempty"` +} + +type RespSendTransaction struct { + PDUs map[id.EventID]PDUProcessingResult `json:"pdus"` +} + +func (c *Client) SendTransaction(ctx context.Context, req *ReqSendTransaction) (resp *RespSendTransaction, err error) { + err = c.MakeRequest(ctx, req.Destination, true, http.MethodPut, URLPath{"v1", "send", req.TxnID}, req, &resp) + return +} + +type RespGetEventAuthChain struct { + AuthChain []PDU `json:"auth_chain"` +} + +func (c *Client) GetEventAuthChain(ctx context.Context, serverName string, roomID id.RoomID, eventID id.EventID) (resp *RespGetEventAuthChain, err error) { + err = c.MakeRequest(ctx, serverName, true, http.MethodGet, URLPath{"v1", "event_auth", roomID, eventID}, nil, &resp) + return +} + +type ReqBackfill struct { + ServerName string + RoomID id.RoomID + Limit int + BackfillFrom []id.EventID +} + +type RespBackfill struct { + Origin string `json:"origin"` + OriginServerTS jsontime.UnixMilli `json:"origin_server_ts"` + PDUs []PDU `json:"pdus"` +} + +func (c *Client) Backfill(ctx context.Context, req *ReqBackfill) (resp *RespBackfill, err error) { + _, _, err = c.MakeFullRequest(ctx, RequestParams{ + ServerName: req.ServerName, + Method: http.MethodGet, + Path: URLPath{"v1", "backfill", req.RoomID}, + Query: url.Values{ + "limit": {strconv.Itoa(req.Limit)}, + "v": exslices.CastToString[string](req.BackfillFrom), + }, + Authenticate: true, + ResponseJSON: &resp, + }) + return +} + +type ReqGetMissingEvents struct { + ServerName string `json:"-"` + RoomID id.RoomID `json:"-"` + EarliestEvents []id.EventID `json:"earliest_events"` + LatestEvents []id.EventID `json:"latest_events"` + Limit int `json:"limit,omitempty"` + MinDepth int `json:"min_depth,omitempty"` +} + +type RespGetMissingEvents struct { + Events []PDU `json:"events"` +} + +func (c *Client) GetMissingEvents(ctx context.Context, req *ReqGetMissingEvents) (resp *RespGetMissingEvents, err error) { + err = c.MakeRequest(ctx, req.ServerName, true, http.MethodPost, URLPath{"v1", "get_missing_events", req.RoomID}, req, &resp) + return +} + +func (c *Client) GetEvent(ctx context.Context, serverName string, eventID id.EventID) (resp *RespBackfill, err error) { + err = c.MakeRequest(ctx, serverName, true, http.MethodGet, URLPath{"v1", "event", eventID}, nil, &resp) + return +} + +type RespGetState struct { + AuthChain []PDU `json:"auth_chain"` + PDUs []PDU `json:"pdus"` +} + +func (c *Client) GetState(ctx context.Context, serverName string, roomID id.RoomID, eventID id.EventID) (resp *RespGetState, err error) { + _, _, err = c.MakeFullRequest(ctx, RequestParams{ + ServerName: serverName, + Method: http.MethodGet, + Path: URLPath{"v1", "state", roomID}, + Query: url.Values{ + "event_id": {string(eventID)}, + }, + Authenticate: true, + ResponseJSON: &resp, + }) + return +} + +type RespGetStateIDs struct { + AuthChain []id.EventID `json:"auth_chain_ids"` + PDUs []id.EventID `json:"pdu_ids"` +} + +func (c *Client) GetStateIDs(ctx context.Context, serverName string, roomID id.RoomID, eventID id.EventID) (resp *RespGetStateIDs, err error) { + _, _, err = c.MakeFullRequest(ctx, RequestParams{ + ServerName: serverName, + Method: http.MethodGet, + Path: URLPath{"v1", "state_ids", roomID}, + Query: url.Values{ + "event_id": {string(eventID)}, + }, + Authenticate: true, + ResponseJSON: &resp, + }) + return +} + +func (c *Client) TimestampToEvent(ctx context.Context, serverName string, roomID id.RoomID, timestamp time.Time, dir mautrix.Direction) (resp *mautrix.RespTimestampToEvent, err error) { + _, _, err = c.MakeFullRequest(ctx, RequestParams{ + ServerName: serverName, + Method: http.MethodGet, + Path: URLPath{"v1", "timestamp_to_event", roomID}, + Query: url.Values{ + "dir": {string(dir)}, + "ts": {strconv.FormatInt(timestamp.UnixMilli(), 10)}, + }, + Authenticate: true, + ResponseJSON: &resp, + }) + return +} + +func (c *Client) QueryProfile(ctx context.Context, serverName string, userID id.UserID) (resp *mautrix.RespUserProfile, err error) { + err = c.Query(ctx, serverName, "profile", url.Values{"user_id": {userID.String()}}, &resp) + return +} + +func (c *Client) QueryDirectory(ctx context.Context, serverName string, roomAlias id.RoomAlias) (resp *mautrix.RespAliasResolve, err error) { + err = c.Query(ctx, serverName, "directory", url.Values{"room_alias": {roomAlias.String()}}, &resp) + return +} + +func (c *Client) Query(ctx context.Context, serverName, queryType string, queryParams url.Values, respStruct any) (err error) { + _, _, err = c.MakeFullRequest(ctx, RequestParams{ + ServerName: serverName, + Method: http.MethodGet, + Path: URLPath{"v1", "query", queryType}, + Query: queryParams, + Authenticate: true, + ResponseJSON: respStruct, + }) + return +} + +func queryToValues(query map[string]string) url.Values { + values := make(url.Values, len(query)) + for k, v := range query { + values[k] = []string{v} + } + return values +} + +func (c *Client) PublicRooms(ctx context.Context, serverName string, req *mautrix.ReqPublicRooms) (resp *mautrix.RespPublicRooms, err error) { + _, _, err = c.MakeFullRequest(ctx, RequestParams{ + ServerName: serverName, + Method: http.MethodGet, + Path: URLPath{"v1", "publicRooms"}, + Query: queryToValues(req.Query()), + Authenticate: true, + ResponseJSON: &resp, + }) + return +} + +type RespOpenIDUserInfo struct { + Sub id.UserID `json:"sub"` +} + +func (c *Client) GetOpenIDUserInfo(ctx context.Context, serverName, accessToken string) (resp *RespOpenIDUserInfo, err error) { + _, _, err = c.MakeFullRequest(ctx, RequestParams{ + ServerName: serverName, + Method: http.MethodGet, + Path: URLPath{"v1", "openid", "userinfo"}, + Query: url.Values{"access_token": {accessToken}}, + ResponseJSON: &resp, + }) + return +} + +type ReqMakeJoin struct { + RoomID id.RoomID + UserID id.UserID + Via string + SupportedVersions []id.RoomVersion +} + +type RespMakeJoin struct { + RoomVersion id.RoomVersion `json:"room_version"` + Event PDU `json:"event"` +} + +type ReqSendJoin struct { + RoomID id.RoomID + EventID id.EventID + OmitMembers bool + Event PDU + Via string +} + +type ReqSendKnock struct { + RoomID id.RoomID + EventID id.EventID + Event PDU + Via string +} + +type RespSendJoin struct { + AuthChain []PDU `json:"auth_chain"` + Event PDU `json:"event"` + MembersOmitted bool `json:"members_omitted"` + ServersInRoom []string `json:"servers_in_room"` + State []PDU `json:"state"` +} + +type RespSendKnock struct { + KnockRoomState []PDU `json:"knock_room_state"` +} + +type ReqSendInvite struct { + RoomID id.RoomID `json:"-"` + UserID id.UserID `json:"-"` + Event PDU `json:"event"` + InviteRoomState []PDU `json:"invite_room_state"` + RoomVersion id.RoomVersion `json:"room_version"` +} + +type RespSendInvite struct { + Event PDU `json:"event"` +} + +type ReqMakeLeave struct { + RoomID id.RoomID + UserID id.UserID + Via string +} + +type ReqSendLeave struct { + RoomID id.RoomID + EventID id.EventID + Event PDU + Via string +} + +type ( + ReqMakeKnock = ReqMakeJoin + RespMakeKnock = RespMakeJoin + RespMakeLeave = RespMakeJoin +) + +func (c *Client) MakeJoin(ctx context.Context, req *ReqMakeJoin) (resp *RespMakeJoin, err error) { + versions := make([]string, len(req.SupportedVersions)) + for i, v := range req.SupportedVersions { + versions[i] = string(v) + } + _, _, err = c.MakeFullRequest(ctx, RequestParams{ + ServerName: req.Via, + Method: http.MethodGet, + Path: URLPath{"v1", "make_join", req.RoomID, req.UserID}, + Query: url.Values{"ver": versions}, + Authenticate: true, + ResponseJSON: &resp, + }) + return +} + +func (c *Client) MakeKnock(ctx context.Context, req *ReqMakeKnock) (resp *RespMakeKnock, err error) { + versions := make([]string, len(req.SupportedVersions)) + for i, v := range req.SupportedVersions { + versions[i] = string(v) + } + _, _, err = c.MakeFullRequest(ctx, RequestParams{ + ServerName: req.Via, + Method: http.MethodGet, + Path: URLPath{"v1", "make_knock", req.RoomID, req.UserID}, + Query: url.Values{"ver": versions}, + Authenticate: true, + ResponseJSON: &resp, + }) + return +} + +func (c *Client) SendJoin(ctx context.Context, req *ReqSendJoin) (resp *RespSendJoin, err error) { + _, _, err = c.MakeFullRequest(ctx, RequestParams{ + ServerName: req.Via, + Method: http.MethodPut, + Path: URLPath{"v2", "send_join", req.RoomID, req.EventID}, + Query: url.Values{ + "omit_members": {strconv.FormatBool(req.OmitMembers)}, + }, + Authenticate: true, + RequestJSON: req.Event, + ResponseJSON: &resp, + }) + return +} + +func (c *Client) SendKnock(ctx context.Context, req *ReqSendKnock) (resp *RespSendKnock, err error) { + _, _, err = c.MakeFullRequest(ctx, RequestParams{ + ServerName: req.Via, + Method: http.MethodPut, + Path: URLPath{"v1", "send_knock", req.RoomID, req.EventID}, + Authenticate: true, + RequestJSON: req.Event, + ResponseJSON: &resp, + }) + return +} + +func (c *Client) SendInvite(ctx context.Context, req *ReqSendInvite) (resp *RespSendInvite, err error) { + _, _, err = c.MakeFullRequest(ctx, RequestParams{ + ServerName: req.UserID.Homeserver(), + Method: http.MethodPut, + Path: URLPath{"v2", "invite", req.RoomID, req.UserID}, + Authenticate: true, + RequestJSON: req, + ResponseJSON: &resp, + }) + return +} + +func (c *Client) MakeLeave(ctx context.Context, req *ReqMakeLeave) (resp *RespMakeLeave, err error) { + _, _, err = c.MakeFullRequest(ctx, RequestParams{ + ServerName: req.Via, + Method: http.MethodGet, + Path: URLPath{"v1", "make_leave", req.RoomID, req.UserID}, + Authenticate: true, + ResponseJSON: &resp, + }) + return +} + +func (c *Client) SendLeave(ctx context.Context, req *ReqSendLeave) (err error) { + _, _, err = c.MakeFullRequest(ctx, RequestParams{ + ServerName: req.Via, + Method: http.MethodPut, + Path: URLPath{"v2", "send_leave", req.RoomID, req.EventID}, + Authenticate: true, + RequestJSON: req.Event, + }) + return +} + +type URLPath []any + +func (fup URLPath) FullPath() []any { + return append([]any{"_matrix", "federation"}, []any(fup)...) +} + +type KeyURLPath []any + +func (fkup KeyURLPath) FullPath() []any { + return append([]any{"_matrix", "key"}, []any(fkup)...) +} + +type RequestParams struct { + ServerName string + Method string + Path mautrix.PrefixableURLPath + Query url.Values + Authenticate bool + RequestJSON any + + ResponseJSON any + DontReadBody bool +} + +func (c *Client) MakeRequest(ctx context.Context, serverName string, authenticate bool, method string, path mautrix.PrefixableURLPath, reqJSON, respJSON any) error { + _, _, err := c.MakeFullRequest(ctx, RequestParams{ + ServerName: serverName, + Method: method, + Path: path, + Authenticate: authenticate, + RequestJSON: reqJSON, + ResponseJSON: respJSON, + }) + return err +} + +func (c *Client) MakeFullRequest(ctx context.Context, params RequestParams) ([]byte, *http.Response, error) { + req, err := c.compileRequest(ctx, params) + if err != nil { + return nil, nil, err + } + resp, err := c.HTTP.Do(req) + if err != nil { + return nil, nil, mautrix.HTTPError{ + Request: req, + Response: resp, + + Message: "request error", + WrappedError: err, + } + } + if !params.DontReadBody { + defer resp.Body.Close() + } + var body []byte + if resp.StatusCode >= 300 { + body, err = mautrix.ParseErrorResponse(req, resp) + return body, resp, err + } else if params.ResponseJSON != nil || !params.DontReadBody { + if resp.ContentLength > c.ResponseSizeLimit { + return body, resp, mautrix.HTTPError{ + Request: req, + Response: resp, + + Message: "not reading response", + WrappedError: fmt.Errorf("%w (%.2f MiB)", mautrix.ErrResponseTooLong, float64(resp.ContentLength)/1024/1024), + } + } + body, err = io.ReadAll(io.LimitReader(resp.Body, c.ResponseSizeLimit+1)) + if err == nil && len(body) > int(c.ResponseSizeLimit) { + err = mautrix.ErrBodyReadReachedLimit + } + if err != nil { + return body, resp, mautrix.HTTPError{ + Request: req, + Response: resp, + + Message: "failed to read response body", + WrappedError: err, + } + } + if params.ResponseJSON != nil { + err = json.Unmarshal(body, params.ResponseJSON) + if err != nil { + return body, resp, mautrix.HTTPError{ + Request: req, + Response: resp, + + Message: "failed to unmarshal response JSON", + ResponseBody: string(body), + WrappedError: err, + } + } + } + } + return body, resp, nil +} + +func (c *Client) compileRequest(ctx context.Context, params RequestParams) (*http.Request, error) { + reqURL := mautrix.BuildURL(&url.URL{ + Scheme: "matrix-federation", + Host: params.ServerName, + }, params.Path.FullPath()...) + reqURL.RawQuery = params.Query.Encode() + var reqJSON json.RawMessage + var reqBody io.Reader + if params.RequestJSON != nil { + var err error + reqJSON, err = json.Marshal(params.RequestJSON) + if err != nil { + return nil, mautrix.HTTPError{ + Message: "failed to marshal JSON", + WrappedError: err, + } + } + reqBody = bytes.NewReader(reqJSON) + } + req, err := http.NewRequestWithContext(ctx, params.Method, reqURL.String(), reqBody) + if err != nil { + return nil, mautrix.HTTPError{ + Message: "failed to create request", + WrappedError: err, + } + } + req.Header.Set("User-Agent", c.UserAgent) + if params.Authenticate { + if c.ServerName == "" || c.Key == nil { + return nil, mautrix.HTTPError{ + Message: "client not configured for authentication", + } + } + auth, err := (&signableRequest{ + Method: req.Method, + URI: reqURL.RequestURI(), + Origin: c.ServerName, + Destination: params.ServerName, + Content: reqJSON, + }).Sign(c.Key) + if err != nil { + return nil, mautrix.HTTPError{ + Message: "failed to sign request", + WrappedError: err, + } + } + req.Header.Set("Authorization", auth) + } + return req, nil +} + +type signableRequest struct { + Method string `json:"method"` + URI string `json:"uri"` + Origin string `json:"origin"` + Destination string `json:"destination"` + Content json.RawMessage `json:"content,omitempty"` +} + +func (r *signableRequest) Verify(key id.SigningKey, sig string) error { + message, err := json.Marshal(r) + if err != nil { + return fmt.Errorf("failed to marshal data: %w", err) + } + return signutil.VerifyJSONRaw(key, sig, message) +} + +func (r *signableRequest) Sign(key *SigningKey) (string, error) { + sig, err := key.SignJSON(r) + if err != nil { + return "", err + } + return XMatrixAuth{ + Origin: r.Origin, + Destination: r.Destination, + KeyID: key.ID, + Signature: sig, + }.String(), nil +} diff --git a/federation/client_test.go b/federation/client_test.go new file mode 100644 index 00000000..ece399ea --- /dev/null +++ b/federation/client_test.go @@ -0,0 +1,23 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package federation_test + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + + "maunium.net/go/mautrix/federation" +) + +func TestClient_Version(t *testing.T) { + cli := federation.NewClient("", nil, nil) + resp, err := cli.Version(context.TODO(), "maunium.net") + require.NoError(t, err) + require.Equal(t, "Synapse", resp.Server.Name) +} diff --git a/federation/context.go b/federation/context.go new file mode 100644 index 00000000..eedb2dc1 --- /dev/null +++ b/federation/context.go @@ -0,0 +1,42 @@ +// Copyright (c) 2025 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package federation + +import ( + "context" + "net/http" +) + +type contextKey int + +const ( + contextKeyIPPort contextKey = iota + contextKeyDestinationServer + contextKeyOriginServer +) + +func DestinationServerNameFromRequest(r *http.Request) string { + return DestinationServerName(r.Context()) +} + +func DestinationServerName(ctx context.Context) string { + if dest, ok := ctx.Value(contextKeyDestinationServer).(string); ok { + return dest + } + return "" +} + +func OriginServerNameFromRequest(r *http.Request) string { + return OriginServerName(r.Context()) +} + +func OriginServerName(ctx context.Context) string { + if origin, ok := ctx.Value(contextKeyOriginServer).(string); ok { + return origin + } + return "" +} diff --git a/federation/eventauth/eventauth.go b/federation/eventauth/eventauth.go new file mode 100644 index 00000000..c72933c2 --- /dev/null +++ b/federation/eventauth/eventauth.go @@ -0,0 +1,851 @@ +// Copyright (c) 2025 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +//go:build goexperiment.jsonv2 + +package eventauth + +import ( + "encoding/json" + "encoding/json/jsontext" + "errors" + "fmt" + "slices" + "strconv" + "strings" + + "github.com/tidwall/gjson" + "go.mau.fi/util/exgjson" + "go.mau.fi/util/exstrings" + "go.mau.fi/util/ptr" + + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/federation/pdu" + "maunium.net/go/mautrix/federation/signutil" + "maunium.net/go/mautrix/id" +) + +type AuthFailError struct { + Index string + Message string + Wrapped error +} + +func (afe AuthFailError) Error() string { + if afe.Message != "" { + return fmt.Sprintf("fail %s: %s", afe.Index, afe.Message) + } else if afe.Wrapped != nil { + return fmt.Sprintf("fail %s: %s", afe.Index, afe.Wrapped.Error()) + } + return fmt.Sprintf("fail %s", afe.Index) +} + +func (afe AuthFailError) Unwrap() error { + return afe.Wrapped +} + +var mFederatePath = exgjson.Path("m.federate") + +var ( + ErrCreateHasPrevEvents = AuthFailError{Index: "1.1", Message: "m.room.create event has prev_events"} + ErrCreateHasRoomID = AuthFailError{Index: "1.2", Message: "m.room.create event has room_id set"} + ErrRoomIDDoesntMatchSender = AuthFailError{Index: "1.2", Message: "room ID server doesn't match sender server"} + ErrUnknownRoomVersion = AuthFailError{Index: "1.3", Wrapped: id.ErrUnknownRoomVersion} + ErrInvalidAdditionalCreators = AuthFailError{Index: "1.4", Message: "m.room.create event has invalid additional_creators"} + ErrMissingCreator = AuthFailError{Index: "1.4", Message: "m.room.create event is missing creator field"} + + ErrInvalidRoomIDLength = AuthFailError{Index: "2", Message: "room ID length is invalid"} + ErrFailedToGetCreateEvent = AuthFailError{Index: "2", Message: "failed to get m.room.create event"} + ErrCreateEventNotFound = AuthFailError{Index: "2", Message: "m.room.create event not found using room ID as event ID"} + ErrRejectedCreateEvent = AuthFailError{Index: "2", Message: "m.room.create event was rejected"} + + ErrFailedToGetAuthEvents = AuthFailError{Index: "3", Message: "failed to get auth events"} + ErrFailedToParsePowerLevels = AuthFailError{Index: "?", Message: "failed to parse power levels"} + ErrDuplicateAuthEvent = AuthFailError{Index: "3.1", Message: "duplicate type/state key pair in auth events"} + ErrNonStateAuthEvent = AuthFailError{Index: "3.2", Message: "non-state event in auth events"} + ErrMissingAuthEvent = AuthFailError{Index: "3.2", Message: "missing auth event"} + ErrUnexpectedAuthEvent = AuthFailError{Index: "3.2", Message: "unexpected type/state key pair in auth events"} + ErrNoCreateEvent = AuthFailError{Index: "3.2", Message: "no m.room.create event found in auth events"} + ErrRejectedAuthEvent = AuthFailError{Index: "3.3", Message: "auth event was rejected"} + ErrMismatchingRoomIDInAuthEvent = AuthFailError{Index: "3.4", Message: "auth event room ID does not match event room ID"} + + ErrFederationDisabled = AuthFailError{Index: "4", Message: "federation is disabled for this room"} + + ErrMemberNotState = AuthFailError{Index: "5.1", Message: "m.room.member event is not a state event"} + ErrNotSignedByAuthoriser = AuthFailError{Index: "5.2", Message: "m.room.member event is not signed by server of join_authorised_via_users_server"} + ErrCantJoinOtherUser = AuthFailError{Index: "5.3.2", Message: "can't send join event with different state key"} + ErrCantJoinBanned = AuthFailError{Index: "5.3.3", Message: "user is banned from the room"} + ErrAuthoriserCantInvite = AuthFailError{Index: "5.3.5.2", Message: "authoriser doesn't have sufficient power level to invite"} + ErrAuthoriserNotInRoom = AuthFailError{Index: "5.3.5.2", Message: "authoriser isn't a member of the room"} + ErrCantJoinWithoutInvite = AuthFailError{Index: "5.3.7", Message: "can't join invite-only room without invite"} + ErrInvalidJoinRule = AuthFailError{Index: "5.3.7", Message: "invalid join rule in room"} + ErrThirdPartyInviteBanned = AuthFailError{Index: "5.4.1.1", Message: "third party invite target user is banned"} + ErrThirdPartyInviteMissingFields = AuthFailError{Index: "5.4.1.3", Message: "third party invite is missing mxid or token fields"} + ErrThirdPartyInviteMXIDMismatch = AuthFailError{Index: "5.4.1.4", Message: "mxid in signed third party invite doesn't match event state key"} + ErrThirdPartyInviteNotFound = AuthFailError{Index: "5.4.1.5", Message: "matching m.room.third_party_invite event not found in auth events"} + ErrThirdPartyInviteSenderMismatch = AuthFailError{Index: "5.4.1.6", Message: "sender of third party invite doesn't match sender of member event"} + ErrThirdPartyInviteNotSigned = AuthFailError{Index: "5.4.1.8", Message: "no valid signatures found for third party invite"} + ErrInviterNotInRoom = AuthFailError{Index: "5.4.2", Message: "inviter's membership is not join"} + ErrInviteTargetAlreadyInRoom = AuthFailError{Index: "5.4.3", Message: "invite target user is already in the room"} + ErrInviteTargetBanned = AuthFailError{Index: "5.4.3", Message: "invite target user is banned"} + ErrInsufficientPermissionForInvite = AuthFailError{Index: "5.4.5", Message: "inviter does not have sufficient permission to send invites"} + ErrCantLeaveWithoutBeingInRoom = AuthFailError{Index: "5.5.1", Message: "can't leave room without being in it"} + ErrCantKickWithoutBeingInRoom = AuthFailError{Index: "5.5.2", Message: "can't kick another user without being in the room"} + ErrInsufficientPermissionForUnban = AuthFailError{Index: "5.5.3", Message: "sender does not have sufficient permission to unban users"} + ErrInsufficientPermissionForKick = AuthFailError{Index: "5.5.5", Message: "sender does not have sufficient permission to kick the user"} + ErrCantBanWithoutBeingInRoom = AuthFailError{Index: "5.6.1", Message: "can't ban another user without being in the room"} + ErrInsufficientPermissionForBan = AuthFailError{Index: "5.6.3", Message: "sender does not have sufficient permission to ban the user"} + ErrNotKnockableRoom = AuthFailError{Index: "5.7.1", Message: "join rule doesn't allow knocking"} + ErrCantKnockOtherUser = AuthFailError{Index: "5.7.1", Message: "can't send knock event with different state key"} + ErrCantKnockWhileInRoom = AuthFailError{Index: "5.7.2", Message: "can't knock while joined, invited or banned"} + ErrUnknownMembership = AuthFailError{Index: "5.8", Message: "unknown membership in m.room.member event"} + + ErrNotInRoom = AuthFailError{Index: "6", Message: "sender is not a member of the room"} + + ErrInsufficientPowerForThirdPartyInvite = AuthFailError{Index: "7.1", Message: "sender does not have sufficient power level to send third party invite"} + + ErrInsufficientPowerLevel = AuthFailError{Index: "8", Message: "sender does not have sufficient power level to send event"} + + ErrMismatchingPrivateStateKey = AuthFailError{Index: "9", Message: "state keys starting with @ must match sender user ID"} + + ErrTopLevelPLNotInteger = AuthFailError{Index: "10.1", Message: "invalid type for top-level power level field"} + ErrPLNotInteger = AuthFailError{Index: "10.2", Message: "invalid type for power level"} + ErrInvalidUserIDInPL = AuthFailError{Index: "10.3", Message: "invalid user ID in power levels"} + ErrUserPLNotInteger = AuthFailError{Index: "10.3", Message: "invalid type for user power level"} + ErrCreatorInPowerLevels = AuthFailError{Index: "10.4", Message: "room creators must not be specified in power levels"} + ErrInvalidPowerChange = AuthFailError{Index: "10.x", Message: "illegal power level change"} + ErrInvalidUserPowerChange = AuthFailError{Index: "10.9", Message: "illegal power level change"} +) + +func isRejected(evt *pdu.PDU) bool { + return evt.InternalMeta.Rejected +} + +type GetEventsFunc = func(ids []id.EventID) ([]*pdu.PDU, error) + +func Authorize(roomVersion id.RoomVersion, evt *pdu.PDU, getEvents GetEventsFunc, getKey pdu.GetKeyFunc) error { + if evt.Type == event.StateCreate.Type { + // 1. If type is m.room.create: + return authorizeCreate(roomVersion, evt) + } + var createEvt *pdu.PDU + if roomVersion.RoomIDIsCreateEventID() { + // 2. If the event’s room_id is not an event ID for an accepted (not rejected) m.room.create event, + // with the sigil ! instead of $, reject. + if len(evt.RoomID) != 44 { + return fmt.Errorf("%w (%d)", ErrInvalidRoomIDLength, len(evt.RoomID)) + } else if createEvts, err := getEvents([]id.EventID{id.EventID("$" + evt.RoomID[1:])}); err != nil { + return fmt.Errorf("%w: %w", ErrFailedToGetCreateEvent, err) + } else if len(createEvts) != 1 { + return fmt.Errorf("%w (%s)", ErrCreateEventNotFound, evt.RoomID) + } else if isRejected(createEvts[0]) { + return ErrRejectedCreateEvent + } else { + createEvt = createEvts[0] + } + } + authEvents, err := getEvents(evt.AuthEvents) + if err != nil { + return fmt.Errorf("%w: %w", ErrFailedToGetAuthEvents, err) + } + expectedAuthEvents := evt.AuthEventSelection(roomVersion) + deduplicator := make(map[pdu.StateKey]id.EventID, len(expectedAuthEvents)) + // 3. Considering the event’s auth_events: + for i, ae := range authEvents { + authEvtID := evt.AuthEvents[i] + if ae == nil { + return fmt.Errorf("%w (%s)", ErrMissingAuthEvent, authEvtID) + } else if ae.StateKey == nil { + // This approximately falls under rule 3.2. + return fmt.Errorf("%w (%s)", ErrNonStateAuthEvent, authEvtID) + } + key := pdu.StateKey{Type: ae.Type, StateKey: *ae.StateKey} + if prevEvtID, alreadyFound := deduplicator[key]; alreadyFound { + // 3.1. If there are duplicate entries for a given type and state_key pair, reject. + return fmt.Errorf("%w for %s/%s: found %s and %s", ErrDuplicateAuthEvent, ae.Type, *ae.StateKey, prevEvtID, authEvtID) + } else if !expectedAuthEvents.Has(key) { + // 3.2. If there are entries whose type and state_key don’t match those specified by + // the auth events selection algorithm described in the server specification, reject. + return fmt.Errorf("%w: found %s with key %s/%s", ErrUnexpectedAuthEvent, authEvtID, ae.Type, *ae.StateKey) + } else if isRejected(ae) { + // 3.3. If there are entries which were themselves rejected under the checks performed on receipt of a PDU, reject. + return fmt.Errorf("%w (%s)", ErrRejectedAuthEvent, authEvtID) + } else if ae.RoomID != evt.RoomID { + // 3.4. If any event in auth_events has a room_id which does not match that of the event being authorised, reject. + return fmt.Errorf("%w (%s)", ErrMismatchingRoomIDInAuthEvent, authEvtID) + } else { + deduplicator[key] = authEvtID + } + if ae.Type == event.StateCreate.Type { + if createEvt == nil { + createEvt = ae + } else { + // Duplicates are prevented by deduplicator, AuthEventSelection also won't allow a create event at all for v12+ + panic(fmt.Errorf("impossible case: multiple create events found in auth events")) + } + } + } + if createEvt == nil { + // This comes either from auth_events or room_id depending on the room version. + // The checks above make sure it's from the right source. + return ErrNoCreateEvent + } + if federateVal := gjson.GetBytes(createEvt.Content, mFederatePath); federateVal.Type == gjson.False && createEvt.Sender.Homeserver() != evt.Sender.Homeserver() { + // 4. If the content of the m.room.create event in the room state has the property m.federate set to false, + // and the sender domain of the event does not match the sender domain of the create event, reject. + return ErrFederationDisabled + } + if evt.Type == event.StateMember.Type { + // 5. If type is m.room.member: + return authorizeMember(roomVersion, evt, createEvt, authEvents, getKey) + } + senderMembership := event.Membership(findEventAndReadString(authEvents, event.StateMember.Type, evt.Sender.String(), "membership", "leave")) + if senderMembership != event.MembershipJoin { + // 6. If the sender’s current membership state is not join, reject. + return ErrNotInRoom + } + powerLevels, err := getPowerLevels(roomVersion, authEvents, createEvt) + if err != nil { + return err + } + senderPL := powerLevels.GetUserLevel(evt.Sender) + if evt.Type == event.StateThirdPartyInvite.Type { + // 7.1. Allow if and only if sender’s current power level is greater than or equal to the invite level. + if senderPL >= powerLevels.Invite() { + return nil + } + return ErrInsufficientPowerForThirdPartyInvite + } + typeClass := event.MessageEventType + if evt.StateKey != nil { + typeClass = event.StateEventType + } + evtLevel := powerLevels.GetEventLevel(event.Type{Type: evt.Type, Class: typeClass}) + if evtLevel > senderPL { + // 8. If the event type’s required power level is greater than the sender’s power level, reject. + return fmt.Errorf("%w (%d > %d)", ErrInsufficientPowerLevel, evtLevel, senderPL) + } + + if evt.StateKey != nil && strings.HasPrefix(*evt.StateKey, "@") && *evt.StateKey != evt.Sender.String() { + // 9. If the event has a state_key that starts with an @ and does not match the sender, reject. + return ErrMismatchingPrivateStateKey + } + + if evt.Type == event.StatePowerLevels.Type { + // 10. If type is m.room.power_levels: + return authorizePowerLevels(roomVersion, evt, createEvt, authEvents) + } + + // 11. Otherwise, allow. + return nil +} + +var ErrUserIDNotAString = errors.New("not a string") +var ErrUserIDNotValid = errors.New("not a valid user ID") + +func isValidUserID(roomVersion id.RoomVersion, userID gjson.Result) error { + if userID.Type != gjson.String { + return ErrUserIDNotAString + } + // In a future room version, user IDs will have stricter validation + _, _, err := id.UserID(userID.Str).Parse() + if err != nil { + return ErrUserIDNotValid + } + return nil +} + +func authorizeCreate(roomVersion id.RoomVersion, evt *pdu.PDU) error { + if len(evt.PrevEvents) > 0 { + // 1.1. If it has any prev_events, reject. + return ErrCreateHasPrevEvents + } + if roomVersion.RoomIDIsCreateEventID() { + if evt.RoomID != "" { + // 1.2. If the event has a room_id, reject. + return ErrCreateHasRoomID + } + } else { + _, _, server := id.ParseCommonIdentifier(evt.RoomID) + if server == "" || server != evt.Sender.Homeserver() { + // 1.2. (v11 and below) If the domain of the room_id does not match the domain of the sender, reject. + return ErrRoomIDDoesntMatchSender + } + } + if !roomVersion.IsKnown() { + // 1.3. If content.room_version is present and is not a recognised version, reject. + return fmt.Errorf("%w %s", ErrUnknownRoomVersion, roomVersion) + } + if roomVersion.PrivilegedRoomCreators() { + additionalCreators := gjson.GetBytes(evt.Content, "additional_creators") + if additionalCreators.Exists() { + if !additionalCreators.IsArray() { + return fmt.Errorf("%w: not an array", ErrInvalidAdditionalCreators) + } + for i, item := range additionalCreators.Array() { + // 1.4. If additional_creators is present in content and is not an array of strings + // where each string passes the same user ID validation applied to sender, reject. + if err := isValidUserID(roomVersion, item); err != nil { + return fmt.Errorf("%w: item #%d %w", ErrInvalidAdditionalCreators, i+1, err) + } + } + } + } + if roomVersion.CreatorInContent() { + // 1.4. (v10 and below) If content has no creator property, reject. + if !gjson.GetBytes(evt.Content, "creator").Exists() { + return ErrMissingCreator + } + } + // 1.5. Otherwise, allow. + return nil +} + +func authorizeMember(roomVersion id.RoomVersion, evt, createEvt *pdu.PDU, authEvents []*pdu.PDU, getKey pdu.GetKeyFunc) error { + membership := event.Membership(gjson.GetBytes(evt.Content, "membership").Str) + if evt.StateKey == nil { + // 5.1. If there is no state_key property, or no membership property in content, reject. + return ErrMemberNotState + } + authorizedVia := id.UserID(gjson.GetBytes(evt.Content, "authorised_via_users_server").Str) + if authorizedVia != "" { + homeserver := authorizedVia.Homeserver() + err := evt.VerifySignature(roomVersion, homeserver, getKey) + if err != nil { + // 5.2. If content has a join_authorised_via_users_server key: + // 5.2.1. If the event is not validly signed by the homeserver of the user ID denoted by the key, reject. + return fmt.Errorf("%w: %w", ErrNotSignedByAuthoriser, err) + } + } + targetPrevMembership := event.Membership(findEventAndReadString(authEvents, event.StateMember.Type, *evt.StateKey, "membership", "leave")) + senderMembership := event.Membership(findEventAndReadString(authEvents, event.StateMember.Type, evt.Sender.String(), "membership", "leave")) + switch membership { + case event.MembershipJoin: + createEvtID, err := createEvt.GetEventID(roomVersion) + if err != nil { + return fmt.Errorf("failed to get create event ID: %w", err) + } + creator := createEvt.Sender.String() + if roomVersion.CreatorInContent() { + creator = gjson.GetBytes(evt.Content, "creator").Str + } + if len(evt.PrevEvents) == 1 && + len(evt.AuthEvents) <= 1 && + evt.PrevEvents[0] == createEvtID && + *evt.StateKey == creator { + // 5.3.1. If the only previous event is an m.room.create and the state_key is the sender of the m.room.create, allow. + return nil + } + // Spec wart: this would make more sense before the check above. + // Now you can set anyone as the sender of the first join. + if evt.Sender.String() != *evt.StateKey { + // 5.3.2. If the sender does not match state_key, reject. + return ErrCantJoinOtherUser + } + + if senderMembership == event.MembershipBan { + // 5.3.3. If the sender is banned, reject. + return ErrCantJoinBanned + } + + joinRule := event.JoinRule(findEventAndReadString(authEvents, event.StateJoinRules.Type, "", "join_rule", "invite")) + switch joinRule { + case event.JoinRuleKnock: + if !roomVersion.Knocks() { + return ErrInvalidJoinRule + } + fallthrough + case event.JoinRuleInvite: + // 5.3.4. If the join_rule is invite or knock then allow if membership state is invite or join. + if targetPrevMembership == event.MembershipJoin || targetPrevMembership == event.MembershipInvite { + return nil + } + return ErrCantJoinWithoutInvite + case event.JoinRuleKnockRestricted: + if !roomVersion.KnockRestricted() { + return ErrInvalidJoinRule + } + fallthrough + case event.JoinRuleRestricted: + if joinRule == event.JoinRuleRestricted && !roomVersion.RestrictedJoins() { + return ErrInvalidJoinRule + } + if targetPrevMembership == event.MembershipJoin || targetPrevMembership == event.MembershipInvite { + // 5.3.5.1. If membership state is join or invite, allow. + return nil + } + powerLevels, err := getPowerLevels(roomVersion, authEvents, createEvt) + if err != nil { + return err + } + if powerLevels.GetUserLevel(authorizedVia) < powerLevels.Invite() { + // 5.3.5.2. If the join_authorised_via_users_server key in content is not a user with sufficient permission to invite other users, reject. + return ErrAuthoriserCantInvite + } + authorizerMembership := event.Membership(findEventAndReadString(authEvents, event.StateMember.Type, authorizedVia.String(), "membership", string(event.MembershipLeave))) + if authorizerMembership != event.MembershipJoin { + return ErrAuthoriserNotInRoom + } + // 5.3.5.3. Otherwise, allow. + return nil + case event.JoinRulePublic: + // 5.3.6. If the join_rule is public, allow. + return nil + default: + // 5.3.7. Otherwise, reject. + return ErrInvalidJoinRule + } + case event.MembershipInvite: + tpiVal := gjson.GetBytes(evt.Content, "third_party_invite") + if tpiVal.Exists() { + if targetPrevMembership == event.MembershipBan { + return ErrThirdPartyInviteBanned + } + signed := tpiVal.Get("signed") + mxid := signed.Get("mxid").Str + token := signed.Get("token").Str + if mxid == "" || token == "" { + // 5.4.1.2. If content.third_party_invite does not have a signed property, reject. + // 5.4.1.3. If signed does not have mxid and token properties, reject. + return ErrThirdPartyInviteMissingFields + } + if mxid != *evt.StateKey { + // 5.4.1.4. If mxid does not match state_key, reject. + return ErrThirdPartyInviteMXIDMismatch + } + tpiEvt := findEvent(authEvents, event.StateThirdPartyInvite.Type, token) + if tpiEvt == nil { + // 5.4.1.5. If there is no m.room.third_party_invite event in the current room state with state_key matching token, reject. + return ErrThirdPartyInviteNotFound + } + if tpiEvt.Sender != evt.Sender { + // 5.4.1.6. If sender does not match sender of the m.room.third_party_invite, reject. + return ErrThirdPartyInviteSenderMismatch + } + var keys []id.Ed25519 + const ed25519Base64Len = 43 + oldPubKey := gjson.GetBytes(evt.Content, "public_key.token") + if oldPubKey.Type == gjson.String && len(oldPubKey.Str) == ed25519Base64Len { + keys = append(keys, id.Ed25519(oldPubKey.Str)) + } + gjson.GetBytes(evt.Content, "public_keys").ForEach(func(key, value gjson.Result) bool { + if key.Type != gjson.Number { + return false + } + if value.Type == gjson.String && len(value.Str) == ed25519Base64Len { + keys = append(keys, id.Ed25519(value.Str)) + } + return true + }) + rawSigned := jsontext.Value(exstrings.UnsafeBytes(signed.Str)) + var validated bool + for _, key := range keys { + if signutil.VerifyJSONAny(key, rawSigned) == nil { + validated = true + } + } + if validated { + // 4.4.1.7. If any signature in signed matches any public key in the m.room.third_party_invite event, allow. + return nil + } + // 4.4.1.8. Otherwise, reject. + return ErrThirdPartyInviteNotSigned + } + if senderMembership != event.MembershipJoin { + // 5.4.2. If the sender’s current membership state is not join, reject. + return ErrInviterNotInRoom + } + // 5.4.3. If target user’s current membership state is join or ban, reject. + if targetPrevMembership == event.MembershipJoin { + return ErrInviteTargetAlreadyInRoom + } else if targetPrevMembership == event.MembershipBan { + return ErrInviteTargetBanned + } + powerLevels, err := getPowerLevels(roomVersion, authEvents, createEvt) + if err != nil { + return err + } + if powerLevels.GetUserLevel(evt.Sender) >= powerLevels.Invite() { + // 5.4.4. If the sender’s power level is greater than or equal to the invite level, allow. + return nil + } + // 5.4.5. Otherwise, reject. + return ErrInsufficientPermissionForInvite + case event.MembershipLeave: + if evt.Sender.String() == *evt.StateKey { + // 5.5.1. If the sender matches state_key, allow if and only if that user’s current membership state is invite, join, or knock. + if senderMembership == event.MembershipInvite || + senderMembership == event.MembershipJoin || + (senderMembership == event.MembershipKnock && roomVersion.Knocks()) { + return nil + } + return ErrCantLeaveWithoutBeingInRoom + } + if senderMembership != event.MembershipJoin { + // 5.5.2. If the sender’s current membership state is not join, reject. + return ErrCantKickWithoutBeingInRoom + } + powerLevels, err := getPowerLevels(roomVersion, authEvents, createEvt) + if err != nil { + return err + } + senderLevel := powerLevels.GetUserLevel(evt.Sender) + if targetPrevMembership == event.MembershipBan && senderLevel < powerLevels.Ban() { + // 5.5.3. If the target user’s current membership state is ban, and the sender’s power level is less than the ban level, reject. + return ErrInsufficientPermissionForUnban + } + if senderLevel >= powerLevels.Kick() && powerLevels.GetUserLevel(id.UserID(*evt.StateKey)) < senderLevel { + // 5.5.4. If the sender’s power level is greater than or equal to the kick level, and the target user’s power level is less than the sender’s power level, allow. + return nil + } + // TODO separate errors for < kick and < target user level? + // 5.5.5. Otherwise, reject. + return ErrInsufficientPermissionForKick + case event.MembershipBan: + if senderMembership != event.MembershipJoin { + // 5.6.1. If the sender’s current membership state is not join, reject. + return ErrCantBanWithoutBeingInRoom + } + powerLevels, err := getPowerLevels(roomVersion, authEvents, createEvt) + if err != nil { + return err + } + senderLevel := powerLevels.GetUserLevel(evt.Sender) + if senderLevel >= powerLevels.Ban() && powerLevels.GetUserLevel(id.UserID(*evt.StateKey)) < senderLevel { + // 5.6.2. If the sender’s power level is greater than or equal to the ban level, and the target user’s power level is less than the sender’s power level, allow. + return nil + } + // 5.6.3. Otherwise, reject. + return ErrInsufficientPermissionForBan + case event.MembershipKnock: + joinRule := event.JoinRule(findEventAndReadString(authEvents, event.StateJoinRules.Type, "", "join_rule", "invite")) + validKnockRule := roomVersion.Knocks() && joinRule == event.JoinRuleKnock + validKnockRestrictedRule := roomVersion.KnockRestricted() && joinRule == event.JoinRuleKnockRestricted + if !validKnockRule && !validKnockRestrictedRule { + // 5.7.1. If the join_rule is anything other than knock or knock_restricted, reject. + return ErrNotKnockableRoom + } + if evt.Sender.String() != *evt.StateKey { + // 5.7.2. If the sender does not match state_key, reject. + return ErrCantKnockOtherUser + } + if senderMembership != event.MembershipBan && senderMembership != event.MembershipInvite && senderMembership != event.MembershipJoin { + // 5.7.3. If the sender’s current membership is not ban, invite, or join, allow. + return nil + } + // 5.7.4. Otherwise, reject. + return ErrCantKnockWhileInRoom + default: + // 5.8. Otherwise, the membership is unknown. Reject. + return ErrUnknownMembership + } +} + +func authorizePowerLevels(roomVersion id.RoomVersion, evt, createEvt *pdu.PDU, authEvents []*pdu.PDU) error { + if roomVersion.ValidatePowerLevelInts() { + for _, key := range []string{"users_default", "events_default", "state_default", "ban", "redact", "kick", "invite"} { + res := gjson.GetBytes(evt.Content, key) + if !res.Exists() { + continue + } + if parseIntWithVersion(roomVersion, res) == nil { + // 10.1. If any of the properties users_default, events_default, state_default, ban, redact, kick, or invite in content are present and not an integer, reject. + return fmt.Errorf("%w %s", ErrTopLevelPLNotInteger, key) + } + } + for _, key := range []string{"events", "notifications"} { + obj := gjson.GetBytes(evt.Content, key) + if !obj.Exists() { + continue + } + // 10.2. If either of the properties events or notifications in content are present and not an object [...], reject. + if !obj.IsObject() { + return fmt.Errorf("%w %s", ErrTopLevelPLNotInteger, key) + } + var err error + // 10.2. [...] are not an object with values that are integers, reject. + obj.ForEach(func(innerKey, value gjson.Result) bool { + if parseIntWithVersion(roomVersion, value) == nil { + err = fmt.Errorf("%w %s.%s", ErrPLNotInteger, key, innerKey.Str) + return false + } + return true + }) + if err != nil { + return err + } + } + } + var creators []id.UserID + if roomVersion.PrivilegedRoomCreators() { + creators = append(creators, createEvt.Sender) + gjson.GetBytes(createEvt.Content, "additional_creators").ForEach(func(key, value gjson.Result) bool { + creators = append(creators, id.UserID(value.Str)) + return true + }) + } + users := gjson.GetBytes(evt.Content, "users") + if users.Exists() { + if !users.IsObject() { + // 10.3. If the users property in content is not an object [...], reject. + return fmt.Errorf("%w users", ErrTopLevelPLNotInteger) + } + var err error + users.ForEach(func(key, value gjson.Result) bool { + if validatorErr := isValidUserID(roomVersion, key); validatorErr != nil { + // 10.3. [...] is not an object with keys that are valid user IDs [...], reject. + err = fmt.Errorf("%w: %q %w", ErrInvalidUserIDInPL, key.Str, validatorErr) + return false + } + if parseIntWithVersion(roomVersion, value) == nil { + // 10.3. [...] is not an object [...] with values that are integers, reject. + err = fmt.Errorf("%w %q", ErrUserPLNotInteger, key.Str) + return false + } + // creators is only filled if the room version has privileged room creators + if slices.Contains(creators, id.UserID(key.Str)) { + // 10.4. If the users property in content contains the sender of the m.room.create event or any of + // the additional_creators array (if present) from the content of the m.room.create event, reject. + err = fmt.Errorf("%w: %q", ErrCreatorInPowerLevels, key.Str) + return false + } + return true + }) + if err != nil { + return err + } + } + oldPL := findEvent(authEvents, event.StatePowerLevels.Type, "") + if oldPL == nil { + // 10.5. If there is no previous m.room.power_levels event in the room, allow. + return nil + } + if slices.Contains(creators, evt.Sender) { + // Skip remaining checks for creators + return nil + } + senderPLPtr := parsePythonInt(gjson.GetBytes(oldPL.Content, exgjson.Path("users", evt.Sender.String()))) + if senderPLPtr == nil { + senderPLPtr = parsePythonInt(gjson.GetBytes(oldPL.Content, "users_default")) + if senderPLPtr == nil { + senderPLPtr = ptr.Ptr(0) + } + } + for _, key := range []string{"users_default", "events_default", "state_default", "ban", "redact", "kick", "invite"} { + oldVal := gjson.GetBytes(oldPL.Content, key) + newVal := gjson.GetBytes(evt.Content, key) + if err := allowPowerChange(roomVersion, *senderPLPtr, key, oldVal, newVal); err != nil { + return err + } + } + if err := allowPowerChangeMap( + roomVersion, *senderPLPtr, "events", "", + gjson.GetBytes(oldPL.Content, "events"), + gjson.GetBytes(evt.Content, "events"), + ); err != nil { + return err + } + if err := allowPowerChangeMap( + roomVersion, *senderPLPtr, "notifications", "", + gjson.GetBytes(oldPL.Content, "notifications"), + gjson.GetBytes(evt.Content, "notifications"), + ); err != nil { + return err + } + if err := allowPowerChangeMap( + roomVersion, *senderPLPtr, "users", evt.Sender.String(), + gjson.GetBytes(oldPL.Content, "users"), + gjson.GetBytes(evt.Content, "users"), + ); err != nil { + return err + } + return nil +} + +func allowPowerChangeMap(roomVersion id.RoomVersion, maxVal int, path, ownID string, old, new gjson.Result) (err error) { + old.ForEach(func(key, value gjson.Result) bool { + newVal := new.Get(exgjson.Path(key.Str)) + err = allowPowerChange(roomVersion, maxVal, path+"."+key.Str, value, newVal) + if err == nil && ownID != "" && key.Str != ownID { + parsedOldVal := parseIntWithVersion(roomVersion, value) + parsedNewVal := parseIntWithVersion(roomVersion, newVal) + if *parsedOldVal >= maxVal && *parsedOldVal != *parsedNewVal { + err = fmt.Errorf("%w: can't change users.%s from %s to %s with sender level %d", ErrInvalidUserPowerChange, key.Str, stringifyForError(value), stringifyForError(newVal), maxVal) + } + } + return err == nil + }) + if err != nil { + return + } + new.ForEach(func(key, value gjson.Result) bool { + err = allowPowerChange(roomVersion, maxVal, path+"."+key.Str, old.Get(exgjson.Path(key.Str)), value) + return err == nil + }) + return +} + +func allowPowerChange(roomVersion id.RoomVersion, maxVal int, path string, old, new gjson.Result) error { + oldVal := parseIntWithVersion(roomVersion, old) + newVal := parseIntWithVersion(roomVersion, new) + if oldVal == nil { + if newVal == nil || *newVal <= maxVal { + return nil + } + } else if newVal == nil { + if *oldVal <= maxVal { + return nil + } + } else if *oldVal == *newVal || (*oldVal <= maxVal && *newVal <= maxVal) { + return nil + } + return fmt.Errorf("%w can't change %s from %s to %s with sender level %d", ErrInvalidPowerChange, path, stringifyForError(old), stringifyForError(new), maxVal) +} + +func stringifyForError(val gjson.Result) string { + if !val.Exists() { + return "null" + } + return val.Raw +} + +func findEvent(events []*pdu.PDU, evtType, stateKey string) *pdu.PDU { + for _, evt := range events { + if evt.Type == evtType && *evt.StateKey == stateKey { + return evt + } + } + return nil +} + +func findEventAndReadData[T any](events []*pdu.PDU, evtType, stateKey string, reader func(evt *pdu.PDU) T) T { + return reader(findEvent(events, evtType, stateKey)) +} + +func findEventAndReadString(events []*pdu.PDU, evtType, stateKey, fieldPath, defVal string) string { + return findEventAndReadData(events, evtType, stateKey, func(evt *pdu.PDU) string { + if evt == nil { + return defVal + } + res := gjson.GetBytes(evt.Content, fieldPath) + if res.Type != gjson.String { + return defVal + } + return res.Str + }) +} + +func getPowerLevels(roomVersion id.RoomVersion, authEvents []*pdu.PDU, createEvt *pdu.PDU) (*event.PowerLevelsEventContent, error) { + var err error + powerLevels := findEventAndReadData(authEvents, event.StatePowerLevels.Type, "", func(evt *pdu.PDU) *event.PowerLevelsEventContent { + if evt == nil { + return nil + } + content := evt.Content + out := &event.PowerLevelsEventContent{} + if !roomVersion.ValidatePowerLevelInts() { + safeParsePowerLevels(content, out) + } else { + err = json.Unmarshal(content, out) + } + return out + }) + if err != nil { + // This should never happen thanks to safeParsePowerLevels for v1-9 and strict validation in v10+ + return nil, fmt.Errorf("%w: %w", ErrFailedToParsePowerLevels, err) + } + if roomVersion.PrivilegedRoomCreators() { + if powerLevels == nil { + powerLevels = &event.PowerLevelsEventContent{} + } + powerLevels.CreateEvent, err = createEvt.ToClientEvent(roomVersion) + if err != nil { + return nil, fmt.Errorf("%w: %w", ErrFailedToParsePowerLevels, err) + } + err = powerLevels.CreateEvent.Content.ParseRaw(powerLevels.CreateEvent.Type) + if err != nil { + return nil, fmt.Errorf("%w: %w", ErrFailedToParsePowerLevels, err) + } + } else if powerLevels == nil { + powerLevels = &event.PowerLevelsEventContent{ + Users: map[id.UserID]int{ + createEvt.Sender: 100, + }, + } + } + return powerLevels, nil +} + +func parseIntWithVersion(roomVersion id.RoomVersion, val gjson.Result) *int { + if roomVersion.ValidatePowerLevelInts() { + if val.Type != gjson.Number { + return nil + } + return ptr.Ptr(int(val.Int())) + } + return parsePythonInt(val) +} + +func parsePythonInt(val gjson.Result) *int { + switch val.Type { + case gjson.True: + return ptr.Ptr(1) + case gjson.False: + return ptr.Ptr(0) + case gjson.Number: + return ptr.Ptr(int(val.Int())) + case gjson.String: + // strconv.Atoi accepts signs as well as leading zeroes, so we just need to trim spaces beforehand + num, err := strconv.Atoi(strings.TrimSpace(val.Str)) + if err != nil { + return nil + } + return &num + default: + // Python int() doesn't accept nulls, arrays or dicts + return nil + } +} + +func safeParsePowerLevels(content jsontext.Value, into *event.PowerLevelsEventContent) { + *into = event.PowerLevelsEventContent{ + Users: make(map[id.UserID]int), + UsersDefault: ptr.Val(parsePythonInt(gjson.GetBytes(content, "users_default"))), + Events: make(map[string]int), + EventsDefault: ptr.Val(parsePythonInt(gjson.GetBytes(content, "events_default"))), + Notifications: nil, // irrelevant for event auth + StateDefaultPtr: parsePythonInt(gjson.GetBytes(content, "state_default")), + InvitePtr: parsePythonInt(gjson.GetBytes(content, "invite")), + KickPtr: parsePythonInt(gjson.GetBytes(content, "kick")), + BanPtr: parsePythonInt(gjson.GetBytes(content, "ban")), + RedactPtr: parsePythonInt(gjson.GetBytes(content, "redact")), + } + gjson.GetBytes(content, "events").ForEach(func(key, value gjson.Result) bool { + if key.Type != gjson.String { + return false + } + val := parsePythonInt(value) + if val != nil { + into.Events[key.Str] = *val + } + return true + }) + gjson.GetBytes(content, "users").ForEach(func(key, value gjson.Result) bool { + if key.Type != gjson.String { + return false + } + val := parsePythonInt(value) + if val == nil { + return false + } + userID := id.UserID(key.Str) + if _, _, err := userID.Parse(); err != nil { + return false + } + into.Users[userID] = *val + return true + }) +} diff --git a/federation/eventauth/eventauth_internal_test.go b/federation/eventauth/eventauth_internal_test.go new file mode 100644 index 00000000..d316f3c8 --- /dev/null +++ b/federation/eventauth/eventauth_internal_test.go @@ -0,0 +1,66 @@ +// Copyright (c) 2026 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +//go:build goexperiment.jsonv2 + +package eventauth + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" +) + +type pythonIntTest struct { + Name string + Input string + Expected int64 +} + +var pythonIntTests = []pythonIntTest{ + {"True", `true`, 1}, + {"False", `false`, 0}, + {"SmallFloat", `3.1415`, 3}, + {"SmallFloatRoundDown", `10.999999999999999`, 10}, + {"SmallFloatRoundUp", `10.9999999999999999`, 11}, + {"BigFloatRoundDown", `1000000.9999999999`, 1000000}, + {"BigFloatRoundUp", `1000000.99999999999`, 1000001}, + {"BigFloatPrecisionError", `9007199254740993.0`, 9007199254740992}, + {"BigFloatPrecisionError2", `9007199254740993.123`, 9007199254740994}, + {"Int64", `9223372036854775807`, 9223372036854775807}, + {"Int64String", `"9223372036854775807"`, 9223372036854775807}, + {"String", `"123"`, 123}, + {"InvalidFloatInString", `"123.456"`, 0}, + {"StringWithPlusSign", `"+123"`, 123}, + {"StringWithMinusSign", `"-123"`, -123}, + {"StringWithSpaces", `" 123 "`, 123}, + {"StringWithSpacesAndSign", `" -123 "`, -123}, + //{"StringWithUnderscores", `"123_456"`, 123456}, + //{"StringWithUnderscores", `"123_456"`, 123456}, + {"InvalidStringWithTrailingUnderscore", `"123_456_"`, 0}, + {"InvalidStringWithMultipleUnderscores", `"123__456"`, 0}, + {"InvalidStringWithLeadingUnderscore", `"_123_456"`, 0}, + {"InvalidStringWithUnderscoreAfterSign", `"+_123_456"`, 0}, + {"InvalidStringWithUnderscoreAfterSpace", `" _123_456"`, 0}, + //{"StringWithUnderscoresAndSpaces", `" +1_2_3_4_5_6 "`, 123456}, +} + +func TestParsePythonInt(t *testing.T) { + for _, test := range pythonIntTests { + t.Run(test.Name, func(t *testing.T) { + output := parsePythonInt(gjson.Parse(test.Input)) + if strings.HasPrefix(test.Name, "Invalid") { + assert.Nil(t, output) + } else { + require.NotNil(t, output) + assert.Equal(t, int(test.Expected), *output) + } + }) + } +} diff --git a/federation/eventauth/eventauth_test.go b/federation/eventauth/eventauth_test.go new file mode 100644 index 00000000..e3c5cd76 --- /dev/null +++ b/federation/eventauth/eventauth_test.go @@ -0,0 +1,85 @@ +// Copyright (c) 2025 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +//go:build goexperiment.jsonv2 + +package eventauth_test + +import ( + "embed" + "encoding/json/jsontext" + "encoding/json/v2" + "errors" + "io" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" + "go.mau.fi/util/exerrors" + "go.mau.fi/util/ptr" + + "maunium.net/go/mautrix/federation/eventauth" + "maunium.net/go/mautrix/federation/pdu" + "maunium.net/go/mautrix/id" +) + +//go:embed *.jsonl +var data embed.FS + +type eventMap map[id.EventID]*pdu.PDU + +func (em eventMap) Get(ids []id.EventID) ([]*pdu.PDU, error) { + output := make([]*pdu.PDU, len(ids)) + for i, evtID := range ids { + output[i] = em[evtID] + } + return output, nil +} + +func GetKey(serverName string, keyID id.KeyID, validUntilTS time.Time) (id.SigningKey, time.Time, error) { + return "", time.Time{}, nil +} + +func TestAuthorize(t *testing.T) { + files := exerrors.Must(data.ReadDir(".")) + for _, file := range files { + t.Run(file.Name(), func(t *testing.T) { + decoder := jsontext.NewDecoder(exerrors.Must(data.Open(file.Name()))) + events := make(eventMap) + var roomVersion *id.RoomVersion + for i := 1; ; i++ { + var evt *pdu.PDU + err := json.UnmarshalDecode(decoder, &evt) + if errors.Is(err, io.EOF) { + break + } + require.NoError(t, err) + if roomVersion == nil { + require.Equal(t, evt.Type, "m.room.create") + roomVersion = ptr.Ptr(id.RoomVersion(gjson.GetBytes(evt.Content, "room_version").Str)) + } + expectedEventID := gjson.GetBytes(evt.Unsigned, "event_id").Str + evtID, err := evt.GetEventID(*roomVersion) + require.NoError(t, err) + require.Equalf(t, id.EventID(expectedEventID), evtID, "Event ID mismatch for event #%d", i) + + // TODO allow redacted events + assert.True(t, evt.VerifyContentHash(), i) + + events[evtID] = evt + err = eventauth.Authorize(*roomVersion, evt, events.Get, GetKey) + if err != nil { + evt.InternalMeta.Rejected = true + } + // TODO allow testing intentionally rejected events + assert.NoErrorf(t, err, "Failed to authorize event #%d / %s of type %s", i, evtID, evt.Type) + } + }) + } + +} diff --git a/federation/eventauth/testroom-v12-success.jsonl b/federation/eventauth/testroom-v12-success.jsonl new file mode 100644 index 00000000..2b751de3 --- /dev/null +++ b/federation/eventauth/testroom-v12-success.jsonl @@ -0,0 +1,21 @@ +{"auth_events":[],"content":{"room_version":"12"},"depth":1,"hashes":{"sha256":"qJYytb+EqWPiiZ0ogDODcLeA8XYw/2hVTaLHihcVBZQ"},"origin_server_ts":1756071567186,"prev_events":[],"sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"/9pp+2tkLo6XcZ3opqLeIpa3D96fh3QLpR2PQrZ6Z6j7wyRAvBrcgCpAeMtuyDCzW8Wh1QFEPG4FSsGvVaEFBg"}},"state_key":"","type":"m.room.create","unsigned":{"age_ts":1756071567186,"event_id":"$lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54"}} +{"auth_events":[],"content":{"avatar_url":"mxc://maunium.net/jdlSfvudiMSmcRrleeiYjjFO","displayname":"tulir","membership":"join"},"depth":2,"hashes":{"sha256":"MXmgq0e4J9CdIP0IVKVvueFhOb+ndlsXpeyI+6l/2FI"},"origin_server_ts":1756071567259,"prev_events":["$lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"xMgRzyRg9VM9XCKpfFJA+MrYoI68b8PIddKpMTcxz/fDzmGSHEy6Ta2b59VxiX3NoJe2CigkDZ3+jVsQoZYIBA"}},"state_key":"@tulir:maunium.net","type":"m.room.member","unsigned":{"age_ts":1756071567259,"event_id":"$mmqm2KS4UExkNL65c6CIhKofn_L9fzF2OhghVqajksU"}} +{"auth_events":["$mmqm2KS4UExkNL65c6CIhKofn_L9fzF2OhghVqajksU"],"content":{"ban":50,"events":{"m.room.avatar":50,"m.room.canonical_alias":50,"m.room.encryption":100,"m.room.history_visibility":100,"m.room.name":50,"m.room.power_levels":100,"m.room.server_acl":100,"m.room.tombstone":150},"events_default":0,"historical":100,"invite":0,"kick":50,"redact":50,"state_default":50,"users":{"@tulir:envs.net":9001},"users_default":0},"depth":3,"hashes":{"sha256":"/JzQNBNqJ/i8vwj6xESDaD5EDdOqB4l/LmKlvAVl5jY"},"origin_server_ts":1756071567319,"prev_events":["$mmqm2KS4UExkNL65c6CIhKofn_L9fzF2OhghVqajksU"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"W3N3X/enja+lumXw3uz66/wT9oczoxrmHbAD5/RF069cX4wkCtqtDd61VWPkSGmKxdV1jurgbCqSX6+Q9/t3AA"}},"state_key":"","type":"m.room.power_levels","unsigned":{"age_ts":1756071567319,"event_id":"$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU"}} +{"auth_events":["$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU","$mmqm2KS4UExkNL65c6CIhKofn_L9fzF2OhghVqajksU"],"content":{"join_rule":"invite"},"depth":4,"hashes":{"sha256":"GBu5AySj75ZXlOLd65mB03KueFKOHNgvtg2o/LUnLyI"},"origin_server_ts":1756071567320,"prev_events":["$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"XqWEnFREo2PhRnaebGjNzdHdtD691BtCQKkLnpKd8P3lVDewDt8OkCbDSk/Uzh9rDtzwWEsbsIoKSYuOm+G6CA"}},"state_key":"","type":"m.room.join_rules","unsigned":{"age_ts":1756071567320,"event_id":"$deNVGs6Ef7OKVrvewhtPv7DCCqSip112cEJYp-jkP6M"}} +{"auth_events":["$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU","$mmqm2KS4UExkNL65c6CIhKofn_L9fzF2OhghVqajksU"],"content":{"history_visibility":"shared"},"depth":5,"hashes":{"sha256":"niDi5vG2akQm0f5pm0aoCYXqmWjXRfmP1ulr/ZEPm/k"},"origin_server_ts":1756071567320,"prev_events":["$deNVGs6Ef7OKVrvewhtPv7DCCqSip112cEJYp-jkP6M"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"PTIrNke/fc9+ObKAl/K0PGZfmpe8dwREyoA5rXffOXWdRHSaBifn9UIiJUqd68Bzvrv4RcADTR/ci7lUquFBBw"}},"state_key":"","type":"m.room.history_visibility","unsigned":{"age_ts":1756071567320,"event_id":"$Wmy3G9yxl9ArVg5ZsdeIDPxBsNAdgseuvHoqHTZ2vug"}} +{"auth_events":["$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU","$mmqm2KS4UExkNL65c6CIhKofn_L9fzF2OhghVqajksU"],"content":{"guest_access":"can_join"},"depth":6,"hashes":{"sha256":"sZ9QqsId4oarFF724esTohXuRxDNnaXPl+QmTDG60dw"},"origin_server_ts":1756071567321,"prev_events":["$Wmy3G9yxl9ArVg5ZsdeIDPxBsNAdgseuvHoqHTZ2vug"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"Eh2P9/hl38wfZx2AQbeS5VCD4wldXPfeP2sQsJsLtfmdwFV74jrlGVBaKIkaYcXY4eA08iDp8HW5jqttZqKKDg"}},"state_key":"","type":"m.room.guest_access","unsigned":{"age_ts":1756071567321,"event_id":"$hYVRH7F4P5mB5IqvBDDU5aXY7pYGG0ApstrryiVPKmQ"}} +{"auth_events":["$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU","$mmqm2KS4UExkNL65c6CIhKofn_L9fzF2OhghVqajksU"],"content":{"name":"event auth test v12"},"depth":7,"hashes":{"sha256":"tjwPo38yR+23Was6SbxLvPMhNx44DaXLhF3rKgngepU"},"origin_server_ts":1756071567321,"prev_events":["$hYVRH7F4P5mB5IqvBDDU5aXY7pYGG0ApstrryiVPKmQ"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"q1rk0c5m8TJYE9tePsMaLeaigatNNbvaLRom0X8KiZY0EH+itujfA+/UnksvmPmMmThfAXWlFLx5u8tcuSVyCQ"}},"state_key":"","type":"m.room.name","unsigned":{"age_ts":1756071567321,"event_id":"$fFDwIavLTEIfcnggWuryB6JwfS-L2KT6vP1ap3P6ctE"}} +{"auth_events":["$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU","$mmqm2KS4UExkNL65c6CIhKofn_L9fzF2OhghVqajksU","$deNVGs6Ef7OKVrvewhtPv7DCCqSip112cEJYp-jkP6M"],"content":{"avatar_url":"mxc://envs.net/000cf1510b7c61018f9c72ca4cc63668370782c81725865933316030464","displayname":"tulir[e]","membership":"invite"},"depth":8,"hashes":{"sha256":"r5EBUZN/4LbVcMYwuffDcVV9G4OMHzAQuNbnjigL+OE"},"origin_server_ts":1756071567548,"prev_events":["$fFDwIavLTEIfcnggWuryB6JwfS-L2KT6vP1ap3P6ctE"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:maunium.net","signatures":{"envs.net":{"ed25519:wuJyKT":"svB+uW4Tsj8/I+SYbLl+LPPjBlqxGNXE4wGyAxlP7vfyJtFf7Kn/19jx65wT9ebeCq5sTGlEDV4Fabwma9LhDA"},"maunium.net":{"ed25519:a_xxeS":"LBYMcdJVSNsLd6SmOgx5oOU/0xOeCl03o4g83VwJfHWlRuTT5l9+qlpNED28wY07uxoU9MgLgXXICJ0EezMBCg"}},"state_key":"@tulir:envs.net","type":"m.room.member","unsigned":{"age_ts":1756071567548,"event_id":"$qYZqSKiKMCNjzH6Trhr6nBSvbfuwr8Sh2bC4USSAxok","invite_room_state":[{"content":{"join_rule":"invite"},"sender":"@tulir:maunium.net","state_key":"","type":"m.room.join_rules"},{"content":{"name":"event auth test v12"},"sender":"@tulir:maunium.net","state_key":"","type":"m.room.name"},{"auth_events":[],"content":{"room_version":"12"},"depth":1,"hashes":{"sha256":"qJYytb+EqWPiiZ0ogDODcLeA8XYw/2hVTaLHihcVBZQ"},"origin_server_ts":1756071567186,"prev_events":[],"sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"/9pp+2tkLo6XcZ3opqLeIpa3D96fh3QLpR2PQrZ6Z6j7wyRAvBrcgCpAeMtuyDCzW8Wh1QFEPG4FSsGvVaEFBg"}},"state_key":"","type":"m.room.create","unsigned":{"age_ts":1756071567186}},{"content":{"avatar_url":"mxc://maunium.net/jdlSfvudiMSmcRrleeiYjjFO","displayname":"tulir","membership":"join"},"sender":"@tulir:maunium.net","state_key":"@tulir:maunium.net","type":"m.room.member"}]}} +{"auth_events":["$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU","$mmqm2KS4UExkNL65c6CIhKofn_L9fzF2OhghVqajksU"],"content":{"body":"meow","com.beeper.linkpreviews":[],"m.mentions":{},"msgtype":"m.text"},"depth":9,"hashes":{"sha256":"23rgMf7EGJcYt3Aj0qAFnmBWCxuU9Uk+ReidqtIJDKQ"},"origin_server_ts":1756071575986,"prev_events":["$qYZqSKiKMCNjzH6Trhr6nBSvbfuwr8Sh2bC4USSAxok"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"p+Fm/uWO8VXJdCYvN/dVb8HF8W3t1sssNCBiOWbzAeuS3QqYjoMKHyixLuN1mOdnCyATv7SsHHmA4+cELRGdAA"}},"type":"m.room.message","unsigned":{"age_ts":1756071576002,"event_id":"$eZDCydRWSRnR5od0c7ahz2qSZQDHbl5g5PITT0OMC3E"}} +{"auth_events":["$qYZqSKiKMCNjzH6Trhr6nBSvbfuwr8Sh2bC4USSAxok","$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU","$deNVGs6Ef7OKVrvewhtPv7DCCqSip112cEJYp-jkP6M"],"content":{"avatar_url":"mxc://envs.net/000cf1510b7c61018f9c72ca4cc63668370782c81725865933316030464","displayname":"tulir[e]","membership":"join"},"depth":10,"hashes":{"sha256":"2kJPx2UsysNzTH8QGYHUKTO/05yetxKRlI0nKFeGbts"},"origin_server_ts":1756071578631,"prev_events":["$eZDCydRWSRnR5od0c7ahz2qSZQDHbl5g5PITT0OMC3E"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:envs.net","signatures":{"envs.net":{"ed25519:wuJyKT":"Wuzxkh8nEEX6mdJzph6Bt5ku+odFkEg2RIpFAAirOqxgcrwRaz42PsJni3YbfzH1qneF+iWQ/neA+up6jLXFBw"}},"state_key":"@tulir:envs.net","type":"m.room.member","unsigned":{"age":6,"event_id":"$Bz2lxsbUYkeBDE7eMAsOm_TK_iuSuHNvQdrHnc-T1PE","replaces_state":"$qYZqSKiKMCNjzH6Trhr6nBSvbfuwr8Sh2bC4USSAxok"}} +{"auth_events":["$Bz2lxsbUYkeBDE7eMAsOm_TK_iuSuHNvQdrHnc-T1PE","$deNVGs6Ef7OKVrvewhtPv7DCCqSip112cEJYp-jkP6M","$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU"],"content":{"avatar_url":"mxc://matrix.org/BDYVQFSLvZHMaKHDGiRkvhVg","displayname":"tulir[m]","membership":"invite"},"depth":11,"hashes":{"sha256":"dRE11R2hBfFalQ5tIJdyaElUIiSE5aCKMddjek4wR3c"},"origin_server_ts":1756071591449,"prev_events":["$Bz2lxsbUYkeBDE7eMAsOm_TK_iuSuHNvQdrHnc-T1PE"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:envs.net","signatures":{"envs.net":{"ed25519:wuJyKT":"/Mi4kX40fbR+V3DCJJGI/9L3Uuf8y5Un8LHlCQv1T0O5gnFZGQ3qN6rRNaZ1Kdh3QJBU6H4NTfnd+SVj3wt3CQ"},"matrix.org":{"ed25519:a_RXGa":"ZeLm/oxP3/Cds/uCL2FaZpgjUp0vTDBlGG6YVFNl76yIVlyIKKQKR6BSVw2u5KC5Mu9M1f+0lDmLGQujR5NkBg"}},"state_key":"@tulir:matrix.org","type":"m.room.member","unsigned":{"event_id":"$g4eBtA9EFNGLkHOofvQ4U87GNt4W8NmfmNRyR0wOUO4","invite_room_state":[{"content":{"join_rule":"invite"},"sender":"@tulir:maunium.net","state_key":"","type":"m.room.join_rules"},{"content":{"name":"event auth test v12"},"sender":"@tulir:maunium.net","state_key":"","type":"m.room.name"},{"auth_events":[],"content":{"room_version":"12"},"depth":1,"hashes":{"sha256":"qJYytb+EqWPiiZ0ogDODcLeA8XYw/2hVTaLHihcVBZQ"},"origin_server_ts":1756071567186,"prev_events":[],"sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"/9pp+2tkLo6XcZ3opqLeIpa3D96fh3QLpR2PQrZ6Z6j7wyRAvBrcgCpAeMtuyDCzW8Wh1QFEPG4FSsGvVaEFBg"}},"state_key":"","type":"m.room.create","unsigned":{"age":11553}},{"content":{"avatar_url":"mxc://envs.net/000cf1510b7c61018f9c72ca4cc63668370782c81725865933316030464","displayname":"tulir[e]","membership":"join"},"sender":"@tulir:envs.net","state_key":"@tulir:envs.net","type":"m.room.member"}]}} +{"auth_events":["$g4eBtA9EFNGLkHOofvQ4U87GNt4W8NmfmNRyR0wOUO4","$deNVGs6Ef7OKVrvewhtPv7DCCqSip112cEJYp-jkP6M","$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU"],"content":{"avatar_url":"mxc://matrix.org/BDYVQFSLvZHMaKHDGiRkvhVg","displayname":"tulir[m]","membership":"join"},"depth":12,"hashes":{"sha256":"hR/fRIyFkxKnA1XNxIB+NKC0VR0vHs82EDgydhmmZXU"},"origin_server_ts":1756071609205,"prev_events":["$g4eBtA9EFNGLkHOofvQ4U87GNt4W8NmfmNRyR0wOUO4"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:matrix.org","signatures":{"matrix.org":{"ed25519:a_RXGa":"keWbZHm+LPW22XWxb14Att4Ae4GVc6XAKAnxFRr3hxhrgEhsnMcxUx7fjqlA1dk3As6kjLKdekcyCef+AQCXCA"}},"state_key":"@tulir:matrix.org","type":"m.room.member","unsigned":{"age":19,"event_id":"$_gYjNODWJdo5-S1IN0bmAk3rzIeXzr5W5cmXZSmUsNw","replaces_state":"$g4eBtA9EFNGLkHOofvQ4U87GNt4W8NmfmNRyR0wOUO4"}} +{"auth_events":["$Bz2lxsbUYkeBDE7eMAsOm_TK_iuSuHNvQdrHnc-T1PE","$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU"],"content":{"ban":50,"events":{"m.room.avatar":50,"m.room.canonical_alias":50,"m.room.encryption":100,"m.room.history_visibility":100,"m.room.name":50,"m.room.power_levels":100,"m.room.server_acl":100,"m.room.tombstone":150},"events_default":0,"historical":100,"invite":0,"kick":50,"redact":50,"state_default":50,"users":{"@tulir:envs.net":9001,"@tulir:matrix.org":9000},"users_default":0},"depth":13,"hashes":{"sha256":"30Wuw3xIbA8+eXQBa4nFDKcyHtMbKPBYhLW1zft9/fE"},"origin_server_ts":1756071643928,"prev_events":["$_gYjNODWJdo5-S1IN0bmAk3rzIeXzr5W5cmXZSmUsNw"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:envs.net","signatures":{"envs.net":{"ed25519:wuJyKT":"x6Y4uViq4nK8LVPqtMLdCuvNET2bnjxYTgiKuEe1JYfwB4jPBnPuqvrt1O9oaanMpcRWbnuiZjckq4bUlRZ7Cw"}},"state_key":"","type":"m.room.power_levels","unsigned":{"event_id":"$Qg1xRB8nL8lGykGvt9_agu_WCWq8Y3rl_p_LKa6D2Hg","replaces_state":"$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU"}} +{"auth_events":["$Qg1xRB8nL8lGykGvt9_agu_WCWq8Y3rl_p_LKa6D2Hg","$_gYjNODWJdo5-S1IN0bmAk3rzIeXzr5W5cmXZSmUsNw"],"content":{"name":"event auth test v12!"},"depth":14,"hashes":{"sha256":"WT0gz7KYXvbdNruRavqIi9Hhul3rxCdZ+YY9yMGN+Fw"},"origin_server_ts":1756071656988,"prev_events":["$Qg1xRB8nL8lGykGvt9_agu_WCWq8Y3rl_p_LKa6D2Hg"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:matrix.org","signatures":{"matrix.org":{"ed25519:a_RXGa":"bSplmqtXVhO2Z3hJ8JMQ/u7G2Wmg6yt7SwhYXObRQJfthekddJN152ME4YJIwy7YD8WFq7EkyB/NMyQoliYyCg"}},"state_key":"","type":"m.room.name","unsigned":{"event_id":"$p4xvOczrhzQMtRW3-Tf86LYUb5aqpGFIgjwHBuxWIcI","replaces_state":"$fFDwIavLTEIfcnggWuryB6JwfS-L2KT6vP1ap3P6ctE"}} +{"auth_events":["$Bz2lxsbUYkeBDE7eMAsOm_TK_iuSuHNvQdrHnc-T1PE","$Qg1xRB8nL8lGykGvt9_agu_WCWq8Y3rl_p_LKa6D2Hg"],"content":{"ban":50,"events":{"m.room.avatar":50,"m.room.canonical_alias":50,"m.room.encryption":100,"m.room.history_visibility":100,"m.room.name":50,"m.room.power_levels":100,"m.room.server_acl":100,"m.room.tombstone":9001},"events_default":0,"historical":12345,"invite":0,"kick":50,"redact":50,"state_default":50,"users":{"@tulir:envs.net":9001,"@tulir:matrix.org":9000},"users_default":0},"depth":15,"hashes":{"sha256":"FnGzbcXc8YOiB1TY33QunGA17Axoyuu3sdVOj5Z408o"},"origin_server_ts":1756071804931,"prev_events":["$p4xvOczrhzQMtRW3-Tf86LYUb5aqpGFIgjwHBuxWIcI"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:envs.net","signatures":{"envs.net":{"ed25519:wuJyKT":"uyTUsPR+CzCtlevzB5+sNXvmfbPSp6u7RZC4E4TLVsj45+pjmMRswAvuHP9PT2+Tkl6Hu8ZPigsXgbKZtR35Aw"}},"state_key":"","type":"m.room.power_levels","unsigned":{"event_id":"$uZ4OOtkM8RcbEkhjNp-YlEH0zBqgsRx1eI8b2YP7ovw","replaces_state":"$Qg1xRB8nL8lGykGvt9_agu_WCWq8Y3rl_p_LKa6D2Hg"}} +{"auth_events":["$Bz2lxsbUYkeBDE7eMAsOm_TK_iuSuHNvQdrHnc-T1PE","$uZ4OOtkM8RcbEkhjNp-YlEH0zBqgsRx1eI8b2YP7ovw"],"content":{"ban":50,"events":{"m.room.avatar":50,"m.room.canonical_alias":50,"m.room.encryption":100,"m.room.history_visibility":100,"m.room.name":50,"m.room.power_levels":100,"m.room.server_acl":100,"m.room.tombstone":100},"events_default":0,"historical":12345,"invite":0,"kick":50,"redact":50,"state_default":50,"users":{"@tulir:envs.net":9001,"@tulir:matrix.org":9000},"users_default":0},"depth":16,"hashes":{"sha256":"KcivsiLesdnUnKX23Akk3OJEJFGRSY0g4H+p7XIThnw"},"origin_server_ts":1756071812688,"prev_events":["$uZ4OOtkM8RcbEkhjNp-YlEH0zBqgsRx1eI8b2YP7ovw"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:envs.net","signatures":{"envs.net":{"ed25519:wuJyKT":"cAK8dO2AVZklY9te5aVKbF1jR/eB5rzeNOXfYPjBLf+aSAS4Z6R2aMKW6hJB9PqRS4S+UZc24DTrjUjnvMzeBA"}},"state_key":"","type":"m.room.power_levels","unsigned":{"event_id":"$iwqRXQc2cx8K4AclTjU1Se-BMJpUl4DxrLm3nfUgeQU","replaces_state":"$uZ4OOtkM8RcbEkhjNp-YlEH0zBqgsRx1eI8b2YP7ovw"}} +{"auth_events":["$Bz2lxsbUYkeBDE7eMAsOm_TK_iuSuHNvQdrHnc-T1PE","$iwqRXQc2cx8K4AclTjU1Se-BMJpUl4DxrLm3nfUgeQU"],"content":{"body":"meow #2","com.beeper.linkpreviews":[],"m.mentions":{},"msgtype":"m.text"},"depth":17,"hashes":{"sha256":"SgH9fOXGdbdqpRfYmoz1t29+gX8Ze4ThSoj6klZs3Og"},"origin_server_ts":1756247476706,"prev_events":["$iwqRXQc2cx8K4AclTjU1Se-BMJpUl4DxrLm3nfUgeQU"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:envs.net","signatures":{"envs.net":{"ed25519:wuJyKT":"SMYK7zP3SaQOKhzZUKUBVCKwffYqi3PFAlPM34kRJtmfGU3KZXNBT0zi+veXDMmxkMunqhF2RTHBD6joa0kBAQ"}},"type":"m.room.message","unsigned":{"event_id":"$KFHLO0-ENYOGQXogp84C-ISSu1xtKUzIMaZ6LiBcR_w"}} +{"auth_events":["$_gYjNODWJdo5-S1IN0bmAk3rzIeXzr5W5cmXZSmUsNw","$iwqRXQc2cx8K4AclTjU1Se-BMJpUl4DxrLm3nfUgeQU"],"content":{"ban":50,"events":{"m.room.avatar":50,"m.room.canonical_alias":50,"m.room.encryption":100,"m.room.history_visibility":100,"m.room.name":50,"m.room.power_levels":100,"m.room.server_acl":100,"m.room.tombstone":100},"events_default":0,"historical":12345,"invite":0,"kick":50,"redact":50,"state_default":50,"users":{"@tulir:beeper.com":8999,"@tulir:envs.net":9001,"@tulir:matrix.org":9000},"users_default":0},"depth":18,"hashes":{"sha256":"l8Mw3VKn/Bvntg7bZ8uh5J8M2IBZM93Xg7hsdaSci8s"},"origin_server_ts":1758918656341,"prev_events":["$KFHLO0-ENYOGQXogp84C-ISSu1xtKUzIMaZ6LiBcR_w"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:matrix.org","signatures":{"matrix.org":{"ed25519:a_RXGa":"cg5LP0WuTnVB5jFhNERLLU5b+EhmyACiOq6cp3gKJnZsTAb1yajcgJybLWKrc8QQqxPa7hPnskRBgt4OBTFNAA"}},"state_key":"","type":"m.room.power_levels","unsigned":{"event_id":"$x-CCUewbWOHQXqfcUsywOmbHvNnOSwNM1RyOu-c8SB0","replaces_state":"$iwqRXQc2cx8K4AclTjU1Se-BMJpUl4DxrLm3nfUgeQU"}} +{"auth_events":["$deNVGs6Ef7OKVrvewhtPv7DCCqSip112cEJYp-jkP6M","$_gYjNODWJdo5-S1IN0bmAk3rzIeXzr5W5cmXZSmUsNw","$x-CCUewbWOHQXqfcUsywOmbHvNnOSwNM1RyOu-c8SB0"],"content":{"avatar_url":"mxc://beeper.com/eBdwbHbllONoAySQkXLjbfFM","displayname":"tulir[b]","membership":"invite"},"depth":19,"hashes":{"sha256":"KpmaRUQnJju8TIDMPzakitUIKOWJxTvULpFB3a1CGgc"},"origin_server_ts":1758918665952,"prev_events":["$x-CCUewbWOHQXqfcUsywOmbHvNnOSwNM1RyOu-c8SB0"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:matrix.org","signatures":{"beeper.com":{"ed25519:a_zgvp":"mzI9rPkQ1xHl2/G5Yrn0qmIRt5OyjPNqRwilPfH4jmr1tP+vv3vC0m4mph/MCOq8S1c/DQaCWSpdOX1uWfchBQ"},"matrix.org":{"ed25519:a_RXGa":"kEdfr8DjxC/bdvGYxnniFI/pxDWeyG73OjG/Gu1uoHLhjdtAT/vEQ6lotJJs214/KX5eAaQWobE9qtMvtPwMDw"}},"state_key":"@tulir:beeper.com","type":"m.room.member","unsigned":{"event_id":"$PZJZoUwNySl0jY16DkHBHR0HyAppLdxc0rkSuYp5Mro","invite_room_state":[{"auth_events":[],"content":{"room_version":"12"},"depth":1,"hashes":{"sha256":"qJYytb+EqWPiiZ0ogDODcLeA8XYw/2hVTaLHihcVBZQ"},"origin_server_ts":1756071567186,"prev_events":[],"sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"/9pp+2tkLo6XcZ3opqLeIpa3D96fh3QLpR2PQrZ6Z6j7wyRAvBrcgCpAeMtuyDCzW8Wh1QFEPG4FSsGvVaEFBg"}},"state_key":"","type":"m.room.create","unsigned":{"age":11553}},{"content":{"avatar_url":"mxc://matrix.org/BDYVQFSLvZHMaKHDGiRkvhVg","displayname":"tulir[m]","membership":"join"},"sender":"@tulir:matrix.org","state_key":"@tulir:matrix.org","type":"m.room.member"},{"content":{"name":"event auth test v12!"},"sender":"@tulir:matrix.org","state_key":"","type":"m.room.name"},{"content":{"join_rule":"invite"},"sender":"@tulir:maunium.net","state_key":"","type":"m.room.join_rules"}]}} +{"auth_events":["$deNVGs6Ef7OKVrvewhtPv7DCCqSip112cEJYp-jkP6M","$PZJZoUwNySl0jY16DkHBHR0HyAppLdxc0rkSuYp5Mro","$x-CCUewbWOHQXqfcUsywOmbHvNnOSwNM1RyOu-c8SB0"],"content":{"avatar_url":"mxc://beeper.com/eBdwbHbllONoAySQkXLjbfFM","displayname":"tulir[b]","membership":"join"},"depth":20,"hashes":{"sha256":"bmaHSm4mYPNBNlUfFsauSTxLrUH4CUSAKYvr1v76qkk"},"origin_server_ts":1758918670276,"prev_events":["$PZJZoUwNySl0jY16DkHBHR0HyAppLdxc0rkSuYp5Mro"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:beeper.com","signatures":{"beeper.com":{"ed25519:a_zgvp":"D3cz3m15m89a3G4c5yWOBCjhtSeI5IxBfQKt5XOr9a44QHyc3nwjjvIJaRrKNcS5tLUJwZ2IpVzjlrpbPHpxDA"}},"state_key":"@tulir:beeper.com","type":"m.room.member","unsigned":{"age":6,"event_id":"$_hayW1Y0HRWp3VEGZZbsMf0Ncg9x6n0ikveD0lbCwMw","replaces_state":"$PZJZoUwNySl0jY16DkHBHR0HyAppLdxc0rkSuYp5Mro"}} +{"auth_events":["$x-CCUewbWOHQXqfcUsywOmbHvNnOSwNM1RyOu-c8SB0","$Bz2lxsbUYkeBDE7eMAsOm_TK_iuSuHNvQdrHnc-T1PE"],"content":{"ban":50,"events":{"m.room.avatar":50,"m.room.canonical_alias":50,"m.room.encryption":100,"m.room.history_visibility":100,"m.room.name":50,"m.room.power_levels":100,"m.room.server_acl":100,"m.room.tombstone":100},"events_default":0,"historical":12345,"invite":0,"kick":50,"redact":50,"state_default":50,"users":{"@tulir:beeper.com":9000,"@tulir:envs.net":9001,"@tulir:matrix.org":8999},"users_default":0},"depth":21,"hashes":{"sha256":"xCj9vszChHiXba9DaPzhtF79Tphek3pRViMp36DOurU"},"origin_server_ts":1758918689485,"prev_events":["$_hayW1Y0HRWp3VEGZZbsMf0Ncg9x6n0ikveD0lbCwMw"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:envs.net","signatures":{"envs.net":{"ed25519:wuJyKT":"odkrWD30+ObeYtagULtECB/QmGae7qNy66nmJMWYXiQMYUJw/GMzSmgAiLAWfVYlfD3aEvMb/CBdrhL07tfSBw"}},"state_key":"","type":"m.room.power_levels","unsigned":{"event_id":"$di6cI89-GxX8-Wbx-0T69l4wg6TUWITRkjWXzG7EBqo","replaces_state":"$x-CCUewbWOHQXqfcUsywOmbHvNnOSwNM1RyOu-c8SB0"}} diff --git a/federation/httpclient.go b/federation/httpclient.go new file mode 100644 index 00000000..2f8dbb4f --- /dev/null +++ b/federation/httpclient.go @@ -0,0 +1,92 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package federation + +import ( + "context" + "fmt" + "net" + "net/http" + "sync" +) + +// ServerResolvingTransport is an http.RoundTripper that resolves Matrix server names before sending requests. +// It only allows requests using the "matrix-federation" scheme. +type ServerResolvingTransport struct { + ResolveOpts *ResolveServerNameOpts + Transport *http.Transport + Dialer *net.Dialer + + cache ResolutionCache + + resolveLocks map[string]*sync.Mutex + resolveLocksLock sync.Mutex +} + +func NewServerResolvingTransport(cache ResolutionCache) *ServerResolvingTransport { + if cache == nil { + cache = NewInMemoryCache() + } + srt := &ServerResolvingTransport{ + resolveLocks: make(map[string]*sync.Mutex), + cache: cache, + Dialer: &net.Dialer{}, + } + srt.Transport = &http.Transport{ + DialContext: srt.DialContext, + } + return srt +} + +var _ http.RoundTripper = (*ServerResolvingTransport)(nil) + +func (srt *ServerResolvingTransport) DialContext(ctx context.Context, network, addr string) (net.Conn, error) { + addrs, ok := ctx.Value(contextKeyIPPort).([]string) + if !ok { + return nil, fmt.Errorf("no IP:port in context") + } + return srt.Dialer.DialContext(ctx, network, addrs[0]) +} + +func (srt *ServerResolvingTransport) RoundTrip(request *http.Request) (*http.Response, error) { + if request.URL.Scheme != "matrix-federation" { + return nil, fmt.Errorf("unsupported scheme: %s", request.URL.Scheme) + } + resolved, err := srt.resolve(request.Context(), request.URL.Host) + if err != nil { + return nil, fmt.Errorf("failed to resolve server name: %w", err) + } + request = request.WithContext(context.WithValue(request.Context(), contextKeyIPPort, resolved.IPPort)) + request.URL.Scheme = "https" + request.URL.Host = resolved.HostHeader + request.Host = resolved.HostHeader + return srt.Transport.RoundTrip(request) +} + +func (srt *ServerResolvingTransport) resolve(ctx context.Context, serverName string) (*ResolvedServerName, error) { + srt.resolveLocksLock.Lock() + lock, ok := srt.resolveLocks[serverName] + if !ok { + lock = &sync.Mutex{} + srt.resolveLocks[serverName] = lock + } + srt.resolveLocksLock.Unlock() + + lock.Lock() + defer lock.Unlock() + res, err := srt.cache.LoadResolution(serverName) + if err != nil { + return nil, fmt.Errorf("failed to read cache: %w", err) + } else if res != nil { + return res, nil + } else if res, err = ResolveServerName(ctx, serverName, srt.ResolveOpts); err != nil { + return nil, err + } else { + srt.cache.StoreResolution(res) + return res, nil + } +} diff --git a/federation/keyserver.go b/federation/keyserver.go index 3e74bfdf..d32ba5cf 100644 --- a/federation/keyserver.go +++ b/federation/keyserver.go @@ -1,4 +1,4 @@ -// Copyright (c) 2024 Tulir Asokan +// Copyright (c) 2025 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this @@ -8,13 +8,17 @@ package federation import ( "encoding/json" - "fmt" "net/http" "strconv" "time" - "github.com/gorilla/mux" + "github.com/rs/zerolog" + "github.com/rs/zerolog/hlog" + "go.mau.fi/util/exerrors" + "go.mau.fi/util/exhttp" "go.mau.fi/util/jsontime" + "go.mau.fi/util/ptr" + "go.mau.fi/util/requestlog" "maunium.net/go/mautrix" "maunium.net/go/mautrix/id" @@ -47,34 +51,29 @@ type KeyServer struct { KeyProvider ServerKeyProvider Version ServerVersion WellKnownTarget string + OtherKeys KeyCache } // Register registers the key server endpoints to the given router. -func (ks *KeyServer) Register(r *mux.Router) { - r.HandleFunc("/.well-known/matrix/server", ks.GetWellKnown).Methods(http.MethodGet) - r.HandleFunc("/_matrix/federation/v1/version", ks.GetServerVersion).Methods(http.MethodGet) - keyRouter := r.PathPrefix("/_matrix/key").Subrouter() - keyRouter.HandleFunc("/v2/server", ks.GetServerKey).Methods(http.MethodGet) - keyRouter.HandleFunc("/v2/query/{serverName}", ks.GetQueryKeys).Methods(http.MethodGet) - keyRouter.HandleFunc("/v2/query", ks.PostQueryKeys).Methods(http.MethodPost) - keyRouter.NotFoundHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - jsonResponse(w, http.StatusNotFound, &mautrix.RespError{ - ErrCode: mautrix.MUnrecognized.ErrCode, - Err: "Unrecognized endpoint", - }) - }) - keyRouter.MethodNotAllowedHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - jsonResponse(w, http.StatusMethodNotAllowed, &mautrix.RespError{ - ErrCode: mautrix.MUnrecognized.ErrCode, - Err: "Invalid method for endpoint", - }) - }) -} - -func jsonResponse(w http.ResponseWriter, code int, data any) { - w.Header().Add("Content-Type", "application/json") - w.WriteHeader(code) - _ = json.NewEncoder(w).Encode(data) +func (ks *KeyServer) Register(r *http.ServeMux, log zerolog.Logger) { + r.HandleFunc("GET /.well-known/matrix/server", ks.GetWellKnown) + r.HandleFunc("GET /_matrix/federation/v1/version", ks.GetServerVersion) + keyRouter := http.NewServeMux() + keyRouter.HandleFunc("GET /v2/server", ks.GetServerKey) + keyRouter.HandleFunc("GET /v2/query/{serverName}", ks.GetQueryKeys) + keyRouter.HandleFunc("POST /v2/query", ks.PostQueryKeys) + errorBodies := exhttp.ErrorBodies{ + NotFound: exerrors.Must(ptr.Ptr(mautrix.MUnrecognized.WithMessage("Unrecognized endpoint")).MarshalJSON()), + MethodNotAllowed: exerrors.Must(ptr.Ptr(mautrix.MUnrecognized.WithMessage("Invalid method for endpoint")).MarshalJSON()), + } + r.Handle("/_matrix/key/", exhttp.ApplyMiddleware( + keyRouter, + exhttp.StripPrefix("/_matrix/key"), + hlog.NewHandler(log), + hlog.RequestIDHandler("request_id", "Request-Id"), + requestlog.AccessLogger(requestlog.Options{TrustXForwardedFor: true}), + exhttp.HandleErrors(errorBodies), + )) } // RespWellKnown is the response body for the `GET /.well-known/matrix/server` endpoint. @@ -87,12 +86,9 @@ type RespWellKnown struct { // https://spec.matrix.org/v1.9/server-server-api/#get_well-knownmatrixserver func (ks *KeyServer) GetWellKnown(w http.ResponseWriter, r *http.Request) { if ks.WellKnownTarget == "" { - jsonResponse(w, http.StatusNotFound, &mautrix.RespError{ - ErrCode: mautrix.MNotFound.ErrCode, - Err: "No well-known target set", - }) + mautrix.MNotFound.WithMessage("No well-known target set").Write(w) } else { - jsonResponse(w, http.StatusOK, &RespWellKnown{Server: ks.WellKnownTarget}) + exhttp.WriteJSONResponse(w, http.StatusOK, &RespWellKnown{Server: ks.WellKnownTarget}) } } @@ -105,7 +101,7 @@ type RespServerVersion struct { // // https://spec.matrix.org/v1.9/server-server-api/#get_matrixfederationv1version func (ks *KeyServer) GetServerVersion(w http.ResponseWriter, r *http.Request) { - jsonResponse(w, http.StatusOK, &RespServerVersion{Server: ks.Version}) + exhttp.WriteJSONResponse(w, http.StatusOK, &RespServerVersion{Server: ks.Version}) } // GetServerKey implements the `GET /_matrix/key/v2/server` endpoint. @@ -114,12 +110,9 @@ func (ks *KeyServer) GetServerVersion(w http.ResponseWriter, r *http.Request) { func (ks *KeyServer) GetServerKey(w http.ResponseWriter, r *http.Request) { domain, key := ks.KeyProvider.Get(r) if key == nil { - jsonResponse(w, http.StatusNotFound, &mautrix.RespError{ - ErrCode: mautrix.MNotFound.ErrCode, - Err: fmt.Sprintf("No signing key found for %q", r.Host), - }) + mautrix.MNotFound.WithMessage("No signing key found for %q", r.Host).Write(w) } else { - jsonResponse(w, http.StatusOK, key.GenerateKeyResponse(domain, nil)) + exhttp.WriteJSONResponse(w, http.StatusOK, key.GenerateKeyResponse(domain, nil)) } } @@ -144,10 +137,7 @@ func (ks *KeyServer) PostQueryKeys(w http.ResponseWriter, r *http.Request) { var req ReqQueryKeys err := json.NewDecoder(r.Body).Decode(&req) if err != nil { - jsonResponse(w, http.StatusBadRequest, &mautrix.RespError{ - ErrCode: mautrix.MBadJSON.ErrCode, - Err: fmt.Sprintf("failed to parse request: %v", err), - }) + mautrix.MBadJSON.WithMessage("failed to parse request: %v", err).Write(w) return } @@ -165,7 +155,7 @@ func (ks *KeyServer) PostQueryKeys(w http.ResponseWriter, r *http.Request) { } } } - jsonResponse(w, http.StatusOK, resp) + exhttp.WriteJSONResponse(w, http.StatusOK, resp) } // GetQueryKeysResponse is the response body for the `GET /_matrix/key/v2/query/{serverName}` endpoint @@ -177,27 +167,39 @@ type GetQueryKeysResponse struct { // // https://spec.matrix.org/v1.9/server-server-api/#get_matrixkeyv2queryservername func (ks *KeyServer) GetQueryKeys(w http.ResponseWriter, r *http.Request) { - serverName := mux.Vars(r)["serverName"] + serverName := r.PathValue("serverName") minimumValidUntilTSString := r.URL.Query().Get("minimum_valid_until_ts") minimumValidUntilTS, err := strconv.ParseInt(minimumValidUntilTSString, 10, 64) if err != nil && minimumValidUntilTSString != "" { - jsonResponse(w, http.StatusBadRequest, &mautrix.RespError{ - ErrCode: mautrix.MInvalidParam.ErrCode, - Err: fmt.Sprintf("failed to parse ?minimum_valid_until_ts: %v", err), - }) + mautrix.MInvalidParam.WithMessage("failed to parse ?minimum_valid_until_ts: %v", err).Write(w) return } else if time.UnixMilli(minimumValidUntilTS).After(time.Now().Add(24 * time.Hour)) { - jsonResponse(w, http.StatusBadRequest, &mautrix.RespError{ - ErrCode: mautrix.MInvalidParam.ErrCode, - Err: "minimum_valid_until_ts may not be more than 24 hours in the future", - }) + mautrix.MInvalidParam.WithMessage("minimum_valid_until_ts may not be more than 24 hours in the future").Write(w) return } resp := &GetQueryKeysResponse{ ServerKeys: []*ServerKeyResponse{}, } - if domain, key := ks.KeyProvider.Get(r); key != nil && domain == serverName { - resp.ServerKeys = append(resp.ServerKeys, key.GenerateKeyResponse(serverName, nil)) + domain, key := ks.KeyProvider.Get(r) + if domain == serverName { + if key != nil { + resp.ServerKeys = append(resp.ServerKeys, key.GenerateKeyResponse(serverName, nil)) + } + } else if ks.OtherKeys != nil { + otherKey, err := ks.OtherKeys.LoadKeys(serverName) + if err != nil { + mautrix.MUnknown.WithMessage("Failed to load keys from cache").Write(w) + return + } + if key != nil && domain != "" { + signature, err := key.SignJSON(otherKey) + if err == nil { + otherKey.Signatures[domain] = map[id.KeyID]string{ + key.ID: signature, + } + } + } + resp.ServerKeys = append(resp.ServerKeys, otherKey) } - jsonResponse(w, http.StatusOK, resp) + exhttp.WriteJSONResponse(w, http.StatusOK, resp) } diff --git a/federation/pdu/auth.go b/federation/pdu/auth.go new file mode 100644 index 00000000..16706fe5 --- /dev/null +++ b/federation/pdu/auth.go @@ -0,0 +1,71 @@ +// Copyright (c) 2025 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +//go:build goexperiment.jsonv2 + +package pdu + +import ( + "slices" + + "github.com/tidwall/gjson" + "go.mau.fi/util/exgjson" + + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +type StateKey struct { + Type string + StateKey string +} + +var thirdPartyInviteTokenPath = exgjson.Path("third_party_invite", "signed", "token") + +type AuthEventSelection []StateKey + +func (aes *AuthEventSelection) Add(evtType, stateKey string) { + key := StateKey{Type: evtType, StateKey: stateKey} + if !aes.Has(key) { + *aes = append(*aes, key) + } +} + +func (aes *AuthEventSelection) Has(key StateKey) bool { + return slices.Contains(*aes, key) +} + +func (pdu *PDU) AuthEventSelection(roomVersion id.RoomVersion) (keys AuthEventSelection) { + if pdu.Type == event.StateCreate.Type && pdu.StateKey != nil { + return AuthEventSelection{} + } + keys = make(AuthEventSelection, 0, 3) + if !roomVersion.RoomIDIsCreateEventID() { + keys.Add(event.StateCreate.Type, "") + } + keys.Add(event.StatePowerLevels.Type, "") + keys.Add(event.StateMember.Type, pdu.Sender.String()) + if pdu.Type == event.StateMember.Type && pdu.StateKey != nil { + keys.Add(event.StateMember.Type, *pdu.StateKey) + membership := event.Membership(gjson.GetBytes(pdu.Content, "membership").Str) + if membership == event.MembershipJoin || membership == event.MembershipInvite || membership == event.MembershipKnock { + keys.Add(event.StateJoinRules.Type, "") + } + if membership == event.MembershipInvite { + thirdPartyInviteToken := gjson.GetBytes(pdu.Content, thirdPartyInviteTokenPath).Str + if thirdPartyInviteToken != "" { + keys.Add(event.StateThirdPartyInvite.Type, thirdPartyInviteToken) + } + } + if membership == event.MembershipJoin && roomVersion.RestrictedJoins() { + authorizedVia := gjson.GetBytes(pdu.Content, "authorised_via_users_server").Str + if authorizedVia != "" { + keys.Add(event.StateMember.Type, authorizedVia) + } + } + } + return +} diff --git a/federation/pdu/hash.go b/federation/pdu/hash.go new file mode 100644 index 00000000..38ef83e9 --- /dev/null +++ b/federation/pdu/hash.go @@ -0,0 +1,118 @@ +// Copyright (c) 2025 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +//go:build goexperiment.jsonv2 + +package pdu + +import ( + "crypto/hmac" + "crypto/sha256" + "encoding/base64" + "fmt" + + "github.com/tidwall/gjson" + + "maunium.net/go/mautrix/id" +) + +func (pdu *PDU) CalculateContentHash() ([32]byte, error) { + if pdu == nil { + return [32]byte{}, ErrPDUIsNil + } + pduClone := pdu.Clone() + pduClone.Signatures = nil + pduClone.Unsigned = nil + pduClone.Hashes = nil + rawJSON, err := marshalCanonical(pduClone) + if err != nil { + return [32]byte{}, fmt.Errorf("failed to marshal PDU to calculate content hash: %w", err) + } + return sha256.Sum256(rawJSON), nil +} + +func (pdu *PDU) FillContentHash() error { + if pdu == nil { + return ErrPDUIsNil + } else if pdu.Hashes != nil { + return nil + } else if hash, err := pdu.CalculateContentHash(); err != nil { + return err + } else { + pdu.Hashes = &Hashes{SHA256: hash[:]} + return nil + } +} + +func (pdu *PDU) VerifyContentHash() bool { + if pdu == nil || pdu.Hashes == nil { + return false + } + calculatedHash, err := pdu.CalculateContentHash() + if err != nil { + return false + } + return hmac.Equal(calculatedHash[:], pdu.Hashes.SHA256) +} + +func (pdu *PDU) GetRoomID() (id.RoomID, error) { + if pdu == nil { + return "", ErrPDUIsNil + } else if pdu.Type != "m.room.create" { + return "", fmt.Errorf("room ID can only be calculated for m.room.create events") + } else if roomVersion := id.RoomVersion(gjson.GetBytes(pdu.Content, "room_version").Str); !roomVersion.RoomIDIsCreateEventID() { + return "", fmt.Errorf("room version %s does not use m.room.create event ID as room ID", roomVersion) + } else if evtID, err := pdu.calculateEventID(roomVersion, '!'); err != nil { + return "", fmt.Errorf("failed to calculate event ID: %w", err) + } else { + return id.RoomID(evtID), nil + } +} + +var UseInternalMetaForGetEventID = false + +func (pdu *PDU) GetEventID(roomVersion id.RoomVersion) (id.EventID, error) { + if UseInternalMetaForGetEventID && pdu.InternalMeta.EventID != "" { + return pdu.InternalMeta.EventID, nil + } + return pdu.calculateEventID(roomVersion, '$') +} + +func (pdu *PDU) GetReferenceHash(roomVersion id.RoomVersion) ([32]byte, error) { + if pdu == nil { + return [32]byte{}, ErrPDUIsNil + } + if pdu.Hashes == nil || pdu.Hashes.SHA256 == nil { + if err := pdu.FillContentHash(); err != nil { + return [32]byte{}, err + } + } + rawJSON, err := marshalCanonical(pdu.Clone().RedactForSignature(roomVersion)) + if err != nil { + return [32]byte{}, fmt.Errorf("failed to marshal redacted PDU to calculate event ID: %w", err) + } + return sha256.Sum256(rawJSON), nil +} + +func (pdu *PDU) calculateEventID(roomVersion id.RoomVersion, prefix byte) (id.EventID, error) { + referenceHash, err := pdu.GetReferenceHash(roomVersion) + if err != nil { + return "", err + } + eventID := make([]byte, 44) + eventID[0] = prefix + switch roomVersion.EventIDFormat() { + case id.EventIDFormatCustom: + return "", fmt.Errorf("*pdu.PDU can only be used for room v3+") + case id.EventIDFormatBase64: + base64.RawStdEncoding.Encode(eventID[1:], referenceHash[:]) + case id.EventIDFormatURLSafeBase64: + base64.RawURLEncoding.Encode(eventID[1:], referenceHash[:]) + default: + return "", fmt.Errorf("unknown event ID format %v", roomVersion.EventIDFormat()) + } + return id.EventID(eventID), nil +} diff --git a/federation/pdu/hash_test.go b/federation/pdu/hash_test.go new file mode 100644 index 00000000..17417e12 --- /dev/null +++ b/federation/pdu/hash_test.go @@ -0,0 +1,55 @@ +// Copyright (c) 2025 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +//go:build goexperiment.jsonv2 + +package pdu_test + +import ( + "encoding/base64" + "testing" + + "github.com/stretchr/testify/assert" + "go.mau.fi/util/exerrors" +) + +func TestPDU_CalculateContentHash(t *testing.T) { + for _, test := range testPDUs { + if test.redacted { + continue + } + t.Run(test.name, func(t *testing.T) { + parsed := parsePDU(test.pdu) + contentHash := exerrors.Must(parsed.CalculateContentHash()) + assert.Equal( + t, + base64.RawStdEncoding.EncodeToString(parsed.Hashes.SHA256), + base64.RawStdEncoding.EncodeToString(contentHash[:]), + ) + }) + } +} + +func TestPDU_VerifyContentHash(t *testing.T) { + for _, test := range testPDUs { + if test.redacted { + continue + } + t.Run(test.name, func(t *testing.T) { + parsed := parsePDU(test.pdu) + assert.True(t, parsed.VerifyContentHash()) + }) + } +} + +func TestPDU_GetEventID(t *testing.T) { + for _, test := range testPDUs { + t.Run(test.name, func(t *testing.T) { + gotEventID := exerrors.Must(parsePDU(test.pdu).GetEventID(test.roomVersion)) + assert.Equal(t, test.eventID, gotEventID) + }) + } +} diff --git a/federation/pdu/pdu.go b/federation/pdu/pdu.go new file mode 100644 index 00000000..17db6995 --- /dev/null +++ b/federation/pdu/pdu.go @@ -0,0 +1,156 @@ +// Copyright (c) 2025 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +//go:build goexperiment.jsonv2 + +package pdu + +import ( + "bytes" + "crypto/ed25519" + "encoding/json/jsontext" + "encoding/json/v2" + "errors" + "fmt" + "strings" + "time" + + "github.com/tidwall/gjson" + "go.mau.fi/util/jsonbytes" + "go.mau.fi/util/ptr" + + "maunium.net/go/mautrix/crypto/canonicaljson" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +// GetKeyFunc is a callback for retrieving the key corresponding to a given key ID when verifying the signature of a PDU. +// +// The input time is the timestamp of the event. The function should attempt to fetch a key that is +// valid at or after this time, but if that is not possible, the latest available key should be +// returned without an error. The verify function will do its own validity checking based on the +// returned valid until timestamp. +type GetKeyFunc = func(serverName string, keyID id.KeyID, minValidUntil time.Time) (key id.SigningKey, validUntil time.Time, err error) + +type AnyPDU interface { + GetRoomID() (id.RoomID, error) + GetEventID(roomVersion id.RoomVersion) (id.EventID, error) + GetReferenceHash(roomVersion id.RoomVersion) ([32]byte, error) + CalculateContentHash() ([32]byte, error) + FillContentHash() error + VerifyContentHash() bool + Sign(roomVersion id.RoomVersion, serverName string, keyID id.KeyID, privateKey ed25519.PrivateKey) error + VerifySignature(roomVersion id.RoomVersion, serverName string, getKey GetKeyFunc) error + ToClientEvent(roomVersion id.RoomVersion) (*event.Event, error) + AuthEventSelection(roomVersion id.RoomVersion) (keys AuthEventSelection) +} + +var ( + _ AnyPDU = (*PDU)(nil) + _ AnyPDU = (*RoomV1PDU)(nil) +) + +type InternalMeta struct { + EventID id.EventID `json:"event_id,omitempty"` + Rejected bool `json:"rejected,omitempty"` + Extra map[string]any `json:",unknown"` +} + +type PDU struct { + AuthEvents []id.EventID `json:"auth_events"` + Content jsontext.Value `json:"content"` + Depth int64 `json:"depth"` + Hashes *Hashes `json:"hashes,omitzero"` + OriginServerTS int64 `json:"origin_server_ts"` + PrevEvents []id.EventID `json:"prev_events"` + Redacts *id.EventID `json:"redacts,omitzero"` + RoomID id.RoomID `json:"room_id,omitzero"` // not present for room v12+ create events + Sender id.UserID `json:"sender"` + Signatures map[string]map[id.KeyID]string `json:"signatures,omitzero"` + StateKey *string `json:"state_key,omitzero"` + Type string `json:"type"` + Unsigned jsontext.Value `json:"unsigned,omitzero"` + InternalMeta InternalMeta `json:"-"` + + Unknown jsontext.Value `json:",unknown"` + + // Deprecated legacy fields + DeprecatedPrevState jsontext.Value `json:"prev_state,omitzero"` + DeprecatedOrigin jsontext.Value `json:"origin,omitzero"` + DeprecatedMembership jsontext.Value `json:"membership,omitzero"` +} + +var ErrPDUIsNil = errors.New("PDU is nil") + +type Hashes struct { + SHA256 jsonbytes.UnpaddedBytes `json:"sha256"` + + Unknown jsontext.Value `json:",unknown"` +} + +func (pdu *PDU) ToClientEvent(roomVersion id.RoomVersion) (*event.Event, error) { + if pdu.Type == "m.room.create" && roomVersion == "" { + roomVersion = id.RoomVersion(gjson.GetBytes(pdu.Content, "room_version").Str) + } + evtType := event.Type{Type: pdu.Type, Class: event.MessageEventType} + if pdu.StateKey != nil { + evtType.Class = event.StateEventType + } + eventID, err := pdu.GetEventID(roomVersion) + if err != nil { + return nil, err + } + roomID := pdu.RoomID + if pdu.Type == "m.room.create" && roomVersion.RoomIDIsCreateEventID() { + roomID = id.RoomID(strings.Replace(string(eventID), "$", "!", 1)) + } + evt := &event.Event{ + StateKey: pdu.StateKey, + Sender: pdu.Sender, + Type: evtType, + Timestamp: pdu.OriginServerTS, + ID: eventID, + RoomID: roomID, + Redacts: ptr.Val(pdu.Redacts), + } + err = json.Unmarshal(pdu.Content, &evt.Content) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal content: %w", err) + } + return evt, nil +} + +func (pdu *PDU) AddSignature(serverName string, keyID id.KeyID, signature string) { + if signature == "" { + return + } + if pdu.Signatures == nil { + pdu.Signatures = make(map[string]map[id.KeyID]string) + } + if _, ok := pdu.Signatures[serverName]; !ok { + pdu.Signatures[serverName] = make(map[id.KeyID]string) + } + pdu.Signatures[serverName][keyID] = signature +} + +func marshalCanonical(data any) (jsontext.Value, error) { + marshaledBytes, err := json.Marshal(data) + if err != nil { + return nil, err + } + marshaled := jsontext.Value(marshaledBytes) + err = marshaled.Canonicalize() + if err != nil { + return nil, err + } + check := canonicaljson.CanonicalJSONAssumeValid(marshaled) + if !bytes.Equal(marshaled, check) { + fmt.Println(string(marshaled)) + fmt.Println(string(check)) + return nil, fmt.Errorf("canonical JSON mismatch for %s", string(marshaled)) + } + return marshaled, nil +} diff --git a/federation/pdu/pdu_test.go b/federation/pdu/pdu_test.go new file mode 100644 index 00000000..59d7c3a6 --- /dev/null +++ b/federation/pdu/pdu_test.go @@ -0,0 +1,193 @@ +// Copyright (c) 2025 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +//go:build goexperiment.jsonv2 + +package pdu_test + +import ( + "encoding/json/v2" + "time" + + "go.mau.fi/util/exerrors" + + "maunium.net/go/mautrix/federation/pdu" + "maunium.net/go/mautrix/id" +) + +type serverKey struct { + key id.SigningKey + validUntilTS time.Time +} + +type serverDetails struct { + serverName string + keys map[id.KeyID]serverKey +} + +func (sd serverDetails) getKey(serverName string, keyID id.KeyID, _ time.Time) (id.SigningKey, time.Time, error) { + if serverName != sd.serverName { + return "", time.Time{}, nil + } + key, ok := sd.keys[keyID] + if ok { + return key.key, key.validUntilTS, nil + } + return "", time.Time{}, nil +} + +var mauniumNet = serverDetails{ + serverName: "maunium.net", + keys: map[id.KeyID]serverKey{ + "ed25519:a_xxeS": { + key: "lVt/CC3tv74OH6xTph2JrUmeRj/j+1q0HVa0Xf4QlCg", + validUntilTS: time.Now(), + }, + }, +} +var envsNet = serverDetails{ + serverName: "envs.net", + keys: map[id.KeyID]serverKey{ + "ed25519:a_zIqy": { + key: "vCUcZpt9hUn0aabfh/9GP/6sZvXcydww8DUstPHdJm0", + validUntilTS: time.UnixMilli(1722360538068), + }, + "ed25519:wuJyKT": { + key: "xbE1QssgomL4wCSlyMYF5/7KxVyM4HPwAbNa+nFFnx0", + validUntilTS: time.Now(), + }, + }, +} +var matrixOrg = serverDetails{ + serverName: "matrix.org", + keys: map[id.KeyID]serverKey{ + "ed25519:auto": { + key: "Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw", + validUntilTS: time.UnixMilli(1576767829750), + }, + "ed25519:a_RXGa": { + key: "l8Hft5qXKn1vfHrg3p4+W8gELQVo8N13JkluMfmn2sQ", + validUntilTS: time.Now(), + }, + }, +} +var continuwuityOrg = serverDetails{ + serverName: "continuwuity.org", + keys: map[id.KeyID]serverKey{ + "ed25519:PwHlNsFu": { + key: "8eNx2s0zWW+heKAmOH5zKv/nCPkEpraDJfGHxDu6hFI", + validUntilTS: time.Now(), + }, + }, +} +var novaAstraltechOrg = serverDetails{ + serverName: "nova.astraltech.org", + keys: map[id.KeyID]serverKey{ + "ed25519:a_afpo": { + key: "O1Y9GWuKo9xkuzuQef6gROxtTgxxAbS3WPNghPYXF3o", + validUntilTS: time.Now(), + }, + }, +} + +type testPDU struct { + name string + pdu string + eventID id.EventID + roomVersion id.RoomVersion + redacted bool + serverDetails +} + +var roomV4MessageTestPDU = testPDU{ + name: "m.room.message in v4 room", + pdu: `{"auth_events":["$OB87jNemaIVDHAfu0-pa_cP7OPFXUXCbFpjYVi8gll4","$RaWbTF9wQfGQgUpe1S13wzICtGTB2PNKRHUNHu9IO1c","$ZmEWOXw6cC4Rd1wTdY5OzeLJVzjhrkxFPwwKE4gguGk"],"content":{"body":"the last one is saying it shouldn't have effects","com.beeper.linkpreviews":[],"m.mentions":{},"msgtype":"m.text"},"depth":13103,"hashes":{"sha256":"c2wb8qMlvzIPCP1Wd+eYZ4BRgnGYxS97dR1UlJjVMeg"},"origin_server_ts":1752875275263,"prev_events":["$-7_BMI3BXwj3ayoxiJvraJxYWTKwjiQ6sh7CW_Brvj0"],"room_id":"!JiiOHXrIUCtcOJsZCa:matrix.org","sender":"@tulir:maunium.net","type":"m.room.message","signatures":{"maunium.net":{"ed25519:a_xxeS":"99TAqHpBkUEtgCraXsVXogmf/hnijPbgbG9eACtA+mbix3Y6gURI4QGQgcX/NhcE3pJQZ/YDjmbuvCnKvEccAA"}},"unsigned":{"age_ts":1752875275281}}`, + eventID: "$Jo_lmFR-e6lzrimzCA7DevIn2OwhuQYmd9xkcJBoqAA", + roomVersion: id.RoomV4, + serverDetails: mauniumNet, +} + +var roomV12MessageTestPDU = testPDU{ + name: "m.room.message in v12 room", + pdu: `{"auth_events":["$gCzdJUVV93Qory0x7p_PLG5UUiDjPJNe1H12qbHTuFA","$hyeL_nU_L3tsZ2dtZZpAHk0Skv-PqFQIipuII_By584"],"content":{"body":"meow","com.beeper.linkpreviews":[],"m.mentions":{},"msgtype":"m.text"},"depth":122,"hashes":{"sha256":"IQ0zlc+PXeEs6R3JvRkW3xTPV3zlGKSSd3x07KXGjzs"},"origin_server_ts":1755384351627,"prev_events":["$gCzdJUVV93Qory0x7p_PLG5UUiDjPJNe1H12qbHTuFA"],"room_id":"!mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ","sender":"@tulir_test:maunium.net","type":"m.room.message","signatures":{"maunium.net":{"ed25519:a_xxeS":"0GDMddL2k7gF4V1VU8sL3wTfhAIzAu5iVH5jeavZ2VEg3J9/tHLWXAOn2tzkLaMRWl0/XpINT2YlH/rd2U21Ag"}},"unsigned":{"age_ts":1755384351627}}`, + eventID: "$xmP-wZfpannuHG-Akogi6c4YvqxChMtdyYbUMGOrMWc", + roomVersion: id.RoomV12, + serverDetails: mauniumNet, +} + +var testPDUs = []testPDU{roomV4MessageTestPDU, { + name: "m.room.message in v5 room", + pdu: `{"auth_events":["$hp0ImHqYgHTRbLeWKPeTeFmxdb5SdMJN9cfmTrTk7d0","$KAj7X7tnJbR9qYYMWJSw-1g414_KlPptbbkZm7_kUtg","$V-2ShOwZYhA_nxMijaf3lqFgIJgzE2UMeFPtOLnoBYM"],"content":{"body":"meow","com.beeper.linkpreviews":[],"m.mentions":{},"msgtype":"m.text"},"depth":2248,"hashes":{"sha256":"kV+JuLbWXJ2r6PjHT3wt8bFc/TfI1nTaSN3Lamg/xHs"},"origin_server_ts":1755422945654,"prev_events":["$49lFLem2Nk4dxHk9RDXxTdaq9InIJpmkHpzVnjKcYwg"],"room_id":"!vzBgJsjNzgHSdWsmki:mozilla.org","sender":"@tulir:maunium.net","type":"m.room.message","signatures":{"maunium.net":{"ed25519:a_xxeS":"JIl60uVgfCLBZLPoSiE7wVkJ9U5cNEPVPuv1sCCYUOq5yOW56WD1adgpBUdX2UFpYkCHvkRnyQGxU0+6HBp5BA"}},"unsigned":{"age_ts":1755422945673}}`, + eventID: "$Qn4tHfuAe6PlnKXPZnygAU9wd6RXqMKtt_ZzstHTSgA", + roomVersion: id.RoomV5, + serverDetails: mauniumNet, +}, { + name: "m.room.message in v10 room", + pdu: `{"auth_events":["$--ilpwnsHaEdHrwiMrZNu5xHP6TthWG0FIXMHnlHCcs","$tn1FZUI_YUpfTr_a3Y_r8kC3inliIZZratzg0UsNdCQ","$Z-qMWmiMvm-aIEffcfSO6lN7TyjyTOsIcHIymfzoo20"],"content":{"body":"meow","com.beeper.linkpreviews":[],"m.mentions":{},"msgtype":"m.text"},"depth":100885,"hashes":{"sha256":"jc9272JPpPIVreJC3UEAm3BNVnLX8sm3U/TZs23wsHo"},"origin_server_ts":1755422792518,"prev_events":["$HDtbzpSys36Hk-F2NsiXfp9slsGXBH0b58qyddj_q5E"],"room_id":"!UzZHbJYcgggctGnlzr:envs.net","sender":"@tulir:maunium.net","type":"m.room.message","signatures":{"maunium.net":{"ed25519:a_xxeS":"sAMLo9jPtNB0Jq67IQm06siEBx82qZa2edu56IDQ4tDylEV4Mq7iFO23gCghqXA7B/MqBsjXotGBxv6AvlJ2Dw"}},"unsigned":{"age_ts":1755422792540}}`, + eventID: "$4ZFr_ypfp4DyZQP4zyxM_cvuOMFkl07doJmwi106YFY", + roomVersion: id.RoomV10, + serverDetails: mauniumNet, +}, { + name: "m.room.message in v11 room", + pdu: `{"auth_events":["$L8Ak6A939llTRIsZrytMlLDXQhI4uLEjx-wb1zSg-Bw","$QJmr7mmGeXGD4Tof0ZYSPW2oRGklseyHTKtZXnF-YNM","$7bkKK_Z-cGQ6Ae4HXWGBwXyZi3YjC6rIcQzGfVyl3Eo"],"content":{"body":"meow","com.beeper.linkpreviews":[],"m.mentions":{},"msgtype":"m.text"},"depth":3212,"hashes":{"sha256":"K549YdTnv62Jn84Y7sS5ZN3+AdmhleZHbenbhUpR2R8"},"origin_server_ts":1754242687127,"prev_events":["$DAhJg4jVsqk5FRatE2hbT1dSA8D2ASy5DbjEHIMSHwY"],"room_id":"!offtopic-2:continuwuity.org","sender":"@tulir:maunium.net","type":"m.room.message","signatures":{"maunium.net":{"ed25519:a_xxeS":"SkzZdZ+rH22kzCBBIAErTdB0Vg6vkFmzvwjlOarGul72EnufgtE/tJcd3a8szAdK7f1ZovRyQxDgVm/Ib2u0Aw"}},"unsigned":{"age_ts":1754242687146}}`, + eventID: `$qkWfTL7_l3oRZO2CItW8-Q0yAmi_l_1ua629ZDqponE`, + roomVersion: id.RoomV11, + serverDetails: mauniumNet, +}, roomV12MessageTestPDU, { + name: "m.room.create in v4 room", + pdu: `{"auth_events": [], "prev_events": [], "type": "m.room.create", "room_id": "!jxlRxnrZCsjpjDubDX:matrix.org", "sender": "@neilj:matrix.org", "content": {"room_version": "4", "predecessor": {"room_id": "!DYgXKezaHgMbiPMzjX:matrix.org", "event_id": "$156171636353XwPJT:matrix.org"}, "creator": "@neilj:matrix.org"}, "depth": 1, "prev_state": [], "state_key": "", "origin": "matrix.org", "origin_server_ts": 1561716363993, "hashes": {"sha256": "9tj8GpXjTAJvdNAbnuKLemZZk+Tjv2LAbGodSX6nJAo"}, "signatures": {"matrix.org": {"ed25519:auto": "2+sNt8uJUhzU4GPxnFVYtU2ZRgFdtVLT1vEZGUdJYN40zBpwYEGJy+kyb5matA+8/yLeYD9gu1O98lhleH0aCA"}}, "unsigned": {"age": 104769}}`, + eventID: "$ay_9_nPilrTpb3UxIwHHBBfFjTJb6hBAE_JzQwSjqeY", + roomVersion: id.RoomV4, + serverDetails: matrixOrg, +}, { + name: "m.room.create in v10 room", + pdu: `{"auth_events":[],"content":{"creator":"@creme:envs.net","predecessor":{"event_id":"$BxYNisKcyBDhPLiVC06t18qhv7wsT72MzMCqn5vRhfY","room_id":"!tEyFYiMHhwJlDXTxwf:envs.net"},"room_version":"10"},"depth":1,"hashes":{"sha256":"us3TrsIjBWpwbm+k3F9fUVnz9GIuhnb+LcaY47fWwUI"},"origin":"envs.net","origin_server_ts":1664394769527,"prev_events":[],"room_id":"!UzZHbJYcgggctGnlzr:envs.net","sender":"@creme:envs.net","state_key":"","type":"m.room.create","signatures":{"envs.net":{"ed25519:a_zIqy":"0g3FDaD1e5BekJYW2sR7dgxuKoZshrf8P067c9+jmH6frsWr2Ua86Ax08CFa/n46L8uvV2SGofP8iiVYgXCRBg"}},"unsigned":{"age":2060}}`, + eventID: "$tn1FZUI_YUpfTr_a3Y_r8kC3inliIZZratzg0UsNdCQ", + roomVersion: id.RoomV10, + serverDetails: envsNet, +}, { + name: "m.room.create in v12 room", + pdu: `{"auth_events":[],"content":{"fi.mau.randomness":"AAXZ6aIc","predecessor":{"room_id":"!#test/room\nversion 11, with @\ud83d\udc08\ufe0f:maunium.net"},"room_version":"12"},"depth":1,"hashes":{"sha256":"d3L1M3KUdyIKWcShyW6grUoJ8GOjCdSIEvQrDVHSpE8"},"origin_server_ts":1754940000000,"prev_events":[],"sender":"@tulir:maunium.net","state_key":"","type":"m.room.create","signatures":{"maunium.net":{"ed25519:a_xxeS":"ebjIRpzToc82cjb/RGY+VUzZic0yeRZrjctgx0SUTJxkprXn3/i1KdiYULfl/aD0cUJ5eL8gLakOSk2glm+sBw"}},"unsigned":{"age_ts":1754939139045}}`, + eventID: "$mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ", + roomVersion: id.RoomV12, + serverDetails: mauniumNet, +}, { + name: "m.room.member in v4 room", + pdu: `{"auth_events":["$ay_9_nPilrTpb3UxIwHHBBfFjTJb6hBAE_JzQwSjqeY","$jg2AgCfnwnjR-osoyM0lVYS21QrtfmZxhGO90PRkmO4","$wMGMP4Ucij2_d4h_fVDgIT2xooLZAgMcBruT9oo3Jio","$yyDgV8w0_e8qslmn0nh9OeSq_fO0zjpjTjSEdKFxDso"],"prev_events":["$zSjNuTXhUe3Rq6NpKD3sNyl8a_asMnBhGC5IbacHlJ4"],"type":"m.room.member","room_id":"!jxlRxnrZCsjpjDubDX:matrix.org","sender":"@tulir:maunium.net","content":{"membership":"join","displayname":"tulir","avatar_url":"mxc://maunium.net/jdlSfvudiMSmcRrleeiYjjFO","clicked \"send membership event with no changes\"":true},"depth":14370,"prev_state":[],"state_key":"@tulir:maunium.net","origin":"maunium.net","origin_server_ts":1600871136259,"hashes":{"sha256":"Ga6bG9Mk0887ruzM9TAAfa1O3DbNssb+qSFtE9oeRL4"},"signatures":{"maunium.net":{"ed25519:a_xxeS":"fzOyDG3G3pEzixtWPttkRA1DfnHETiKbiG8SEBQe2qycQbZWPky7xX8WujSrUJH/+bxTABpQwEH49d+RakxtBw"}},"unsigned":{"age_ts":1600871136259,"replaces_state":"$jg2AgCfnwnjR-osoyM0lVYS21QrtfmZxhGO90PRkmO4"}}`, + eventID: "$VtuCNOfAWGow-cxy0ajeK3fvONcC8QzF2yWa43g0Gwo", + roomVersion: id.RoomV4, + serverDetails: mauniumNet, +}, { + name: "m.room.member in v10 room", + pdu: `{"auth_events":["$HQC4hWaioLKVbMH94qKbfb3UnL4ocql2vi-VdUYI48I","$R9FUDgNAp9ms7b6ASunZOIkpqmsIRq_ROrNEznu62fs","$kEPF8Aj87EzRmFPriu2zdyEY0rY15XSqywTYVLUUlCA","$tn1FZUI_YUpfTr_a3Y_r8kC3inliIZZratzg0UsNdCQ"],"content":{"avatar_url":"mxc://maunium.net/jdlSfvudiMSmcRrleeiYjjFO","displayname":"tulir","membership":"join"},"depth":182,"hashes":{"sha256":"0HscBc921QV2dxK2qY7qrnyoAgfxBM7kKvqAXlEk+GE"},"origin":"maunium.net","origin_server_ts":1665402609039,"prev_events":["$R9FUDgNAp9ms7b6ASunZOIkpqmsIRq_ROrNEznu62fs"],"room_id":"!UzZHbJYcgggctGnlzr:envs.net","sender":"@tulir:maunium.net","state_key":"@tulir:maunium.net","type":"m.room.member","signatures":{"maunium.net":{"ed25519:a_xxeS":"lkOW0FSJ8MJ0wZpdwLH1Uf6FSl2q9/u6KthRIlM0CwHDJG4sIZ9DrMA8BdU8L/PWoDS/CoDUlLanDh99SplgBw"}},"unsigned":{"age_ts":1665402609039,"replaces_state":"$R9FUDgNAp9ms7b6ASunZOIkpqmsIRq_ROrNEznu62fs"}}`, + eventID: "$--ilpwnsHaEdHrwiMrZNu5xHP6TthWG0FIXMHnlHCcs", + roomVersion: id.RoomV10, + serverDetails: mauniumNet, +}, { + name: "m.room.member of creator in v12 room", + pdu: `{"auth_events":[],"content":{"avatar_url":"mxc://maunium.net/jdlSfvudiMSmcRrleeiYjjFO","displayname":"tulir","membership":"join"},"depth":2,"hashes":{"sha256":"IebdOBYaaWYIx2zq/lkVCnjWIXTLk1g+vgFpJMgd2/E"},"origin_server_ts":1754939139117,"prev_events":["$mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ"],"room_id":"!mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ","sender":"@tulir:maunium.net","state_key":"@tulir:maunium.net","type":"m.room.member","signatures":{"maunium.net":{"ed25519:a_xxeS":"rFCgF2hmavdm6+P6/f7rmuOdoSOmELFaH3JdWjgBLZXS2z51Ma7fa2v2+BkAH1FvBo9FLhvEoFVM4WbNQLXtAA"}},"unsigned":{"age_ts":1754939139117}}`, + eventID: "$accqGxfvhBvMP4Sf6P7t3WgnaJK6UbonO2ZmwqSE5Sg", + roomVersion: id.RoomV12, + serverDetails: mauniumNet, +}, { + name: "custom message event in v4 room", + pdu: `{"auth_events":["$VtuCNOfAWGow-cxy0ajeK3fvONcC8QzF2yWa43g0Gwo","$ay_9_nPilrTpb3UxIwHHBBfFjTJb6hBAE_JzQwSjqeY","$Gau_XwziYsr-rt3SouhbKN14twgmbKjcZZc_hz-nOgU"],"content":{"\ud83d\udc08\ufe0f":true,"\ud83d\udc15\ufe0f":false},"depth":69645,"hashes":{"sha256":"VHtWyCt+15ZesNnStU3FOkxrjzHJYZfd3JUgO9JWe0s"},"origin_server_ts":1755423939146,"prev_events":["$exmp4cj0OKOFSxuqBYiOYwQi5j_0XRc78d6EavAkhy0"],"room_id":"!jxlRxnrZCsjpjDubDX:matrix.org","sender":"@tulir:maunium.net","type":"\ud83d\udc08\ufe0f","signatures":{"maunium.net":{"ed25519:a_xxeS":"wfmP1XN4JBkKVkqrQnwysyEUslXt8hQRFwN9NC9vJaIeDMd0OJ6uqCas75808DuG71p23fzqbzhRnHckst6FCQ"}},"unsigned":{"age_ts":1755423939164}}`, + eventID: "$kAagtZAIEeZaLVCUSl74tAxQbdKbE22GU7FM-iAJBc0", + roomVersion: id.RoomV4, + serverDetails: mauniumNet, +}, { + name: "redacted m.room.member event in v11 room with 2 signatures", + pdu: `{"auth_events":["$9f12-_stoY07BOTmyguE1QlqvghLBh9Rk6PWRLoZn_M","$IP8hyjBkIDREVadyv0fPCGAW9IXGNllaZyxqQwiY_tA","$7dN5J8EveliaPkX6_QSejl4GQtem4oieavgALMeWZyE"],"content":{"membership":"join"},"depth":96978,"hashes":{"sha256":"APYA/aj3u+P0EwNaEofuSIlfqY3cK3lBz6RkwHX+Zak"},"origin_server_ts":1755664164485,"prev_events":["$XBN9W5Ll8VEH3eYqJaemxCBTDdy0hZB0sWpmyoUp93c"],"room_id":"!main-1:continuwuity.org","sender":"@6a19abdd4766:nova.astraltech.org","state_key":"@6a19abdd4766:nova.astraltech.org","type":"m.room.member","signatures":{"continuwuity.org":{"ed25519:PwHlNsFu":"+b/Fp2vWnC+Z2lI3GnCu7ZHdo3iWNDZ2AJqMoU9owMtLBPMxs4dVIsJXvaFq0ryawsgwDwKZ7f4xaFUNARJSDg"},"nova.astraltech.org":{"ed25519:a_afpo":"pXIngyxKukCPR7WOIIy8FTZxQ5L2dLiou5Oc8XS4WyY4YzJuckQzOaToigLLZxamfbN/jXbO+XUizpRpYccDAA"}},"unsigned":{}}`, + eventID: "$r6d9m125YWG28-Tln47bWtm6Jlv4mcSUWJTHijBlXLQ", + roomVersion: id.RoomV11, + serverDetails: novaAstraltechOrg, + redacted: true, +}} + +func parsePDU(pdu string) (out *pdu.PDU) { + exerrors.PanicIfNotNil(json.Unmarshal([]byte(pdu), &out)) + return +} diff --git a/federation/pdu/redact.go b/federation/pdu/redact.go new file mode 100644 index 00000000..d7ee0c15 --- /dev/null +++ b/federation/pdu/redact.go @@ -0,0 +1,111 @@ +// Copyright (c) 2025 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +//go:build goexperiment.jsonv2 + +package pdu + +import ( + "encoding/json/jsontext" + + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" + "go.mau.fi/util/exgjson" + "go.mau.fi/util/ptr" + + "maunium.net/go/mautrix/id" +) + +func filteredObject(object jsontext.Value, allowedPaths ...string) jsontext.Value { + filtered := jsontext.Value("{}") + var err error + for _, path := range allowedPaths { + res := gjson.GetBytes(object, path) + if res.Exists() { + var raw jsontext.Value + if res.Index > 0 { + raw = object[res.Index : res.Index+len(res.Raw)] + } else { + raw = jsontext.Value(res.Raw) + } + filtered, err = sjson.SetRawBytes(filtered, path, raw) + if err != nil { + panic(err) + } + } + } + return filtered +} + +func (pdu *PDU) Clone() *PDU { + return ptr.Clone(pdu) +} + +func (pdu *PDU) RedactForSignature(roomVersion id.RoomVersion) *PDU { + pdu.Signatures = nil + return pdu.Redact(roomVersion) +} + +var emptyObject = jsontext.Value("{}") + +func RedactContent(eventType string, content jsontext.Value, roomVersion id.RoomVersion) jsontext.Value { + switch eventType { + case "m.room.member": + allowedPaths := []string{"membership"} + if roomVersion.RestrictedJoinsFix() { + allowedPaths = append(allowedPaths, "join_authorised_via_users_server") + } + if roomVersion.UpdatedRedactionRules() { + allowedPaths = append(allowedPaths, exgjson.Path("third_party_invite", "signed")) + } + return filteredObject(content, allowedPaths...) + case "m.room.create": + if !roomVersion.UpdatedRedactionRules() { + return filteredObject(content, "creator") + } + return content + case "m.room.join_rules": + if roomVersion.RestrictedJoins() { + return filteredObject(content, "join_rule", "allow") + } + return filteredObject(content, "join_rule") + case "m.room.power_levels": + allowedKeys := []string{"ban", "events", "events_default", "kick", "redact", "state_default", "users", "users_default"} + if roomVersion.UpdatedRedactionRules() { + allowedKeys = append(allowedKeys, "invite") + } + return filteredObject(content, allowedKeys...) + case "m.room.history_visibility": + return filteredObject(content, "history_visibility") + case "m.room.redaction": + if roomVersion.RedactsInContent() { + return filteredObject(content, "redacts") + } + return emptyObject + case "m.room.aliases": + if roomVersion.SpecialCasedAliasesAuth() { + return filteredObject(content, "aliases") + } + return emptyObject + default: + return emptyObject + } +} + +func (pdu *PDU) Redact(roomVersion id.RoomVersion) *PDU { + pdu.Unknown = nil + pdu.Unsigned = nil + if roomVersion.UpdatedRedactionRules() { + pdu.DeprecatedPrevState = nil + pdu.DeprecatedOrigin = nil + pdu.DeprecatedMembership = nil + } + if pdu.Type != "m.room.redaction" || roomVersion.RedactsInContent() { + pdu.Redacts = nil + } + pdu.Content = RedactContent(pdu.Type, pdu.Content, roomVersion) + return pdu +} diff --git a/federation/pdu/signature.go b/federation/pdu/signature.go new file mode 100644 index 00000000..04e7c5ef --- /dev/null +++ b/federation/pdu/signature.go @@ -0,0 +1,60 @@ +// Copyright (c) 2025 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +//go:build goexperiment.jsonv2 + +package pdu + +import ( + "crypto/ed25519" + "encoding/base64" + "fmt" + "time" + + "maunium.net/go/mautrix/federation/signutil" + "maunium.net/go/mautrix/id" +) + +func (pdu *PDU) Sign(roomVersion id.RoomVersion, serverName string, keyID id.KeyID, privateKey ed25519.PrivateKey) error { + err := pdu.FillContentHash() + if err != nil { + return err + } + rawJSON, err := marshalCanonical(pdu.Clone().RedactForSignature(roomVersion)) + if err != nil { + return fmt.Errorf("failed to marshal redacted PDU to sign: %w", err) + } + signature := ed25519.Sign(privateKey, rawJSON) + pdu.AddSignature(serverName, keyID, base64.RawStdEncoding.EncodeToString(signature)) + return nil +} + +func (pdu *PDU) VerifySignature(roomVersion id.RoomVersion, serverName string, getKey GetKeyFunc) error { + rawJSON, err := marshalCanonical(pdu.Clone().RedactForSignature(roomVersion)) + if err != nil { + return fmt.Errorf("failed to marshal redacted PDU to verify signature: %w", err) + } + verified := false + for keyID, sig := range pdu.Signatures[serverName] { + originServerTS := time.UnixMilli(pdu.OriginServerTS) + key, validUntil, err := getKey(serverName, keyID, originServerTS) + if err != nil { + return fmt.Errorf("failed to get key %s for %s: %w", keyID, serverName, err) + } else if key == "" { + return fmt.Errorf("key %s not found for %s", keyID, serverName) + } else if validUntil.Before(originServerTS) && roomVersion.EnforceSigningKeyValidity() { + return fmt.Errorf("key %s for %s is only valid until %s, but event is from %s", keyID, serverName, validUntil, originServerTS) + } else if err = signutil.VerifyJSONRaw(key, sig, rawJSON); err != nil { + return fmt.Errorf("failed to verify signature from key %s: %w", keyID, err) + } else { + verified = true + } + } + if !verified { + return fmt.Errorf("no verifiable signatures found for server %s", serverName) + } + return nil +} diff --git a/federation/pdu/signature_test.go b/federation/pdu/signature_test.go new file mode 100644 index 00000000..01df5076 --- /dev/null +++ b/federation/pdu/signature_test.go @@ -0,0 +1,102 @@ +// Copyright (c) 2025 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +//go:build goexperiment.jsonv2 + +package pdu_test + +import ( + "crypto/ed25519" + "encoding/base64" + "encoding/json/jsontext" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.mau.fi/util/exerrors" + + "maunium.net/go/mautrix/federation/pdu" + "maunium.net/go/mautrix/id" +) + +func TestPDU_VerifySignature(t *testing.T) { + for _, test := range testPDUs { + t.Run(test.name, func(t *testing.T) { + parsed := parsePDU(test.pdu) + err := parsed.VerifySignature(test.roomVersion, test.serverName, test.getKey) + assert.NoError(t, err) + }) + } +} + +func TestPDU_VerifySignature_Fail_NoKey(t *testing.T) { + test := roomV12MessageTestPDU + parsed := parsePDU(test.pdu) + err := parsed.VerifySignature(test.roomVersion, test.serverName, func(serverName string, keyID id.KeyID, minValidUntil time.Time) (key id.SigningKey, validUntil time.Time, err error) { + return + }) + assert.Error(t, err) +} + +func TestPDU_VerifySignature_V4ExpiredKey(t *testing.T) { + test := roomV4MessageTestPDU + parsed := parsePDU(test.pdu) + err := parsed.VerifySignature(test.roomVersion, test.serverName, func(serverName string, keyID id.KeyID, minValidUntil time.Time) (key id.SigningKey, validUntil time.Time, err error) { + key = test.keys[keyID].key + validUntil = time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC) + return + }) + assert.NoError(t, err) +} + +func TestPDU_VerifySignature_V12ExpiredKey(t *testing.T) { + test := roomV12MessageTestPDU + parsed := parsePDU(test.pdu) + err := parsed.VerifySignature(test.roomVersion, test.serverName, func(serverName string, keyID id.KeyID, minValidUntil time.Time) (key id.SigningKey, validUntil time.Time, err error) { + key = test.keys[keyID].key + validUntil = time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC) + return + }) + assert.Error(t, err) +} + +func TestPDU_VerifySignature_V12InvalidSignature(t *testing.T) { + test := roomV12MessageTestPDU + parsed := parsePDU(test.pdu) + for _, sigs := range parsed.Signatures { + for key := range sigs { + sigs[key] = sigs[key][:len(sigs[key])-3] + "ABC" + } + } + err := parsed.VerifySignature(test.roomVersion, test.serverName, test.getKey) + assert.Error(t, err) +} + +func TestPDU_Sign(t *testing.T) { + pubKey, privKey := exerrors.Must2(ed25519.GenerateKey(nil)) + evt := &pdu.PDU{ + AuthEvents: []id.EventID{"$gCzdJUVV93Qory0x7p_PLG5UUiDjPJNe1H12qbHTuFA", "$hyeL_nU_L3tsZ2dtZZpAHk0Skv-PqFQIipuII_By584"}, + Content: jsontext.Value(`{"msgtype":"m.text","body":"Hello, world!"}`), + Depth: 123, + OriginServerTS: 1755384351627, + PrevEvents: []id.EventID{"$gCzdJUVV93Qory0x7p_PLG5UUiDjPJNe1H12qbHTuFA"}, + RoomID: "!mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ", + Sender: "@tulir:example.com", + Type: "m.room.message", + } + err := evt.Sign(id.RoomV12, "example.com", "ed25519:rand", privKey) + require.NoError(t, err) + err = evt.VerifySignature(id.RoomV11, "example.com", func(serverName string, keyID id.KeyID, minValidUntil time.Time) (key id.SigningKey, validUntil time.Time, err error) { + if serverName == "example.com" && keyID == "ed25519:rand" { + key = id.SigningKey(base64.RawStdEncoding.EncodeToString(pubKey)) + validUntil = time.Now() + } + return + }) + require.NoError(t, err) + +} diff --git a/federation/pdu/v1.go b/federation/pdu/v1.go new file mode 100644 index 00000000..9557f8ab --- /dev/null +++ b/federation/pdu/v1.go @@ -0,0 +1,277 @@ +// Copyright (c) 2025 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +//go:build goexperiment.jsonv2 + +package pdu + +import ( + "crypto/ed25519" + "crypto/hmac" + "crypto/sha256" + "encoding/base64" + "encoding/json/jsontext" + "encoding/json/v2" + "fmt" + "time" + + "github.com/tidwall/gjson" + "go.mau.fi/util/ptr" + + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/federation/signutil" + "maunium.net/go/mautrix/id" +) + +type V1EventReference struct { + ID id.EventID + Hashes Hashes +} + +var ( + _ json.UnmarshalerFrom = (*V1EventReference)(nil) + _ json.MarshalerTo = (*V1EventReference)(nil) +) + +func (er *V1EventReference) MarshalJSONTo(enc *jsontext.Encoder) error { + return json.MarshalEncode(enc, []any{er.ID, er.Hashes}) +} + +func (er *V1EventReference) UnmarshalJSONFrom(dec *jsontext.Decoder) error { + var ref V1EventReference + var data []jsontext.Value + if err := json.UnmarshalDecode(dec, &data); err != nil { + return err + } else if len(data) != 2 { + return fmt.Errorf("V1EventReference.UnmarshalJSONFrom: expected array with 2 elements, got %d", len(data)) + } else if err = json.Unmarshal(data[0], &ref.ID); err != nil { + return fmt.Errorf("V1EventReference.UnmarshalJSONFrom: failed to unmarshal event ID: %w", err) + } else if err = json.Unmarshal(data[1], &ref.Hashes); err != nil { + return fmt.Errorf("V1EventReference.UnmarshalJSONFrom: failed to unmarshal hashes: %w", err) + } + *er = ref + return nil +} + +type RoomV1PDU struct { + AuthEvents []V1EventReference `json:"auth_events"` + Content jsontext.Value `json:"content"` + Depth int64 `json:"depth"` + EventID id.EventID `json:"event_id"` + Hashes *Hashes `json:"hashes,omitzero"` + OriginServerTS int64 `json:"origin_server_ts"` + PrevEvents []V1EventReference `json:"prev_events"` + Redacts *id.EventID `json:"redacts,omitzero"` + RoomID id.RoomID `json:"room_id"` + Sender id.UserID `json:"sender"` + Signatures map[string]map[id.KeyID]string `json:"signatures,omitzero"` + StateKey *string `json:"state_key,omitzero"` + Type string `json:"type"` + Unsigned jsontext.Value `json:"unsigned,omitzero"` + + Unknown jsontext.Value `json:",unknown"` + + // Deprecated legacy fields + DeprecatedPrevState jsontext.Value `json:"prev_state,omitzero"` + DeprecatedOrigin jsontext.Value `json:"origin,omitzero"` + DeprecatedMembership jsontext.Value `json:"membership,omitzero"` +} + +func (pdu *RoomV1PDU) GetRoomID() (id.RoomID, error) { + return pdu.RoomID, nil +} + +func (pdu *RoomV1PDU) GetEventID(roomVersion id.RoomVersion) (id.EventID, error) { + if !pdu.SupportsRoomVersion(roomVersion) { + return "", fmt.Errorf("RoomV1PDU.GetEventID: unsupported room version %s", roomVersion) + } + return pdu.EventID, nil +} + +func (pdu *RoomV1PDU) RedactForSignature(roomVersion id.RoomVersion) *RoomV1PDU { + pdu.Signatures = nil + return pdu.Redact(roomVersion) +} + +func (pdu *RoomV1PDU) Redact(roomVersion id.RoomVersion) *RoomV1PDU { + pdu.Unknown = nil + pdu.Unsigned = nil + if pdu.Type != "m.room.redaction" { + pdu.Redacts = nil + } + pdu.Content = RedactContent(pdu.Type, pdu.Content, roomVersion) + return pdu +} + +func (pdu *RoomV1PDU) GetReferenceHash(roomVersion id.RoomVersion) ([32]byte, error) { + if !pdu.SupportsRoomVersion(roomVersion) { + return [32]byte{}, fmt.Errorf("RoomV1PDU.GetReferenceHash: unsupported room version %s", roomVersion) + } + if pdu == nil { + return [32]byte{}, ErrPDUIsNil + } + if pdu.Hashes == nil || pdu.Hashes.SHA256 == nil { + if err := pdu.FillContentHash(); err != nil { + return [32]byte{}, err + } + } + rawJSON, err := marshalCanonical(pdu.Clone().RedactForSignature(roomVersion)) + if err != nil { + return [32]byte{}, fmt.Errorf("failed to marshal redacted PDU to calculate event ID: %w", err) + } + return sha256.Sum256(rawJSON), nil +} + +func (pdu *RoomV1PDU) CalculateContentHash() ([32]byte, error) { + if pdu == nil { + return [32]byte{}, ErrPDUIsNil + } + pduClone := pdu.Clone() + pduClone.Signatures = nil + pduClone.Unsigned = nil + pduClone.Hashes = nil + rawJSON, err := marshalCanonical(pduClone) + if err != nil { + return [32]byte{}, fmt.Errorf("failed to marshal PDU to calculate content hash: %w", err) + } + return sha256.Sum256(rawJSON), nil +} + +func (pdu *RoomV1PDU) FillContentHash() error { + if pdu == nil { + return ErrPDUIsNil + } else if pdu.Hashes != nil { + return nil + } else if hash, err := pdu.CalculateContentHash(); err != nil { + return err + } else { + pdu.Hashes = &Hashes{SHA256: hash[:]} + return nil + } +} + +func (pdu *RoomV1PDU) VerifyContentHash() bool { + if pdu == nil || pdu.Hashes == nil { + return false + } + calculatedHash, err := pdu.CalculateContentHash() + if err != nil { + return false + } + return hmac.Equal(calculatedHash[:], pdu.Hashes.SHA256) +} + +func (pdu *RoomV1PDU) Clone() *RoomV1PDU { + return ptr.Clone(pdu) +} + +func (pdu *RoomV1PDU) Sign(roomVersion id.RoomVersion, serverName string, keyID id.KeyID, privateKey ed25519.PrivateKey) error { + if !pdu.SupportsRoomVersion(roomVersion) { + return fmt.Errorf("RoomV1PDU.Sign: unsupported room version %s", roomVersion) + } + err := pdu.FillContentHash() + if err != nil { + return err + } + rawJSON, err := marshalCanonical(pdu.Clone().RedactForSignature(roomVersion)) + if err != nil { + return fmt.Errorf("failed to marshal redacted PDU to sign: %w", err) + } + signature := ed25519.Sign(privateKey, rawJSON) + if pdu.Signatures == nil { + pdu.Signatures = make(map[string]map[id.KeyID]string) + } + if _, ok := pdu.Signatures[serverName]; !ok { + pdu.Signatures[serverName] = make(map[id.KeyID]string) + } + pdu.Signatures[serverName][keyID] = base64.RawStdEncoding.EncodeToString(signature) + return nil +} + +func (pdu *RoomV1PDU) VerifySignature(roomVersion id.RoomVersion, serverName string, getKey GetKeyFunc) error { + if !pdu.SupportsRoomVersion(roomVersion) { + return fmt.Errorf("RoomV1PDU.VerifySignature: unsupported room version %s", roomVersion) + } + rawJSON, err := marshalCanonical(pdu.Clone().RedactForSignature(roomVersion)) + if err != nil { + return fmt.Errorf("failed to marshal redacted PDU to verify signature: %w", err) + } + verified := false + for keyID, sig := range pdu.Signatures[serverName] { + originServerTS := time.UnixMilli(pdu.OriginServerTS) + key, _, err := getKey(serverName, keyID, originServerTS) + if err != nil { + return fmt.Errorf("failed to get key %s for %s: %w", keyID, serverName, err) + } else if key == "" { + return fmt.Errorf("key %s not found for %s", keyID, serverName) + } else if err = signutil.VerifyJSONRaw(key, sig, rawJSON); err != nil { + return fmt.Errorf("failed to verify signature from key %s: %w", keyID, err) + } else { + verified = true + } + } + if !verified { + return fmt.Errorf("no verifiable signatures found for server %s", serverName) + } + return nil +} + +func (pdu *RoomV1PDU) SupportsRoomVersion(roomVersion id.RoomVersion) bool { + switch roomVersion { + case id.RoomV0, id.RoomV1, id.RoomV2: + return true + default: + return false + } +} + +func (pdu *RoomV1PDU) ToClientEvent(roomVersion id.RoomVersion) (*event.Event, error) { + if !pdu.SupportsRoomVersion(roomVersion) { + return nil, fmt.Errorf("RoomV1PDU.ToClientEvent: unsupported room version %s", roomVersion) + } + evtType := event.Type{Type: pdu.Type, Class: event.MessageEventType} + if pdu.StateKey != nil { + evtType.Class = event.StateEventType + } + evt := &event.Event{ + StateKey: pdu.StateKey, + Sender: pdu.Sender, + Type: evtType, + Timestamp: pdu.OriginServerTS, + ID: pdu.EventID, + RoomID: pdu.RoomID, + Redacts: ptr.Val(pdu.Redacts), + } + err := json.Unmarshal(pdu.Content, &evt.Content) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal content: %w", err) + } + return evt, nil +} + +func (pdu *RoomV1PDU) AuthEventSelection(_ id.RoomVersion) (keys AuthEventSelection) { + if pdu.Type == event.StateCreate.Type && pdu.StateKey != nil { + return AuthEventSelection{} + } + keys = make(AuthEventSelection, 0, 3) + keys.Add(event.StateCreate.Type, "") + keys.Add(event.StatePowerLevels.Type, "") + keys.Add(event.StateMember.Type, pdu.Sender.String()) + if pdu.Type == event.StateMember.Type && pdu.StateKey != nil { + keys.Add(event.StateMember.Type, *pdu.StateKey) + membership := event.Membership(gjson.GetBytes(pdu.Content, "membership").Str) + if membership == event.MembershipJoin || membership == event.MembershipInvite || membership == event.MembershipKnock { + keys.Add(event.StateJoinRules.Type, "") + } + if membership == event.MembershipInvite { + thirdPartyInviteToken := gjson.GetBytes(pdu.Content, thirdPartyInviteTokenPath).Str + if thirdPartyInviteToken != "" { + keys.Add(event.StateThirdPartyInvite.Type, thirdPartyInviteToken) + } + } + } + return +} diff --git a/federation/pdu/v1_test.go b/federation/pdu/v1_test.go new file mode 100644 index 00000000..ecf2dbd2 --- /dev/null +++ b/federation/pdu/v1_test.go @@ -0,0 +1,86 @@ +// Copyright (c) 2025 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +//go:build goexperiment.jsonv2 + +package pdu_test + +import ( + "encoding/base64" + "encoding/json/v2" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "go.mau.fi/util/exerrors" + + "maunium.net/go/mautrix/federation/pdu" + "maunium.net/go/mautrix/id" +) + +var testV1PDUs = []testPDU{{ + name: "m.room.message in v1 room", + pdu: `{"auth_events":[["$159234730483190eXavq:matrix.org",{"sha256":"VprZrhMqOQyKbfF3UE26JXE8D27ih4R/FGGc8GZ0Whs"}],["$143454825711DhCxH:matrix.org",{"sha256":"3sJh/5GOB094OKuhbjL634Gt69YIcge9GD55ciJa9ok"}],["$156837651426789wiPdh:maunium.net",{"sha256":"FGyR3sxJ/VxYabDkO/5qtwrPR3hLwGknJ0KX0w3GUHE"}]],"content":{"body":"photo-1526336024174-e58f5cdd8e13.jpg","info":{"h":1620,"mimetype":"image/jpeg","size":208053,"w":1080},"msgtype":"m.image","url":"mxc://maunium.net/aEqEghIjFPAerIhCxJCYpQeC"},"depth":16669,"event_id":"$16738169022163bokdi:maunium.net","hashes":{"sha256":"XYB47Gf2vAci3BTguIJaC75ZYGMuVY65jcvoUVgpcLA"},"origin":"maunium.net","origin_server_ts":1673816902100,"prev_events":[["$1673816901121325UMCjA:matrix.org",{"sha256":"t7e0IYHLI3ydIPoIU8a8E/pIWXH9cNLlQBEtGyGtHwc"}]],"room_id":"!jhpZBTbckszblMYjMK:matrix.org","sender":"@cat:maunium.net","type":"m.room.message","signatures":{"maunium.net":{"ed25519:a_xxeS":"uRZbEm+P+Y1ZVgwBn5I6SlaUZdzlH1bB4nv81yt5EIQ0b1fZ8YgM4UWMijrrXp3+NmqRFl0cakSM3MneJOtFCw"}},"unsigned":{"age_ts":1673816902100}}`, + eventID: "$16738169022163bokdi:maunium.net", + roomVersion: id.RoomV1, + serverDetails: mauniumNet, +}, { + name: "m.room.create in v1 room", + pdu: `{"origin": "matrix.org", "signatures": {"matrix.org": {"ed25519:auto": "XTejpXn5REoHrZWgCpJglGX7MfOWS2zUjYwJRLrwW2PQPbFdqtL+JnprBXwIP2C1NmgWSKG+am1QdApu0KoHCQ"}}, "origin_server_ts": 1434548257426, "sender": "@appservice-irc:matrix.org", "event_id": "$143454825711DhCxH:matrix.org", "prev_events": [], "unsigned": {"age": 12872287834}, "state_key": "", "content": {"creator": "@appservice-irc:matrix.org"}, "depth": 1, "prev_state": [], "room_id": "!jhpZBTbckszblMYjMK:matrix.org", "auth_events": [], "hashes": {"sha256": "+SSdmeeoKI/6yK6sY4XAFljWFiugSlCiXQf0QMCZjTs"}, "type": "m.room.create"}`, + eventID: "$143454825711DhCxH:matrix.org", + roomVersion: id.RoomV1, + serverDetails: matrixOrg, +}, { + name: "m.room.member in v1 room", + pdu: `{"auth_events": [["$1536447669931522zlyWe:matrix.org", {"sha256": "UkzPGd7cPAGvC0FVx3Yy2/Q0GZhA2kcgj8MGp5pjYV8"}], ["$143454825711DhCxH:matrix.org", {"sha256": "3sJh/5GOB094OKuhbjL634Gt69YIcge9GD55ciJa9ok"}], ["$143454825714nUEqZ:matrix.org", {"sha256": "NjuZXu8EDMfIfejPcNlC/IdnKQAGpPIcQjHaf0BZaHk"}]], "prev_events": [["$15660585503271JRRMm:maunium.net", {"sha256": "/Sm7uSLkYMHapp6I3NuEVJlk2JucW2HqjsQy9vzhciA"}]], "type": "m.room.member", "room_id": "!jhpZBTbckszblMYjMK:matrix.org", "sender": "@tulir:maunium.net", "content": {"membership": "join", "avatar_url": "mxc://maunium.net/jdlSfvudiMSmcRrleeiYjjFO", "displayname": "tulir"}, "depth": 10485, "prev_state": [], "state_key": "@tulir:maunium.net", "event_id": "$15660585693272iEryv:maunium.net", "origin": "maunium.net", "origin_server_ts": 1566058569201, "hashes": {"sha256": "1D6fdDzKsMGCxSqlXPA7I9wGQNTutVuJke1enGHoWK8"}, "signatures": {"maunium.net": {"ed25519:a_xxeS": "Lj/zDK6ozr4vgsxyL8jY56wTGWoA4jnlvkTs5paCX1w3nNKHnQnSMi+wuaqI6yv5vYh9usGWco2LLMuMzYXcBg"}}, "unsigned": {"age_ts": 1566058569201, "replaces_state": "$15660585383268liyBc:maunium.net"}}`, + eventID: "$15660585693272iEryv:maunium.net", + roomVersion: id.RoomV1, + serverDetails: mauniumNet, +}} + +func parseV1PDU(pdu string) (out *pdu.RoomV1PDU) { + exerrors.PanicIfNotNil(json.Unmarshal([]byte(pdu), &out)) + return +} + +func TestRoomV1PDU_CalculateContentHash(t *testing.T) { + for _, test := range testV1PDUs { + t.Run(test.name, func(t *testing.T) { + parsed := parseV1PDU(test.pdu) + contentHash := exerrors.Must(parsed.CalculateContentHash()) + assert.Equal( + t, + base64.RawStdEncoding.EncodeToString(parsed.Hashes.SHA256), + base64.RawStdEncoding.EncodeToString(contentHash[:]), + ) + }) + } +} + +func TestRoomV1PDU_VerifyContentHash(t *testing.T) { + for _, test := range testV1PDUs { + t.Run(test.name, func(t *testing.T) { + parsed := parseV1PDU(test.pdu) + assert.True(t, parsed.VerifyContentHash()) + }) + } +} + +func TestRoomV1PDU_VerifySignature(t *testing.T) { + for _, test := range testV1PDUs { + t.Run(test.name, func(t *testing.T) { + parsed := parseV1PDU(test.pdu) + err := parsed.VerifySignature(test.roomVersion, test.serverName, func(serverName string, keyID id.KeyID, _ time.Time) (id.SigningKey, time.Time, error) { + key, ok := test.keys[keyID] + if ok { + return key.key, key.validUntilTS, nil + } + return "", time.Time{}, nil + }) + assert.NoError(t, err) + }) + } +} diff --git a/federation/resolution.go b/federation/resolution.go new file mode 100644 index 00000000..a3188266 --- /dev/null +++ b/federation/resolution.go @@ -0,0 +1,198 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package federation + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net" + "net/http" + "net/url" + "strconv" + "strings" + "time" + + "github.com/rs/zerolog" + + "maunium.net/go/mautrix" +) + +type ResolvedServerName struct { + ServerName string `json:"server_name"` + HostHeader string `json:"host_header"` + IPPort []string `json:"ip_port"` + Expires time.Time `json:"expires"` +} + +type ResolveServerNameOpts struct { + HTTPClient *http.Client + DNSClient *net.Resolver +} + +var ( + ErrInvalidServerName = errors.New("invalid server name") +) + +// ResolveServerName implements the full server discovery algorithm as specified in https://spec.matrix.org/v1.11/server-server-api/#resolving-server-names +func ResolveServerName(ctx context.Context, serverName string, opts ...*ResolveServerNameOpts) (*ResolvedServerName, error) { + var opt ResolveServerNameOpts + if len(opts) > 0 && opts[0] != nil { + opt = *opts[0] + } + if opt.HTTPClient == nil { + opt.HTTPClient = http.DefaultClient + } + if opt.DNSClient == nil { + opt.DNSClient = net.DefaultResolver + } + output := ResolvedServerName{ + ServerName: serverName, + HostHeader: serverName, + IPPort: []string{serverName}, + Expires: time.Now().Add(24 * time.Hour), + } + hostname, port, ok := ParseServerName(serverName) + if !ok { + return nil, ErrInvalidServerName + } + // Steps 1 and 2: handle IP literals and hostnames with port + if net.ParseIP(hostname) != nil || port != 0 { + if port == 0 { + port = 8448 + } + output.IPPort = []string{net.JoinHostPort(hostname, strconv.Itoa(int(port)))} + return &output, nil + } + // Step 3: resolve .well-known + wellKnown, expiry, err := RequestWellKnown(ctx, opt.HTTPClient, hostname) + if err != nil { + zerolog.Ctx(ctx).Trace(). + Str("server_name", serverName). + Err(err). + Msg("Failed to get well-known data") + } else if wellKnown != nil { + output.Expires = expiry + output.HostHeader = wellKnown.Server + wkHost, wkPort, ok := ParseServerName(wellKnown.Server) + if ok { + hostname, port = wkHost, wkPort + } + // Step 3.1 and 3.2: IP literals and hostnames with port inside .well-known + if net.ParseIP(hostname) != nil || port != 0 { + if port == 0 { + port = 8448 + } + output.IPPort = []string{net.JoinHostPort(hostname, strconv.Itoa(int(port)))} + return &output, nil + } + } + // Step 3.3, 3.4, 4 and 5: resolve SRV records + srv, err := RequestSRV(ctx, opt.DNSClient, hostname) + if err != nil { + // TODO log more noisily for abnormal errors? + zerolog.Ctx(ctx).Trace(). + Str("server_name", serverName). + Str("hostname", hostname). + Err(err). + Msg("Failed to get SRV record") + } else if len(srv) > 0 { + output.IPPort = make([]string, len(srv)) + for i, record := range srv { + output.IPPort[i] = net.JoinHostPort(strings.TrimRight(record.Target, "."), strconv.Itoa(int(record.Port))) + } + return &output, nil + } + // Step 6 or 3.5: no SRV records were found, so default to port 8448 + output.IPPort = []string{net.JoinHostPort(hostname, "8448")} + return &output, nil +} + +// RequestSRV resolves the `_matrix-fed._tcp` SRV record for the given hostname. +// If the new matrix-fed record is not found, it falls back to the old `_matrix._tcp` record. +func RequestSRV(ctx context.Context, cli *net.Resolver, hostname string) ([]*net.SRV, error) { + _, target, err := cli.LookupSRV(ctx, "matrix-fed", "tcp", hostname) + var dnsErr *net.DNSError + if err != nil && errors.As(err, &dnsErr) && dnsErr.IsNotFound { + _, target, err = cli.LookupSRV(ctx, "matrix", "tcp", hostname) + } + return target, err +} + +func parseCacheControl(resp *http.Response) time.Duration { + cc := resp.Header.Get("Cache-Control") + if cc == "" { + return 0 + } + parts := strings.Split(cc, ",") + for _, part := range parts { + kv := strings.SplitN(strings.TrimSpace(part), "=", 1) + switch kv[0] { + case "no-cache", "no-store": + return 0 + case "max-age": + if len(kv) < 2 { + continue + } + maxAge, err := strconv.Atoi(kv[1]) + if err != nil || maxAge < 0 { + continue + } + age, _ := strconv.Atoi(resp.Header.Get("Age")) + return time.Duration(maxAge-age) * time.Second + } + } + return 0 +} + +const ( + MinCacheDuration = 1 * time.Hour + MaxCacheDuration = 72 * time.Hour + DefaultCacheDuration = 24 * time.Hour +) + +// RequestWellKnown sends a request to the well-known endpoint of a server and returns the response, +// plus the time when the cache should expire. +func RequestWellKnown(ctx context.Context, cli *http.Client, hostname string) (*RespWellKnown, time.Time, error) { + wellKnownURL := url.URL{ + Scheme: "https", + Host: hostname, + Path: "/.well-known/matrix/server", + } + req, err := http.NewRequestWithContext(ctx, http.MethodGet, wellKnownURL.String(), nil) + if err != nil { + return nil, time.Time{}, fmt.Errorf("failed to prepare request: %w", err) + } + resp, err := cli.Do(req) + if err != nil { + return nil, time.Time{}, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return nil, time.Time{}, fmt.Errorf("unexpected status code %d", resp.StatusCode) + } else if resp.ContentLength > mautrix.WellKnownMaxSize { + return nil, time.Time{}, fmt.Errorf("response too large: %d bytes", resp.ContentLength) + } + var respData RespWellKnown + err = json.NewDecoder(io.LimitReader(resp.Body, mautrix.WellKnownMaxSize)).Decode(&respData) + if err != nil { + return nil, time.Time{}, fmt.Errorf("failed to decode response: %w", err) + } else if respData.Server == "" { + return nil, time.Time{}, errors.New("server name not found in response") + } + cacheDuration := parseCacheControl(resp) + if cacheDuration <= 0 { + cacheDuration = DefaultCacheDuration + } else if cacheDuration < MinCacheDuration { + cacheDuration = MinCacheDuration + } else if cacheDuration > MaxCacheDuration { + cacheDuration = MaxCacheDuration + } + return &respData, time.Now().Add(24 * time.Hour), nil +} diff --git a/federation/resolution_test.go b/federation/resolution_test.go new file mode 100644 index 00000000..62200454 --- /dev/null +++ b/federation/resolution_test.go @@ -0,0 +1,115 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package federation_test + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "maunium.net/go/mautrix/federation" +) + +type resolveTestCase struct { + name string + serverName string + expected federation.ResolvedServerName +} + +func TestResolveServerName(t *testing.T) { + // See https://t2bot.io/docs/resolvematrix/ for more info on the RM test cases + testCases := []resolveTestCase{{ + "maunium", + "maunium.net", + federation.ResolvedServerName{ + HostHeader: "federation.mau.chat", + IPPort: []string{"meow.host.mau.fi:443"}, + }, + }, { + "IP literal", + "135.181.208.158", + federation.ResolvedServerName{ + HostHeader: "135.181.208.158", + IPPort: []string{"135.181.208.158:8448"}, + }, + }, { + "IP literal with port", + "135.181.208.158:8447", + federation.ResolvedServerName{ + HostHeader: "135.181.208.158:8447", + IPPort: []string{"135.181.208.158:8447"}, + }, + }, { + "RM Step 2", + "2.s.resolvematrix.dev:7652", + federation.ResolvedServerName{ + HostHeader: "2.s.resolvematrix.dev:7652", + IPPort: []string{"2.s.resolvematrix.dev:7652"}, + }, + }, { + "RM Step 3B", + "3b.s.resolvematrix.dev", + federation.ResolvedServerName{ + HostHeader: "wk.3b.s.resolvematrix.dev:7753", + IPPort: []string{"wk.3b.s.resolvematrix.dev:7753"}, + }, + }, { + "RM Step 3C", + "3c.s.resolvematrix.dev", + federation.ResolvedServerName{ + HostHeader: "wk.3c.s.resolvematrix.dev", + IPPort: []string{"srv.wk.3c.s.resolvematrix.dev:7754"}, + }, + }, { + "RM Step 3C MSC4040", + "3c.msc4040.s.resolvematrix.dev", + federation.ResolvedServerName{ + HostHeader: "wk.3c.msc4040.s.resolvematrix.dev", + IPPort: []string{"srv.wk.3c.msc4040.s.resolvematrix.dev:7053"}, + }, + }, { + "RM Step 3D", + "3d.s.resolvematrix.dev", + federation.ResolvedServerName{ + HostHeader: "wk.3d.s.resolvematrix.dev", + IPPort: []string{"wk.3d.s.resolvematrix.dev:8448"}, + }, + }, { + "RM Step 4", + "4.s.resolvematrix.dev", + federation.ResolvedServerName{ + HostHeader: "4.s.resolvematrix.dev", + IPPort: []string{"srv.4.s.resolvematrix.dev:7855"}, + }, + }, { + "RM Step 4 MSC4040", + "4.msc4040.s.resolvematrix.dev", + federation.ResolvedServerName{ + HostHeader: "4.msc4040.s.resolvematrix.dev", + IPPort: []string{"srv.4.msc4040.s.resolvematrix.dev:7054"}, + }, + }, { + "RM Step 5", + "5.s.resolvematrix.dev", + federation.ResolvedServerName{ + HostHeader: "5.s.resolvematrix.dev", + IPPort: []string{"5.s.resolvematrix.dev:8448"}, + }, + }} + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + tc.expected.ServerName = tc.serverName + resp, err := federation.ResolveServerName(context.TODO(), tc.serverName) + require.NoError(t, err) + resp.Expires = time.Time{} + assert.Equal(t, tc.expected, *resp) + }) + } +} diff --git a/federation/serverauth.go b/federation/serverauth.go new file mode 100644 index 00000000..cd300341 --- /dev/null +++ b/federation/serverauth.go @@ -0,0 +1,264 @@ +// Copyright (c) 2025 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package federation + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "maps" + "net/http" + "slices" + "strings" + "sync" + + "github.com/rs/zerolog" + "go.mau.fi/util/ptr" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/id" +) + +type ServerAuth struct { + Keys KeyCache + Client *Client + GetDestination func(XMatrixAuth) string + MaxBodySize int64 + + keyFetchLocks map[string]*sync.Mutex + keyFetchLocksLock sync.Mutex +} + +func NewServerAuth(client *Client, keyCache KeyCache, getDestination func(auth XMatrixAuth) string) *ServerAuth { + return &ServerAuth{ + Keys: keyCache, + Client: client, + GetDestination: getDestination, + MaxBodySize: 50 * 1024 * 1024, + keyFetchLocks: make(map[string]*sync.Mutex), + } +} + +var MUnauthorized = mautrix.RespError{ErrCode: "M_UNAUTHORIZED", StatusCode: http.StatusUnauthorized} + +var ( + errMissingAuthHeader = MUnauthorized.WithMessage("Missing Authorization header") + errInvalidAuthHeader = MUnauthorized.WithMessage("Authorization header does not start with X-Matrix") + errMalformedAuthHeader = MUnauthorized.WithMessage("X-Matrix value is missing required components") + errInvalidDestination = MUnauthorized.WithMessage("Invalid destination in X-Matrix header") + errFailedToQueryKeys = MUnauthorized.WithMessage("Failed to query server keys") + errInvalidSelfSignatures = MUnauthorized.WithMessage("Server keys don't have valid self-signatures") + errRequestBodyTooLarge = mautrix.MTooLarge.WithMessage("Request body too large") + errInvalidJSONBody = mautrix.MBadJSON.WithMessage("Request body is not valid JSON") + errBodyReadFailed = mautrix.MUnknown.WithMessage("Failed to read request body") + errInvalidRequestSignature = MUnauthorized.WithMessage("Failed to verify request signature") +) + +type XMatrixAuth struct { + Origin string + Destination string + KeyID id.KeyID + Signature string +} + +func (xma XMatrixAuth) String() string { + return fmt.Sprintf( + `X-Matrix origin="%s",destination="%s",key="%s",sig="%s"`, + xma.Origin, + xma.Destination, + xma.KeyID, + xma.Signature, + ) +} + +func ParseXMatrixAuth(auth string) (xma XMatrixAuth) { + auth = strings.TrimPrefix(auth, "X-Matrix ") + // TODO upgrade to strings.SplitSeq after Go 1.24 is the minimum + for _, part := range strings.Split(auth, ",") { + part = strings.TrimSpace(part) + eqIdx := strings.Index(part, "=") + if eqIdx == -1 || strings.Count(part, "=") > 1 { + continue + } + val := strings.Trim(part[eqIdx+1:], "\"") + switch strings.ToLower(part[:eqIdx]) { + case "origin": + xma.Origin = val + case "destination": + xma.Destination = val + case "key": + xma.KeyID = id.KeyID(val) + case "sig": + xma.Signature = val + } + } + return +} + +func (sa *ServerAuth) GetKeysWithCache(ctx context.Context, serverName string, keyID id.KeyID) (*ServerKeyResponse, error) { + res, err := sa.Keys.LoadKeys(serverName) + if err != nil { + return nil, fmt.Errorf("failed to read cache: %w", err) + } else if res.HasKey(keyID) { + return res, nil + } + + sa.keyFetchLocksLock.Lock() + lock, ok := sa.keyFetchLocks[serverName] + if !ok { + lock = &sync.Mutex{} + sa.keyFetchLocks[serverName] = lock + } + sa.keyFetchLocksLock.Unlock() + + lock.Lock() + defer lock.Unlock() + res, err = sa.Keys.LoadKeys(serverName) + if err != nil { + return nil, fmt.Errorf("failed to read cache: %w", err) + } else if res != nil { + if res.HasKey(keyID) { + return res, nil + } else if !sa.Keys.ShouldReQuery(serverName) { + zerolog.Ctx(ctx).Trace(). + Str("server_name", serverName). + Stringer("key_id", keyID). + Msg("Not sending key request for missing key ID, last query was too recent") + return res, nil + } + } + res, err = sa.Client.ServerKeys(ctx, serverName) + if err != nil { + sa.Keys.StoreFetchError(serverName, err) + return nil, err + } + sa.Keys.StoreKeys(res) + return res, nil +} + +type fixedLimitedReader struct { + R io.Reader + N int64 + Err error +} + +func (l *fixedLimitedReader) Read(p []byte) (n int, err error) { + if l.N <= 0 { + return 0, l.Err + } + if int64(len(p)) > l.N { + p = p[0:l.N] + } + n, err = l.R.Read(p) + l.N -= int64(n) + return +} + +func (sa *ServerAuth) Authenticate(r *http.Request) (*http.Request, *mautrix.RespError) { + defer func() { + _ = r.Body.Close() + }() + log := zerolog.Ctx(r.Context()) + if r.ContentLength > sa.MaxBodySize { + return nil, &errRequestBodyTooLarge + } + auth := r.Header.Get("Authorization") + if auth == "" { + return nil, &errMissingAuthHeader + } else if !strings.HasPrefix(auth, "X-Matrix ") { + return nil, &errInvalidAuthHeader + } + parsed := ParseXMatrixAuth(auth) + if parsed.Origin == "" || parsed.KeyID == "" || parsed.Signature == "" { + log.Trace().Str("auth_header", auth).Msg("Malformed X-Matrix header") + return nil, &errMalformedAuthHeader + } + destination := sa.GetDestination(parsed) + if destination == "" || (parsed.Destination != "" && parsed.Destination != destination) { + log.Trace(). + Str("got_destination", parsed.Destination). + Str("expected_destination", destination). + Msg("Invalid destination in X-Matrix header") + return nil, &errInvalidDestination + } + resp, err := sa.GetKeysWithCache(r.Context(), parsed.Origin, parsed.KeyID) + if err != nil { + if !errors.Is(err, ErrRecentKeyQueryFailed) { + log.Err(err). + Str("server_name", parsed.Origin). + Msg("Failed to query keys to authenticate request") + } else { + log.Trace().Err(err). + Str("server_name", parsed.Origin). + Msg("Failed to query keys to authenticate request (cached error)") + } + return nil, &errFailedToQueryKeys + } else if err := resp.VerifySelfSignature(); err != nil { + log.Trace().Err(err). + Str("server_name", parsed.Origin). + Msg("Failed to validate self-signatures of server keys") + return nil, &errInvalidSelfSignatures + } + key, ok := resp.VerifyKeys[parsed.KeyID] + if !ok { + keys := slices.Collect(maps.Keys(resp.VerifyKeys)) + log.Trace(). + Stringer("expected_key_id", parsed.KeyID). + Any("found_key_ids", keys). + Msg("Didn't find expected key ID to verify request") + return nil, ptr.Ptr(MUnauthorized.WithMessage("Key ID %q not found (got %v)", parsed.KeyID, keys)) + } + var reqBody []byte + if r.ContentLength != 0 && r.Method != http.MethodGet && r.Method != http.MethodHead { + reqBody, err = io.ReadAll(&fixedLimitedReader{R: r.Body, N: sa.MaxBodySize, Err: errRequestBodyTooLarge}) + if errors.Is(err, errRequestBodyTooLarge) { + return nil, &errRequestBodyTooLarge + } else if err != nil { + log.Err(err). + Str("server_name", parsed.Origin). + Msg("Failed to read request body to authenticate") + return nil, &errBodyReadFailed + } else if !json.Valid(reqBody) { + return nil, &errInvalidJSONBody + } + } + err = (&signableRequest{ + Method: r.Method, + URI: r.URL.RequestURI(), + Origin: parsed.Origin, + Destination: destination, + Content: reqBody, + }).Verify(key.Key, parsed.Signature) + if err != nil { + log.Trace().Err(err).Msg("Request has invalid signature") + return nil, &errInvalidRequestSignature + } + ctx := context.WithValue(r.Context(), contextKeyDestinationServer, destination) + ctx = context.WithValue(ctx, contextKeyOriginServer, parsed.Origin) + ctx = log.With(). + Str("origin_server_name", parsed.Origin). + Str("destination_server_name", destination). + Logger().WithContext(ctx) + modifiedReq := r.WithContext(ctx) + if reqBody != nil { + modifiedReq.Body = io.NopCloser(bytes.NewReader(reqBody)) + } + return modifiedReq, nil +} + +func (sa *ServerAuth) AuthenticateMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if modifiedReq, err := sa.Authenticate(r); err != nil { + err.Write(w) + } else { + next.ServeHTTP(w, modifiedReq) + } + }) +} diff --git a/federation/serverauth_test.go b/federation/serverauth_test.go new file mode 100644 index 00000000..f99fc6cf --- /dev/null +++ b/federation/serverauth_test.go @@ -0,0 +1,29 @@ +// Copyright (c) 2025 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package federation_test + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "maunium.net/go/mautrix/federation" +) + +func TestServerKeyResponse_VerifySelfSignature(t *testing.T) { + cli := federation.NewClient("", nil, nil) + ctx := context.Background() + for _, name := range []string{"matrix.org", "maunium.net", "cd.mau.dev", "uwu.mau.dev"} { + t.Run(name, func(t *testing.T) { + resp, err := cli.ServerKeys(ctx, name) + require.NoError(t, err) + assert.NoError(t, resp.VerifySelfSignature()) + }) + } +} diff --git a/federation/servername.go b/federation/servername.go new file mode 100644 index 00000000..33590712 --- /dev/null +++ b/federation/servername.go @@ -0,0 +1,95 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package federation + +import ( + "net" + "strconv" + "strings" +) + +func isSpecCompliantIPv6(host string) bool { + // IPv6address = 2*45IPv6char + // IPv6char = DIGIT / %x41-46 / %x61-66 / ":" / "." + // ; 0-9, A-F, a-f, :, . + if len(host) < 2 || len(host) > 45 { + return false + } + for _, ch := range host { + if (ch < '0' || ch > '9') && (ch < 'a' || ch > 'f') && (ch < 'A' || ch > 'F') && ch != ':' && ch != '.' { + return false + } + } + return true +} + +func isValidIPv4Chunk(str string) bool { + if len(str) == 0 || len(str) > 3 { + return false + } + for _, ch := range str { + if ch < '0' || ch > '9' { + return false + } + } + return true + +} + +func isSpecCompliantIPv4(host string) bool { + // IPv4address = 1*3DIGIT "." 1*3DIGIT "." 1*3DIGIT "." 1*3DIGIT + if len(host) < 7 || len(host) > 15 { + return false + } + parts := strings.Split(host, ".") + return len(parts) == 4 && + isValidIPv4Chunk(parts[0]) && + isValidIPv4Chunk(parts[1]) && + isValidIPv4Chunk(parts[2]) && + isValidIPv4Chunk(parts[3]) +} + +func isSpecCompliantDNSName(host string) bool { + // dns-name = 1*255dns-char + // dns-char = DIGIT / ALPHA / "-" / "." + if len(host) == 0 || len(host) > 255 { + return false + } + for _, ch := range host { + if (ch < '0' || ch > '9') && (ch < 'a' || ch > 'z') && (ch < 'A' || ch > 'Z') && ch != '-' && ch != '.' { + return false + } + } + return true +} + +// ParseServerName parses the port and hostname from a Matrix server name and validates that +// it matches the grammar specified in https://spec.matrix.org/v1.11/appendices/#server-name +func ParseServerName(serverName string) (host string, port uint16, ok bool) { + if len(serverName) == 0 || len(serverName) > 255 { + return + } + colonIdx := strings.LastIndexByte(serverName, ':') + if colonIdx > 0 { + u64Port, err := strconv.ParseUint(serverName[colonIdx+1:], 10, 16) + if err == nil { + port = uint16(u64Port) + serverName = serverName[:colonIdx] + } + } + if serverName[0] == '[' { + if serverName[len(serverName)-1] != ']' { + return + } + host = serverName[1 : len(serverName)-1] + ok = isSpecCompliantIPv6(host) && net.ParseIP(host) != nil + } else { + host = serverName + ok = isSpecCompliantDNSName(host) || isSpecCompliantIPv4(host) + } + return +} diff --git a/federation/servername_test.go b/federation/servername_test.go new file mode 100644 index 00000000..156d692f --- /dev/null +++ b/federation/servername_test.go @@ -0,0 +1,64 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package federation_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "maunium.net/go/mautrix/federation" +) + +type parseTestCase struct { + name string + serverName string + hostname string + port uint16 +} + +func TestParseServerName(t *testing.T) { + testCases := []parseTestCase{{ + "Domain", + "matrix.org", + "matrix.org", + 0, + }, { + "Domain with port", + "matrix.org:8448", + "matrix.org", + 8448, + }, { + "IPv4 literal", + "1.2.3.4", + "1.2.3.4", + 0, + }, { + "IPv4 literal with port", + "1.2.3.4:8448", + "1.2.3.4", + 8448, + }, { + "IPv6 literal", + "[1234:5678::abcd]", + "1234:5678::abcd", + 0, + }, { + "IPv6 literal with port", + "[1234:5678::abcd]:8448", + "1234:5678::abcd", + 8448, + }} + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + hostname, port, ok := federation.ParseServerName(tc.serverName) + assert.True(t, ok) + assert.Equal(t, tc.hostname, hostname) + assert.Equal(t, tc.port, port) + }) + } +} diff --git a/federation/signingkey.go b/federation/signingkey.go index 3d118233..a4ad9679 100644 --- a/federation/signingkey.go +++ b/federation/signingkey.go @@ -14,9 +14,11 @@ import ( "strings" "time" + "github.com/tidwall/sjson" "go.mau.fi/util/jsontime" "maunium.net/go/mautrix/crypto/canonicaljson" + "maunium.net/go/mautrix/federation/signutil" "maunium.net/go/mautrix/id" ) @@ -31,8 +33,8 @@ type SigningKey struct { // // The output of this function can be parsed back into a [SigningKey] using the [ParseSynapseKey] function. func (sk *SigningKey) SynapseString() string { - alg, id := sk.ID.Parse() - return fmt.Sprintf("%s %s %s", alg, id, base64.RawStdEncoding.EncodeToString(sk.Priv.Seed())) + alg, keyID := sk.ID.Parse() + return fmt.Sprintf("%s %s %s", alg, keyID, base64.RawStdEncoding.EncodeToString(sk.Priv.Seed())) } // ParseSynapseKey parses a Synapse-compatible private key string into a SigningKey. @@ -77,23 +79,62 @@ type ServerKeyResponse struct { OldVerifyKeys map[id.KeyID]OldVerifyKey `json:"old_verify_keys,omitempty"` Signatures map[string]map[id.KeyID]string `json:"signatures,omitempty"` ValidUntilTS jsontime.UnixMilli `json:"valid_until_ts"` + + Raw json.RawMessage `json:"-"` +} + +type QueryKeysResponse struct { + ServerKeys []*ServerKeyResponse `json:"server_keys"` +} + +func (skr *ServerKeyResponse) HasKey(keyID id.KeyID) bool { + if skr == nil { + return false + } else if _, ok := skr.VerifyKeys[keyID]; ok { + return true + } + return false +} + +func (skr *ServerKeyResponse) VerifySelfSignature() error { + for keyID, key := range skr.VerifyKeys { + if err := signutil.VerifyJSON(skr.ServerName, keyID, key.Key, skr.Raw); err != nil { + return fmt.Errorf("failed to verify self signature for key %s: %w", keyID, err) + } + } + return nil +} + +type marshalableSKR ServerKeyResponse + +func (skr *ServerKeyResponse) UnmarshalJSON(data []byte) error { + skr.Raw = data + return json.Unmarshal(data, (*marshalableSKR)(skr)) } type ServerVerifyKey struct { Key id.SigningKey `json:"key"` } +func (svk *ServerVerifyKey) Decode() (ed25519.PublicKey, error) { + return base64.RawStdEncoding.DecodeString(string(svk.Key)) +} + type OldVerifyKey struct { Key id.SigningKey `json:"key"` ExpiredTS jsontime.UnixMilli `json:"expired_ts"` } -func (sk *SigningKey) SignJSON(data any) ([]byte, error) { +func (sk *SigningKey) SignJSON(data any) (string, error) { marshaled, err := json.Marshal(data) if err != nil { - return nil, err + return "", err } - return sk.SignRawJSON(marshaled), nil + marshaled, err = sjson.DeleteBytes(marshaled, "signatures") + if err != nil { + return "", err + } + return base64.RawStdEncoding.EncodeToString(sk.SignRawJSON(marshaled)), nil } func (sk *SigningKey) SignRawJSON(data json.RawMessage) []byte { @@ -116,7 +157,7 @@ func (sk *SigningKey) GenerateKeyResponse(serverName string, oldVerifyKeys map[i } skr.Signatures = map[string]map[id.KeyID]string{ serverName: { - sk.ID: base64.RawURLEncoding.EncodeToString(signature), + sk.ID: signature, }, } return skr diff --git a/federation/signutil/verify.go b/federation/signutil/verify.go new file mode 100644 index 00000000..ea0e7886 --- /dev/null +++ b/federation/signutil/verify.go @@ -0,0 +1,106 @@ +// Copyright (c) 2025 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package signutil + +import ( + "crypto/ed25519" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" + "go.mau.fi/util/exgjson" + + "maunium.net/go/mautrix/crypto/canonicaljson" + "maunium.net/go/mautrix/id" +) + +var ErrSignatureNotFound = errors.New("signature not found") +var ErrInvalidSignature = errors.New("invalid signature") + +func VerifyJSON(serverName string, keyID id.KeyID, key id.SigningKey, data any) error { + var err error + message, ok := data.(json.RawMessage) + if !ok { + message, err = json.Marshal(data) + if err != nil { + return fmt.Errorf("failed to marshal data: %w", err) + } + } + sigVal := gjson.GetBytes(message, exgjson.Path("signatures", serverName, string(keyID))) + if sigVal.Type != gjson.String { + return ErrSignatureNotFound + } + message, err = sjson.DeleteBytes(message, "signatures") + if err != nil { + return fmt.Errorf("failed to delete signatures: %w", err) + } + message, err = sjson.DeleteBytes(message, "unsigned") + if err != nil { + return fmt.Errorf("failed to delete unsigned: %w", err) + } + return VerifyJSONRaw(key, sigVal.Str, message) +} + +func VerifyJSONAny(key id.SigningKey, data any) error { + var err error + message, ok := data.(json.RawMessage) + if !ok { + message, err = json.Marshal(data) + if err != nil { + return fmt.Errorf("failed to marshal data: %w", err) + } + } + sigs := gjson.GetBytes(message, "signatures") + if !sigs.IsObject() { + return ErrSignatureNotFound + } + message, err = sjson.DeleteBytes(message, "signatures") + if err != nil { + return fmt.Errorf("failed to delete signatures: %w", err) + } + message, err = sjson.DeleteBytes(message, "unsigned") + if err != nil { + return fmt.Errorf("failed to delete unsigned: %w", err) + } + var validated bool + sigs.ForEach(func(_, value gjson.Result) bool { + if !value.IsObject() { + return true + } + value.ForEach(func(_, value gjson.Result) bool { + if value.Type != gjson.String { + return true + } + validated = VerifyJSONRaw(key, value.Str, message) == nil + return !validated + }) + return !validated + }) + if !validated { + return ErrInvalidSignature + } + return nil +} + +func VerifyJSONRaw(key id.SigningKey, sig string, message json.RawMessage) error { + sigBytes, err := base64.RawStdEncoding.DecodeString(sig) + if err != nil { + return fmt.Errorf("failed to decode signature: %w", err) + } + keyBytes, err := base64.RawStdEncoding.DecodeString(string(key)) + if err != nil { + return fmt.Errorf("failed to decode key: %w", err) + } + message = canonicaljson.CanonicalJSONAssumeValid(message) + if !ed25519.Verify(keyBytes, message, sigBytes) { + return ErrInvalidSignature + } + return nil +} diff --git a/filter.go b/filter.go index fd6de7a0..54973dab 100644 --- a/filter.go +++ b/filter.go @@ -19,43 +19,45 @@ const ( // Filter is used by clients to specify how the server should filter responses to e.g. sync requests // Specified by: https://spec.matrix.org/v1.2/client-server-api/#filtering type Filter struct { - AccountData FilterPart `json:"account_data,omitempty"` + AccountData *FilterPart `json:"account_data,omitempty"` EventFields []string `json:"event_fields,omitempty"` EventFormat EventFormat `json:"event_format,omitempty"` - Presence FilterPart `json:"presence,omitempty"` - Room RoomFilter `json:"room,omitempty"` + Presence *FilterPart `json:"presence,omitempty"` + Room *RoomFilter `json:"room,omitempty"` + + BeeperToDevice *FilterPart `json:"com.beeper.to_device,omitempty"` } // RoomFilter is used to define filtering rules for room events type RoomFilter struct { - AccountData FilterPart `json:"account_data,omitempty"` - Ephemeral FilterPart `json:"ephemeral,omitempty"` + AccountData *FilterPart `json:"account_data,omitempty"` + Ephemeral *FilterPart `json:"ephemeral,omitempty"` IncludeLeave bool `json:"include_leave,omitempty"` NotRooms []id.RoomID `json:"not_rooms,omitempty"` Rooms []id.RoomID `json:"rooms,omitempty"` - State FilterPart `json:"state,omitempty"` - Timeline FilterPart `json:"timeline,omitempty"` + State *FilterPart `json:"state,omitempty"` + Timeline *FilterPart `json:"timeline,omitempty"` } // FilterPart is used to define filtering rules for specific categories of events type FilterPart struct { - NotRooms []id.RoomID `json:"not_rooms,omitempty"` - Rooms []id.RoomID `json:"rooms,omitempty"` - Limit int `json:"limit,omitempty"` - NotSenders []id.UserID `json:"not_senders,omitempty"` - NotTypes []event.Type `json:"not_types,omitempty"` - Senders []id.UserID `json:"senders,omitempty"` - Types []event.Type `json:"types,omitempty"` - ContainsURL *bool `json:"contains_url,omitempty"` - - LazyLoadMembers bool `json:"lazy_load_members,omitempty"` - IncludeRedundantMembers bool `json:"include_redundant_members,omitempty"` + NotRooms []id.RoomID `json:"not_rooms,omitempty"` + Rooms []id.RoomID `json:"rooms,omitempty"` + Limit int `json:"limit,omitempty"` + NotSenders []id.UserID `json:"not_senders,omitempty"` + NotTypes []event.Type `json:"not_types,omitempty"` + Senders []id.UserID `json:"senders,omitempty"` + Types []event.Type `json:"types,omitempty"` + ContainsURL *bool `json:"contains_url,omitempty"` + LazyLoadMembers bool `json:"lazy_load_members,omitempty"` + IncludeRedundantMembers bool `json:"include_redundant_members,omitempty"` + UnreadThreadNotifications bool `json:"unread_thread_notifications,omitempty"` } // Validate checks if the filter contains valid property values func (filter *Filter) Validate() error { if filter.EventFormat != EventFormatClient && filter.EventFormat != EventFormatFederation { - return errors.New("Bad event_format value. Must be one of [\"client\", \"federation\"]") + return errors.New("bad event_format value") } return nil } @@ -67,7 +69,7 @@ func DefaultFilter() Filter { EventFields: nil, EventFormat: "client", Presence: DefaultFilterPart(), - Room: RoomFilter{ + Room: &RoomFilter{ AccountData: DefaultFilterPart(), Ephemeral: DefaultFilterPart(), IncludeLeave: false, @@ -80,8 +82,8 @@ func DefaultFilter() Filter { } // DefaultFilterPart returns the default filter part used by the Matrix server if no filter is provided in the request -func DefaultFilterPart() FilterPart { - return FilterPart{ +func DefaultFilterPart() *FilterPart { + return &FilterPart{ NotRooms: nil, Rooms: nil, Limit: 20, diff --git a/format/htmlparser.go b/format/htmlparser.go index eb2a662b..e0507d93 100644 --- a/format/htmlparser.go +++ b/format/htmlparser.go @@ -7,13 +7,16 @@ package format import ( + "context" "fmt" "math" "strconv" "strings" + "go.mau.fi/util/exstrings" "golang.org/x/net/html" + "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" ) @@ -33,14 +36,16 @@ func (ts TagStack) Has(tag string) bool { } type Context struct { + Ctx context.Context ReturnData map[string]any TagStack TagStack PreserveWhitespace bool } -func NewContext() Context { +func NewContext(ctx context.Context) Context { return Context{ + Ctx: ctx, ReturnData: map[string]any{}, TagStack: make(TagStack, 0, 4), } @@ -62,10 +67,15 @@ type LinkConverter func(text, href string, ctx Context) string type ColorConverter func(text, fg, bg string, ctx Context) string type CodeBlockConverter func(code, language string, ctx Context) string type PillConverter func(displayname, mxid, eventID string, ctx Context) string +type ImageConverter func(src, alt, title, width, height string, isEmoji bool) string -func DefaultPillConverter(displayname, mxid, eventID string, _ Context) string { +const ContextKeyMentions = "_mentions" + +func DefaultPillConverter(displayname, mxid, eventID string, ctx Context) string { switch { case len(mxid) == 0, mxid[0] == '@': + existingMentions, _ := ctx.ReturnData[ContextKeyMentions].([]id.UserID) + ctx.ReturnData[ContextKeyMentions] = append(existingMentions, id.UserID(mxid)) // User link, always just show the displayname return displayname case len(eventID) > 0: @@ -83,6 +93,30 @@ func DefaultPillConverter(displayname, mxid, eventID string, _ Context) string { } } +func onlyBacktickCount(line string) (count int) { + for i := 0; i < len(line); i++ { + if line[i] != '`' { + return -1 + } + count++ + } + return +} + +func DefaultMonospaceBlockConverter(code, language string, ctx Context) string { + if len(code) == 0 || code[len(code)-1] != '\n' { + code += "\n" + } + fence := "```" + for line := range strings.SplitSeq(code, "\n") { + count := onlyBacktickCount(strings.TrimSpace(line)) + if count >= len(fence) { + fence = strings.Repeat("`", count+1) + } + } + return fmt.Sprintf("%s%s\n%s%s", fence, language, code, fence) +} + // HTMLParser is a somewhat customizable Matrix HTML parser. type HTMLParser struct { PillConverter PillConverter @@ -93,12 +127,15 @@ type HTMLParser struct { ItalicConverter TextConverter StrikethroughConverter TextConverter UnderlineConverter TextConverter + MathConverter TextConverter + MathBlockConverter TextConverter LinkConverter LinkConverter SpoilerConverter SpoilerConverter ColorConverter ColorConverter MonospaceBlockConverter CodeBlockConverter MonospaceConverter TextConverter TextConverter TextConverter + ImageConverter ImageConverter } // TaggedString is a string that also contains a HTML tag. @@ -175,25 +212,6 @@ func (parser *HTMLParser) listToString(node *html.Node, ctx Context) string { return strings.Join(children, "\n") } -func LongestSequence(in string, of rune) int { - currentSeq := 0 - maxSeq := 0 - for _, chr := range in { - if chr == of { - currentSeq++ - } else { - if currentSeq > maxSeq { - maxSeq = currentSeq - } - currentSeq = 0 - } - } - if currentSeq > maxSeq { - maxSeq = currentSeq - } - return maxSeq -} - func (parser *HTMLParser) basicFormatToString(node *html.Node, ctx Context) string { str := parser.nodeToTagAwareString(node.FirstChild, ctx) switch node.Data { @@ -220,14 +238,23 @@ func (parser *HTMLParser) basicFormatToString(node *html.Node, ctx Context) stri if parser.MonospaceConverter != nil { return parser.MonospaceConverter(str, ctx) } - surround := strings.Repeat("`", LongestSequence(str, '`')+1) - return fmt.Sprintf("%s%s%s", surround, str, surround) + return SafeMarkdownCode(str) } return str } func (parser *HTMLParser) spanToString(node *html.Node, ctx Context) string { str := parser.nodeToTagAwareString(node.FirstChild, ctx) + if node.Data == "span" || node.Data == "div" { + math, _ := parser.maybeGetAttribute(node, "data-mx-maths") + if math != "" && parser.MathConverter != nil { + if node.Data == "div" && parser.MathBlockConverter != nil { + str = parser.MathBlockConverter(math, ctx) + } else { + str = parser.MathConverter(math, ctx) + } + } + } if node.Data == "span" { reason, isSpoiler := parser.maybeGetAttribute(node, "data-mx-spoiler") if isSpoiler { @@ -284,12 +311,28 @@ func (parser *HTMLParser) linkToString(node *html.Node, ctx Context) string { } if parser.LinkConverter != nil { return parser.LinkConverter(str, href, ctx) - } else if str == href { + } else if str == href || + str == strings.TrimPrefix(href, "mailto:") || + str == strings.TrimPrefix(href, "http://") || + str == strings.TrimPrefix(href, "https://") { return str } return fmt.Sprintf("%s (%s)", str, href) } +func (parser *HTMLParser) imgToString(node *html.Node, ctx Context) string { + src := parser.getAttribute(node, "src") + alt := parser.getAttribute(node, "alt") + title := parser.getAttribute(node, "title") + width := parser.getAttribute(node, "width") + height := parser.getAttribute(node, "height") + _, isEmoji := parser.maybeGetAttribute(node, "data-mx-emoticon") + if parser.ImageConverter != nil { + return parser.ImageConverter(src, alt, title, width, height, isEmoji) + } + return alt +} + func (parser *HTMLParser) tagToString(node *html.Node, ctx Context) string { ctx = ctx.WithTag(node.Data) switch node.Data { @@ -309,8 +352,12 @@ func (parser *HTMLParser) tagToString(node *html.Node, ctx Context) string { return parser.linkToString(node, ctx) case "p": return parser.nodeToTagAwareString(node.FirstChild, ctx) + case "img": + return parser.imgToString(node, ctx) case "hr": return parser.HorizontalLine + case "input": + return parser.inputToString(node, ctx) case "pre": var preStr, language string if node.FirstChild != nil && node.FirstChild.Type == html.ElementNode && node.FirstChild.Data == "code" { @@ -325,20 +372,28 @@ func (parser *HTMLParser) tagToString(node *html.Node, ctx Context) string { if parser.MonospaceBlockConverter != nil { return parser.MonospaceBlockConverter(preStr, language, ctx) } - if len(preStr) == 0 || preStr[len(preStr)-1] != '\n' { - preStr += "\n" - } - return fmt.Sprintf("```%s\n%s```", language, preStr) + return DefaultMonospaceBlockConverter(preStr, language, ctx) default: return parser.nodeToTagAwareString(node.FirstChild, ctx) } } +func (parser *HTMLParser) inputToString(node *html.Node, ctx Context) string { + if len(ctx.TagStack) > 1 && ctx.TagStack[len(ctx.TagStack)-2] == "li" { + _, checked := parser.maybeGetAttribute(node, "checked") + if checked { + return "[x]" + } + return "[ ]" + } + return parser.nodeToTagAwareString(node.FirstChild, ctx) +} + func (parser *HTMLParser) singleNodeToString(node *html.Node, ctx Context) TaggedString { switch node.Type { case html.TextNode: if !ctx.PreserveWhitespace { - node.Data = strings.Replace(node.Data, "\n", "", -1) + node.Data = exstrings.CollapseSpaces(strings.ReplaceAll(node.Data, "\n", "")) } if parser.TextConverter != nil { node.Data = parser.TextConverter(node.Data, ctx) @@ -404,6 +459,35 @@ func (parser *HTMLParser) Parse(htmlData string, ctx Context) string { return parser.nodeToTagAwareString(node, ctx) } +var TextHTMLParser = &HTMLParser{ + TabsToSpaces: 4, + Newline: "\n", + HorizontalLine: "\n---\n", + PillConverter: DefaultPillConverter, +} + +var MarkdownHTMLParser = &HTMLParser{ + TabsToSpaces: 4, + Newline: "\n", + HorizontalLine: "\n---\n", + PillConverter: DefaultPillConverter, + LinkConverter: func(text, href string, ctx Context) string { + if text == href { + return fmt.Sprintf("<%s>", href) + } + return fmt.Sprintf("[%s](%s)", text, href) + }, + MathConverter: func(s string, c Context) string { + return fmt.Sprintf("$%s$", s) + }, + MathBlockConverter: func(s string, c Context) string { + return fmt.Sprintf("$$\n%s\n$$", s) + }, + UnderlineConverter: func(s string, c Context) string { + return fmt.Sprintf("%s", s) + }, +} + // HTMLToText converts Matrix HTML into text with the default settings. func HTMLToText(html string) string { return (&HTMLParser{ @@ -411,23 +495,26 @@ func HTMLToText(html string) string { Newline: "\n", HorizontalLine: "\n---\n", PillConverter: DefaultPillConverter, - }).Parse(html, NewContext()) + }).Parse(html, NewContext(context.TODO())) +} + +func HTMLToMarkdownFull(parser *HTMLParser, html string) (parsed string, mentions *event.Mentions) { + if parser == nil { + parser = MarkdownHTMLParser + } + ctx := NewContext(context.TODO()) + parsed = parser.Parse(html, ctx) + mentionList, _ := ctx.ReturnData[ContextKeyMentions].([]id.UserID) + mentions = &event.Mentions{ + UserIDs: mentionList, + } + return } // HTMLToMarkdown converts Matrix HTML into markdown with the default settings. // // Currently, the only difference to HTMLToText is how links are formatted. func HTMLToMarkdown(html string) string { - return (&HTMLParser{ - TabsToSpaces: 4, - Newline: "\n", - HorizontalLine: "\n---\n", - PillConverter: DefaultPillConverter, - LinkConverter: func(text, href string, ctx Context) string { - if text == href { - return text - } - return fmt.Sprintf("[%s](%s)", text, href) - }, - }).Parse(html, NewContext()) + parsed, _ := HTMLToMarkdownFull(nil, html) + return parsed } diff --git a/format/markdown.go b/format/markdown.go index fa2a8e8a..77ced0dc 100644 --- a/format/markdown.go +++ b/format/markdown.go @@ -8,14 +8,17 @@ package format import ( "fmt" + "regexp" "strings" "github.com/yuin/goldmark" "github.com/yuin/goldmark/extension" "github.com/yuin/goldmark/renderer/html" + "go.mau.fi/util/exstrings" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/format/mdext" + "maunium.net/go/mautrix/id" ) const paragraphStart = "

      " @@ -39,6 +42,55 @@ func UnwrapSingleParagraph(html string) string { return html } +var mdEscapeRegex = regexp.MustCompile("([\\\\`*_[\\]()])") + +func EscapeMarkdown(text string) string { + text = mdEscapeRegex.ReplaceAllString(text, "\\$1") + text = strings.ReplaceAll(text, ">", ">") + text = strings.ReplaceAll(text, "<", "<") + return text +} + +type uriAble interface { + String() string + URI() *id.MatrixURI +} + +func MarkdownMention(id uriAble) string { + return MarkdownMentionWithName(id.String(), id) +} + +func MarkdownMentionWithName(name string, id uriAble) string { + return MarkdownLink(name, id.URI().MatrixToURL()) +} + +func MarkdownMentionRoomID(name string, id id.RoomID, via ...string) string { + if name == "" { + name = id.String() + } + return MarkdownLink(name, id.URI(via...).MatrixToURL()) +} + +func MarkdownLink(name string, url string) string { + return fmt.Sprintf("[%s](%s)", EscapeMarkdown(name), EscapeMarkdown(url)) +} + +func SafeMarkdownCode[T ~string](textInput T) string { + if textInput == "" { + return "` `" + } + text := strings.ReplaceAll(string(textInput), "\n", " ") + backtickCount := exstrings.LongestSequenceOf(text, '`') + if backtickCount == 0 { + return fmt.Sprintf("`%s`", text) + } + quotes := strings.Repeat("`", backtickCount+1) + if text[0] == '`' || text[len(text)-1] == '`' { + return fmt.Sprintf("%s %s %s", quotes, text, quotes) + } + return fmt.Sprintf("%s%s%s", quotes, text, quotes) +} + func RenderMarkdownCustom(text string, renderer goldmark.Markdown) event.MessageEventContent { var buf strings.Builder err := renderer.Convert([]byte(text), &buf) @@ -49,20 +101,30 @@ func RenderMarkdownCustom(text string, renderer goldmark.Markdown) event.Message return HTMLToContent(htmlBody) } -func HTMLToContent(html string) event.MessageEventContent { - text := HTMLToMarkdown(html) +func TextToContent(text string) event.MessageEventContent { + return event.MessageEventContent{ + MsgType: event.MsgText, + Body: text, + Mentions: &event.Mentions{}, + } +} + +func HTMLToContentFull(renderer *HTMLParser, html string) event.MessageEventContent { + text, mentions := HTMLToMarkdownFull(renderer, html) if html != text { return event.MessageEventContent{ FormattedBody: html, Format: event.FormatHTML, MsgType: event.MsgText, Body: text, + Mentions: mentions, } } - return event.MessageEventContent{ - MsgType: event.MsgText, - Body: text, - } + return TextToContent(text) +} + +func HTMLToContent(html string) event.MessageEventContent { + return HTMLToContentFull(nil, html) } func RenderMarkdown(text string, allowMarkdown, allowHTML bool) event.MessageEventContent { @@ -78,9 +140,6 @@ func RenderMarkdown(text string, allowMarkdown, allowHTML bool) event.MessageEve htmlBody = strings.Replace(text, "\n", "
      ", -1) return HTMLToContent(htmlBody) } else { - return event.MessageEventContent{ - MsgType: event.MsgText, - Body: text, - } + return TextToContent(text) } } diff --git a/format/markdown_test.go b/format/markdown_test.go index 179de6b6..46ea4886 100644 --- a/format/markdown_test.go +++ b/format/markdown_test.go @@ -17,17 +17,20 @@ import ( "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/format" "maunium.net/go/mautrix/format/mdext" + "maunium.net/go/mautrix/id" ) func TestRenderMarkdown_PlainText(t *testing.T) { content := format.RenderMarkdown("hello world", true, true) - assert.Equal(t, event.MessageEventContent{MsgType: event.MsgText, Body: "hello world"}, content) + assert.Equal(t, event.MessageEventContent{MsgType: event.MsgText, Body: "hello world", Mentions: &event.Mentions{}}, content) content = format.RenderMarkdown("hello world", true, false) - assert.Equal(t, event.MessageEventContent{MsgType: event.MsgText, Body: "hello world"}, content) + assert.Equal(t, event.MessageEventContent{MsgType: event.MsgText, Body: "hello world", Mentions: &event.Mentions{}}, content) content = format.RenderMarkdown("hello world", false, true) - assert.Equal(t, event.MessageEventContent{MsgType: event.MsgText, Body: "hello world"}, content) + assert.Equal(t, event.MessageEventContent{MsgType: event.MsgText, Body: "hello world", Mentions: &event.Mentions{}}, content) content = format.RenderMarkdown("hello world", false, false) - assert.Equal(t, event.MessageEventContent{MsgType: event.MsgText, Body: "hello world"}, content) + assert.Equal(t, event.MessageEventContent{MsgType: event.MsgText, Body: "hello world", Mentions: &event.Mentions{}}, content) + content = format.RenderMarkdown(`mention`, false, false) + assert.Equal(t, event.MessageEventContent{MsgType: event.MsgText, Body: "mention", Mentions: &event.Mentions{}}, content) } func TestRenderMarkdown_EscapeHTML(t *testing.T) { @@ -37,6 +40,7 @@ func TestRenderMarkdown_EscapeHTML(t *testing.T) { Body: "hello world", Format: event.FormatHTML, FormattedBody: "<b>hello world</b>", + Mentions: &event.Mentions{}, }, content) } @@ -47,6 +51,7 @@ func TestRenderMarkdown_HTML(t *testing.T) { Body: "**hello world**", Format: event.FormatHTML, FormattedBody: "hello world", + Mentions: &event.Mentions{}, }, content) content = format.RenderMarkdown("hello world", true, true) @@ -55,6 +60,18 @@ func TestRenderMarkdown_HTML(t *testing.T) { Body: "**hello world**", Format: event.FormatHTML, FormattedBody: "hello world", + Mentions: &event.Mentions{}, + }, content) + + content = format.RenderMarkdown(`[mention](https://matrix.to/#/@user:example.com)`, true, false) + assert.Equal(t, event.MessageEventContent{ + MsgType: event.MsgText, + Body: "mention", + Format: event.FormatHTML, + FormattedBody: `mention`, + Mentions: &event.Mentions{ + UserIDs: []id.UserID{"@user:example.com"}, + }, }, content) } @@ -141,3 +158,56 @@ func TestRenderMarkdown_DiscordUnderline(t *testing.T) { assert.Equal(t, html, strings.ReplaceAll(rendered, "\n", "")) } } + +var mathTests = map[string]string{ + "$foo$": `foo`, + "hello $foo$ world": `hello foo world`, + "$$\nfoo\nbar\n$$": `

      foo
      bar
      `, + "`$foo$`": `$foo$`, + "```\n$foo$\n```": `
      $foo$\n
      `, + "~~meow $foo$ asd~~": `meow foo asd`, + "$5 or $10": `$5 or $10`, + "5$ or 10$": `5$ or 10$`, + "$5 or 10$": `5 or 10`, + "$*500*$": `*500*`, + "$$\n*500*\n$$": `
      *500*
      `, + + // TODO: This doesn't work :( + // Maybe same reason as the spoiler wrapping not working? + //"~~$foo$~~": `foo`, +} + +func TestRenderMarkdown_Math(t *testing.T) { + renderer := goldmark.New(goldmark.WithExtensions(extension.Strikethrough, mdext.Math, mdext.EscapeHTML), format.HTMLOptions) + for markdown, html := range mathTests { + rendered := format.UnwrapSingleParagraph(render(renderer, markdown)) + assert.Equal(t, html, strings.ReplaceAll(rendered, "\n", "\\n"), "with input %q", markdown) + } +} + +var customEmojiTests = map[string]string{ + `![:meow:](mxc://example.com/emoji.png "Emoji: meow")`: `:meow:`, +} + +func TestRenderMarkdown_CustomEmoji(t *testing.T) { + renderer := goldmark.New(goldmark.WithExtensions(mdext.CustomEmoji), format.HTMLOptions) + for markdown, html := range customEmojiTests { + rendered := format.UnwrapSingleParagraph(render(renderer, markdown)) + assert.Equal(t, html, rendered, "with input %q", markdown) + } +} + +var codeTests = map[string]string{ + "meow": "`meow`", + "me`ow": "``me`ow``", + "`me`ow": "`` `me`ow ``", + "me`ow`": "`` me`ow` ``", + "`meow`": "`` `meow` ``", + "`````````": "`````````` ````````` ``````````", +} + +func TestSafeMarkdownCode(t *testing.T) { + for input, expected := range codeTests { + assert.Equal(t, expected, format.SafeMarkdownCode(input), "with input %q", input) + } +} diff --git a/format/mdext/customemoji.go b/format/mdext/customemoji.go new file mode 100644 index 00000000..2884a5ea --- /dev/null +++ b/format/mdext/customemoji.go @@ -0,0 +1,73 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package mdext + +import ( + "bytes" + + "github.com/yuin/goldmark" + "github.com/yuin/goldmark/ast" + "github.com/yuin/goldmark/renderer" + "github.com/yuin/goldmark/util" +) + +type extCustomEmoji struct{} +type customEmojiRenderer struct { + funcs functionCapturer +} + +// CustomEmoji is an extension that converts certain markdown images into Matrix custom emojis. +var CustomEmoji = &extCustomEmoji{} + +type functionCapturer struct { + renderImage renderer.NodeRendererFunc + renderText renderer.NodeRendererFunc + renderString renderer.NodeRendererFunc +} + +func (fc *functionCapturer) Register(kind ast.NodeKind, rendererFunc renderer.NodeRendererFunc) { + switch kind { + case ast.KindImage: + fc.renderImage = rendererFunc + case ast.KindText: + fc.renderText = rendererFunc + case ast.KindString: + fc.renderString = rendererFunc + } +} + +var ( + _ renderer.NodeRendererFuncRegisterer = (*functionCapturer)(nil) + _ renderer.Option = (*functionCapturer)(nil) +) + +func (fc *functionCapturer) SetConfig(cfg *renderer.Config) { + cfg.NodeRenderers[0].Value.(renderer.NodeRenderer).RegisterFuncs(fc) +} + +func (eeh *extCustomEmoji) Extend(m goldmark.Markdown) { + var fc functionCapturer + m.Renderer().AddOptions(&fc) + m.Renderer().AddOptions(renderer.WithNodeRenderers(util.Prioritized(&customEmojiRenderer{fc}, 0))) +} + +func (cer *customEmojiRenderer) RegisterFuncs(reg renderer.NodeRendererFuncRegisterer) { + reg.Register(ast.KindImage, cer.renderImage) +} + +var emojiPrefix = []byte("Emoji: ") +var mxcPrefix = []byte("mxc://") + +func (cer *customEmojiRenderer) renderImage(w util.BufWriter, source []byte, node ast.Node, entering bool) (ast.WalkStatus, error) { + n, ok := node.(*ast.Image) + if ok && entering && bytes.HasPrefix(n.Title, emojiPrefix) && bytes.HasPrefix(n.Destination, mxcPrefix) { + n.Title = bytes.TrimPrefix(n.Title, emojiPrefix) + n.SetAttributeString("data-mx-emoticon", nil) + n.SetAttributeString("height", "32") + } + return cer.funcs.renderImage(w, source, node, entering) +} diff --git a/format/mdext/indentableparagraph.go b/format/mdext/indentableparagraph.go new file mode 100644 index 00000000..a6ebd6c0 --- /dev/null +++ b/format/mdext/indentableparagraph.go @@ -0,0 +1,28 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package mdext + +import ( + "github.com/yuin/goldmark" + "github.com/yuin/goldmark/parser" + "github.com/yuin/goldmark/util" +) + +// indentableParagraphParser is the default paragraph parser with CanAcceptIndentedLine. +// Used when disabling CodeBlockParser (as disabling it without a replacement will make indented blocks disappear). +type indentableParagraphParser struct { + parser.BlockParser +} + +var defaultIndentableParagraphParser = &indentableParagraphParser{BlockParser: parser.NewParagraphParser()} + +func (b *indentableParagraphParser) CanAcceptIndentedLine() bool { + return true +} + +// FixIndentedParagraphs is a goldmark option which fixes indented paragraphs when disabling CodeBlockParser. +var FixIndentedParagraphs = goldmark.WithParserOptions(parser.WithBlockParsers(util.Prioritized(defaultIndentableParagraphParser, 500))) diff --git a/format/mdext/math.go b/format/mdext/math.go new file mode 100644 index 00000000..e6a6ecc5 --- /dev/null +++ b/format/mdext/math.go @@ -0,0 +1,240 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package mdext + +import ( + "bytes" + "fmt" + stdhtml "html" + "regexp" + "strings" + "unicode" + + "github.com/yuin/goldmark" + "github.com/yuin/goldmark/ast" + "github.com/yuin/goldmark/parser" + "github.com/yuin/goldmark/renderer" + "github.com/yuin/goldmark/renderer/html" + "github.com/yuin/goldmark/text" + "github.com/yuin/goldmark/util" +) + +var astKindMath = ast.NewNodeKind("Math") + +type astMath struct { + ast.BaseInline + value []byte +} + +func (n *astMath) Dump(source []byte, level int) { + ast.DumpHelper(n, source, level, nil, nil) +} + +func (n *astMath) Kind() ast.NodeKind { + return astKindMath +} + +type astMathBlock struct { + ast.BaseBlock +} + +func (n *astMathBlock) Dump(source []byte, level int) { + ast.DumpHelper(n, source, level, nil, nil) +} + +func (n *astMathBlock) Kind() ast.NodeKind { + return astKindMath +} + +type inlineMathParser struct{} + +var defaultInlineMathParser = &inlineMathParser{} + +func NewInlineMathParser() parser.InlineParser { + return defaultInlineMathParser +} + +const mathDelimiter = '$' + +func (s *inlineMathParser) Trigger() []byte { + return []byte{mathDelimiter} +} + +// This ignores lines where there's no space after the closing $ to avoid false positives +var latexInlineRegexp = regexp.MustCompile(`^(\$[^$]*\$)(?:$|\s)`) + +func (s *inlineMathParser) Parse(parent ast.Node, block text.Reader, pc parser.Context) ast.Node { + before := block.PrecendingCharacter() + // Ignore lines where the opening $ comes after a letter or number to avoid false positives + if unicode.IsLetter(before) || unicode.IsNumber(before) { + return nil + } + line, segment := block.PeekLine() + idx := latexInlineRegexp.FindSubmatchIndex(line) + if idx == nil { + return nil + } + block.Advance(idx[3]) + return &astMath{ + value: block.Value(text.NewSegment(segment.Start+1, segment.Start+idx[3]-1)), + } +} + +func (s *inlineMathParser) CloseBlock(parent ast.Node, pc parser.Context) { + // nothing to do +} + +type blockMathParser struct{} + +var defaultBlockMathParser = &blockMathParser{} + +func NewBlockMathParser() parser.BlockParser { + return defaultBlockMathParser +} + +var mathBlockInfoKey = parser.NewContextKey() + +type mathBlockData struct { + indent int + length int + node ast.Node +} + +func (b *blockMathParser) Trigger() []byte { + return []byte{'$'} +} + +func (b *blockMathParser) Open(parent ast.Node, reader text.Reader, pc parser.Context) (ast.Node, parser.State) { + line, _ := reader.PeekLine() + pos := pc.BlockOffset() + if pos < 0 || (line[pos] != mathDelimiter) { + return nil, parser.NoChildren + } + findent := pos + i := pos + for ; i < len(line) && line[i] == mathDelimiter; i++ { + } + oFenceLength := i - pos + if oFenceLength < 2 { + return nil, parser.NoChildren + } + if i < len(line)-1 { + rest := line[i:] + left := util.TrimLeftSpaceLength(rest) + right := util.TrimRightSpaceLength(rest) + if left < len(rest)-right { + value := rest[left : len(rest)-right] + if bytes.IndexByte(value, mathDelimiter) > -1 { + return nil, parser.NoChildren + } + } + } + node := &astMathBlock{} + pc.Set(mathBlockInfoKey, &mathBlockData{findent, oFenceLength, node}) + return node, parser.NoChildren + +} + +func (b *blockMathParser) Continue(node ast.Node, reader text.Reader, pc parser.Context) parser.State { + line, segment := reader.PeekLine() + fdata := pc.Get(mathBlockInfoKey).(*mathBlockData) + + w, pos := util.IndentWidth(line, reader.LineOffset()) + if w < 4 { + i := pos + for ; i < len(line) && line[i] == mathDelimiter; i++ { + } + length := i - pos + if length >= fdata.length && util.IsBlank(line[i:]) { + newline := 1 + if line[len(line)-1] != '\n' { + newline = 0 + } + reader.Advance(segment.Stop - segment.Start - newline + segment.Padding) + return parser.Close + } + } + pos, padding := util.IndentPositionPadding(line, reader.LineOffset(), segment.Padding, fdata.indent) + if pos < 0 { + pos = util.FirstNonSpacePosition(line) + if pos < 0 { + pos = 0 + } + padding = 0 + } + seg := text.NewSegmentPadding(segment.Start+pos, segment.Stop, padding) + seg.ForceNewline = true // EOF as newline + node.Lines().Append(seg) + reader.AdvanceAndSetPadding(segment.Stop-segment.Start-pos-1, padding) + return parser.Continue | parser.NoChildren +} + +func (b *blockMathParser) Close(node ast.Node, reader text.Reader, pc parser.Context) { + fdata := pc.Get(mathBlockInfoKey).(*mathBlockData) + if fdata.node == node { + pc.Set(mathBlockInfoKey, nil) + } +} + +func (b *blockMathParser) CanInterruptParagraph() bool { + return true +} + +func (b *blockMathParser) CanAcceptIndentedLine() bool { + return false +} + +type mathHTMLRenderer struct { + html.Config +} + +func NewMathHTMLRenderer(opts ...html.Option) renderer.NodeRenderer { + r := &mathHTMLRenderer{ + Config: html.NewConfig(), + } + for _, opt := range opts { + opt.SetHTMLOption(&r.Config) + } + return r +} + +func (r *mathHTMLRenderer) RegisterFuncs(reg renderer.NodeRendererFuncRegisterer) { + reg.Register(astKindMath, r.renderMath) +} + +func (r *mathHTMLRenderer) renderMath(w util.BufWriter, source []byte, n ast.Node, entering bool) (ast.WalkStatus, error) { + if entering { + tag := "span" + var tex string + switch typed := n.(type) { + case *astMathBlock: + tag = "div" + tex = string(n.Lines().Value(source)) + case *astMath: + tex = string(typed.value) + } + tex = stdhtml.EscapeString(strings.TrimSpace(tex)) + _, _ = fmt.Fprintf(w, `<%s data-mx-maths="%s">%s`, tag, tex, strings.ReplaceAll(tex, "\n", "
      "), tag) + } + return ast.WalkSkipChildren, nil +} + +type math struct{} + +// Math is an extension that allow you to use math like '$$text$$'. +var Math = &math{} + +func (e *math) Extend(m goldmark.Markdown) { + m.Parser().AddOptions(parser.WithInlineParsers( + util.Prioritized(NewInlineMathParser(), 500), + ), parser.WithBlockParsers( + util.Prioritized(NewBlockMathParser(), 850), + )) + m.Renderer().AddOptions(renderer.WithNodeRenderers( + util.Prioritized(NewMathHTMLRenderer(), 500), + )) +} diff --git a/format/mdext/shortemphasis.go b/format/mdext/shortemphasis.go new file mode 100644 index 00000000..62190326 --- /dev/null +++ b/format/mdext/shortemphasis.go @@ -0,0 +1,96 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package mdext + +import ( + "github.com/yuin/goldmark" + "github.com/yuin/goldmark/ast" + "github.com/yuin/goldmark/parser" + "github.com/yuin/goldmark/text" + "github.com/yuin/goldmark/util" +) + +var ShortEmphasis goldmark.Extender = &shortEmphasisExtender{} + +type shortEmphasisExtender struct{} + +func (s *shortEmphasisExtender) Extend(m goldmark.Markdown) { + m.Parser().AddOptions(parser.WithInlineParsers( + util.Prioritized(&italicsParser{}, 500), + util.Prioritized(&boldParser{}, 500), + )) +} + +type italicsDelimiterProcessor struct{} + +func (p *italicsDelimiterProcessor) IsDelimiter(b byte) bool { + return b == '_' +} + +func (p *italicsDelimiterProcessor) CanOpenCloser(opener, closer *parser.Delimiter) bool { + return opener.Char == closer.Char +} + +func (p *italicsDelimiterProcessor) OnMatch(consumes int) ast.Node { + return ast.NewEmphasis(1) +} + +var defaultItalicsDelimiterProcessor = &italicsDelimiterProcessor{} + +type italicsParser struct{} + +func (s *italicsParser) Trigger() []byte { + return []byte{'_'} +} + +func (s *italicsParser) Parse(parent ast.Node, block text.Reader, pc parser.Context) ast.Node { + before := block.PrecendingCharacter() + line, segment := block.PeekLine() + node := parser.ScanDelimiter(line, before, 1, defaultItalicsDelimiterProcessor) + if node == nil || node.OriginalLength > 1 || before == '_' { + return nil + } + node.Segment = segment.WithStop(segment.Start + node.OriginalLength) + block.Advance(node.OriginalLength) + pc.PushDelimiter(node) + return node +} + +type boldDelimiterProcessor struct{} + +func (p *boldDelimiterProcessor) IsDelimiter(b byte) bool { + return b == '*' +} + +func (p *boldDelimiterProcessor) CanOpenCloser(opener, closer *parser.Delimiter) bool { + return opener.Char == closer.Char +} + +func (p *boldDelimiterProcessor) OnMatch(consumes int) ast.Node { + return ast.NewEmphasis(2) +} + +var defaultBoldDelimiterProcessor = &boldDelimiterProcessor{} + +type boldParser struct{} + +func (s *boldParser) Trigger() []byte { + return []byte{'*'} +} + +func (s *boldParser) Parse(parent ast.Node, block text.Reader, pc parser.Context) ast.Node { + before := block.PrecendingCharacter() + line, segment := block.PeekLine() + node := parser.ScanDelimiter(line, before, 1, defaultBoldDelimiterProcessor) + if node == nil || node.OriginalLength > 1 || before == '*' { + return nil + } + node.Segment = segment.WithStop(segment.Start + node.OriginalLength) + block.Advance(node.OriginalLength) + pc.PushDelimiter(node) + return node +} diff --git a/format/mdext/shortstrike.go b/format/mdext/shortstrike.go new file mode 100644 index 00000000..00328f22 --- /dev/null +++ b/format/mdext/shortstrike.go @@ -0,0 +1,76 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package mdext + +import ( + "github.com/yuin/goldmark" + gast "github.com/yuin/goldmark/ast" + "github.com/yuin/goldmark/extension" + "github.com/yuin/goldmark/extension/ast" + "github.com/yuin/goldmark/parser" + "github.com/yuin/goldmark/renderer" + "github.com/yuin/goldmark/text" + "github.com/yuin/goldmark/util" +) + +var ShortStrike goldmark.Extender = &shortStrikeExtender{length: 1} +var LongStrike goldmark.Extender = &shortStrikeExtender{length: 2} + +type shortStrikeExtender struct { + length int +} + +func (s *shortStrikeExtender) Extend(m goldmark.Markdown) { + m.Parser().AddOptions(parser.WithInlineParsers( + util.Prioritized(&strikethroughParser{length: s.length}, 500), + )) + m.Renderer().AddOptions(renderer.WithNodeRenderers( + util.Prioritized(extension.NewStrikethroughHTMLRenderer(), 500), + )) +} + +type strikethroughDelimiterProcessor struct{} + +func (p *strikethroughDelimiterProcessor) IsDelimiter(b byte) bool { + return b == '~' +} + +func (p *strikethroughDelimiterProcessor) CanOpenCloser(opener, closer *parser.Delimiter) bool { + return opener.Char == closer.Char +} + +func (p *strikethroughDelimiterProcessor) OnMatch(consumes int) gast.Node { + return ast.NewStrikethrough() +} + +var defaultStrikethroughDelimiterProcessor = &strikethroughDelimiterProcessor{} + +type strikethroughParser struct { + length int +} + +func (s *strikethroughParser) Trigger() []byte { + return []byte{'~'} +} + +func (s *strikethroughParser) Parse(parent gast.Node, block text.Reader, pc parser.Context) gast.Node { + before := block.PrecendingCharacter() + line, segment := block.PeekLine() + node := parser.ScanDelimiter(line, before, 1, defaultStrikethroughDelimiterProcessor) + if node == nil || node.OriginalLength != s.length || before == '~' { + return nil + } + + node.Segment = segment.WithStop(segment.Start + node.OriginalLength) + block.Advance(node.OriginalLength) + pc.PushDelimiter(node) + return node +} + +func (s *strikethroughParser) CloseBlock(parent gast.Node, pc parser.Context) { + // nothing to do +} diff --git a/go.mod b/go.mod index 3e349953..49a1d4e4 100644 --- a/go.mod +++ b/go.mod @@ -1,34 +1,42 @@ module maunium.net/go/mautrix -go 1.21 +go 1.25.0 + +toolchain go1.26.0 require ( - github.com/gorilla/mux v1.8.0 - github.com/gorilla/websocket v1.5.0 - github.com/lib/pq v1.10.9 - github.com/mattn/go-sqlite3 v1.14.22 - github.com/rs/zerolog v1.32.0 - github.com/stretchr/testify v1.9.0 - github.com/tidwall/gjson v1.17.1 + filippo.io/edwards25519 v1.2.0 + github.com/chzyer/readline v1.5.1 + github.com/coder/websocket v1.8.14 + github.com/lib/pq v1.11.2 + github.com/mattn/go-sqlite3 v1.14.34 + github.com/rs/xid v1.6.0 + github.com/rs/zerolog v1.34.0 + github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e + github.com/stretchr/testify v1.11.1 + github.com/tidwall/gjson v1.18.0 github.com/tidwall/sjson v1.2.5 - github.com/yuin/goldmark v1.7.0 - go.mau.fi/util v0.4.1 - go.mau.fi/zeroconfig v0.1.2 - golang.org/x/crypto v0.21.0 - golang.org/x/exp v0.0.0-20240314144324-c7f7c6466f7f - golang.org/x/net v0.22.0 + github.com/yuin/goldmark v1.7.16 + go.mau.fi/util v0.9.6 + go.mau.fi/zeroconfig v0.2.0 + golang.org/x/crypto v0.48.0 + golang.org/x/exp v0.0.0-20260218203240-3dfff04db8fa + golang.org/x/net v0.50.0 + golang.org/x/sync v0.19.0 gopkg.in/yaml.v3 v3.0.1 maunium.net/go/mauflag v1.0.0 ) require ( - github.com/coreos/go-systemd/v22 v22.5.0 // indirect + github.com/coreos/go-systemd/v22 v22.6.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect - github.com/mattn/go-colorable v0.1.13 // indirect - github.com/mattn/go-isatty v0.0.19 // indirect + github.com/mattn/go-colorable v0.1.14 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect + github.com/petermattis/goid v0.0.0-20260113132338-7c7de50cc741 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/tidwall/match v1.1.1 // indirect - github.com/tidwall/pretty v1.2.0 // indirect - golang.org/x/sys v0.18.0 // indirect + github.com/tidwall/pretty v1.2.1 // indirect + golang.org/x/sys v0.41.0 // indirect + golang.org/x/text v0.34.0 // indirect gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect ) diff --git a/go.sum b/go.sum index 705282d4..871a5156 100644 --- a/go.sum +++ b/go.sum @@ -1,57 +1,77 @@ +filippo.io/edwards25519 v1.2.0 h1:crnVqOiS4jqYleHd9vaKZ+HKtHfllngJIiOpNpoJsjo= +filippo.io/edwards25519 v1.2.0/go.mod h1:xzAOLCNug/yB62zG1bQ8uziwrIqIuxhctzJT18Q77mc= github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU= github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU= -github.com/coreos/go-systemd/v22 v22.5.0 h1:RrqgGjYQKalulkV8NGVIfkXQf6YYmOyiJKk8iXXhfZs= +github.com/chzyer/logex v1.2.1 h1:XHDu3E6q+gdHgsdTPH6ImJMIp436vR6MPtH8gP05QzM= +github.com/chzyer/logex v1.2.1/go.mod h1:JLbx6lG2kDbNRFnfkgvh4eRJRPX1QCoOIWomwysCBrQ= +github.com/chzyer/readline v1.5.1 h1:upd/6fQk4src78LMRzh5vItIt361/o4uq553V8B5sGI= +github.com/chzyer/readline v1.5.1/go.mod h1:Eh+b79XXUwfKfcPLepksvw2tcLE/Ct21YObkaSkeBlk= +github.com/chzyer/test v1.0.0 h1:p3BQDXSxOhOG0P9z6/hGnII4LGiEPOYBhs8asl/fC04= +github.com/chzyer/test v1.0.0/go.mod h1:2JlltgoNkt4TW/z9V/IzDdFaMTM2JPIi26O1pF38GC8= +github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g= +github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= +github.com/coreos/go-systemd/v22 v22.6.0 h1:aGVa/v8B7hpb0TKl0MWoAavPDmHvobFe5R5zn0bCJWo= +github.com/coreos/go-systemd/v22 v22.6.0/go.mod h1:iG+pp635Fo7ZmV/j14KUcmEyWF+0X7Lua8rrTWzYgWU= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= -github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI= -github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So= -github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= -github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= -github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= -github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= -github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= +github.com/lib/pq v1.11.2 h1:x6gxUeu39V0BHZiugWe8LXZYZ+Utk7hSJGThs8sdzfs= +github.com/lib/pq v1.11.2/go.mod h1:/p+8NSbOcwzAEI7wiMXFlgydTwcgTr3OSKMsD2BitpA= github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= +github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= +github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= -github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= -github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU= -github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-sqlite3 v1.14.34 h1:3NtcvcUnFBPsuRcno8pUtupspG/GM+9nZ88zgJcp6Zk= +github.com/mattn/go-sqlite3 v1.14.34/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= +github.com/petermattis/goid v0.0.0-20260113132338-7c7de50cc741 h1:KPpdlQLZcHfTMQRi6bFQ7ogNO0ltFT4PmtwTLW4W+14= +github.com/petermattis/goid v0.0.0-20260113132338-7c7de50cc741/go.mod h1:pxMtw7cyUw6B2bRH0ZBANSPg+AoSud1I1iyJHI69jH4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= -github.com/rs/zerolog v1.32.0 h1:keLypqrlIjaFsbmJOBdB/qvyF8KEtCWHwobLp5l/mQ0= -github.com/rs/zerolog v1.32.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss= -github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= -github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/rs/xid v1.6.0 h1:fV591PaemRlL6JfRxGDEPl69wICngIQ3shQtzfy2gxU= +github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0= +github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY= +github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ= +github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0= +github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= -github.com/tidwall/gjson v1.17.1 h1:wlYEnwqAHgzmhNUFfw7Xalt2JzQvsMx2Se4PcoFCT/U= -github.com/tidwall/gjson v1.17.1/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY= +github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= -github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs= github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= +github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= -github.com/yuin/goldmark v1.7.0 h1:EfOIvIMZIzHdB/R/zVrikYLPPwJlfMcNczJFMs1m6sA= -github.com/yuin/goldmark v1.7.0/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E= -go.mau.fi/util v0.4.1 h1:3EC9KxIXo5+h869zDGf5OOZklRd/FjeVnimTwtm3owg= -go.mau.fi/util v0.4.1/go.mod h1:GjkTEBsehYZbSh2LlE6cWEn+6ZIZTGrTMM/5DMNlmFY= -go.mau.fi/zeroconfig v0.1.2 h1:DKOydWnhPMn65GbXZOafgkPm11BvFashZWLct0dGFto= -go.mau.fi/zeroconfig v0.1.2/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70= -golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA= -golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs= -golang.org/x/exp v0.0.0-20240314144324-c7f7c6466f7f h1:3CW0unweImhOzd5FmYuRsD4Y4oQFKZIjAnKbjV4WIrw= -golang.org/x/exp v0.0.0-20240314144324-c7f7c6466f7f/go.mod h1:CxmFvTBINI24O/j8iY7H1xHzx2i4OsyguNBmN/uPtqc= -golang.org/x/net v0.22.0 h1:9sGLhx7iRIHEiX0oAJ3MRZMUCElJgy7Br1nO+AMN3Tc= -golang.org/x/net v0.22.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg= +github.com/yuin/goldmark v1.7.16 h1:n+CJdUxaFMiDUNnWC3dMWCIQJSkxH4uz3ZwQBkAlVNE= +github.com/yuin/goldmark v1.7.16/go.mod h1:ip/1k0VRfGynBgxOz0yCqHrbZXhcjxyuS66Brc7iBKg= +go.mau.fi/util v0.9.6 h1:2nsvxm49KhI3wrFltr0+wSUBlnQ4CMtykuELjpIU+ts= +go.mau.fi/util v0.9.6/go.mod h1:sIJpRH7Iy5Ad1SBuxQoatxtIeErgzxCtjd/2hCMkYMI= +go.mau.fi/zeroconfig v0.2.0 h1:e/OGEERqVRRKlgaro7E6bh8xXiKFSXB3eNNIud7FUjU= +go.mau.fi/zeroconfig v0.2.0/go.mod h1:J0Vn0prHNOm493oZoQ84kq83ZaNCYZnq+noI1b1eN8w= +golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts= +golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos= +golang.org/x/exp v0.0.0-20260218203240-3dfff04db8fa h1:Zt3DZoOFFYkKhDT3v7Lm9FDMEV06GpzjG2jrqW+QTE0= +golang.org/x/exp v0.0.0-20260218203240-3dfff04db8fa/go.mod h1:K79w1Vqn7PoiZn+TkNpx3BUWUQksGO3JcVX6qIjytmA= +golang.org/x/net v0.50.0 h1:ucWh9eiCGyDR3vtzso0WMQinm2Dnt8cFMuQa9K33J60= +golang.org/x/net v0.50.0/go.mod h1:UgoSli3F/pBgdJBHCTc+tp3gmrU4XswgGRgtnwWTfyM= +golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= +golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= +golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4= -golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k= +golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk= +golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc= diff --git a/id/contenturi.go b/id/contenturi.go index cfd00c3e..67127b6c 100644 --- a/id/contenturi.go +++ b/id/contenturi.go @@ -12,12 +12,19 @@ import ( "encoding/json" "errors" "fmt" + "regexp" "strings" ) var ( - InvalidContentURI = errors.New("invalid Matrix content URI") - InputNotJSONString = errors.New("input doesn't look like a JSON string") + ErrInvalidContentURI = errors.New("invalid Matrix content URI") + ErrInputNotJSONString = errors.New("input doesn't look like a JSON string") +) + +// Deprecated: use variables prefixed with Err +var ( + InvalidContentURI = ErrInvalidContentURI + InputNotJSONString = ErrInputNotJSONString ) // ContentURIString is a string that's expected to be a Matrix content URI. @@ -54,9 +61,9 @@ func ParseContentURI(uri string) (parsed ContentURI, err error) { if len(uri) == 0 { return } else if !strings.HasPrefix(uri, "mxc://") { - err = InvalidContentURI + err = ErrInvalidContentURI } else if index := strings.IndexRune(uri[6:], '/'); index == -1 || index == len(uri)-7 { - err = InvalidContentURI + err = ErrInvalidContentURI } else { parsed.Homeserver = uri[6 : 6+index] parsed.FileID = uri[6+index+1:] @@ -70,9 +77,9 @@ func ParseContentURIBytes(uri []byte) (parsed ContentURI, err error) { if len(uri) == 0 { return } else if !bytes.HasPrefix(uri, mxcBytes) { - err = InvalidContentURI + err = ErrInvalidContentURI } else if index := bytes.IndexRune(uri[6:], '/'); index == -1 || index == len(uri)-7 { - err = InvalidContentURI + err = ErrInvalidContentURI } else { parsed.Homeserver = string(uri[6 : 6+index]) parsed.FileID = string(uri[6+index+1:]) @@ -85,7 +92,7 @@ func (uri *ContentURI) UnmarshalJSON(raw []byte) (err error) { *uri = ContentURI{} return nil } else if len(raw) < 2 || raw[0] != '"' || raw[len(raw)-1] != '"' { - return InputNotJSONString + return fmt.Errorf("ContentURI: %w", ErrInputNotJSONString) } parsed, err := ParseContentURIBytes(raw[1 : len(raw)-1]) if err != nil { @@ -156,3 +163,21 @@ func (uri ContentURI) CUString() ContentURIString { func (uri ContentURI) IsEmpty() bool { return len(uri.Homeserver) == 0 || len(uri.FileID) == 0 } + +var simpleHomeserverRegex = regexp.MustCompile(`^[a-zA-Z0-9.:-]+$`) + +func (uri ContentURI) IsValid() bool { + return IsValidMediaID(uri.FileID) && uri.Homeserver != "" && simpleHomeserverRegex.MatchString(uri.Homeserver) +} + +func IsValidMediaID(mediaID string) bool { + if len(mediaID) == 0 { + return false + } + for _, char := range mediaID { + if (char < 'A' || char > 'Z') && (char < 'a' || char > 'z') && (char < '0' || char > '9') && char != '-' && char != '_' { + return false + } + } + return true +} diff --git a/id/crypto.go b/id/crypto.go index 355a84a8..ee857f78 100644 --- a/id/crypto.go +++ b/id/crypto.go @@ -53,6 +53,34 @@ const ( KeyBackupAlgorithmMegolmBackupV1 KeyBackupAlgorithm = "m.megolm_backup.v1.curve25519-aes-sha2" ) +type KeySource string + +func (source KeySource) String() string { + return string(source) +} + +func (source KeySource) Int() int { + switch source { + case KeySourceDirect: + return 100 + case KeySourceBackup: + return 90 + case KeySourceImport: + return 80 + case KeySourceForward: + return 50 + default: + return 0 + } +} + +const ( + KeySourceDirect KeySource = "direct" + KeySourceBackup KeySource = "backup" + KeySourceImport KeySource = "import" + KeySourceForward KeySource = "forward" +) + // BackupVersion is an arbitrary string that identifies a server side key backup. type KeyBackupVersion string diff --git a/id/matrixuri.go b/id/matrixuri.go index 5ec403e9..d5c78bc7 100644 --- a/id/matrixuri.go +++ b/id/matrixuri.go @@ -54,7 +54,7 @@ var SigilToPathSegment = map[rune]string{ func (uri *MatrixURI) getQuery() url.Values { q := make(url.Values) - if uri.Via != nil && len(uri.Via) > 0 { + if len(uri.Via) > 0 { q["via"] = uri.Via } if len(uri.Action) > 0 { @@ -65,6 +65,9 @@ func (uri *MatrixURI) getQuery() url.Values { // String converts the parsed matrix: URI back into the string representation. func (uri *MatrixURI) String() string { + if uri == nil { + return "" + } parts := []string{ SigilToPathSegment[uri.Sigil1], url.PathEscape(uri.MXID1), @@ -81,6 +84,9 @@ func (uri *MatrixURI) String() string { // MatrixToURL converts to parsed matrix: URI into a matrix.to URL func (uri *MatrixURI) MatrixToURL() string { + if uri == nil { + return "" + } fragment := fmt.Sprintf("#/%s", url.PathEscape(uri.PrimaryIdentifier())) if uri.Sigil2 != 0 { fragment = fmt.Sprintf("%s/%s", fragment, url.PathEscape(uri.SecondaryIdentifier())) @@ -96,13 +102,16 @@ func (uri *MatrixURI) MatrixToURL() string { // PrimaryIdentifier returns the first Matrix identifier in the URI. // Currently room IDs, room aliases and user IDs can be in the primary identifier slot. func (uri *MatrixURI) PrimaryIdentifier() string { + if uri == nil { + return "" + } return fmt.Sprintf("%c%s", uri.Sigil1, uri.MXID1) } // SecondaryIdentifier returns the second Matrix identifier in the URI. // Currently only event IDs can be in the secondary identifier slot. func (uri *MatrixURI) SecondaryIdentifier() string { - if uri.Sigil2 == 0 { + if uri == nil || uri.Sigil2 == 0 { return "" } return fmt.Sprintf("%c%s", uri.Sigil2, uri.MXID2) @@ -110,7 +119,7 @@ func (uri *MatrixURI) SecondaryIdentifier() string { // UserID returns the user ID from the URI if the primary identifier is a user ID. func (uri *MatrixURI) UserID() UserID { - if uri.Sigil1 == '@' { + if uri != nil && uri.Sigil1 == '@' { return UserID(uri.PrimaryIdentifier()) } return "" @@ -118,7 +127,7 @@ func (uri *MatrixURI) UserID() UserID { // RoomID returns the room ID from the URI if the primary identifier is a room ID. func (uri *MatrixURI) RoomID() RoomID { - if uri.Sigil1 == '!' { + if uri != nil && uri.Sigil1 == '!' { return RoomID(uri.PrimaryIdentifier()) } return "" @@ -126,7 +135,7 @@ func (uri *MatrixURI) RoomID() RoomID { // RoomAlias returns the room alias from the URI if the primary identifier is a room alias. func (uri *MatrixURI) RoomAlias() RoomAlias { - if uri.Sigil1 == '#' { + if uri != nil && uri.Sigil1 == '#' { return RoomAlias(uri.PrimaryIdentifier()) } return "" @@ -134,7 +143,7 @@ func (uri *MatrixURI) RoomAlias() RoomAlias { // EventID returns the event ID from the URI if the primary identifier is a room ID or alias and the secondary identifier is an event ID. func (uri *MatrixURI) EventID() EventID { - if (uri.Sigil1 == '!' || uri.Sigil1 == '#') && uri.Sigil2 == '$' { + if uri != nil && (uri.Sigil1 == '!' || uri.Sigil1 == '#') && uri.Sigil2 == '$' { return EventID(uri.SecondaryIdentifier()) } return "" @@ -201,10 +210,14 @@ func ProcessMatrixURI(uri *url.URL) (*MatrixURI, error) { if len(parts[1]) == 0 { return nil, ErrEmptySecondSegment } - parsed.MXID1 = parts[1] + var err error + parsed.MXID1, err = url.PathUnescape(parts[1]) + if err != nil { + return nil, fmt.Errorf("failed to url decode second segment %q: %w", parts[1], err) + } // Step 6: if the first part is a room and the URI has 4 segments, construct a second level identifier - if (parsed.Sigil1 == '!' || parsed.Sigil1 == '#') && len(parts) == 4 { + if parsed.Sigil1 == '!' && len(parts) == 4 { // a: find the sigil from the third segment switch parts[2] { case "e", "event": @@ -217,7 +230,10 @@ func ProcessMatrixURI(uri *url.URL) (*MatrixURI, error) { if len(parts[3]) == 0 { return nil, ErrEmptyFourthSegment } - parsed.MXID2 = parts[3] + parsed.MXID2, err = url.PathUnescape(parts[3]) + if err != nil { + return nil, fmt.Errorf("failed to url decode fourth segment %q: %w", parts[3], err) + } } // Step 7: parse the query and extract via and action items diff --git a/id/matrixuri_test.go b/id/matrixuri_test.go index d26d4bfd..90a0754d 100644 --- a/id/matrixuri_test.go +++ b/id/matrixuri_test.go @@ -16,12 +16,11 @@ import ( ) var ( - roomIDLink = id.MatrixURI{Sigil1: '!', MXID1: "7NdBVvkd4aLSbgKt9RXl:example.org"} - roomIDViaLink = id.MatrixURI{Sigil1: '!', MXID1: "7NdBVvkd4aLSbgKt9RXl:example.org", Via: []string{"maunium.net", "matrix.org"}} - roomAliasLink = id.MatrixURI{Sigil1: '#', MXID1: "someroom:example.org"} - roomIDEventLink = id.MatrixURI{Sigil1: '!', MXID1: "7NdBVvkd4aLSbgKt9RXl:example.org", Sigil2: '$', MXID2: "uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s"} - roomAliasEventLink = id.MatrixURI{Sigil1: '#', MXID1: "someroom:example.org", Sigil2: '$', MXID2: "uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s"} - userLink = id.MatrixURI{Sigil1: '@', MXID1: "user:example.org"} + roomIDLink = id.MatrixURI{Sigil1: '!', MXID1: "7NdBVvkd4aLSbgKt9RXl:example.org"} + roomIDViaLink = id.MatrixURI{Sigil1: '!', MXID1: "7NdBVvkd4aLSbgKt9RXl:example.org", Via: []string{"maunium.net", "matrix.org"}} + roomAliasLink = id.MatrixURI{Sigil1: '#', MXID1: "someroom:example.org"} + roomIDEventLink = id.MatrixURI{Sigil1: '!', MXID1: "7NdBVvkd4aLSbgKt9RXl:example.org", Sigil2: '$', MXID2: "uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s"} + userLink = id.MatrixURI{Sigil1: '@', MXID1: "user:example.org"} escapeRoomIDEventLink = id.MatrixURI{Sigil1: '!', MXID1: "meow & 🐈️:example.org", Sigil2: '$', MXID2: "uOH4C9cK4HhMeFWkUXMbdF/dtndJ0j9je+kIK3XpV1s"} ) @@ -31,7 +30,6 @@ func TestMatrixURI_MatrixToURL(t *testing.T) { assert.Equal(t, "https://matrix.to/#/%217NdBVvkd4aLSbgKt9RXl:example.org?via=maunium.net&via=matrix.org", roomIDViaLink.MatrixToURL()) assert.Equal(t, "https://matrix.to/#/%23someroom:example.org", roomAliasLink.MatrixToURL()) assert.Equal(t, "https://matrix.to/#/%217NdBVvkd4aLSbgKt9RXl:example.org/$uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s", roomIDEventLink.MatrixToURL()) - assert.Equal(t, "https://matrix.to/#/%23someroom:example.org/$uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s", roomAliasEventLink.MatrixToURL()) assert.Equal(t, "https://matrix.to/#/@user:example.org", userLink.MatrixToURL()) assert.Equal(t, "https://matrix.to/#/%21meow%20&%20%F0%9F%90%88%EF%B8%8F:example.org/$uOH4C9cK4HhMeFWkUXMbdF%2FdtndJ0j9je+kIK3XpV1s", escapeRoomIDEventLink.MatrixToURL()) } @@ -41,7 +39,6 @@ func TestMatrixURI_String(t *testing.T) { assert.Equal(t, "matrix:roomid/7NdBVvkd4aLSbgKt9RXl:example.org?via=maunium.net&via=matrix.org", roomIDViaLink.String()) assert.Equal(t, "matrix:r/someroom:example.org", roomAliasLink.String()) assert.Equal(t, "matrix:roomid/7NdBVvkd4aLSbgKt9RXl:example.org/e/uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s", roomIDEventLink.String()) - assert.Equal(t, "matrix:r/someroom:example.org/e/uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s", roomAliasEventLink.String()) assert.Equal(t, "matrix:u/user:example.org", userLink.String()) assert.Equal(t, "matrix:roomid/meow%20&%20%F0%9F%90%88%EF%B8%8F:example.org/e/uOH4C9cK4HhMeFWkUXMbdF%2FdtndJ0j9je+kIK3XpV1s", escapeRoomIDEventLink.String()) } @@ -80,8 +77,12 @@ func TestParseMatrixURI_RoomID(t *testing.T) { parsedVia, err := id.ParseMatrixURI("matrix:roomid/7NdBVvkd4aLSbgKt9RXl:example.org?via=maunium.net&via=matrix.org") require.NoError(t, err) require.NotNil(t, parsedVia) + parsedEncoded, err := id.ParseMatrixURI("matrix:roomid/7NdBVvkd4aLSbgKt9RXl%3Aexample.org") + require.NoError(t, err) + require.NotNil(t, parsedEncoded) assert.Equal(t, roomIDLink, *parsed) + assert.Equal(t, roomIDLink, *parsedEncoded) assert.Equal(t, roomIDViaLink, *parsedVia) } @@ -98,19 +99,11 @@ func TestParseMatrixURI_UserID(t *testing.T) { } func TestParseMatrixURI_EventID(t *testing.T) { - parsed1, err := id.ParseMatrixURI("matrix:r/someroom:example.org/e/uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s") + parsed, err := id.ParseMatrixURI("matrix:roomid/7NdBVvkd4aLSbgKt9RXl:example.org/e/uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s") require.NoError(t, err) - require.NotNil(t, parsed1) - parsed2, err := id.ParseMatrixURI("matrix:room/someroom:example.org/e/uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s") - require.NoError(t, err) - require.NotNil(t, parsed2) - parsed3, err := id.ParseMatrixURI("matrix:roomid/7NdBVvkd4aLSbgKt9RXl:example.org/e/uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s") - require.NoError(t, err) - require.NotNil(t, parsed3) + require.NotNil(t, parsed) - assert.Equal(t, roomAliasEventLink, *parsed1) - assert.Equal(t, roomAliasEventLink, *parsed2) - assert.Equal(t, roomIDEventLink, *parsed3) + assert.Equal(t, roomIDEventLink, *parsed) } func TestParseMatrixToURL_RoomAlias(t *testing.T) { @@ -158,21 +151,13 @@ func TestParseMatrixToURL_UserID(t *testing.T) { } func TestParseMatrixToURL_EventID(t *testing.T) { - parsed1, err := id.ParseMatrixToURL("https://matrix.to/#/#someroom:example.org/$uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s") + parsed, err := id.ParseMatrixToURL("https://matrix.to/#/!7NdBVvkd4aLSbgKt9RXl:example.org/$uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s") require.NoError(t, err) - require.NotNil(t, parsed1) - parsed2, err := id.ParseMatrixToURL("https://matrix.to/#/!7NdBVvkd4aLSbgKt9RXl:example.org/$uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s") + require.NotNil(t, parsed) + parsedEncoded, err := id.ParseMatrixToURL("https://matrix.to/#/%217NdBVvkd4aLSbgKt9RXl:example.org/%24uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s") require.NoError(t, err) - require.NotNil(t, parsed2) - parsed1Encoded, err := id.ParseMatrixToURL("https://matrix.to/#/%23someroom:example.org/%24uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s") - require.NoError(t, err) - require.NotNil(t, parsed1) - parsed2Encoded, err := id.ParseMatrixToURL("https://matrix.to/#/%217NdBVvkd4aLSbgKt9RXl:example.org/%24uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s") - require.NoError(t, err) - require.NotNil(t, parsed2) + require.NotNil(t, parsedEncoded) - assert.Equal(t, roomAliasEventLink, *parsed1) - assert.Equal(t, roomAliasEventLink, *parsed1Encoded) - assert.Equal(t, roomIDEventLink, *parsed2) - assert.Equal(t, roomIDEventLink, *parsed2Encoded) + assert.Equal(t, roomIDEventLink, *parsed) + assert.Equal(t, roomIDEventLink, *parsedEncoded) } diff --git a/id/opaque.go b/id/opaque.go index 16863b95..c1ad4988 100644 --- a/id/opaque.go +++ b/id/opaque.go @@ -32,11 +32,17 @@ type EventID string // https://github.com/matrix-org/matrix-doc/pull/2716 type BatchID string +// A DelayID is a string identifying a delayed event. +type DelayID string + func (roomID RoomID) String() string { return string(roomID) } func (roomID RoomID) URI(via ...string) *MatrixURI { + if roomID == "" { + return nil + } return &MatrixURI{ Sigil1: '!', MXID1: string(roomID)[1:], @@ -45,6 +51,11 @@ func (roomID RoomID) URI(via ...string) *MatrixURI { } func (roomID RoomID) EventURI(eventID EventID, via ...string) *MatrixURI { + if roomID == "" { + return nil + } else if eventID == "" { + return roomID.URI(via...) + } return &MatrixURI{ Sigil1: '!', MXID1: string(roomID)[1:], @@ -59,13 +70,20 @@ func (roomAlias RoomAlias) String() string { } func (roomAlias RoomAlias) URI() *MatrixURI { + if roomAlias == "" { + return nil + } return &MatrixURI{ Sigil1: '#', MXID1: string(roomAlias)[1:], } } +// Deprecated: room alias event links should not be used. Use room IDs instead. func (roomAlias RoomAlias) EventURI(eventID EventID) *MatrixURI { + if roomAlias == "" { + return nil + } return &MatrixURI{ Sigil1: '#', MXID1: string(roomAlias)[1:], diff --git a/id/roomversion.go b/id/roomversion.go new file mode 100644 index 00000000..578c10bd --- /dev/null +++ b/id/roomversion.go @@ -0,0 +1,265 @@ +// Copyright (c) 2025 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package id + +import ( + "errors" + "fmt" + "slices" +) + +type RoomVersion string + +const ( + RoomV0 RoomVersion = "" // No room version, used for rooms created before room versions were introduced, equivalent to v1 + RoomV1 RoomVersion = "1" + RoomV2 RoomVersion = "2" + RoomV3 RoomVersion = "3" + RoomV4 RoomVersion = "4" + RoomV5 RoomVersion = "5" + RoomV6 RoomVersion = "6" + RoomV7 RoomVersion = "7" + RoomV8 RoomVersion = "8" + RoomV9 RoomVersion = "9" + RoomV10 RoomVersion = "10" + RoomV11 RoomVersion = "11" + RoomV12 RoomVersion = "12" +) + +func (rv RoomVersion) Equals(versions ...RoomVersion) bool { + return slices.Contains(versions, rv) +} + +func (rv RoomVersion) NotEquals(versions ...RoomVersion) bool { + return !rv.Equals(versions...) +} + +var ErrUnknownRoomVersion = errors.New("unknown room version") + +func (rv RoomVersion) unknownVersionError() error { + return fmt.Errorf("%w %s", ErrUnknownRoomVersion, rv) +} + +func (rv RoomVersion) IsKnown() bool { + switch rv { + case RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5, RoomV6, RoomV7, RoomV8, RoomV9, RoomV10, RoomV11, RoomV12: + return true + default: + return false + } +} + +type StateResVersion int + +const ( + // StateResV1 is the original state resolution algorithm. + StateResV1 StateResVersion = 0 + // StateResV2 is state resolution v2 introduced by https://github.com/matrix-org/matrix-spec-proposals/pull/1759 + StateResV2 StateResVersion = 1 + // StateResV2_1 is state resolution v2.1 introduced by https://github.com/matrix-org/matrix-spec-proposals/pull/4297 + StateResV2_1 StateResVersion = 2 +) + +// StateResVersion returns the version of the state resolution algorithm used by this room version. +func (rv RoomVersion) StateResVersion() StateResVersion { + switch rv { + case RoomV0, RoomV1: + return StateResV1 + case RoomV2, RoomV3, RoomV4, RoomV5, RoomV6, RoomV7, RoomV8, RoomV9, RoomV10, RoomV11: + return StateResV2 + case RoomV12: + return StateResV2_1 + default: + panic(rv.unknownVersionError()) + } +} + +type EventIDFormat int + +const ( + // EventIDFormatCustom is the original format used by room v1 and v2. + // Event IDs in this format are an arbitrary string followed by a colon and the server name. + EventIDFormatCustom EventIDFormat = 0 + // EventIDFormatBase64 is the format used by room v3 introduced by https://github.com/matrix-org/matrix-spec-proposals/pull/1659. + // Event IDs in this format are the standard unpadded base64-encoded SHA256 reference hash of the event. + EventIDFormatBase64 EventIDFormat = 1 + // EventIDFormatURLSafeBase64 is the format used by room v4 and later introduced by https://github.com/matrix-org/matrix-spec-proposals/pull/2002. + // Event IDs in this format are the url-safe unpadded base64-encoded SHA256 reference hash of the event. + EventIDFormatURLSafeBase64 EventIDFormat = 2 +) + +// EventIDFormat returns the format of event IDs used by this room version. +func (rv RoomVersion) EventIDFormat() EventIDFormat { + switch rv { + case RoomV0, RoomV1, RoomV2: + return EventIDFormatCustom + case RoomV3: + return EventIDFormatBase64 + default: + return EventIDFormatURLSafeBase64 + } +} + +///////////////////// +// Room v5 changes // +///////////////////// +// https://github.com/matrix-org/matrix-spec-proposals/pull/2077 + +// EnforceSigningKeyValidity returns true if the `valid_until_ts` field of federation signing keys +// must be enforced on received events. +// +// See https://github.com/matrix-org/matrix-spec-proposals/pull/2076 +func (rv RoomVersion) EnforceSigningKeyValidity() bool { + return rv.NotEquals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4) +} + +///////////////////// +// Room v6 changes // +///////////////////// +// https://github.com/matrix-org/matrix-spec-proposals/pull/2240 + +// SpecialCasedAliasesAuth returns true if the `m.room.aliases` event authorization is special cased +// to only always allow servers to modify the state event with their own server name as state key. +// This also implies that the `aliases` field is protected from redactions. +// +// See https://github.com/matrix-org/matrix-spec-proposals/pull/2432 +func (rv RoomVersion) SpecialCasedAliasesAuth() bool { + return rv.Equals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5) +} + +// ForbidFloatsAndBigInts returns true if floats and integers greater than 2^53-1 or lower than -2^53+1 are forbidden everywhere. +// +// See https://github.com/matrix-org/matrix-spec-proposals/pull/2540 +func (rv RoomVersion) ForbidFloatsAndBigInts() bool { + return rv.NotEquals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5) +} + +// NotificationsPowerLevels returns true if the `notifications` field in `m.room.power_levels` is validated in event auth. +// However, the field is not protected from redactions. +// +// See https://github.com/matrix-org/matrix-spec-proposals/pull/2209 +func (rv RoomVersion) NotificationsPowerLevels() bool { + return rv.NotEquals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5) +} + +///////////////////// +// Room v7 changes // +///////////////////// +// https://github.com/matrix-org/matrix-spec-proposals/pull/2998 + +// Knocks returns true if the `knock` join rule is supported. +// +// See https://github.com/matrix-org/matrix-spec-proposals/pull/2403 +func (rv RoomVersion) Knocks() bool { + return rv.NotEquals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5, RoomV6) +} + +///////////////////// +// Room v8 changes // +///////////////////// +// https://github.com/matrix-org/matrix-spec-proposals/pull/3289 + +// RestrictedJoins returns true if the `restricted` join rule is supported. +// This also implies that the `allow` field in the `m.room.join_rules` event is supported and protected from redactions. +// +// See https://github.com/matrix-org/matrix-spec-proposals/pull/3083 +func (rv RoomVersion) RestrictedJoins() bool { + return rv.NotEquals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5, RoomV6, RoomV7) +} + +///////////////////// +// Room v9 changes // +///////////////////// +// https://github.com/matrix-org/matrix-spec-proposals/pull/3375 + +// RestrictedJoinsFix returns true if the `join_authorised_via_users_server` field in `m.room.member` events is protected from redactions. +// +// See https://github.com/matrix-org/matrix-spec-proposals/pull/3375 +func (rv RoomVersion) RestrictedJoinsFix() bool { + return rv.RestrictedJoins() && rv != RoomV8 +} + +////////////////////// +// Room v10 changes // +////////////////////// +// https://github.com/matrix-org/matrix-spec-proposals/pull/3604 + +// ValidatePowerLevelInts returns true if the known values in `m.room.power_levels` must be integers (and not strings). +// +// See https://github.com/matrix-org/matrix-spec-proposals/pull/3667 +func (rv RoomVersion) ValidatePowerLevelInts() bool { + return rv.NotEquals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5, RoomV6, RoomV7, RoomV8, RoomV9) +} + +// KnockRestricted returns true if the `knock_restricted` join rule is supported. +// +// See https://github.com/matrix-org/matrix-spec-proposals/pull/3787 +func (rv RoomVersion) KnockRestricted() bool { + return rv.NotEquals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5, RoomV6, RoomV7, RoomV8, RoomV9) +} + +////////////////////// +// Room v11 changes // +////////////////////// +// https://github.com/matrix-org/matrix-spec-proposals/pull/3820 + +// CreatorInContent returns true if the `m.room.create` event has a `creator` field in content. +// +// See https://github.com/matrix-org/matrix-spec-proposals/pull/2175 +func (rv RoomVersion) CreatorInContent() bool { + return rv.Equals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5, RoomV6, RoomV7, RoomV8, RoomV9, RoomV10) +} + +// RedactsInContent returns true if the `m.room.redaction` event has the `redacts` field in content instead of at the top level. +// The redaction protection is also moved from the top level to the content field. +// +// See https://github.com/matrix-org/matrix-spec-proposals/pull/2174 +// (and https://github.com/matrix-org/matrix-spec-proposals/pull/2176 for the redaction protection). +func (rv RoomVersion) RedactsInContent() bool { + return rv.NotEquals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5, RoomV6, RoomV7, RoomV8, RoomV9, RoomV10) +} + +// UpdatedRedactionRules returns true if various updates to the redaction algorithm are applied. +// +// Specifically: +// +// * the `membership`, `origin`, and `prev_state` fields at the top level of all events are no longer protected. +// * the entire content of `m.room.create` is protected. +// * the `redacts` field in `m.room.redaction` content is protected instead of the top-level field. +// * the `m.room.power_levels` event protects the `invite` field in content. +// * the `signed` field inside the `third_party_invite` field in content of `m.room.member` events is protected. +// +// See https://github.com/matrix-org/matrix-spec-proposals/pull/2176, +// https://github.com/matrix-org/matrix-spec-proposals/pull/3821, and +// https://github.com/matrix-org/matrix-spec-proposals/pull/3989 +func (rv RoomVersion) UpdatedRedactionRules() bool { + return rv.NotEquals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5, RoomV6, RoomV7, RoomV8, RoomV9, RoomV10) +} + +////////////////////// +// Room v12 changes // +////////////////////// +// https://github.com/matrix-org/matrix-spec-proposals/pull/4304 + +// Return value of StateResVersion was changed to StateResV2_1 + +// PrivilegedRoomCreators returns true if the creator(s) of a room always have infinite power level. +// This also implies that the `m.room.create` event has an `additional_creators` field, +// and that the creators can't be present in the `m.room.power_levels` event. +// +// See https://github.com/matrix-org/matrix-spec-proposals/pull/4289 +func (rv RoomVersion) PrivilegedRoomCreators() bool { + return rv.NotEquals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5, RoomV6, RoomV7, RoomV8, RoomV9, RoomV10, RoomV11) +} + +// RoomIDIsCreateEventID returns true if the ID of rooms is the same as the ID of the `m.room.create` event. +// This also implies that `m.room.create` events do not have a `room_id` field. +// +// See https://github.com/matrix-org/matrix-spec-proposals/pull/4291 +func (rv RoomVersion) RoomIDIsCreateEventID() bool { + return rv.NotEquals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5, RoomV6, RoomV7, RoomV8, RoomV9, RoomV10, RoomV11) +} diff --git a/id/servername.go b/id/servername.go new file mode 100644 index 00000000..923705b6 --- /dev/null +++ b/id/servername.go @@ -0,0 +1,58 @@ +// Copyright (c) 2025 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package id + +import ( + "regexp" + "strconv" +) + +type ParsedServerNameType int + +const ( + ServerNameDNS ParsedServerNameType = iota + ServerNameIPv4 + ServerNameIPv6 +) + +type ParsedServerName struct { + Type ParsedServerNameType + Host string + Port int +} + +var ServerNameRegex = regexp.MustCompile(`^(?:\[([0-9A-Fa-f:.]{2,45})]|(\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})|([0-9A-Za-z.-]{1,255}))(?::(\d{1,5}))?$`) + +func ValidateServerName(serverName string) bool { + return len(serverName) <= 255 && len(serverName) > 0 && ServerNameRegex.MatchString(serverName) +} + +func ParseServerName(serverName string) *ParsedServerName { + if len(serverName) > 255 || len(serverName) < 1 { + return nil + } + match := ServerNameRegex.FindStringSubmatch(serverName) + if len(match) != 5 { + return nil + } + port, _ := strconv.Atoi(match[4]) + parsed := &ParsedServerName{ + Port: port, + } + switch { + case match[1] != "": + parsed.Type = ServerNameIPv6 + parsed.Host = match[1] + case match[2] != "": + parsed.Type = ServerNameIPv4 + parsed.Host = match[2] + case match[3] != "": + parsed.Type = ServerNameDNS + parsed.Host = match[3] + } + return parsed +} diff --git a/id/trust.go b/id/trust.go index 04f6e36b..6255093e 100644 --- a/id/trust.go +++ b/id/trust.go @@ -16,6 +16,7 @@ type TrustState int const ( TrustStateBlacklisted TrustState = -100 + TrustStateDeviceKeyMismatch TrustState = -5 TrustStateUnset TrustState = 0 TrustStateUnknownDevice TrustState = 10 TrustStateForwarded TrustState = 20 @@ -23,7 +24,7 @@ const ( TrustStateCrossSignedTOFU TrustState = 100 TrustStateCrossSignedVerified TrustState = 200 TrustStateVerified TrustState = 300 - TrustStateInvalid TrustState = (1 << 31) - 1 + TrustStateInvalid TrustState = -2147483647 ) func (ts *TrustState) UnmarshalText(data []byte) error { @@ -44,6 +45,8 @@ func ParseTrustState(val string) TrustState { switch strings.ToLower(val) { case "blacklisted": return TrustStateBlacklisted + case "device-key-mismatch": + return TrustStateDeviceKeyMismatch case "unverified": return TrustStateUnset case "cross-signed-untrusted": @@ -67,6 +70,8 @@ func (ts TrustState) String() string { switch ts { case TrustStateBlacklisted: return "blacklisted" + case TrustStateDeviceKeyMismatch: + return "device-key-mismatch" case TrustStateUnset: return "unverified" case TrustStateCrossSignedUntrusted: diff --git a/id/userid.go b/id/userid.go index 53b68b96..726a0d58 100644 --- a/id/userid.go +++ b/id/userid.go @@ -30,10 +30,11 @@ func NewEncodedUserID(localpart, homeserver string) UserID { } var ( - ErrInvalidUserID = errors.New("is not a valid user ID") - ErrNoncompliantLocalpart = errors.New("contains characters that are not allowed") - ErrUserIDTooLong = errors.New("the given user ID is longer than 255 characters") - ErrEmptyLocalpart = errors.New("empty localparts are not allowed") + ErrInvalidUserID = errors.New("is not a valid user ID") + ErrNoncompliantLocalpart = errors.New("contains characters that are not allowed") + ErrUserIDTooLong = errors.New("the given user ID is longer than 255 characters") + ErrEmptyLocalpart = errors.New("empty localparts are not allowed") + ErrNoncompliantServerPart = errors.New("is not a valid server name") ) // ParseCommonIdentifier parses a common identifier according to https://spec.matrix.org/v1.9/appendices/#common-identifier-format @@ -43,10 +44,10 @@ func ParseCommonIdentifier[Stringish ~string](identifier Stringish) (sigil byte, } sigil = identifier[0] strIdentifier := string(identifier) - if strings.ContainsRune(strIdentifier, ':') { - parts := strings.SplitN(strIdentifier, ":", 2) - localpart = parts[0][1:] - homeserver = parts[1] + colonIdx := strings.IndexByte(strIdentifier, ':') + if colonIdx > 0 { + localpart = strIdentifier[1:colonIdx] + homeserver = strIdentifier[colonIdx+1:] } else { localpart = strIdentifier[1:] } @@ -81,6 +82,9 @@ func (userID UserID) Homeserver() string { // // This does not parse or validate the user ID. Use the ParseAndValidate method if you want to ensure the user ID is valid first. func (userID UserID) URI() *MatrixURI { + if userID == "" { + return nil + } return &MatrixURI{ Sigil1: '@', MXID1: string(userID)[1:], @@ -100,21 +104,32 @@ func ValidateUserLocalpart(localpart string) error { return nil } -// ParseAndValidate parses the user ID into the localpart and server name like Parse, -// and also validates that the localpart is allowed according to the user identifiers spec. -func (userID UserID) ParseAndValidate() (localpart, homeserver string, err error) { - localpart, homeserver, err = userID.Parse() +// ParseAndValidateStrict is a stricter version of ParseAndValidateRelaxed that checks the localpart to only allow non-historical localparts. +// This should be used with care: there are real users still using historical localparts. +func (userID UserID) ParseAndValidateStrict() (localpart, homeserver string, err error) { + localpart, homeserver, err = userID.ParseAndValidateRelaxed() if err == nil { err = ValidateUserLocalpart(localpart) } - if err == nil && len(userID) > UserIDMaxLength { + return +} + +// ParseAndValidateRelaxed parses the user ID into the localpart and server name like Parse, +// and also validates that the user ID is not too long and that the server name is valid. +func (userID UserID) ParseAndValidateRelaxed() (localpart, homeserver string, err error) { + if len(userID) > UserIDMaxLength { err = ErrUserIDTooLong + return + } + localpart, homeserver, err = userID.Parse() + if err == nil && !ValidateServerName(homeserver) { + err = fmt.Errorf("%q %q", homeserver, ErrNoncompliantServerPart) } return } func (userID UserID) ParseAndDecode() (localpart, homeserver string, err error) { - localpart, homeserver, err = userID.ParseAndValidate() + localpart, homeserver, err = userID.ParseAndValidateStrict() if err == nil { localpart, err = DecodeUserLocalpart(localpart) } @@ -204,15 +219,15 @@ func DecodeUserLocalpart(str string) (string, error) { for i := 0; i < len(strBytes); i++ { b := strBytes[i] if !isValidByte(b) { - return "", fmt.Errorf("Byte pos %d: Invalid byte", i) + return "", fmt.Errorf("invalid encoded byte at position %d: %c", i, b) } if b == '_' { // next byte is a-z and should be upper-case or is another _ and should be a literal _ if i+1 >= len(strBytes) { - return "", fmt.Errorf("Byte pos %d: expected _[a-z_] encoding but ran out of string", i) + return "", fmt.Errorf("unexpected end of string after underscore at %d", i) } if !isValidEscapedChar(strBytes[i+1]) { // invalid escaping - return "", fmt.Errorf("Byte pos %d: expected _[a-z_] encoding", i) + return "", fmt.Errorf("unexpected byte %c after underscore at %d", strBytes[i+1], i) } if strBytes[i+1] == '_' { outputBuffer.WriteByte('_') @@ -222,7 +237,7 @@ func DecodeUserLocalpart(str string) (string, error) { i++ // skip next byte since we just handled it } else if b == '=' { // next 2 bytes are hex and should be buffered ready to be read as utf8 if i+2 >= len(strBytes) { - return "", fmt.Errorf("Byte pos: %d: expected quote-printable encoding but ran out of string", i) + return "", fmt.Errorf("unexpected end of string after equals sign at %d", i) } dst := make([]byte, 1) _, err := hex.Decode(dst, strBytes[i+1:i+3]) diff --git a/id/userid_test.go b/id/userid_test.go index 359bc687..57a88066 100644 --- a/id/userid_test.go +++ b/id/userid_test.go @@ -38,30 +38,30 @@ func TestUserID_Parse_Invalid(t *testing.T) { assert.True(t, errors.Is(err, id.ErrInvalidUserID)) } -func TestUserID_ParseAndValidate_Invalid(t *testing.T) { +func TestUserID_ParseAndValidateStrict_Invalid(t *testing.T) { const inputUserID = "@s p a c e:maunium.net" - _, _, err := id.UserID(inputUserID).ParseAndValidate() + _, _, err := id.UserID(inputUserID).ParseAndValidateStrict() assert.Error(t, err) assert.True(t, errors.Is(err, id.ErrNoncompliantLocalpart)) } -func TestUserID_ParseAndValidate_Empty(t *testing.T) { +func TestUserID_ParseAndValidateStrict_Empty(t *testing.T) { const inputUserID = "@:ponies.im" - _, _, err := id.UserID(inputUserID).ParseAndValidate() + _, _, err := id.UserID(inputUserID).ParseAndValidateStrict() assert.Error(t, err) assert.True(t, errors.Is(err, id.ErrEmptyLocalpart)) } -func TestUserID_ParseAndValidate_Long(t *testing.T) { +func TestUserID_ParseAndValidateStrict_Long(t *testing.T) { const inputUserID = "@aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa:example.com" - _, _, err := id.UserID(inputUserID).ParseAndValidate() + _, _, err := id.UserID(inputUserID).ParseAndValidateStrict() assert.Error(t, err) assert.True(t, errors.Is(err, id.ErrUserIDTooLong)) } -func TestUserID_ParseAndValidate_NotLong(t *testing.T) { +func TestUserID_ParseAndValidateStrict_NotLong(t *testing.T) { const inputUserID = "@aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa:example.com" - _, _, err := id.UserID(inputUserID).ParseAndValidate() + _, _, err := id.UserID(inputUserID).ParseAndValidateStrict() assert.NoError(t, err) } @@ -70,7 +70,7 @@ func TestUserIDEncoding(t *testing.T) { const encodedLocalpart = "_this=20local+part=20contains=20_il_le_ga_l=20ch=c3=a4racters=20=f0=9f=9a=a8" const inputServerName = "example.com" userID := id.NewEncodedUserID(inputLocalpart, inputServerName) - parsedLocalpart, parsedServerName, err := userID.ParseAndValidate() + parsedLocalpart, parsedServerName, err := userID.ParseAndValidateStrict() assert.NoError(t, err) assert.Equal(t, encodedLocalpart, parsedLocalpart) assert.Equal(t, inputServerName, parsedServerName) diff --git a/mediaproxy/mediaproxy.go b/mediaproxy/mediaproxy.go new file mode 100644 index 00000000..4d2bc7cf --- /dev/null +++ b/mediaproxy/mediaproxy.go @@ -0,0 +1,525 @@ +// Copyright (c) 2025 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package mediaproxy + +import ( + "context" + "errors" + "fmt" + "io" + "mime" + "mime/multipart" + "net/http" + "net/textproto" + "net/url" + "os" + "strconv" + "strings" + "time" + + "github.com/rs/zerolog" + "github.com/rs/zerolog/hlog" + "go.mau.fi/util/exerrors" + "go.mau.fi/util/exhttp" + "go.mau.fi/util/ptr" + "go.mau.fi/util/requestlog" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/federation" + "maunium.net/go/mautrix/id" +) + +type GetMediaResponse interface { + isGetMediaResponse() +} + +func (*GetMediaResponseURL) isGetMediaResponse() {} +func (*GetMediaResponseData) isGetMediaResponse() {} +func (*GetMediaResponseCallback) isGetMediaResponse() {} +func (*GetMediaResponseFile) isGetMediaResponse() {} + +type GetMediaResponseURL struct { + URL string + ExpiresAt time.Time +} + +type GetMediaResponseWriter interface { + GetMediaResponse + io.WriterTo + GetContentType() string + GetContentLength() int64 +} + +var ( + _ GetMediaResponseWriter = (*GetMediaResponseCallback)(nil) + _ GetMediaResponseWriter = (*GetMediaResponseData)(nil) +) + +type GetMediaResponseData struct { + Reader io.ReadCloser + ContentType string + ContentLength int64 +} + +func (d *GetMediaResponseData) WriteTo(w io.Writer) (int64, error) { + return io.Copy(w, d.Reader) +} + +func (d *GetMediaResponseData) GetContentType() string { + return d.ContentType +} + +func (d *GetMediaResponseData) GetContentLength() int64 { + return d.ContentLength +} + +type GetMediaResponseCallback struct { + Callback func(w io.Writer) (int64, error) + ContentType string + ContentLength int64 +} + +func (d *GetMediaResponseCallback) WriteTo(w io.Writer) (int64, error) { + return d.Callback(w) +} + +func (d *GetMediaResponseCallback) GetContentLength() int64 { + return d.ContentLength +} + +func (d *GetMediaResponseCallback) GetContentType() string { + return d.ContentType +} + +type FileMeta struct { + ContentType string + ReplacementFile string +} + +type GetMediaResponseFile struct { + Callback func(w *os.File) (*FileMeta, error) +} + +type GetMediaFunc = func(ctx context.Context, mediaID string, params map[string]string) (response GetMediaResponse, err error) + +type MediaProxy struct { + KeyServer *federation.KeyServer + ServerAuth *federation.ServerAuth + + GetMedia GetMediaFunc + PrepareProxyRequest func(*http.Request) + + serverName string + serverKey *federation.SigningKey + + FederationRouter *http.ServeMux + ClientMediaRouter *http.ServeMux +} + +func New(serverName string, serverKey string, getMedia GetMediaFunc) (*MediaProxy, error) { + parsed, err := federation.ParseSynapseKey(serverKey) + if err != nil { + return nil, err + } + mp := &MediaProxy{ + serverName: serverName, + serverKey: parsed, + GetMedia: getMedia, + KeyServer: &federation.KeyServer{ + KeyProvider: &federation.StaticServerKey{ + ServerName: serverName, + Key: parsed, + }, + WellKnownTarget: fmt.Sprintf("%s:443", serverName), + Version: federation.ServerVersion{ + Name: "mautrix-go media proxy", + Version: strings.TrimPrefix(mautrix.VersionWithCommit, "v"), + }, + }, + } + mp.FederationRouter = http.NewServeMux() + mp.FederationRouter.HandleFunc("GET /v1/media/download/{mediaID}", mp.DownloadMediaFederation) + mp.FederationRouter.HandleFunc("GET /v1/media/thumbnail/{mediaID}", mp.DownloadMediaFederation) + mp.FederationRouter.HandleFunc("GET /v1/version", mp.KeyServer.GetServerVersion) + mp.ClientMediaRouter = http.NewServeMux() + mp.ClientMediaRouter.HandleFunc("GET /download/{serverName}/{mediaID}", mp.DownloadMedia) + mp.ClientMediaRouter.HandleFunc("GET /download/{serverName}/{mediaID}/{fileName}", mp.DownloadMedia) + mp.ClientMediaRouter.HandleFunc("GET /thumbnail/{serverName}/{mediaID}", mp.DownloadMedia) + mp.ClientMediaRouter.HandleFunc("PUT /upload/{serverName}/{mediaID}", mp.UploadNotSupported) + mp.ClientMediaRouter.HandleFunc("POST /upload", mp.UploadNotSupported) + mp.ClientMediaRouter.HandleFunc("POST /create", mp.UploadNotSupported) + mp.ClientMediaRouter.HandleFunc("GET /config", mp.UploadNotSupported) + mp.ClientMediaRouter.HandleFunc("GET /preview_url", mp.PreviewURLNotSupported) + return mp, nil +} + +type BasicConfig struct { + ServerName string `yaml:"server_name" json:"server_name"` + ServerKey string `yaml:"server_key" json:"server_key"` + FederationAuth bool `yaml:"federation_auth" json:"federation_auth"` + WellKnownResponse string `yaml:"well_known_response" json:"well_known_response"` +} + +func NewFromConfig(cfg BasicConfig, getMedia GetMediaFunc) (*MediaProxy, error) { + mp, err := New(cfg.ServerName, cfg.ServerKey, getMedia) + if err != nil { + return nil, err + } + if cfg.WellKnownResponse != "" { + mp.KeyServer.WellKnownTarget = cfg.WellKnownResponse + } + if cfg.FederationAuth { + mp.EnableServerAuth(nil, nil) + } + return mp, nil +} + +type ServerConfig struct { + Hostname string `yaml:"hostname" json:"hostname"` + Port uint16 `yaml:"port" json:"port"` +} + +func (mp *MediaProxy) Listen(cfg ServerConfig) error { + router := http.NewServeMux() + mp.RegisterRoutes(router, zerolog.Nop()) + return http.ListenAndServe(fmt.Sprintf("%s:%d", cfg.Hostname, cfg.Port), router) +} + +func (mp *MediaProxy) GetServerName() string { + return mp.serverName +} + +func (mp *MediaProxy) GetServerKey() *federation.SigningKey { + return mp.serverKey +} + +func (mp *MediaProxy) EnableServerAuth(client *federation.Client, keyCache federation.KeyCache) { + if keyCache == nil { + keyCache = federation.NewInMemoryCache() + } + if client == nil { + resCache, _ := keyCache.(federation.ResolutionCache) + client = federation.NewClient(mp.serverName, mp.serverKey, resCache) + } + mp.ServerAuth = federation.NewServerAuth(client, keyCache, func(auth federation.XMatrixAuth) string { + return mp.GetServerName() + }) +} + +func (mp *MediaProxy) RegisterRoutes(router *http.ServeMux, log zerolog.Logger) { + errorBodies := exhttp.ErrorBodies{ + NotFound: exerrors.Must(ptr.Ptr(mautrix.MUnrecognized.WithMessage("Unrecognized endpoint")).MarshalJSON()), + MethodNotAllowed: exerrors.Must(ptr.Ptr(mautrix.MUnrecognized.WithMessage("Invalid method for endpoint")).MarshalJSON()), + } + router.Handle("/_matrix/federation/", exhttp.ApplyMiddleware( + mp.FederationRouter, + exhttp.StripPrefix("/_matrix/federation"), + hlog.NewHandler(log), + hlog.RequestIDHandler("request_id", "Request-Id"), + requestlog.AccessLogger(requestlog.Options{TrustXForwardedFor: true}), + exhttp.HandleErrors(errorBodies), + )) + router.Handle("/_matrix/client/v1/media/", exhttp.ApplyMiddleware( + mp.ClientMediaRouter, + exhttp.StripPrefix("/_matrix/client/v1/media"), + hlog.NewHandler(log), + hlog.RequestIDHandler("request_id", "Request-Id"), + exhttp.CORSMiddleware, + requestlog.AccessLogger(requestlog.Options{TrustXForwardedFor: true}), + exhttp.HandleErrors(errorBodies), + )) + mp.KeyServer.Register(router, log) +} + +var ErrInvalidMediaIDSyntax = errors.New("invalid media ID syntax") + +func queryToMap(vals url.Values) map[string]string { + m := make(map[string]string, len(vals)) + for k, v := range vals { + m[k] = v[0] + } + return m +} + +func (mp *MediaProxy) getMedia(w http.ResponseWriter, r *http.Request) GetMediaResponse { + mediaID := r.PathValue("mediaID") + if !id.IsValidMediaID(mediaID) { + mautrix.MNotFound.WithMessage("Media ID %q is not valid", mediaID).Write(w) + return nil + } + resp, err := mp.GetMedia(r.Context(), mediaID, queryToMap(r.URL.Query())) + if err != nil { + var mautrixRespError mautrix.RespError + if errors.Is(err, ErrInvalidMediaIDSyntax) { + mautrix.MNotFound.WithMessage("This is a media proxy at %q, other media downloads are not available here", mp.serverName).Write(w) + } else if errors.As(err, &mautrixRespError) { + mautrixRespError.Write(w) + } else { + zerolog.Ctx(r.Context()).Err(err).Str("media_id", mediaID).Msg("Failed to get media URL") + mautrix.MNotFound.WithMessage("Media not found").Write(w) + } + return nil + } + return resp +} + +func startMultipart(ctx context.Context, w http.ResponseWriter) *multipart.Writer { + mpw := multipart.NewWriter(w) + w.Header().Set("Content-Type", strings.Replace(mpw.FormDataContentType(), "form-data", "mixed", 1)) + w.WriteHeader(http.StatusOK) + metaPart, err := mpw.CreatePart(textproto.MIMEHeader{ + "Content-Type": {"application/json"}, + }) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to create multipart metadata field") + return nil + } + _, err = metaPart.Write([]byte(`{}`)) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to write multipart metadata field") + return nil + } + return mpw +} + +func (mp *MediaProxy) DownloadMediaFederation(w http.ResponseWriter, r *http.Request) { + if mp.ServerAuth != nil { + var err *mautrix.RespError + r, err = mp.ServerAuth.Authenticate(r) + if err != nil { + err.Write(w) + return + } + } + ctx := r.Context() + log := zerolog.Ctx(ctx) + + resp := mp.getMedia(w, r) + if resp == nil { + return + } + + var mpw *multipart.Writer + if urlResp, ok := resp.(*GetMediaResponseURL); ok { + mpw = startMultipart(ctx, w) + if mpw == nil { + return + } + _, err := mpw.CreatePart(textproto.MIMEHeader{ + "Location": {urlResp.URL}, + }) + if err != nil { + log.Err(err).Msg("Failed to create multipart redirect field") + return + } + } else if fileResp, ok := resp.(*GetMediaResponseFile); ok { + responseStarted, err := doTempFileDownload(fileResp, func(wt io.WriterTo, size int64, mimeType string) error { + mpw = startMultipart(ctx, w) + if mpw == nil { + return fmt.Errorf("failed to start multipart writer") + } + dataPart, err := mpw.CreatePart(textproto.MIMEHeader{ + "Content-Type": {mimeType}, + }) + if err != nil { + return fmt.Errorf("failed to create multipart data field: %w", err) + } + _, err = wt.WriteTo(dataPart) + return err + }) + if err != nil { + log.Err(err).Msg("Failed to do media proxy with temp file") + if !responseStarted { + var mautrixRespError mautrix.RespError + if errors.As(err, &mautrixRespError) { + mautrixRespError.Write(w) + } else { + mautrix.MUnknown.WithMessage("Internal error proxying media").Write(w) + } + } + return + } + } else if dataResp, ok := resp.(GetMediaResponseWriter); ok { + mpw = startMultipart(ctx, w) + if mpw == nil { + return + } + dataPart, err := mpw.CreatePart(textproto.MIMEHeader{ + "Content-Type": {dataResp.GetContentType()}, + }) + if err != nil { + log.Err(err).Msg("Failed to create multipart data field") + return + } + _, err = dataResp.WriteTo(dataPart) + if err != nil { + log.Err(err).Msg("Failed to write multipart data field") + return + } + } else { + panic(fmt.Errorf("unknown GetMediaResponse type %T", resp)) + } + err := mpw.Close() + if err != nil { + log.Err(err).Msg("Failed to close multipart writer") + return + } +} + +func (mp *MediaProxy) addHeaders(w http.ResponseWriter, mimeType, fileName string) { + w.Header().Set("Cache-Control", "public, max-age=31536000, immutable") + contentDisposition := "attachment" + switch mimeType { + case "text/css", "text/plain", "text/csv", "application/json", "application/ld+json", "image/jpeg", "image/gif", + "image/png", "image/apng", "image/webp", "image/avif", "video/mp4", "video/webm", "video/ogg", "video/quicktime", + "audio/mp4", "audio/webm", "audio/aac", "audio/mpeg", "audio/ogg", "audio/wave", "audio/wav", "audio/x-wav", + "audio/x-pn-wav", "audio/flac", "audio/x-flac", "application/pdf": + contentDisposition = "inline" + } + if fileName != "" { + contentDisposition = mime.FormatMediaType(contentDisposition, map[string]string{ + "filename": fileName, + }) + } + w.Header().Set("Content-Disposition", contentDisposition) + w.Header().Set("Content-Type", mimeType) +} + +func (mp *MediaProxy) DownloadMedia(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + log := zerolog.Ctx(ctx) + if r.PathValue("serverName") != mp.serverName { + mautrix.MNotFound.WithMessage("This is a media proxy at %q, other media downloads are not available here", mp.serverName).Write(w) + return + } + resp := mp.getMedia(w, r) + if resp == nil { + return + } + + if urlResp, ok := resp.(*GetMediaResponseURL); ok { + w.Header().Set("Location", urlResp.URL) + expirySeconds := (time.Until(urlResp.ExpiresAt) - 5*time.Minute).Seconds() + if urlResp.ExpiresAt.IsZero() { + w.Header().Set("Cache-Control", "public, max-age=31536000, immutable") + } else if expirySeconds > 0 { + cacheControl := fmt.Sprintf("public, max-age=%d, immutable", int(expirySeconds)) + w.Header().Set("Cache-Control", cacheControl) + } else { + w.Header().Set("Cache-Control", "no-store") + } + w.WriteHeader(http.StatusTemporaryRedirect) + } else if fileResp, ok := resp.(*GetMediaResponseFile); ok { + responseStarted, err := doTempFileDownload(fileResp, func(wt io.WriterTo, size int64, mimeType string) error { + mp.addHeaders(w, mimeType, r.PathValue("fileName")) + w.Header().Set("Content-Length", strconv.FormatInt(size, 10)) + w.WriteHeader(http.StatusOK) + _, err := wt.WriteTo(w) + return err + }) + if err != nil { + log.Err(err).Msg("Failed to do media proxy with temp file") + if !responseStarted { + var mautrixRespError mautrix.RespError + if errors.As(err, &mautrixRespError) { + mautrixRespError.Write(w) + } else { + mautrix.MUnknown.WithMessage("Internal error proxying media").Write(w) + } + } + } + } else if writerResp, ok := resp.(GetMediaResponseWriter); ok { + if dataResp, ok := writerResp.(*GetMediaResponseData); ok { + defer dataResp.Reader.Close() + } + mp.addHeaders(w, writerResp.GetContentType(), r.PathValue("fileName")) + if writerResp.GetContentLength() != 0 { + w.Header().Set("Content-Length", strconv.FormatInt(writerResp.GetContentLength(), 10)) + } + w.WriteHeader(http.StatusOK) + _, err := writerResp.WriteTo(w) + if err != nil { + log.Err(err).Msg("Failed to write media data") + } + } else { + panic(fmt.Errorf("unknown GetMediaResponse type %T", resp)) + } +} + +func doTempFileDownload( + data *GetMediaResponseFile, + respond func(w io.WriterTo, size int64, mimeType string) error, +) (bool, error) { + tempFile, err := os.CreateTemp("", "mautrix-mediaproxy-*") + if err != nil { + return false, fmt.Errorf("failed to create temp file: %w", err) + } + origTempFile := tempFile + defer func() { + _ = origTempFile.Close() + _ = os.Remove(origTempFile.Name()) + }() + meta, err := data.Callback(tempFile) + if err != nil { + return false, err + } + if meta.ReplacementFile != "" { + tempFile, err = os.Open(meta.ReplacementFile) + if err != nil { + return false, fmt.Errorf("failed to open replacement file: %w", err) + } + defer func() { + _ = tempFile.Close() + _ = os.Remove(origTempFile.Name()) + }() + } else { + _, err = tempFile.Seek(0, io.SeekStart) + if err != nil { + return false, fmt.Errorf("failed to seek to start of temp file: %w", err) + } + } + fileInfo, err := tempFile.Stat() + if err != nil { + return false, fmt.Errorf("failed to stat temp file: %w", err) + } + mimeType := meta.ContentType + if mimeType == "" { + buf := make([]byte, 512) + n, err := tempFile.Read(buf) + if err != nil { + return false, fmt.Errorf("failed to read temp file to detect mime: %w", err) + } + buf = buf[:n] + _, err = tempFile.Seek(0, io.SeekStart) + if err != nil { + return false, fmt.Errorf("failed to seek to start of temp file: %w", err) + } + mimeType = http.DetectContentType(buf) + } + err = respond(tempFile, fileInfo.Size(), mimeType) + if err != nil { + return true, err + } + return true, nil +} + +var ( + ErrUploadNotSupported = mautrix.MUnrecognized. + WithMessage("This is a media proxy and does not support media uploads."). + WithStatus(http.StatusNotImplemented) + ErrPreviewURLNotSupported = mautrix.MUnrecognized. + WithMessage("This is a media proxy and does not support URL previews."). + WithStatus(http.StatusNotImplemented) +) + +func (mp *MediaProxy) UploadNotSupported(w http.ResponseWriter, r *http.Request) { + ErrUploadNotSupported.Write(w) +} + +func (mp *MediaProxy) PreviewURLNotSupported(w http.ResponseWriter, r *http.Request) { + ErrPreviewURLNotSupported.Write(w) +} diff --git a/mockserver/mockserver.go b/mockserver/mockserver.go new file mode 100644 index 00000000..507c24a5 --- /dev/null +++ b/mockserver/mockserver.go @@ -0,0 +1,307 @@ +// Copyright (c) 2025 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package mockserver + +import ( + "context" + "encoding/json" + "fmt" + "io" + "maps" + "net/http" + "net/http/httptest" + "strings" + "testing" + + globallog "github.com/rs/zerolog/log" // zerolog-allow-global-log + "github.com/stretchr/testify/require" + "go.mau.fi/util/dbutil" + "go.mau.fi/util/exerrors" + "go.mau.fi/util/exhttp" + "go.mau.fi/util/random" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/crypto" + "maunium.net/go/mautrix/crypto/cryptohelper" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +func mustDecode(r *http.Request, data any) { + exerrors.PanicIfNotNil(json.NewDecoder(r.Body).Decode(data)) +} + +type userAndDeviceID struct { + UserID id.UserID + DeviceID id.DeviceID +} + +type MockServer struct { + Router *http.ServeMux + Server *httptest.Server + + AccessTokenToUserID map[string]userAndDeviceID + DeviceInbox map[id.UserID]map[id.DeviceID][]event.Event + AccountData map[id.UserID]map[event.Type]json.RawMessage + DeviceKeys map[id.UserID]map[id.DeviceID]mautrix.DeviceKeys + OneTimeKeys map[id.UserID]map[id.DeviceID]map[id.KeyID]mautrix.OneTimeKey + MasterKeys map[id.UserID]mautrix.CrossSigningKeys + SelfSigningKeys map[id.UserID]mautrix.CrossSigningKeys + UserSigningKeys map[id.UserID]mautrix.CrossSigningKeys + + PopOTKs bool + MemoryStore bool +} + +func Create(t testing.TB) *MockServer { + t.Helper() + + server := MockServer{ + AccessTokenToUserID: map[string]userAndDeviceID{}, + DeviceInbox: map[id.UserID]map[id.DeviceID][]event.Event{}, + AccountData: map[id.UserID]map[event.Type]json.RawMessage{}, + DeviceKeys: map[id.UserID]map[id.DeviceID]mautrix.DeviceKeys{}, + OneTimeKeys: map[id.UserID]map[id.DeviceID]map[id.KeyID]mautrix.OneTimeKey{}, + MasterKeys: map[id.UserID]mautrix.CrossSigningKeys{}, + SelfSigningKeys: map[id.UserID]mautrix.CrossSigningKeys{}, + UserSigningKeys: map[id.UserID]mautrix.CrossSigningKeys{}, + PopOTKs: true, + MemoryStore: true, + } + + router := http.NewServeMux() + router.HandleFunc("POST /_matrix/client/v3/login", server.postLogin) + router.HandleFunc("POST /_matrix/client/v3/keys/query", server.postKeysQuery) + router.HandleFunc("POST /_matrix/client/v3/keys/claim", server.postKeysClaim) + router.HandleFunc("PUT /_matrix/client/v3/sendToDevice/{type}/{txn}", server.putSendToDevice) + router.HandleFunc("PUT /_matrix/client/v3/user/{userID}/account_data/{type}", server.putAccountData) + router.HandleFunc("POST /_matrix/client/v3/keys/device_signing/upload", server.postDeviceSigningUpload) + router.HandleFunc("POST /_matrix/client/v3/keys/signatures/upload", server.emptyResp) + router.HandleFunc("POST /_matrix/client/v3/keys/upload", server.postKeysUpload) + server.Router = router + server.Server = httptest.NewServer(router) + t.Cleanup(server.Server.Close) + return &server +} + +func (ms *MockServer) getUserID(r *http.Request) userAndDeviceID { + authHeader := r.Header.Get("Authorization") + authHeader = strings.TrimPrefix(authHeader, "Bearer ") + userID, ok := ms.AccessTokenToUserID[authHeader] + if !ok { + panic("no user ID found for access token " + authHeader) + } + return userID +} + +func (ms *MockServer) emptyResp(w http.ResponseWriter, _ *http.Request) { + exhttp.WriteEmptyJSONResponse(w, http.StatusOK) +} + +func (ms *MockServer) postLogin(w http.ResponseWriter, r *http.Request) { + var loginReq mautrix.ReqLogin + mustDecode(r, &loginReq) + + deviceID := loginReq.DeviceID + if deviceID == "" { + deviceID = id.DeviceID(random.String(10)) + } + + accessToken := random.String(30) + userID := id.UserID(loginReq.Identifier.User) + ms.AccessTokenToUserID[accessToken] = userAndDeviceID{ + UserID: userID, + DeviceID: deviceID, + } + + exhttp.WriteJSONResponse(w, http.StatusOK, &mautrix.RespLogin{ + AccessToken: accessToken, + DeviceID: deviceID, + UserID: userID, + }) +} + +func (ms *MockServer) putSendToDevice(w http.ResponseWriter, r *http.Request) { + var req mautrix.ReqSendToDevice + mustDecode(r, &req) + evtType := event.Type{Type: r.PathValue("type"), Class: event.ToDeviceEventType} + + for user, devices := range req.Messages { + for device, content := range devices { + if _, ok := ms.DeviceInbox[user]; !ok { + ms.DeviceInbox[user] = map[id.DeviceID][]event.Event{} + } + content.ParseRaw(evtType) + ms.DeviceInbox[user][device] = append(ms.DeviceInbox[user][device], event.Event{ + Sender: ms.getUserID(r).UserID, + Type: evtType, + Content: *content, + }) + } + } + ms.emptyResp(w, r) +} + +func (ms *MockServer) putAccountData(w http.ResponseWriter, r *http.Request) { + userID := id.UserID(r.PathValue("userID")) + eventType := event.Type{Type: r.PathValue("type"), Class: event.AccountDataEventType} + + jsonData, _ := io.ReadAll(r.Body) + if _, ok := ms.AccountData[userID]; !ok { + ms.AccountData[userID] = map[event.Type]json.RawMessage{} + } + ms.AccountData[userID][eventType] = json.RawMessage(jsonData) + ms.emptyResp(w, r) +} + +func (ms *MockServer) postKeysQuery(w http.ResponseWriter, r *http.Request) { + var req mautrix.ReqQueryKeys + mustDecode(r, &req) + resp := mautrix.RespQueryKeys{ + MasterKeys: map[id.UserID]mautrix.CrossSigningKeys{}, + UserSigningKeys: map[id.UserID]mautrix.CrossSigningKeys{}, + SelfSigningKeys: map[id.UserID]mautrix.CrossSigningKeys{}, + DeviceKeys: map[id.UserID]map[id.DeviceID]mautrix.DeviceKeys{}, + } + for user := range req.DeviceKeys { + resp.MasterKeys[user] = ms.MasterKeys[user] + resp.UserSigningKeys[user] = ms.UserSigningKeys[user] + resp.SelfSigningKeys[user] = ms.SelfSigningKeys[user] + resp.DeviceKeys[user] = ms.DeviceKeys[user] + } + exhttp.WriteJSONResponse(w, http.StatusOK, &resp) +} + +func (ms *MockServer) postKeysClaim(w http.ResponseWriter, r *http.Request) { + var req mautrix.ReqClaimKeys + mustDecode(r, &req) + resp := mautrix.RespClaimKeys{ + OneTimeKeys: map[id.UserID]map[id.DeviceID]map[id.KeyID]mautrix.OneTimeKey{}, + } + for user, devices := range req.OneTimeKeys { + resp.OneTimeKeys[user] = map[id.DeviceID]map[id.KeyID]mautrix.OneTimeKey{} + for device := range devices { + keys := ms.OneTimeKeys[user][device] + for keyID, key := range keys { + if ms.PopOTKs { + delete(keys, keyID) + } + resp.OneTimeKeys[user][device] = map[id.KeyID]mautrix.OneTimeKey{ + keyID: key, + } + break + } + } + } + exhttp.WriteJSONResponse(w, http.StatusOK, &resp) +} + +func (ms *MockServer) postKeysUpload(w http.ResponseWriter, r *http.Request) { + var req mautrix.ReqUploadKeys + mustDecode(r, &req) + + uid := ms.getUserID(r) + userID := uid.UserID + if _, ok := ms.DeviceKeys[userID]; !ok { + ms.DeviceKeys[userID] = map[id.DeviceID]mautrix.DeviceKeys{} + } + if _, ok := ms.OneTimeKeys[userID]; !ok { + ms.OneTimeKeys[userID] = map[id.DeviceID]map[id.KeyID]mautrix.OneTimeKey{} + } + + if req.DeviceKeys != nil { + ms.DeviceKeys[userID][uid.DeviceID] = *req.DeviceKeys + } + otks, ok := ms.OneTimeKeys[userID][uid.DeviceID] + if !ok { + otks = map[id.KeyID]mautrix.OneTimeKey{} + ms.OneTimeKeys[userID][uid.DeviceID] = otks + } + if req.OneTimeKeys != nil { + maps.Copy(otks, req.OneTimeKeys) + } + + exhttp.WriteJSONResponse(w, http.StatusOK, &mautrix.RespUploadKeys{ + OneTimeKeyCounts: mautrix.OTKCount{SignedCurve25519: len(otks)}, + }) +} + +func (ms *MockServer) postDeviceSigningUpload(w http.ResponseWriter, r *http.Request) { + var req mautrix.UploadCrossSigningKeysReq[any] + mustDecode(r, &req) + + userID := ms.getUserID(r).UserID + ms.MasterKeys[userID] = req.Master + ms.SelfSigningKeys[userID] = req.SelfSigning + ms.UserSigningKeys[userID] = req.UserSigning + + ms.emptyResp(w, r) +} + +func (ms *MockServer) Login(t testing.TB, ctx context.Context, userID id.UserID, deviceID id.DeviceID) (*mautrix.Client, crypto.Store) { + t.Helper() + if ctx == nil { + ctx = context.TODO() + } + client, err := mautrix.NewClient(ms.Server.URL, "", "") + require.NoError(t, err) + client.Client = ms.Server.Client() + + _, err = client.Login(ctx, &mautrix.ReqLogin{ + Type: mautrix.AuthTypePassword, + Identifier: mautrix.UserIdentifier{ + Type: mautrix.IdentifierTypeUser, + User: userID.String(), + }, + DeviceID: deviceID, + Password: "password", + StoreCredentials: true, + }) + require.NoError(t, err) + + var store any + if ms.MemoryStore { + store = crypto.NewMemoryStore(nil) + client.StateStore = mautrix.NewMemoryStateStore() + } else { + store, err = dbutil.NewFromConfig("", dbutil.Config{ + PoolConfig: dbutil.PoolConfig{ + Type: "sqlite3-fk-wal", + URI: fmt.Sprintf("file:%s?mode=memory&cache=shared&_txlock=immediate", random.String(10)), + MaxOpenConns: 5, + MaxIdleConns: 1, + }, + }, nil) + require.NoError(t, err) + } + cryptoHelper, err := cryptohelper.NewCryptoHelper(client, []byte("test"), store) + require.NoError(t, err) + client.Crypto = cryptoHelper + + err = cryptoHelper.Init(ctx) + require.NoError(t, err) + + machineLog := globallog.Logger.With(). + Stringer("my_user_id", userID). + Stringer("my_device_id", deviceID). + Logger() + cryptoHelper.Machine().Log = &machineLog + + err = cryptoHelper.Machine().ShareKeys(ctx, 50) + require.NoError(t, err) + + return client, cryptoHelper.Machine().CryptoStore +} + +func (ms *MockServer) DispatchToDevice(t testing.TB, ctx context.Context, client *mautrix.Client) { + t.Helper() + + for _, evt := range ms.DeviceInbox[client.UserID][client.DeviceID] { + client.Syncer.(*mautrix.DefaultSyncer).Dispatch(ctx, &evt) + ms.DeviceInbox[client.UserID][client.DeviceID] = ms.DeviceInbox[client.UserID][client.DeviceID][1:] + } +} diff --git a/pushrules/action.go b/pushrules/action.go index 9838e88b..b5a884b2 100644 --- a/pushrules/action.go +++ b/pushrules/action.go @@ -105,7 +105,7 @@ func (action *PushAction) UnmarshalJSON(raw []byte) error { if ok { action.Action = ActionSetTweak action.Tweak = PushActionTweak(tweak) - action.Value, _ = val["value"] + action.Value = val["value"] } } return nil diff --git a/pushrules/action_test.go b/pushrules/action_test.go index a8f68415..3c0aa168 100644 --- a/pushrules/action_test.go +++ b/pushrules/action_test.go @@ -139,9 +139,9 @@ func TestPushAction_UnmarshalJSON_InvalidTypeDoesNothing(t *testing.T) { } err := pa.UnmarshalJSON([]byte(`{"foo": "bar"}`)) - assert.Nil(t, err) + assert.NoError(t, err) err = pa.UnmarshalJSON([]byte(`9001`)) - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal(t, pushrules.PushActionType("unchanged"), pa.Action) assert.Equal(t, pushrules.PushActionTweak("unchanged"), pa.Tweak) @@ -156,7 +156,7 @@ func TestPushAction_UnmarshalJSON_StringChangesActionType(t *testing.T) { } err := pa.UnmarshalJSON([]byte(`"foo"`)) - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal(t, pushrules.PushActionType("foo"), pa.Action) assert.Equal(t, pushrules.PushActionTweak("unchanged"), pa.Tweak) @@ -171,7 +171,7 @@ func TestPushAction_UnmarshalJSON_SetTweakChangesTweak(t *testing.T) { } err := pa.UnmarshalJSON([]byte(`{"set_tweak": "foo", "value": 123.0}`)) - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal(t, pushrules.ActionSetTweak, pa.Action) assert.Equal(t, pushrules.PushActionTweak("foo"), pa.Tweak) @@ -185,7 +185,7 @@ func TestPushAction_MarshalJSON_TweakOutputWorks(t *testing.T) { Value: "bar", } data, err := pa.MarshalJSON() - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal(t, []byte(`{"set_tweak":"foo","value":"bar"}`), data) } @@ -196,6 +196,6 @@ func TestPushAction_MarshalJSON_OtherOutputWorks(t *testing.T) { Value: "bar", } data, err := pa.MarshalJSON() - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal(t, []byte(`"something else"`), data) } diff --git a/pushrules/condition.go b/pushrules/condition.go index 435178fb..caa717de 100644 --- a/pushrules/condition.go +++ b/pushrules/condition.go @@ -15,10 +15,10 @@ import ( "unicode" "github.com/tidwall/gjson" + "go.mau.fi/util/glob" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" - "maunium.net/go/mautrix/pushrules/glob" ) // Room is an interface with the functions that are needed for processing room-specific push conditions @@ -27,6 +27,11 @@ type Room interface { GetMemberCount() int } +type PowerLevelfulRoom interface { + Room + GetPowerLevels() *event.PowerLevelsEventContent +} + // EventfulRoom is an extension of Room to support MSC3664. type EventfulRoom interface { Room @@ -38,11 +43,12 @@ type PushCondKind string // The allowed push condition kinds as specified in https://spec.matrix.org/v1.2/client-server-api/#conditions-1 const ( - KindEventMatch PushCondKind = "event_match" - KindContainsDisplayName PushCondKind = "contains_display_name" - KindRoomMemberCount PushCondKind = "room_member_count" - KindEventPropertyIs PushCondKind = "event_property_is" - KindEventPropertyContains PushCondKind = "event_property_contains" + KindEventMatch PushCondKind = "event_match" + KindContainsDisplayName PushCondKind = "contains_display_name" + KindRoomMemberCount PushCondKind = "room_member_count" + KindEventPropertyIs PushCondKind = "event_property_is" + KindEventPropertyContains PushCondKind = "event_property_contains" + KindSenderNotificationPermission PushCondKind = "sender_notification_permission" // MSC3664: https://github.com/matrix-org/matrix-spec-proposals/pull/3664 @@ -82,6 +88,8 @@ func (cond *PushCondition) Match(room Room, evt *event.Event) bool { return cond.matchDisplayName(room, evt) case KindRoomMemberCount: return cond.matchMemberCount(room) + case KindSenderNotificationPermission: + return cond.matchSenderNotificationPermission(room, evt.Sender, cond.Key) default: return false } @@ -219,11 +227,11 @@ func (cond *PushCondition) matchValue(evt *event.Event) bool { switch cond.Kind { case KindEventMatch, KindRelatedEventMatch, KindUnstableRelatedEventMatch: - pattern, err := glob.Compile(cond.Pattern) - if err != nil { + pattern := glob.CompileWithImplicitContains(cond.Pattern) + if pattern == nil { return false } - return pattern.MatchString(stringifyForPushCondition(val)) + return pattern.Match(stringifyForPushCondition(val)) case KindEventPropertyIs: return valueEquals(val, cond.Value) case KindEventPropertyContains: @@ -334,3 +342,18 @@ func (cond *PushCondition) matchMemberCount(room Room) bool { return false } } + +func (cond *PushCondition) matchSenderNotificationPermission(room Room, sender id.UserID, key string) bool { + if key != "room" { + return false + } + plRoom, ok := room.(PowerLevelfulRoom) + if !ok { + return false + } + pls := plRoom.GetPowerLevels() + if pls == nil { + return false + } + return pls.GetUserLevel(sender) >= pls.Notifications.Room() +} diff --git a/pushrules/condition_test.go b/pushrules/condition_test.go index 0d3eaf7a..37af3e34 100644 --- a/pushrules/condition_test.go +++ b/pushrules/condition_test.go @@ -102,14 +102,6 @@ func newEventPropertyIsPushCondition(key string, value any) *pushrules.PushCondi } } -func newEventPropertyContainsPushCondition(key string, value any) *pushrules.PushCondition { - return &pushrules.PushCondition{ - Kind: pushrules.KindEventPropertyContains, - Key: key, - Value: value, - } -} - func TestPushCondition_Match_InvalidKind(t *testing.T) { condition := &pushrules.PushCondition{ Kind: pushrules.PushCondKind("invalid"), diff --git a/pushrules/glob/LICENSE b/pushrules/glob/LICENSE deleted file mode 100644 index cb00d952..00000000 --- a/pushrules/glob/LICENSE +++ /dev/null @@ -1,22 +0,0 @@ -Glob is licensed under the MIT "Expat" License: - -Copyright (c) 2016: Zachary Yedidia. - -Permission is hereby granted, free of charge, to any person obtaining -a copy of this software and associated documentation files (the -"Software"), to deal in the Software without restriction, including -without limitation the rights to use, copy, modify, merge, publish, -distribute, sublicense, and/or sell copies of the Software, and to -permit persons to whom the Software is furnished to do so, subject to -the following conditions: - -The above copyright notice and this permission notice shall be -included in all copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, -EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF -MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. -IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY -CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, -TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE -SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/pushrules/glob/README.md b/pushrules/glob/README.md deleted file mode 100644 index e2e6c649..00000000 --- a/pushrules/glob/README.md +++ /dev/null @@ -1,28 +0,0 @@ -# String globbing in Go - -[![GoDoc](https://godoc.org/github.com/zyedidia/glob?status.svg)](http://godoc.org/github.com/zyedidia/glob) - -This package adds support for globs in Go. - -It simply converts glob expressions to regexps. I try to follow the standard defined [here](http://pubs.opengroup.org/onlinepubs/009695399/utilities/xcu_chap02.html#tag_02_13). - -# Example - -```go -package main - -import "github.com/zyedidia/glob" - -func main() { - glob, err := glob.Compile("{*.go,*.c}") - if err != nil { - // Error - } - - glob.Match([]byte("test.c")) // true - glob.Match([]byte("hello.go")) // true - glob.Match([]byte("test.d")) // false -} -``` - -You can call all the same functions on a glob that you can call on a regexp. diff --git a/pushrules/glob/glob.go b/pushrules/glob/glob.go deleted file mode 100644 index c270dbc5..00000000 --- a/pushrules/glob/glob.go +++ /dev/null @@ -1,108 +0,0 @@ -// Package glob provides objects for matching strings with globs -package glob - -import "regexp" - -// Glob is a wrapper of *regexp.Regexp. -// It should contain a glob expression compiled into a regular expression. -type Glob struct { - *regexp.Regexp -} - -// Compile a takes a glob expression as a string and transforms it -// into a *Glob object (which is really just a regular expression) -// Compile also returns a possible error. -func Compile(pattern string) (*Glob, error) { - r, err := globToRegex(pattern) - return &Glob{r}, err -} - -func globToRegex(glob string) (*regexp.Regexp, error) { - regex := "" - inGroup := 0 - inClass := 0 - firstIndexInClass := -1 - arr := []byte(glob) - - hasGlobCharacters := false - - for i := 0; i < len(arr); i++ { - ch := arr[i] - - switch ch { - case '\\': - i++ - if i >= len(arr) { - regex += "\\" - } else { - next := arr[i] - switch next { - case ',': - // Nothing - case 'Q', 'E': - regex += "\\\\" - default: - regex += "\\" - } - regex += string(next) - } - case '*': - if inClass == 0 { - regex += ".*" - } else { - regex += "*" - } - hasGlobCharacters = true - case '?': - if inClass == 0 { - regex += "." - } else { - regex += "?" - } - hasGlobCharacters = true - case '[': - inClass++ - firstIndexInClass = i + 1 - regex += "[" - hasGlobCharacters = true - case ']': - inClass-- - regex += "]" - case '.', '(', ')', '+', '|', '^', '$', '@', '%': - if inClass == 0 || (firstIndexInClass == i && ch == '^') { - regex += "\\" - } - regex += string(ch) - hasGlobCharacters = true - case '!': - if firstIndexInClass == i { - regex += "^" - } else { - regex += "!" - } - hasGlobCharacters = true - case '{': - inGroup++ - regex += "(" - hasGlobCharacters = true - case '}': - inGroup-- - regex += ")" - case ',': - if inGroup > 0 { - regex += "|" - hasGlobCharacters = true - } else { - regex += "," - } - default: - regex += string(ch) - } - } - - if hasGlobCharacters { - return regexp.Compile("^" + regex + "$") - } else { - return regexp.Compile(regex) - } -} diff --git a/pushrules/pushrules_test.go b/pushrules/pushrules_test.go index a531ca28..a5a0f5e7 100644 --- a/pushrules/pushrules_test.go +++ b/pushrules/pushrules_test.go @@ -25,7 +25,7 @@ func TestEventToPushRules(t *testing.T) { }, } pushRuleset, err := pushrules.EventToPushRules(evt) - assert.Nil(t, err) + assert.NoError(t, err) assert.NotNil(t, pushRuleset) assert.IsType(t, pushRuleset.Override, pushrules.PushRuleArray{}) diff --git a/pushrules/rule.go b/pushrules/rule.go index 0f7436f3..cf659695 100644 --- a/pushrules/rule.go +++ b/pushrules/rule.go @@ -8,10 +8,14 @@ package pushrules import ( "encoding/gob" + "regexp" + "strings" + + "go.mau.fi/util/exerrors" + "go.mau.fi/util/glob" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" - "maunium.net/go/mautrix/pushrules/glob" ) func init() { @@ -164,13 +168,20 @@ func (rule *PushRule) matchConditions(room Room, evt *event.Event) bool { } func (rule *PushRule) matchPattern(room Room, evt *event.Event) bool { - pattern, err := glob.Compile(rule.Pattern) - if err != nil { - return false - } msg, ok := evt.Content.Raw["body"].(string) if !ok { return false } + var buf strings.Builder + // As per https://spec.matrix.org/unstable/client-server-api/#push-rules, content rules are case-insensitive + // and must match whole words, so wrap the converted glob in (?i) and \b. + buf.WriteString(`(?i)\b`) + // strings.Builder will never return errors + exerrors.PanicIfNotNil(glob.ToRegexPattern(rule.Pattern, &buf)) + buf.WriteString(`\b`) + pattern, err := regexp.Compile(buf.String()) + if err != nil { + return false + } return pattern.MatchString(msg) } diff --git a/pushrules/rule_test.go b/pushrules/rule_test.go index 803c721e..7ff839a7 100644 --- a/pushrules/rule_test.go +++ b/pushrules/rule_test.go @@ -186,6 +186,34 @@ func TestPushRule_Match_Content(t *testing.T) { assert.True(t, rule.Match(blankTestRoom, evt)) } +func TestPushRule_Match_WordBoundary(t *testing.T) { + rule := &pushrules.PushRule{ + Type: pushrules.ContentRule, + Enabled: true, + Pattern: "test", + } + + evt := newFakeEvent(event.EventMessage, &event.MessageEventContent{ + MsgType: event.MsgEmote, + Body: "is testing pushrules", + }) + assert.False(t, rule.Match(blankTestRoom, evt)) +} + +func TestPushRule_Match_CaseInsensitive(t *testing.T) { + rule := &pushrules.PushRule{ + Type: pushrules.ContentRule, + Enabled: true, + Pattern: "test", + } + + evt := newFakeEvent(event.EventMessage, &event.MessageEventContent{ + MsgType: event.MsgEmote, + Body: "is TeSt-InG pushrules", + }) + assert.True(t, rule.Match(blankTestRoom, evt)) +} + func TestPushRule_Match_Content_Fail(t *testing.T) { rule := &pushrules.PushRule{ Type: pushrules.ContentRule, diff --git a/pushrules/ruleset.go b/pushrules/ruleset.go index 609997b4..c42d4799 100644 --- a/pushrules/ruleset.go +++ b/pushrules/ruleset.go @@ -68,6 +68,9 @@ func (rs *PushRuleset) MarshalJSON() ([]byte, error) { var DefaultPushActions = PushActionArray{&PushAction{Action: ActionDontNotify}} func (rs *PushRuleset) GetMatchingRule(room Room, evt *event.Event) (rule *PushRule) { + if rs == nil { + return nil + } // Add push rule collections to array in priority order arrays := []PushRuleCollection{rs.Override, rs.Content, rs.Room, rs.Sender, rs.Underride} // Loop until one of the push rule collections matches the room/event combo. diff --git a/requests.go b/requests.go index cdf020a0..cc8b7266 100644 --- a/requests.go +++ b/requests.go @@ -2,7 +2,9 @@ package mautrix import ( "encoding/json" + "fmt" "strconv" + "time" "maunium.net/go/mautrix/crypto/signatures" "maunium.net/go/mautrix/event" @@ -38,20 +40,40 @@ const ( type Direction rune +func (d Direction) MarshalJSON() ([]byte, error) { + return json.Marshal(string(d)) +} + +func (d *Direction) UnmarshalJSON(data []byte) error { + var str string + if err := json.Unmarshal(data, &str); err != nil { + return err + } + switch str { + case "f": + *d = DirectionForward + case "b": + *d = DirectionBackward + default: + return fmt.Errorf("invalid direction %q, must be 'f' or 'b'", str) + } + return nil +} + const ( DirectionForward Direction = 'f' DirectionBackward Direction = 'b' ) // ReqRegister is the JSON request for https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3register -type ReqRegister struct { +type ReqRegister[UIAType any] struct { Username string `json:"username,omitempty"` Password string `json:"password,omitempty"` DeviceID id.DeviceID `json:"device_id,omitempty"` InitialDeviceDisplayName string `json:"initial_device_display_name,omitempty"` InhibitLogin bool `json:"inhibit_login,omitempty"` RefreshToken bool `json:"refresh_token,omitempty"` - Auth interface{} `json:"auth,omitempty"` + Auth UIAType `json:"auth,omitempty"` // Type for registration, only used for appservice user registrations // https://spec.matrix.org/v1.2/application-service-api/#server-admin-style-permissions @@ -83,6 +105,7 @@ type ReqLogin struct { Token string `json:"token,omitempty"` DeviceID id.DeviceID `json:"device_id,omitempty"` InitialDeviceDisplayName string `json:"initial_device_display_name,omitempty"` + RefreshToken bool `json:"refresh_token,omitempty"` // Whether or not the returned credentials should be stored in the Client StoreCredentials bool `json:"-"` @@ -90,6 +113,10 @@ type ReqLogin struct { StoreHomeserverURL bool `json:"-"` } +type ReqPutDevice struct { + DisplayName string `json:"display_name,omitempty"` +} + type ReqUIAuthFallback struct { Session string `json:"session"` User string `json:"user"` @@ -114,12 +141,17 @@ type ReqCreateRoom struct { InitialState []*event.Event `json:"initial_state,omitempty"` Preset string `json:"preset,omitempty"` IsDirect bool `json:"is_direct,omitempty"` - RoomVersion string `json:"room_version,omitempty"` + RoomVersion id.RoomVersion `json:"room_version,omitempty"` PowerLevelOverride *event.PowerLevelsEventContent `json:"power_level_content_override,omitempty"` - MeowRoomID id.RoomID `json:"fi.mau.room_id,omitempty"` - BeeperAutoJoinInvites bool `json:"com.beeper.auto_join_invites,omitempty"` + MeowRoomID id.RoomID `json:"fi.mau.room_id,omitempty"` + MeowCreateTS int64 `json:"fi.mau.origin_server_ts,omitempty"` + BeeperInitialMembers []id.UserID `json:"com.beeper.initial_members,omitempty"` + BeeperAutoJoinInvites bool `json:"com.beeper.auto_join_invites,omitempty"` + BeeperLocalRoomID id.RoomID `json:"com.beeper.local_room_id,omitempty"` + BeeperBridgeName string `json:"com.beeper.bridge_name,omitempty"` + BeeperBridgeAccountID string `json:"com.beeper.bridge_account_id,omitempty"` } // ReqRedact is the JSON request for https://spec.matrix.org/v1.2/client-server-api/#put_matrixclientv3roomsroomidredacteventidtxnid @@ -129,12 +161,37 @@ type ReqRedact struct { Extra map[string]interface{} } +type ReqRedactUser struct { + Reason string `json:"reason"` + Limit int `json:"-"` +} + type ReqMembers struct { At string `json:"at"` Membership event.Membership `json:"membership,omitempty"` NotMembership event.Membership `json:"not_membership,omitempty"` } +type ReqJoinRoom struct { + Via []string `json:"-"` + Reason string `json:"reason,omitempty"` + ThirdPartySigned any `json:"third_party_signed,omitempty"` +} + +type ReqKnockRoom struct { + Via []string `json:"-"` + Reason string `json:"reason,omitempty"` +} + +type ReqSearchUserDirectory struct { + SearchTerm string `json:"search_term"` + Limit int `json:"limit,omitempty"` +} + +type ReqMutualRooms struct { + From string `json:"-"` +} + // ReqInvite3PID is the JSON request for https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3roomsroomidinvite-1 // It is also a JSON object used in https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3createroom type ReqInvite3PID struct { @@ -163,6 +220,8 @@ type ReqKickUser struct { type ReqBanUser struct { Reason string `json:"reason,omitempty"` UserID id.UserID `json:"user_id"` + + MSC4293RedactEvents bool `json:"org.matrix.msc4293.redact_events,omitempty"` } // ReqUnbanUser is the JSON request for https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3roomsroomidunban @@ -178,7 +237,8 @@ type ReqTyping struct { } type ReqPresence struct { - Presence event.Presence `json:"presence"` + Presence event.Presence `json:"presence"` + StatusMsg string `json:"status_msg,omitempty"` } type ReqAliasCreate struct { @@ -223,7 +283,7 @@ func (otk *OneTimeKey) MarshalJSON() ([]byte, error) { type ReqUploadKeys struct { DeviceKeys *DeviceKeys `json:"device_keys,omitempty"` - OneTimeKeys map[id.KeyID]OneTimeKey `json:"one_time_keys"` + OneTimeKeys map[id.KeyID]OneTimeKey `json:"one_time_keys,omitempty"` } type ReqKeysSignatures struct { @@ -260,11 +320,11 @@ func (csk *CrossSigningKeys) FirstKey() id.Ed25519 { return "" } -type UploadCrossSigningKeysReq struct { +type UploadCrossSigningKeysReq[UIAType any] struct { Master CrossSigningKeys `json:"master_key"` SelfSigning CrossSigningKeys `json:"self_signing_key"` UserSigning CrossSigningKeys `json:"user_signing_key"` - Auth interface{} `json:"auth,omitempty"` + Auth UIAType `json:"auth,omitempty"` } type KeyMap map[id.DeviceKeyID]string @@ -306,20 +366,40 @@ type ReqSendToDevice struct { Messages map[id.UserID]map[id.DeviceID]*event.Content `json:"messages"` } +type ReqSendEvent struct { + Timestamp int64 + TransactionID string + UnstableDelay time.Duration + UnstableStickyDuration time.Duration + DontEncrypt bool + MeowEventID id.EventID +} + +type ReqDelayedEvents struct { + DelayID id.DelayID `json:"-"` + Status event.DelayStatus `json:"-"` + NextBatch string `json:"-"` +} + +type ReqUpdateDelayedEvent struct { + DelayID id.DelayID `json:"-"` + Action event.DelayAction `json:"action"` +} + // ReqDeviceInfo is the JSON request for https://spec.matrix.org/v1.2/client-server-api/#put_matrixclientv3devicesdeviceid type ReqDeviceInfo struct { DisplayName string `json:"display_name,omitempty"` } // ReqDeleteDevice is the JSON request for https://spec.matrix.org/v1.2/client-server-api/#delete_matrixclientv3devicesdeviceid -type ReqDeleteDevice struct { - Auth interface{} `json:"auth,omitempty"` +type ReqDeleteDevice[UIAType any] struct { + Auth UIAType `json:"auth,omitempty"` } // ReqDeleteDevices is the JSON request for https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3delete_devices -type ReqDeleteDevices struct { +type ReqDeleteDevices[UIAType any] struct { Devices []id.DeviceID `json:"devices"` - Auth interface{} `json:"auth,omitempty"` + Auth UIAType `json:"auth,omitempty"` } type ReqPutPushRule struct { @@ -331,18 +411,6 @@ type ReqPutPushRule struct { Pattern string `json:"pattern"` } -// Deprecated: MSC2716 was abandoned -type ReqBatchSend struct { - PrevEventID id.EventID `json:"-"` - BatchID id.BatchID `json:"-"` - - BeeperNewMessages bool `json:"-"` - BeeperMarkReadBy id.UserID `json:"-"` - - StateEventsAtStart []*event.Event `json:"state_events_at_start"` - Events []*event.Event `json:"events"` -} - type ReqBeeperBatchSend struct { // ForwardIfNoMessages should be set to true if the batch should be forward // backfilled if there are no messages currently in the room. @@ -363,10 +431,48 @@ type ReqSetReadMarkers struct { BeeperFullyReadExtra interface{} `json:"com.beeper.fully_read.extra,omitempty"` } +type BeeperInboxDone struct { + Delta int64 `json:"at_delta"` + AtOrder int64 `json:"at_order"` +} + +type ReqSetBeeperInboxState struct { + MarkedUnread *bool `json:"marked_unread,omitempty"` + Done *BeeperInboxDone `json:"done,omitempty"` + ReadMarkers *ReqSetReadMarkers `json:"read_markers,omitempty"` +} + type ReqSendReceipt struct { ThreadID string `json:"thread_id,omitempty"` } +type ReqPublicRooms struct { + IncludeAllNetworks bool + Limit int + Since string + ThirdPartyInstanceID string +} + +func (req *ReqPublicRooms) Query() map[string]string { + query := map[string]string{} + if req == nil { + return query + } + if req.IncludeAllNetworks { + query["include_all_networks"] = "true" + } + if req.Limit > 0 { + query["limit"] = strconv.Itoa(req.Limit) + } + if req.Since != "" { + query["since"] = req.Since + } + if req.ThirdPartyInstanceID != "" { + query["third_party_instance_id"] = req.ThirdPartyInstanceID + } + return query +} + // ReqHierarchy contains the parameters for https://spec.matrix.org/v1.4/client-server-api/#get_matrixclientv1roomsroomidhierarchy // // As it's a GET method, there is no JSON body, so this is only query parameters. @@ -454,3 +560,59 @@ type ReqKeyBackupData struct { IsVerified bool `json:"is_verified"` SessionData json.RawMessage `json:"session_data"` } + +type ReqReport struct { + Reason string `json:"reason,omitempty"` + Score int `json:"score,omitempty"` +} + +type ReqGetRelations struct { + RelationType event.RelationType + EventType event.Type + + Dir Direction + From string + To string + Limit int + Recurse bool +} + +func (rgr *ReqGetRelations) PathSuffix() ClientURLPath { + if rgr.RelationType != "" { + if rgr.EventType.Type != "" { + return ClientURLPath{rgr.RelationType, rgr.EventType.Type} + } + return ClientURLPath{rgr.RelationType} + } + return ClientURLPath{} +} + +func (rgr *ReqGetRelations) Query() map[string]string { + query := map[string]string{} + if rgr.Dir != 0 { + query["dir"] = string(rgr.Dir) + } + if rgr.From != "" { + query["from"] = rgr.From + } + if rgr.To != "" { + query["to"] = rgr.To + } + if rgr.Limit > 0 { + query["limit"] = strconv.Itoa(rgr.Limit) + } + if rgr.Recurse { + query["recurse"] = "true" + } + return query +} + +// ReqSuspend is the request body for https://github.com/matrix-org/matrix-spec-proposals/pull/4323 +type ReqSuspend struct { + Suspended bool `json:"suspended"` +} + +// ReqLocked is the request body for https://github.com/matrix-org/matrix-spec-proposals/pull/4323 +type ReqLocked struct { + Locked bool `json:"locked"` +} diff --git a/responses.go b/responses.go index 9e5fd0aa..4fbe1fbc 100644 --- a/responses.go +++ b/responses.go @@ -4,13 +4,16 @@ import ( "bytes" "encoding/json" "fmt" + "maps" "reflect" + "slices" "strconv" "strings" "github.com/tidwall/gjson" "github.com/tidwall/sjson" "go.mau.fi/util/jsontime" + "go.mau.fi/util/ptr" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" @@ -32,6 +35,11 @@ type RespJoinRoom struct { RoomID id.RoomID `json:"room_id"` } +// RespKnockRoom is the JSON response for https://spec.matrix.org/v1.13/client-server-api/#post_matrixclientv3knockroomidoralias +type RespKnockRoom struct { + RoomID id.RoomID `json:"room_id"` +} + // RespLeaveRoom is the JSON response for https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3roomsroomidleave type RespLeaveRoom struct{} @@ -97,6 +105,29 @@ type RespContext struct { // RespSendEvent is the JSON response for https://spec.matrix.org/v1.2/client-server-api/#put_matrixclientv3roomsroomidsendeventtypetxnid type RespSendEvent struct { EventID id.EventID `json:"event_id"` + + UnstableDelayID id.DelayID `json:"delay_id,omitempty"` +} + +type RespUpdateDelayedEvent struct{} + +type RespDelayedEvents struct { + Scheduled []*event.ScheduledDelayedEvent `json:"scheduled,omitempty"` + Finalised []*event.FinalisedDelayedEvent `json:"finalised,omitempty"` + NextBatch string `json:"next_batch,omitempty"` + + // Deprecated: Synapse implementation still returns this + DelayedEvents []*event.ScheduledDelayedEvent `json:"delayed_events,omitempty"` + // Deprecated: Synapse implementation still returns this + FinalisedEvents []*event.FinalisedDelayedEvent `json:"finalised_events,omitempty"` +} + +type RespRedactUserEvents struct { + IsMoreEvents bool `json:"is_more_events"` + RedactedEvents struct { + Total int `json:"total"` + SoftFailed int `json:"soft_failed"` + } `json:"redacted_events"` } // RespMediaConfig is the JSON response for https://spec.matrix.org/v1.4/client-server-api/#get_matrixmediav3config @@ -111,10 +142,14 @@ type RespMediaUpload struct { // RespCreateMXC is the JSON response for https://spec.matrix.org/v1.7/client-server-api/#post_matrixmediav1create type RespCreateMXC struct { - ContentURI id.ContentURI `json:"content_uri"` - UnusedExpiresAt int `json:"unused_expires_at,omitempty"` + ContentURI id.ContentURI `json:"content_uri"` + UnusedExpiresAt jsontime.UnixMilli `json:"unused_expires_at,omitempty"` UnstableUploadURL string `json:"com.beeper.msc3870.upload_url,omitempty"` + + // Beeper extensions for uploading unique media only once + BeeperUniqueID string `json:"com.beeper.unique_id,omitempty"` + BeeperCompletedAt jsontime.UnixMilli `json:"com.beeper.completed_at,omitempty"` } // RespPreviewURL is the JSON response for https://spec.matrix.org/v1.2/client-server-api/#get_matrixmediav3preview_url @@ -151,8 +186,89 @@ type RespUserDisplayName struct { } type RespUserProfile struct { - DisplayName string `json:"displayname"` - AvatarURL id.ContentURI `json:"avatar_url"` + DisplayName string `json:"displayname,omitempty"` + AvatarURL id.ContentURI `json:"avatar_url,omitempty"` + Extra map[string]any `json:"-"` +} + +type marshalableUserProfile RespUserProfile + +func (r *RespUserProfile) UnmarshalJSON(data []byte) error { + err := json.Unmarshal(data, &r.Extra) + if err != nil { + return err + } + r.DisplayName, _ = r.Extra["displayname"].(string) + avatarURL, _ := r.Extra["avatar_url"].(string) + if avatarURL != "" { + r.AvatarURL, _ = id.ParseContentURI(avatarURL) + } + delete(r.Extra, "displayname") + delete(r.Extra, "avatar_url") + return nil +} + +func (r *RespUserProfile) MarshalJSON() ([]byte, error) { + if len(r.Extra) == 0 { + return json.Marshal((*marshalableUserProfile)(r)) + } + marshalMap := maps.Clone(r.Extra) + if r.DisplayName != "" { + marshalMap["displayname"] = r.DisplayName + } else { + delete(marshalMap, "displayname") + } + if !r.AvatarURL.IsEmpty() { + marshalMap["avatar_url"] = r.AvatarURL.String() + } else { + delete(marshalMap, "avatar_url") + } + return json.Marshal(marshalMap) +} + +type RespSearchUserDirectory struct { + Limited bool `json:"limited"` + Results []*UserDirectoryEntry `json:"results"` +} + +type UserDirectoryEntry struct { + RespUserProfile + UserID id.UserID `json:"user_id"` +} + +func (r *UserDirectoryEntry) UnmarshalJSON(data []byte) error { + err := r.RespUserProfile.UnmarshalJSON(data) + if err != nil { + return err + } + userIDStr, _ := r.Extra["user_id"].(string) + r.UserID = id.UserID(userIDStr) + delete(r.Extra, "user_id") + return nil +} + +func (r *UserDirectoryEntry) MarshalJSON() ([]byte, error) { + if r.Extra == nil { + r.Extra = make(map[string]any) + } + r.Extra["user_id"] = r.UserID.String() + return r.RespUserProfile.MarshalJSON() +} + +type RespMutualRooms struct { + Joined []id.RoomID `json:"joined"` + NextBatch string `json:"next_batch,omitempty"` + Count int `json:"count,omitempty"` +} + +type RespRoomSummary struct { + PublicRoomInfo + + Membership event.Membership `json:"membership,omitempty"` + + UnstableRoomVersion id.RoomVersion `json:"im.nheko.summary.room_version,omitempty"` + UnstableRoomVersionOld id.RoomVersion `json:"im.nheko.summary.version,omitempty"` + UnstableEncryption id.Algorithm `json:"im.nheko.summary.encryption,omitempty"` } // RespRegisterAvailable is the JSON response for https://spec.matrix.org/v1.4/client-server-api/#get_matrixclientv3registeravailable @@ -203,6 +319,9 @@ type RespLogin struct { DeviceID id.DeviceID `json:"device_id"` UserID id.UserID `json:"user_id"` WellKnown *ClientWellKnown `json:"well_known,omitempty"` + + RefreshToken string `json:"refresh_token,omitempty"` + ExpiresInMS int64 `json:"expires_in_ms,omitempty"` } // RespLogout is the JSON response for https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3logout @@ -223,6 +342,24 @@ type LazyLoadSummary struct { InvitedMemberCount *int `json:"m.invited_member_count,omitempty"` } +func (lls *LazyLoadSummary) MemberCount() int { + if lls == nil { + return 0 + } + return ptr.Val(lls.JoinedMemberCount) + ptr.Val(lls.InvitedMemberCount) +} + +func (lls *LazyLoadSummary) Equal(other *LazyLoadSummary) bool { + if lls == other { + return true + } else if lls == nil || other == nil { + return false + } + return ptr.Val(lls.JoinedMemberCount) == ptr.Val(other.JoinedMemberCount) && + ptr.Val(lls.InvitedMemberCount) == ptr.Val(other.InvitedMemberCount) && + slices.Equal(lls.Heroes, other.Heroes) +} + type SyncEventsList struct { Events []*event.Event `json:"events,omitempty"` } @@ -318,6 +455,7 @@ type BeeperInboxPreviewEvent struct { type SyncJoinedRoom struct { Summary LazyLoadSummary `json:"summary"` State SyncEventsList `json:"state"` + StateAfter *SyncEventsList `json:"state_after,omitempty"` Timeline SyncTimeline `json:"timeline"` Ephemeral SyncEventsList `json:"ephemeral"` AccountData SyncEventsList `json:"account_data"` @@ -343,16 +481,7 @@ func (sjr SyncJoinedRoom) MarshalJSON() ([]byte, error) { } type SyncInvitedRoom struct { - Summary LazyLoadSummary `json:"summary"` - State SyncEventsList `json:"invite_state"` -} - -type marshalableSyncInvitedRoom SyncInvitedRoom - -var syncInvitedRoomPathsToDelete = []string{"summary"} - -func (sir SyncInvitedRoom) MarshalJSON() ([]byte, error) { - return marshalAndDeleteEmpty((marshalableSyncInvitedRoom)(sir), syncInvitedRoomPathsToDelete) + State SyncEventsList `json:"invite_state"` } type SyncKnockedRoom struct { @@ -417,29 +546,19 @@ type RespDeviceInfo struct { LastSeenTS int64 `json:"last_seen_ts"` } -// Deprecated: MSC2716 was abandoned -type RespBatchSend struct { - StateEventIDs []id.EventID `json:"state_event_ids"` - EventIDs []id.EventID `json:"event_ids"` - - InsertionEventID id.EventID `json:"insertion_event_id"` - BatchEventID id.EventID `json:"batch_event_id"` - BaseInsertionEventID id.EventID `json:"base_insertion_event_id"` - - NextBatchID id.BatchID `json:"next_batch_id"` -} - type RespBeeperBatchSend struct { EventIDs []id.EventID `json:"event_ids"` } // RespCapabilities is the JSON response for https://spec.matrix.org/v1.3/client-server-api/#get_matrixclientv3capabilities type RespCapabilities struct { - RoomVersions *CapRoomVersions `json:"m.room_versions,omitempty"` - ChangePassword *CapBooleanTrue `json:"m.change_password,omitempty"` - SetDisplayname *CapBooleanTrue `json:"m.set_displayname,omitempty"` - SetAvatarURL *CapBooleanTrue `json:"m.set_avatar_url,omitempty"` - ThreePIDChanges *CapBooleanTrue `json:"m.3pid_changes,omitempty"` + RoomVersions *CapRoomVersions `json:"m.room_versions,omitempty"` + ChangePassword *CapBooleanTrue `json:"m.change_password,omitempty"` + SetDisplayname *CapBooleanTrue `json:"m.set_displayname,omitempty"` + SetAvatarURL *CapBooleanTrue `json:"m.set_avatar_url,omitempty"` + ThreePIDChanges *CapBooleanTrue `json:"m.3pid_changes,omitempty"` + GetLoginToken *CapBooleanTrue `json:"m.get_login_token,omitempty"` + UnstableAccountModeration *CapUnstableAccountModeration `json:"uk.timedout.msc4323,omitempty"` Custom map[string]interface{} `json:"-"` } @@ -548,29 +667,44 @@ func (vers *CapRoomVersions) IsAvailable(version string) bool { return available } +type CapUnstableAccountModeration struct { + Suspend bool `json:"suspend"` + Lock bool `json:"lock"` +} + +type RespPublicRooms struct { + Chunk []*PublicRoomInfo `json:"chunk"` + NextBatch string `json:"next_batch,omitempty"` + PrevBatch string `json:"prev_batch,omitempty"` + TotalRoomCountEstimate int `json:"total_room_count_estimate"` +} + +type PublicRoomInfo struct { + RoomID id.RoomID `json:"room_id"` + AvatarURL id.ContentURIString `json:"avatar_url,omitempty"` + CanonicalAlias id.RoomAlias `json:"canonical_alias,omitempty"` + GuestCanJoin bool `json:"guest_can_join"` + JoinRule event.JoinRule `json:"join_rule,omitempty"` + Name string `json:"name,omitempty"` + NumJoinedMembers int `json:"num_joined_members"` + RoomType event.RoomType `json:"room_type"` + Topic string `json:"topic,omitempty"` + WorldReadable bool `json:"world_readable"` + + RoomVersion id.RoomVersion `json:"room_version,omitempty"` + Encryption id.Algorithm `json:"encryption,omitempty"` + AllowedRoomIDs []id.RoomID `json:"allowed_room_ids,omitempty"` +} + // RespHierarchy is the JSON response for https://spec.matrix.org/v1.4/client-server-api/#get_matrixclientv1roomsroomidhierarchy type RespHierarchy struct { - NextBatch string `json:"next_batch,omitempty"` - Rooms []ChildRoomsChunk `json:"rooms"` + NextBatch string `json:"next_batch,omitempty"` + Rooms []*ChildRoomsChunk `json:"rooms"` } type ChildRoomsChunk struct { - AvatarURL id.ContentURI `json:"avatar_url,omitempty"` - CanonicalAlias id.RoomAlias `json:"canonical_alias,omitempty"` - ChildrenState []StrippedStateWithTime `json:"children_state"` - GuestCanJoin bool `json:"guest_can_join"` - JoinRule event.JoinRule `json:"join_rule,omitempty"` - Name string `json:"name,omitempty"` - NumJoinedMembers int `json:"num_joined_members"` - RoomID id.RoomID `json:"room_id"` - RoomType event.RoomType `json:"room_type"` - Topic string `json:"topic,omitempty"` - WorldReadble bool `json:"world_readable"` -} - -type StrippedStateWithTime struct { - event.StrippedState - Timestamp jsontime.UnixMilli `json:"origin_server_ts"` + PublicRoomInfo + ChildrenState []*event.Event `json:"children_state"` } type RespAppservicePing struct { @@ -619,3 +753,47 @@ type RespRoomKeysUpdate struct { Count int `json:"count"` ETag string `json:"etag"` } + +type RespOpenIDToken struct { + AccessToken string `json:"access_token"` + ExpiresIn int `json:"expires_in"` + MatrixServerName string `json:"matrix_server_name"` + TokenType string `json:"token_type"` // Always "Bearer" +} + +type RespGetRelations struct { + Chunk []*event.Event `json:"chunk"` + NextBatch string `json:"next_batch,omitempty"` + PrevBatch string `json:"prev_batch,omitempty"` + RecursionDepth int `json:"recursion_depth,omitempty"` +} + +// RespSuspended is the response body for https://github.com/matrix-org/matrix-spec-proposals/pull/4323 +type RespSuspended struct { + Suspended bool `json:"suspended"` +} + +// RespLocked is the response body for https://github.com/matrix-org/matrix-spec-proposals/pull/4323 +type RespLocked struct { + Locked bool `json:"locked"` +} + +type ConnectionInfo struct { + IP string `json:"ip,omitempty"` + LastSeen jsontime.UnixMilli `json:"last_seen,omitempty"` + UserAgent string `json:"user_agent,omitempty"` +} + +type SessionInfo struct { + Connections []ConnectionInfo `json:"connections,omitempty"` +} + +type DeviceInfo struct { + Sessions []SessionInfo `json:"sessions,omitempty"` +} + +// RespWhoIs is the response body for https://spec.matrix.org/v1.15/client-server-api/#get_matrixclientv3adminwhoisuserid +type RespWhoIs struct { + UserID id.UserID `json:"user_id,omitempty"` + Devices map[id.DeviceID]DeviceInfo `json:"devices,omitempty"` +} diff --git a/responses_test.go b/responses_test.go index b23d85ad..73d82635 100644 --- a/responses_test.go +++ b/responses_test.go @@ -8,7 +8,6 @@ package mautrix_test import ( "encoding/json" - "fmt" "testing" "github.com/stretchr/testify/assert" @@ -86,7 +85,6 @@ func TestRespCapabilities_UnmarshalJSON(t *testing.T) { var caps mautrix.RespCapabilities err := json.Unmarshal([]byte(sampleData), &caps) require.NoError(t, err) - fmt.Println(caps) require.NotNil(t, caps.RoomVersions) assert.Equal(t, "9", caps.RoomVersions.Default) diff --git a/room.go b/room.go index c3ddb7e6..4292bff5 100644 --- a/room.go +++ b/room.go @@ -5,8 +5,6 @@ import ( "maunium.net/go/mautrix/id" ) -type RoomStateMap = map[event.Type]map[string]*event.Event - // Room represents a single Matrix room. type Room struct { ID id.RoomID @@ -25,8 +23,8 @@ func (room Room) UpdateState(evt *event.Event) { // GetStateEvent returns the state event for the given type/state_key combo, or nil. func (room Room) GetStateEvent(eventType event.Type, stateKey string) *event.Event { - stateEventMap, _ := room.State[eventType] - evt, _ := stateEventMap[stateKey] + stateEventMap := room.State[eventType] + evt := stateEventMap[stateKey] return evt } diff --git a/sqlstatestore/statestore.go b/sqlstatestore/statestore.go index cd94215d..11957dfa 100644 --- a/sqlstatestore/statestore.go +++ b/sqlstatestore/statestore.go @@ -17,7 +17,9 @@ import ( "strings" "github.com/rs/zerolog" + "go.mau.fi/util/confusable" "go.mau.fi/util/dbutil" + "go.mau.fi/util/exslices" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" @@ -37,6 +39,8 @@ const VersionTableName = "mx_version" type SQLStateStore struct { *dbutil.Database IsBridge bool + + DisableNameDisambiguation bool } func NewSQLStateStore(db *dbutil.Database, log dbutil.DatabaseLogger, isBridge bool) *SQLStateStore { @@ -58,6 +62,9 @@ func (store *SQLStateStore) IsRegistered(ctx context.Context, userID id.UserID) } func (store *SQLStateStore) MarkRegistered(ctx context.Context, userID id.UserID) error { + if userID == "" { + return fmt.Errorf("user ID is empty") + } _, err := store.Exec(ctx, "INSERT INTO mx_registrations (user_id) VALUES ($1) ON CONFLICT (user_id) DO NOTHING", userID) return err } @@ -65,6 +72,7 @@ func (store *SQLStateStore) MarkRegistered(ctx context.Context, userID id.UserID type Member struct { id.UserID event.MemberEventContent + NameSkeleton [32]byte } func (store *SQLStateStore) GetRoomMembers(ctx context.Context, roomID id.RoomID, memberships ...event.Membership) (map[id.UserID]*event.MemberEventContent, error) { @@ -80,14 +88,11 @@ func (store *SQLStateStore) GetRoomMembers(ctx context.Context, roomID id.RoomID query = fmt.Sprintf("%s AND membership IN (%s)", query, strings.Join(placeholders, ",")) } rows, err := store.Query(ctx, query, args...) - if err != nil { - return nil, err - } members := make(map[id.UserID]*event.MemberEventContent) - return members, dbutil.NewRowIter(rows, func(row dbutil.Scannable) (ret Member, err error) { + return members, dbutil.NewRowIterWithError(rows, func(row dbutil.Scannable) (ret Member, err error) { err = row.Scan(&ret.UserID, &ret.Membership, &ret.Displayname, &ret.AvatarURL) return - }).Iter(func(m Member) (bool, error) { + }, err).Iter(func(m Member) (bool, error) { members[m.UserID] = &m.MemberEventContent return true, nil }) @@ -154,10 +159,7 @@ func (store *SQLStateStore) FindSharedRooms(ctx context.Context, userID id.UserI ` } rows, err := store.Query(ctx, query, userID) - if err != nil { - return nil, err - } - return dbutil.NewRowIter(rows, dbutil.ScanSingleColumn[id.RoomID]).AsList() + return dbutil.NewRowIterWithError(rows, dbutil.ScanSingleColumn[id.RoomID], err).AsList() } func (store *SQLStateStore) IsInRoom(ctx context.Context, roomID id.RoomID, userID id.UserID) bool { @@ -183,6 +185,11 @@ func (store *SQLStateStore) IsMembership(ctx context.Context, roomID id.RoomID, } func (store *SQLStateStore) SetMembership(ctx context.Context, roomID id.RoomID, userID id.UserID, membership event.Membership) error { + if roomID == "" { + return fmt.Errorf("room ID is empty") + } else if userID == "" { + return fmt.Errorf("user ID is empty") + } _, err := store.Exec(ctx, ` INSERT INTO mx_user_profile (room_id, user_id, membership, displayname, avatar_url) VALUES ($1, $2, $3, '', '') ON CONFLICT (room_id, user_id) DO UPDATE SET membership=excluded.membership @@ -190,14 +197,101 @@ func (store *SQLStateStore) SetMembership(ctx context.Context, roomID id.RoomID, return err } +const insertUserProfileQuery = ` + INSERT INTO mx_user_profile (room_id, user_id, membership, displayname, avatar_url, name_skeleton) + VALUES ($1, $2, $3, $4, $5, $6) + ON CONFLICT (room_id, user_id) DO UPDATE + SET membership=excluded.membership, + displayname=excluded.displayname, + avatar_url=excluded.avatar_url, + name_skeleton=excluded.name_skeleton +` + +type userProfileRow struct { + UserID id.UserID + Membership event.Membership + Displayname string + AvatarURL id.ContentURIString + NameSkeleton []byte +} + +func (u *userProfileRow) GetMassInsertValues() [5]any { + return [5]any{u.UserID, u.Membership, u.Displayname, u.AvatarURL, u.NameSkeleton} +} + +var userProfileMassInserter = dbutil.NewMassInsertBuilder[*userProfileRow, [1]any](insertUserProfileQuery, "($1, $%d, $%d, $%d, $%d, $%d)") + func (store *SQLStateStore) SetMember(ctx context.Context, roomID id.RoomID, userID id.UserID, member *event.MemberEventContent) error { - _, err := store.Exec(ctx, ` - INSERT INTO mx_user_profile (room_id, user_id, membership, displayname, avatar_url) VALUES ($1, $2, $3, $4, $5) - ON CONFLICT (room_id, user_id) DO UPDATE SET membership=excluded.membership, displayname=excluded.displayname, avatar_url=excluded.avatar_url - `, roomID, userID, member.Membership, member.Displayname, member.AvatarURL) + if roomID == "" { + return fmt.Errorf("room ID is empty") + } else if userID == "" { + return fmt.Errorf("user ID is empty") + } + var nameSkeleton []byte + if !store.DisableNameDisambiguation && len(member.Displayname) > 0 { + nameSkeletonArr := confusable.SkeletonHash(member.Displayname) + nameSkeleton = nameSkeletonArr[:] + } + _, err := store.Exec(ctx, insertUserProfileQuery, roomID, userID, member.Membership, member.Displayname, member.AvatarURL, nameSkeleton) return err } +func (store *SQLStateStore) IsConfusableName(ctx context.Context, roomID id.RoomID, currentUser id.UserID, name string) ([]id.UserID, error) { + if store.DisableNameDisambiguation { + return nil, nil + } + skeleton := confusable.SkeletonHash(name) + rows, err := store.Query(ctx, "SELECT user_id FROM mx_user_profile WHERE room_id=$1 AND name_skeleton=$2 AND user_id<>$3", roomID, skeleton[:], currentUser) + return dbutil.NewRowIterWithError(rows, dbutil.ScanSingleColumn[id.UserID], err).AsList() +} + +const userProfileMassInsertBatchSize = 500 + +func (store *SQLStateStore) ReplaceCachedMembers(ctx context.Context, roomID id.RoomID, evts []*event.Event, onlyMemberships ...event.Membership) error { + if roomID == "" { + return fmt.Errorf("room ID is empty") + } + return store.DoTxn(ctx, nil, func(ctx context.Context) error { + err := store.ClearCachedMembers(ctx, roomID, onlyMemberships...) + if err != nil { + return fmt.Errorf("failed to clear cached members: %w", err) + } + rows := make([]*userProfileRow, min(len(evts), userProfileMassInsertBatchSize)) + for _, evtsChunk := range exslices.Chunk(evts, userProfileMassInsertBatchSize) { + rows = rows[:0] + for _, evt := range evtsChunk { + content, ok := evt.Content.Parsed.(*event.MemberEventContent) + if !ok { + continue + } + row := &userProfileRow{ + UserID: id.UserID(*evt.StateKey), + Membership: content.Membership, + Displayname: content.Displayname, + AvatarURL: content.AvatarURL, + } + if !store.DisableNameDisambiguation && len(content.Displayname) > 0 { + nameSkeletonArr := confusable.SkeletonHash(content.Displayname) + row.NameSkeleton = nameSkeletonArr[:] + } + rows = append(rows, row) + } + query, args := userProfileMassInserter.Build([1]any{roomID}, rows) + _, err = store.Exec(ctx, query, args...) + if err != nil { + return fmt.Errorf("failed to insert members: %w", err) + } + } + if len(onlyMemberships) == 0 { + err = store.MarkMembersFetched(ctx, roomID) + if err != nil { + return fmt.Errorf("failed to mark members as fetched: %w", err) + } + } + return nil + }) +} + func (store *SQLStateStore) ClearCachedMembers(ctx context.Context, roomID id.RoomID, memberships ...event.Membership) error { query := "DELETE FROM mx_user_profile WHERE room_id=$1" params := make([]any, len(memberships)+1) @@ -211,10 +305,57 @@ func (store *SQLStateStore) ClearCachedMembers(ctx context.Context, roomID id.Ro query += fmt.Sprintf(" AND membership IN (%s)", strings.Join(placeholders, ",")) } _, err := store.Exec(ctx, query, params...) + if err != nil { + return err + } + _, err = store.Exec(ctx, "UPDATE mx_room_state SET members_fetched=false WHERE room_id=$1", roomID) return err } +func (store *SQLStateStore) HasFetchedMembers(ctx context.Context, roomID id.RoomID) (fetched bool, err error) { + err = store.QueryRow(ctx, "SELECT COALESCE(members_fetched, false) FROM mx_room_state WHERE room_id=$1", roomID).Scan(&fetched) + if errors.Is(err, sql.ErrNoRows) { + err = nil + } + return +} + +func (store *SQLStateStore) MarkMembersFetched(ctx context.Context, roomID id.RoomID) error { + if roomID == "" { + return fmt.Errorf("room ID is empty") + } + _, err := store.Exec(ctx, ` + INSERT INTO mx_room_state (room_id, members_fetched) VALUES ($1, true) + ON CONFLICT (room_id) DO UPDATE SET members_fetched=true + `, roomID) + return err +} + +type userAndMembership struct { + UserID id.UserID + event.MemberEventContent +} + +func (store *SQLStateStore) GetAllMembers(ctx context.Context, roomID id.RoomID) (map[id.UserID]*event.MemberEventContent, error) { + rows, err := store.Query(ctx, "SELECT user_id, membership, displayname, avatar_url FROM mx_user_profile WHERE room_id=$1", roomID) + if err != nil { + return nil, err + } + output := make(map[id.UserID]*event.MemberEventContent) + err = dbutil.NewRowIterWithError(rows, func(row dbutil.Scannable) (res userAndMembership, err error) { + err = row.Scan(&res.UserID, &res.Membership, &res.Displayname, &res.AvatarURL) + return + }, err).Iter(func(member userAndMembership) (bool, error) { + output[member.UserID] = &member.MemberEventContent + return true, nil + }) + return output, err +} + func (store *SQLStateStore) SetEncryptionEvent(ctx context.Context, roomID id.RoomID, content *event.EncryptionEventContent) error { + if roomID == "" { + return fmt.Errorf("room ID is empty") + } contentBytes, err := json.Marshal(content) if err != nil { return fmt.Errorf("failed to marshal content JSON: %w", err) @@ -229,7 +370,7 @@ func (store *SQLStateStore) SetEncryptionEvent(ctx context.Context, roomID id.Ro func (store *SQLStateStore) GetEncryptionEvent(ctx context.Context, roomID id.RoomID) (*event.EncryptionEventContent, error) { var data []byte err := store. - QueryRow(ctx, "SELECT encryption FROM mx_room_state WHERE room_id=$1", roomID). + QueryRow(ctx, "SELECT encryption FROM mx_room_state WHERE room_id=$1 AND encryption IS NOT NULL", roomID). Scan(&data) if errors.Is(err, sql.ErrNoRows) { return nil, nil @@ -252,6 +393,9 @@ func (store *SQLStateStore) IsEncrypted(ctx context.Context, roomID id.RoomID) ( } func (store *SQLStateStore) SetPowerLevels(ctx context.Context, roomID id.RoomID, levels *event.PowerLevelsEventContent) error { + if roomID == "" { + return fmt.Errorf("room ID is empty") + } _, err := store.Exec(ctx, ` INSERT INTO mx_room_state (room_id, power_levels) VALUES ($1, $2) ON CONFLICT (room_id) DO UPDATE SET power_levels=excluded.power_levels @@ -260,89 +404,92 @@ func (store *SQLStateStore) SetPowerLevels(ctx context.Context, roomID id.RoomID } func (store *SQLStateStore) GetPowerLevels(ctx context.Context, roomID id.RoomID) (levels *event.PowerLevelsEventContent, err error) { + levels = &event.PowerLevelsEventContent{} err = store. - QueryRow(ctx, "SELECT power_levels FROM mx_room_state WHERE room_id=$1", roomID). - Scan(&dbutil.JSON{Data: &levels}) + QueryRow(ctx, "SELECT power_levels, create_event FROM mx_room_state WHERE room_id=$1 AND power_levels IS NOT NULL", roomID). + Scan(&dbutil.JSON{Data: &levels}, &dbutil.JSON{Data: &levels.CreateEvent}) if errors.Is(err, sql.ErrNoRows) { - err = nil + return nil, nil + } else if err != nil { + return nil, err + } + if levels.CreateEvent != nil { + err = levels.CreateEvent.Content.ParseRaw(event.StateCreate) } return } func (store *SQLStateStore) GetPowerLevel(ctx context.Context, roomID id.RoomID, userID id.UserID) (int, error) { - if store.Dialect == dbutil.Postgres { - var powerLevel int - err := store. - QueryRow(ctx, ` - SELECT COALESCE((power_levels->'users'->$2)::int, (power_levels->'users_default')::int, 0) - FROM mx_room_state WHERE room_id=$1 - `, roomID, userID). - Scan(&powerLevel) - return powerLevel, err - } else { - levels, err := store.GetPowerLevels(ctx, roomID) - if err != nil { - return 0, err - } - return levels.GetUserLevel(userID), nil + levels, err := store.GetPowerLevels(ctx, roomID) + if err != nil { + return 0, err } + return levels.GetUserLevel(userID), nil } func (store *SQLStateStore) GetPowerLevelRequirement(ctx context.Context, roomID id.RoomID, eventType event.Type) (int, error) { - if store.Dialect == dbutil.Postgres { - defaultType := "events_default" - defaultValue := 0 - if eventType.IsState() { - defaultType = "state_default" - defaultValue = 50 - } - var powerLevel int - err := store. - QueryRow(ctx, ` - SELECT COALESCE((power_levels->'events'->$2)::int, (power_levels->'$3')::int, $4) - FROM mx_room_state WHERE room_id=$1 - `, roomID, eventType.Type, defaultType, defaultValue). - Scan(&powerLevel) - if errors.Is(err, sql.ErrNoRows) { - err = nil - powerLevel = defaultValue - } - return powerLevel, err - } else { - levels, err := store.GetPowerLevels(ctx, roomID) - if err != nil { - return 0, err - } - return levels.GetEventLevel(eventType), nil + levels, err := store.GetPowerLevels(ctx, roomID) + if err != nil { + return 0, err } + return levels.GetEventLevel(eventType), nil } func (store *SQLStateStore) HasPowerLevel(ctx context.Context, roomID id.RoomID, userID id.UserID, eventType event.Type) (bool, error) { - if store.Dialect == dbutil.Postgres { - defaultType := "events_default" - defaultValue := 0 - if eventType.IsState() { - defaultType = "state_default" - defaultValue = 50 - } - var hasPower bool - err := store. - QueryRow(ctx, `SELECT - COALESCE((power_levels->'users'->$2)::int, (power_levels->'users_default')::int, 0) - >= - COALESCE((power_levels->'events'->$3)::int, (power_levels->'$4')::int, $5) - FROM mx_room_state WHERE room_id=$1`, roomID, userID, eventType.Type, defaultType, defaultValue). - Scan(&hasPower) - if errors.Is(err, sql.ErrNoRows) { - err = nil - hasPower = defaultValue == 0 - } - return hasPower, err - } else { - levels, err := store.GetPowerLevels(ctx, roomID) - if err != nil { - return false, err - } - return levels.GetUserLevel(userID) >= levels.GetEventLevel(eventType), nil + levels, err := store.GetPowerLevels(ctx, roomID) + if err != nil { + return false, err } + return levels.GetUserLevel(userID) >= levels.GetEventLevel(eventType), nil +} + +func (store *SQLStateStore) SetCreate(ctx context.Context, evt *event.Event) error { + if evt.Type != event.StateCreate { + return fmt.Errorf("invalid event type for create event: %s", evt.Type) + } else if evt.RoomID == "" { + return fmt.Errorf("room ID is empty") + } + _, err := store.Exec(ctx, ` + INSERT INTO mx_room_state (room_id, create_event) VALUES ($1, $2) + ON CONFLICT (room_id) DO UPDATE SET create_event=excluded.create_event + `, evt.RoomID, dbutil.JSON{Data: evt}) + return err +} + +func (store *SQLStateStore) GetCreate(ctx context.Context, roomID id.RoomID) (evt *event.Event, err error) { + err = store. + QueryRow(ctx, "SELECT create_event FROM mx_room_state WHERE room_id=$1 AND create_event IS NOT NULL", roomID). + Scan(&dbutil.JSON{Data: &evt}) + if errors.Is(err, sql.ErrNoRows) { + return nil, nil + } else if err != nil { + return nil, err + } + if evt != nil { + err = evt.Content.ParseRaw(event.StateCreate) + } + return +} + +func (store *SQLStateStore) SetJoinRules(ctx context.Context, roomID id.RoomID, rules *event.JoinRulesEventContent) error { + if roomID == "" { + return fmt.Errorf("room ID is empty") + } + _, err := store.Exec(ctx, ` + INSERT INTO mx_room_state (room_id, join_rules) VALUES ($1, $2) + ON CONFLICT (room_id) DO UPDATE SET join_rules=excluded.join_rules + `, roomID, dbutil.JSON{Data: rules}) + return err +} + +func (store *SQLStateStore) GetJoinRules(ctx context.Context, roomID id.RoomID) (levels *event.JoinRulesEventContent, err error) { + levels = &event.JoinRulesEventContent{} + err = store. + QueryRow(ctx, "SELECT join_rules FROM mx_room_state WHERE room_id=$1 AND join_rules IS NOT NULL", roomID). + Scan(&dbutil.JSON{Data: &levels}) + if errors.Is(err, sql.ErrNoRows) { + levels = nil + err = nil + } + return } diff --git a/sqlstatestore/v00-latest-revision.sql b/sqlstatestore/v00-latest-revision.sql index 41c2b9a1..4679f1c6 100644 --- a/sqlstatestore/v00-latest-revision.sql +++ b/sqlstatestore/v00-latest-revision.sql @@ -1,4 +1,4 @@ --- v0 -> v5: Latest revision +-- v0 -> v10 (compatible with v3+): Latest revision CREATE TABLE mx_registrations ( user_id TEXT PRIMARY KEY @@ -8,16 +8,25 @@ CREATE TABLE mx_registrations ( CREATE TYPE membership AS ENUM ('join', 'leave', 'invite', 'ban', 'knock'); CREATE TABLE mx_user_profile ( - room_id TEXT, - user_id TEXT, - membership membership NOT NULL, - displayname TEXT NOT NULL DEFAULT '', - avatar_url TEXT NOT NULL DEFAULT '', + room_id TEXT, + user_id TEXT, + membership membership NOT NULL, + displayname TEXT NOT NULL DEFAULT '', + avatar_url TEXT NOT NULL DEFAULT '', + + name_skeleton bytea, + PRIMARY KEY (room_id, user_id) ); +CREATE INDEX mx_user_profile_membership_idx ON mx_user_profile (room_id, membership); +CREATE INDEX mx_user_profile_name_skeleton_idx ON mx_user_profile (room_id, name_skeleton); + CREATE TABLE mx_room_state ( - room_id TEXT PRIMARY KEY, - power_levels jsonb, - encryption jsonb + room_id TEXT PRIMARY KEY, + power_levels jsonb, + encryption jsonb, + create_event jsonb, + join_rules jsonb, + members_fetched BOOLEAN NOT NULL DEFAULT false ); diff --git a/sqlstatestore/v05-mark-encryption-state-resync.go b/sqlstatestore/v05-mark-encryption-state-resync.go index bf44d308..b7f2b1c2 100644 --- a/sqlstatestore/v05-mark-encryption-state-resync.go +++ b/sqlstatestore/v05-mark-encryption-state-resync.go @@ -8,7 +8,7 @@ import ( ) func init() { - UpgradeTable.Register(-1, 5, 0, "Mark rooms that need crypto state event resynced", true, func(ctx context.Context, db *dbutil.Database) error { + UpgradeTable.Register(-1, 5, 0, "Mark rooms that need crypto state event resynced", dbutil.TxnModeOn, func(ctx context.Context, db *dbutil.Database) error { portalExists, err := db.TableExists(ctx, "portal") if err != nil { return fmt.Errorf("failed to check if portal table exists") diff --git a/sqlstatestore/v06-displayname-disambiguation.go b/sqlstatestore/v06-displayname-disambiguation.go new file mode 100644 index 00000000..d0d1d502 --- /dev/null +++ b/sqlstatestore/v06-displayname-disambiguation.go @@ -0,0 +1,55 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package sqlstatestore + +import ( + "context" + + "go.mau.fi/util/confusable" + "go.mau.fi/util/dbutil" + + "maunium.net/go/mautrix/id" +) + +type roomUserName struct { + RoomID id.RoomID + UserID id.UserID + Name string +} + +func init() { + UpgradeTable.Register(-1, 6, 3, "Add disambiguation column for user profiles", dbutil.TxnModeOn, func(ctx context.Context, db *dbutil.Database) error { + _, err := db.Exec(ctx, ` + ALTER TABLE mx_user_profile ADD COLUMN name_skeleton bytea; + CREATE INDEX mx_user_profile_membership_idx ON mx_user_profile (room_id, membership); + CREATE INDEX mx_user_profile_name_skeleton_idx ON mx_user_profile (room_id, name_skeleton); + `) + if err != nil { + return err + } + const ChunkSize = 1000 + const GetEntriesChunkQuery = "SELECT room_id, user_id, displayname FROM mx_user_profile WHERE displayname<>'' LIMIT $1 OFFSET $2" + const SetSkeletonHashQuery = `UPDATE mx_user_profile SET name_skeleton = $3 WHERE room_id = $1 AND user_id = $2` + for offset := 0; ; offset += ChunkSize { + entries, err := dbutil.NewSimpleReflectRowIter[roomUserName](db.Query(ctx, GetEntriesChunkQuery, ChunkSize, offset)).AsList() + if err != nil { + return err + } + for _, entry := range entries { + skel := confusable.SkeletonHash(entry.Name) + _, err = db.Exec(ctx, SetSkeletonHashQuery, entry.RoomID, entry.UserID, skel[:]) + if err != nil { + return err + } + } + if len(entries) < ChunkSize { + break + } + } + return nil + }) +} diff --git a/sqlstatestore/v07-full-member-flag.sql b/sqlstatestore/v07-full-member-flag.sql new file mode 100644 index 00000000..32f2ef6c --- /dev/null +++ b/sqlstatestore/v07-full-member-flag.sql @@ -0,0 +1,2 @@ +-- v7 (compatible with v3+): Add flag for whether the full member list has been fetched +ALTER TABLE mx_room_state ADD COLUMN members_fetched BOOLEAN NOT NULL DEFAULT false; diff --git a/sqlstatestore/v08-create-event.sql b/sqlstatestore/v08-create-event.sql new file mode 100644 index 00000000..9f1b55c9 --- /dev/null +++ b/sqlstatestore/v08-create-event.sql @@ -0,0 +1,2 @@ +-- v8 (compatible with v3+): Add create event to room state table +ALTER TABLE mx_room_state ADD COLUMN create_event jsonb; diff --git a/sqlstatestore/v09-clear-empty-room-ids.sql b/sqlstatestore/v09-clear-empty-room-ids.sql new file mode 100644 index 00000000..ca951068 --- /dev/null +++ b/sqlstatestore/v09-clear-empty-room-ids.sql @@ -0,0 +1,3 @@ +-- v9 (compatible with v3+): Clear invalid rows +DELETE FROM mx_room_state WHERE room_id=''; +DELETE FROM mx_user_profile WHERE room_id='' OR user_id=''; diff --git a/sqlstatestore/v10-join-rules.sql b/sqlstatestore/v10-join-rules.sql new file mode 100644 index 00000000..3074c46a --- /dev/null +++ b/sqlstatestore/v10-join-rules.sql @@ -0,0 +1,2 @@ +-- v10 (compatible with v3+): Add join rules to room state table +ALTER TABLE mx_room_state ADD COLUMN join_rules jsonb; diff --git a/statestore.go b/statestore.go index 8fe5f8b3..2bd498dd 100644 --- a/statestore.go +++ b/statestore.go @@ -8,6 +8,7 @@ package mautrix import ( "context" + "maps" "sync" "github.com/rs/zerolog" @@ -26,21 +27,41 @@ type StateStore interface { TryGetMember(ctx context.Context, roomID id.RoomID, userID id.UserID) (*event.MemberEventContent, error) SetMembership(ctx context.Context, roomID id.RoomID, userID id.UserID, membership event.Membership) error SetMember(ctx context.Context, roomID id.RoomID, userID id.UserID, member *event.MemberEventContent) error + IsConfusableName(ctx context.Context, roomID id.RoomID, currentUser id.UserID, name string) ([]id.UserID, error) ClearCachedMembers(ctx context.Context, roomID id.RoomID, memberships ...event.Membership) error + ReplaceCachedMembers(ctx context.Context, roomID id.RoomID, evts []*event.Event, onlyMemberships ...event.Membership) error SetPowerLevels(ctx context.Context, roomID id.RoomID, levels *event.PowerLevelsEventContent) error GetPowerLevels(ctx context.Context, roomID id.RoomID) (*event.PowerLevelsEventContent, error) + SetCreate(ctx context.Context, evt *event.Event) error + GetCreate(ctx context.Context, roomID id.RoomID) (*event.Event, error) + + GetJoinRules(ctx context.Context, roomID id.RoomID) (*event.JoinRulesEventContent, error) + SetJoinRules(ctx context.Context, roomID id.RoomID, content *event.JoinRulesEventContent) error + + HasFetchedMembers(ctx context.Context, roomID id.RoomID) (bool, error) + MarkMembersFetched(ctx context.Context, roomID id.RoomID) error + GetAllMembers(ctx context.Context, roomID id.RoomID) (map[id.UserID]*event.MemberEventContent, error) + SetEncryptionEvent(ctx context.Context, roomID id.RoomID, content *event.EncryptionEventContent) error IsEncrypted(ctx context.Context, roomID id.RoomID) (bool, error) GetRoomJoinedOrInvitedMembers(ctx context.Context, roomID id.RoomID) ([]id.UserID, error) } +type StateStoreUpdater interface { + UpdateState(ctx context.Context, evt *event.Event) +} + func UpdateStateStore(ctx context.Context, store StateStore, evt *event.Event) { if store == nil || evt == nil || evt.StateKey == nil { return } + if directUpdater, ok := store.(StateStoreUpdater); ok { + directUpdater.UpdateState(ctx, evt) + return + } // We only care about events without a state key (power levels, encryption) or member events with state key if evt.Type != event.StateMember && evt.GetStateKey() != "" { return @@ -53,6 +74,19 @@ func UpdateStateStore(ctx context.Context, store StateStore, evt *event.Event) { err = store.SetPowerLevels(ctx, evt.RoomID, content) case *event.EncryptionEventContent: err = store.SetEncryptionEvent(ctx, evt.RoomID, content) + case *event.CreateEventContent: + err = store.SetCreate(ctx, evt) + case *event.JoinRulesEventContent: + err = store.SetJoinRules(ctx, evt.RoomID, content) + default: + switch evt.Type { + case event.StateMember, event.StatePowerLevels, event.StateEncryption, event.StateCreate: + zerolog.Ctx(ctx).Warn(). + Stringer("event_id", evt.ID). + Str("event_type", evt.Type.Type). + Type("content_type", evt.Content.Parsed). + Msg("Got known event type with unknown content type in UpdateStateStore") + } } if err != nil { zerolog.Ctx(ctx).Warn().Err(err). @@ -72,23 +106,30 @@ func (cli *Client) StateStoreSyncHandler(ctx context.Context, evt *event.Event) } type MemoryStateStore struct { - Registrations map[id.UserID]bool `json:"registrations"` - Members map[id.RoomID]map[id.UserID]*event.MemberEventContent `json:"memberships"` - PowerLevels map[id.RoomID]*event.PowerLevelsEventContent `json:"power_levels"` - Encryption map[id.RoomID]*event.EncryptionEventContent `json:"encryption"` + Registrations map[id.UserID]bool `json:"registrations"` + Members map[id.RoomID]map[id.UserID]*event.MemberEventContent `json:"memberships"` + MembersFetched map[id.RoomID]bool `json:"members_fetched"` + PowerLevels map[id.RoomID]*event.PowerLevelsEventContent `json:"power_levels"` + Encryption map[id.RoomID]*event.EncryptionEventContent `json:"encryption"` + Create map[id.RoomID]*event.Event `json:"create"` + JoinRules map[id.RoomID]*event.JoinRulesEventContent `json:"join_rules"` registrationsLock sync.RWMutex membersLock sync.RWMutex powerLevelsLock sync.RWMutex encryptionLock sync.RWMutex + joinRulesLock sync.RWMutex } func NewMemoryStateStore() StateStore { return &MemoryStateStore{ - Registrations: make(map[id.UserID]bool), - Members: make(map[id.RoomID]map[id.UserID]*event.MemberEventContent), - PowerLevels: make(map[id.RoomID]*event.PowerLevelsEventContent), - Encryption: make(map[id.RoomID]*event.EncryptionEventContent), + Registrations: make(map[id.UserID]bool), + Members: make(map[id.RoomID]map[id.UserID]*event.MemberEventContent), + MembersFetched: make(map[id.RoomID]bool), + PowerLevels: make(map[id.RoomID]*event.PowerLevelsEventContent), + Encryption: make(map[id.RoomID]*event.EncryptionEventContent), + Create: make(map[id.RoomID]*event.Event), + JoinRules: make(map[id.RoomID]*event.JoinRulesEventContent), } } @@ -143,6 +184,11 @@ func (store *MemoryStateStore) GetMember(ctx context.Context, roomID id.RoomID, return member, err } +func (store *MemoryStateStore) IsConfusableName(ctx context.Context, roomID id.RoomID, currentUser id.UserID, name string) ([]id.UserID, error) { + // TODO implement? + return nil, nil +} + func (store *MemoryStateStore) TryGetMember(_ context.Context, roomID id.RoomID, userID id.UserID) (member *event.MemberEventContent, err error) { store.membersLock.RLock() defer store.membersLock.RUnlock() @@ -223,9 +269,40 @@ func (store *MemoryStateStore) ClearCachedMembers(_ context.Context, roomID id.R } } } + store.MembersFetched[roomID] = false return nil } +func (store *MemoryStateStore) HasFetchedMembers(ctx context.Context, roomID id.RoomID) (bool, error) { + store.membersLock.RLock() + defer store.membersLock.RUnlock() + return store.MembersFetched[roomID], nil +} + +func (store *MemoryStateStore) MarkMembersFetched(ctx context.Context, roomID id.RoomID) error { + store.membersLock.Lock() + defer store.membersLock.Unlock() + store.MembersFetched[roomID] = true + return nil +} + +func (store *MemoryStateStore) ReplaceCachedMembers(ctx context.Context, roomID id.RoomID, evts []*event.Event, onlyMemberships ...event.Membership) error { + _ = store.ClearCachedMembers(ctx, roomID, onlyMemberships...) + for _, evt := range evts { + UpdateStateStore(ctx, store, evt) + } + if len(onlyMemberships) == 0 { + _ = store.MarkMembersFetched(ctx, roomID) + } + return nil +} + +func (store *MemoryStateStore) GetAllMembers(ctx context.Context, roomID id.RoomID) (map[id.UserID]*event.MemberEventContent, error) { + store.membersLock.RLock() + defer store.membersLock.RUnlock() + return maps.Clone(store.Members[roomID]), nil +} + func (store *MemoryStateStore) SetPowerLevels(_ context.Context, roomID id.RoomID, levels *event.PowerLevelsEventContent) error { store.powerLevelsLock.Lock() store.PowerLevels[roomID] = levels @@ -236,6 +313,9 @@ func (store *MemoryStateStore) SetPowerLevels(_ context.Context, roomID id.RoomI func (store *MemoryStateStore) GetPowerLevels(_ context.Context, roomID id.RoomID) (levels *event.PowerLevelsEventContent, err error) { store.powerLevelsLock.RLock() levels = store.PowerLevels[roomID] + if levels != nil && levels.CreateEvent == nil { + levels.CreateEvent = store.Create[roomID] + } store.powerLevelsLock.RUnlock() return } @@ -252,6 +332,23 @@ func (store *MemoryStateStore) HasPowerLevel(ctx context.Context, roomID id.Room return exerrors.Must(store.GetPowerLevel(ctx, roomID, userID)) >= exerrors.Must(store.GetPowerLevelRequirement(ctx, roomID, eventType)), nil } +func (store *MemoryStateStore) SetCreate(ctx context.Context, evt *event.Event) error { + store.powerLevelsLock.Lock() + store.Create[evt.RoomID] = evt + if pls, ok := store.PowerLevels[evt.RoomID]; ok && pls.CreateEvent == nil { + pls.CreateEvent = evt + } + store.powerLevelsLock.Unlock() + return nil +} + +func (store *MemoryStateStore) GetCreate(ctx context.Context, roomID id.RoomID) (*event.Event, error) { + store.powerLevelsLock.RLock() + evt := store.Create[roomID] + store.powerLevelsLock.RUnlock() + return evt, nil +} + func (store *MemoryStateStore) SetEncryptionEvent(_ context.Context, roomID id.RoomID, content *event.EncryptionEventContent) error { store.encryptionLock.Lock() store.Encryption[roomID] = content @@ -265,7 +362,31 @@ func (store *MemoryStateStore) GetEncryptionEvent(_ context.Context, roomID id.R return store.Encryption[roomID], nil } +func (store *MemoryStateStore) SetJoinRules(ctx context.Context, roomID id.RoomID, content *event.JoinRulesEventContent) error { + store.joinRulesLock.Lock() + store.JoinRules[roomID] = content + store.joinRulesLock.Unlock() + return nil +} + +func (store *MemoryStateStore) GetJoinRules(ctx context.Context, roomID id.RoomID) (*event.JoinRulesEventContent, error) { + store.joinRulesLock.RLock() + defer store.joinRulesLock.RUnlock() + return store.JoinRules[roomID], nil +} + func (store *MemoryStateStore) IsEncrypted(ctx context.Context, roomID id.RoomID) (bool, error) { cfg, err := store.GetEncryptionEvent(ctx, roomID) return cfg != nil && cfg.Algorithm == id.AlgorithmMegolmV1, err } + +func (store *MemoryStateStore) FindSharedRooms(ctx context.Context, userID id.UserID) (rooms []id.RoomID, err error) { + store.membersLock.RLock() + defer store.membersLock.RUnlock() + for roomID, members := range store.Members { + if _, ok := members[userID]; ok { + rooms = append(rooms, roomID) + } + } + return rooms, nil +} diff --git a/synapseadmin/client.go b/synapseadmin/client.go index 775b4b13..6925ca7d 100644 --- a/synapseadmin/client.go +++ b/synapseadmin/client.go @@ -14,9 +14,9 @@ import ( // // https://matrix-org.github.io/synapse/latest/usage/administration/admin_api/index.html type Client struct { - *mautrix.Client + Client *mautrix.Client } func (cli *Client) BuildAdminURL(path ...any) string { - return cli.BuildURL(mautrix.SynapseAdminURLPath(path)) + return cli.Client.BuildURL(mautrix.SynapseAdminURLPath(path)) } diff --git a/synapseadmin/register.go b/synapseadmin/register.go index d7a94f6f..05e0729a 100644 --- a/synapseadmin/register.go +++ b/synapseadmin/register.go @@ -73,11 +73,7 @@ func (req *ReqSharedSecretRegister) Sign(secret string) string { // This does not need to be called manually as SharedSecretRegister will automatically call this if no nonce is provided. func (cli *Client) GetRegisterNonce(ctx context.Context) (string, error) { var resp respGetRegisterNonce - _, err := cli.MakeFullRequest(ctx, mautrix.FullRequest{ - Method: http.MethodGet, - URL: cli.BuildURL(mautrix.SynapseAdminURLPath{"v1", "register"}), - ResponseJSON: &resp, - }) + _, err := cli.Client.MakeRequest(ctx, http.MethodGet, cli.BuildAdminURL("v1", "register"), nil, &resp) if err != nil { return "", err } @@ -97,12 +93,7 @@ func (cli *Client) SharedSecretRegister(ctx context.Context, sharedSecret string } req.SHA1Checksum = req.Sign(sharedSecret) var resp mautrix.RespRegister - _, err = cli.MakeFullRequest(ctx, mautrix.FullRequest{ - Method: http.MethodPost, - URL: cli.BuildURL(mautrix.SynapseAdminURLPath{"v1", "register"}), - RequestJSON: req, - ResponseJSON: &resp, - }) + _, err = cli.Client.MakeRequest(ctx, http.MethodPost, cli.BuildAdminURL("v1", "register"), &req, &resp) if err != nil { return nil, err } diff --git a/synapseadmin/roomapi.go b/synapseadmin/roomapi.go index 0953377e..0925b748 100644 --- a/synapseadmin/roomapi.go +++ b/synapseadmin/roomapi.go @@ -75,16 +75,17 @@ type RespListRooms struct { // https://matrix-org.github.io/synapse/latest/admin_api/rooms.html#list-room-api func (cli *Client) ListRooms(ctx context.Context, req ReqListRoom) (RespListRooms, error) { var resp RespListRooms - var reqURL string - reqURL = cli.BuildURLWithQuery(mautrix.SynapseAdminURLPath{"v1", "rooms"}, req.BuildQuery()) - _, err := cli.MakeFullRequest(ctx, mautrix.FullRequest{ - Method: http.MethodGet, - URL: reqURL, - ResponseJSON: &resp, - }) + reqURL := cli.Client.BuildURLWithQuery(mautrix.SynapseAdminURLPath{"v1", "rooms"}, req.BuildQuery()) + _, err := cli.Client.MakeRequest(ctx, http.MethodGet, reqURL, nil, &resp) return resp, err } +func (cli *Client) RoomInfo(ctx context.Context, roomID id.RoomID) (resp *RoomInfo, err error) { + reqURL := cli.BuildAdminURL("v1", "rooms", roomID) + _, err = cli.Client.MakeRequest(ctx, http.MethodGet, reqURL, nil, &resp) + return +} + type RespRoomMessages = mautrix.RespMessages // RoomMessages returns a list of messages in a room. @@ -108,17 +109,14 @@ func (cli *Client) RoomMessages(ctx context.Context, roomID id.RoomID, from, to if limit != 0 { query["limit"] = strconv.Itoa(limit) } - urlPath := cli.BuildURLWithQuery(mautrix.SynapseAdminURLPath{"v1", "rooms", roomID, "messages"}, query) - _, err = cli.MakeFullRequest(ctx, mautrix.FullRequest{ - Method: http.MethodGet, - URL: urlPath, - ResponseJSON: &resp, - }) + urlPath := cli.Client.BuildURLWithQuery(mautrix.SynapseAdminURLPath{"v1", "rooms", roomID, "messages"}, query) + _, err = cli.Client.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp) return resp, err } type ReqDeleteRoom struct { Purge bool `json:"purge,omitempty"` + ForcePurge bool `json:"force_purge,omitempty"` Block bool `json:"block,omitempty"` Message string `json:"message,omitempty"` RoomName string `json:"room_name,omitempty"` @@ -129,6 +127,19 @@ type RespDeleteRoom struct { DeleteID string `json:"delete_id"` } +type RespDeleteRoomResult struct { + KickedUsers []id.UserID `json:"kicked_users,omitempty"` + FailedToKickUsers []id.UserID `json:"failed_to_kick_users,omitempty"` + LocalAliases []id.RoomAlias `json:"local_aliases,omitempty"` + NewRoomID id.RoomID `json:"new_room_id,omitempty"` +} + +type RespDeleteRoomStatus struct { + Status string `json:"status,omitempty"` + Error string `json:"error,omitempty"` + ShutdownRoom RespDeleteRoomResult `json:"shutdown_room,omitempty"` +} + // DeleteRoom deletes a room from the server, optionally blocking it and/or purging all data from the database. // // This calls the async version of the endpoint, which will return immediately and delete the room in the background. @@ -137,13 +148,35 @@ type RespDeleteRoom struct { func (cli *Client) DeleteRoom(ctx context.Context, roomID id.RoomID, req ReqDeleteRoom) (RespDeleteRoom, error) { reqURL := cli.BuildAdminURL("v2", "rooms", roomID) var resp RespDeleteRoom - _, err := cli.MakeFullRequest(ctx, mautrix.FullRequest{ + _, err := cli.Client.MakeRequest(ctx, http.MethodDelete, reqURL, &req, &resp) + return resp, err +} + +func (cli *Client) DeleteRoomStatus(ctx context.Context, deleteID string) (resp RespDeleteRoomStatus, err error) { + reqURL := cli.BuildAdminURL("v2", "rooms", "delete_status", deleteID) + _, err = cli.Client.MakeRequest(ctx, http.MethodGet, reqURL, nil, &resp) + return +} + +// DeleteRoomSync deletes a room from the server, optionally blocking it and/or purging all data from the database. +// +// This calls the synchronous version of the endpoint, which will block until the room is deleted. +// +// https://element-hq.github.io/synapse/latest/admin_api/rooms.html#version-1-old-version +func (cli *Client) DeleteRoomSync(ctx context.Context, roomID id.RoomID, req ReqDeleteRoom) (resp RespDeleteRoomResult, err error) { + reqURL := cli.BuildAdminURL("v1", "rooms", roomID) + httpClient := &http.Client{} + _, err = cli.Client.MakeFullRequest(ctx, mautrix.FullRequest{ Method: http.MethodDelete, URL: reqURL, - ResponseJSON: &resp, RequestJSON: &req, + ResponseJSON: &resp, + MaxAttempts: 1, + // Use a fresh HTTP client without timeouts + Client: httpClient, }) - return resp, err + httpClient.CloseIdleConnections() + return } type RespRoomsMembers struct { @@ -157,11 +190,7 @@ type RespRoomsMembers struct { func (cli *Client) RoomMembers(ctx context.Context, roomID id.RoomID) (RespRoomsMembers, error) { reqURL := cli.BuildAdminURL("v1", "rooms", roomID, "members") var resp RespRoomsMembers - _, err := cli.MakeFullRequest(ctx, mautrix.FullRequest{ - Method: http.MethodGet, - URL: reqURL, - ResponseJSON: &resp, - }) + _, err := cli.Client.MakeRequest(ctx, http.MethodGet, reqURL, nil, &resp) return resp, err } @@ -174,11 +203,7 @@ type ReqMakeRoomAdmin struct { // https://matrix-org.github.io/synapse/latest/admin_api/rooms.html#make-room-admin-api func (cli *Client) MakeRoomAdmin(ctx context.Context, roomIDOrAlias string, req ReqMakeRoomAdmin) error { reqURL := cli.BuildAdminURL("v1", "rooms", roomIDOrAlias, "make_room_admin") - _, err := cli.MakeFullRequest(ctx, mautrix.FullRequest{ - Method: http.MethodPost, - URL: reqURL, - RequestJSON: &req, - }) + _, err := cli.Client.MakeRequest(ctx, http.MethodPost, reqURL, &req, nil) return err } @@ -191,11 +216,7 @@ type ReqJoinUserToRoom struct { // https://matrix-org.github.io/synapse/latest/admin_api/room_membership.html func (cli *Client) JoinUserToRoom(ctx context.Context, roomID id.RoomID, req ReqJoinUserToRoom) error { reqURL := cli.BuildAdminURL("v1", "join", roomID) - _, err := cli.MakeFullRequest(ctx, mautrix.FullRequest{ - Method: http.MethodPost, - URL: reqURL, - RequestJSON: &req, - }) + _, err := cli.Client.MakeRequest(ctx, http.MethodPost, reqURL, &req, nil) return err } @@ -208,11 +229,7 @@ type ReqBlockRoom struct { // https://matrix-org.github.io/synapse/latest/admin_api/rooms.html#block-room-api func (cli *Client) BlockRoom(ctx context.Context, roomID id.RoomID, req ReqBlockRoom) error { reqURL := cli.BuildAdminURL("v1", "rooms", roomID, "block") - _, err := cli.MakeFullRequest(ctx, mautrix.FullRequest{ - Method: http.MethodPut, - URL: reqURL, - RequestJSON: &req, - }) + _, err := cli.Client.MakeRequest(ctx, http.MethodPut, reqURL, &req, nil) return err } @@ -228,10 +245,6 @@ type RoomsBlockResponse struct { func (cli *Client) GetRoomBlockStatus(ctx context.Context, roomID id.RoomID) (RoomsBlockResponse, error) { var resp RoomsBlockResponse reqURL := cli.BuildAdminURL("v1", "rooms", roomID, "block") - _, err := cli.MakeFullRequest(ctx, mautrix.FullRequest{ - Method: http.MethodGet, - URL: reqURL, - ResponseJSON: &resp, - }) + _, err := cli.Client.MakeRequest(ctx, http.MethodGet, reqURL, nil, &resp) return resp, err } diff --git a/synapseadmin/userapi.go b/synapseadmin/userapi.go index 31d0a6dc..b1de55b6 100644 --- a/synapseadmin/userapi.go +++ b/synapseadmin/userapi.go @@ -32,11 +32,7 @@ type ReqResetPassword struct { // https://matrix-org.github.io/synapse/latest/admin_api/user_admin_api.html#reset-password func (cli *Client) ResetPassword(ctx context.Context, req ReqResetPassword) error { reqURL := cli.BuildAdminURL("v1", "reset_password", req.UserID) - _, err := cli.MakeFullRequest(ctx, mautrix.FullRequest{ - Method: http.MethodPost, - URL: reqURL, - RequestJSON: &req, - }) + _, err := cli.Client.MakeRequest(ctx, http.MethodPost, reqURL, &req, nil) return err } @@ -47,12 +43,8 @@ func (cli *Client) ResetPassword(ctx context.Context, req ReqResetPassword) erro // // https://matrix-org.github.io/synapse/latest/admin_api/user_admin_api.html#check-username-availability func (cli *Client) UsernameAvailable(ctx context.Context, username string) (resp *mautrix.RespRegisterAvailable, err error) { - u := cli.BuildURLWithQuery(mautrix.SynapseAdminURLPath{"v1", "username_available"}, map[string]string{"username": username}) - _, err = cli.MakeFullRequest(ctx, mautrix.FullRequest{ - Method: http.MethodGet, - URL: u, - ResponseJSON: &resp, - }) + u := cli.Client.BuildURLWithQuery(mautrix.SynapseAdminURLPath{"v1", "username_available"}, map[string]string{"username": username}) + _, err = cli.Client.MakeRequest(ctx, http.MethodGet, u, nil, &resp) if err == nil && !resp.Available { err = fmt.Errorf(`request returned OK status without "available": true`) } @@ -73,11 +65,7 @@ type RespListDevices struct { // // https://matrix-org.github.io/synapse/latest/admin_api/user_admin_api.html#list-all-devices func (cli *Client) ListDevices(ctx context.Context, userID id.UserID) (resp *RespListDevices, err error) { - _, err = cli.MakeFullRequest(ctx, mautrix.FullRequest{ - Method: http.MethodGet, - URL: cli.BuildAdminURL("v2", "users", userID, "devices"), - ResponseJSON: &resp, - }) + _, err = cli.Client.MakeRequest(ctx, http.MethodGet, cli.BuildAdminURL("v2", "users", userID, "devices"), nil, &resp) return } @@ -101,11 +89,7 @@ type RespUserInfo struct { // // https://matrix-org.github.io/synapse/latest/admin_api/user_admin_api.html#query-user-account func (cli *Client) GetUserInfo(ctx context.Context, userID id.UserID) (resp *RespUserInfo, err error) { - _, err = cli.MakeFullRequest(ctx, mautrix.FullRequest{ - Method: http.MethodGet, - URL: cli.BuildAdminURL("v2", "users", userID), - ResponseJSON: &resp, - }) + _, err = cli.Client.MakeRequest(ctx, http.MethodGet, cli.BuildAdminURL("v2", "users", userID), nil, &resp) return } @@ -118,11 +102,20 @@ type ReqDeleteUser struct { // https://matrix-org.github.io/synapse/latest/admin_api/user_admin_api.html#deactivate-account func (cli *Client) DeactivateAccount(ctx context.Context, userID id.UserID, req ReqDeleteUser) error { reqURL := cli.BuildAdminURL("v1", "deactivate", userID) - _, err := cli.MakeFullRequest(ctx, mautrix.FullRequest{ - Method: http.MethodPost, - URL: reqURL, - RequestJSON: &req, - }) + _, err := cli.Client.MakeRequest(ctx, http.MethodPost, reqURL, &req, nil) + return err +} + +type ReqSuspendUser struct { + Suspend bool `json:"suspend"` +} + +// SuspendAccount suspends or unsuspends a specific local user account. +// +// https://element-hq.github.io/synapse/latest/admin_api/user_admin_api.html#suspendunsuspend-account +func (cli *Client) SuspendAccount(ctx context.Context, userID id.UserID, req ReqSuspendUser) error { + reqURL := cli.BuildAdminURL("v1", "suspend", userID) + _, err := cli.Client.MakeRequest(ctx, http.MethodPut, reqURL, &req, nil) return err } @@ -144,11 +137,7 @@ type ReqCreateOrModifyAccount struct { // https://matrix-org.github.io/synapse/latest/admin_api/user_admin_api.html#create-or-modify-account func (cli *Client) CreateOrModifyAccount(ctx context.Context, userID id.UserID, req ReqCreateOrModifyAccount) error { reqURL := cli.BuildAdminURL("v2", "users", userID) - _, err := cli.MakeFullRequest(ctx, mautrix.FullRequest{ - Method: http.MethodPut, - URL: reqURL, - RequestJSON: &req, - }) + _, err := cli.Client.MakeRequest(ctx, http.MethodPut, reqURL, &req, nil) return err } @@ -164,11 +153,7 @@ type ReqSetRatelimit = RatelimitOverride // https://matrix-org.github.io/synapse/latest/admin_api/user_admin_api.html#set-ratelimit func (cli *Client) SetUserRatelimit(ctx context.Context, userID id.UserID, req ReqSetRatelimit) error { reqURL := cli.BuildAdminURL("v1", "users", userID, "override_ratelimit") - _, err := cli.MakeFullRequest(ctx, mautrix.FullRequest{ - Method: http.MethodPost, - URL: reqURL, - RequestJSON: &req, - }) + _, err := cli.Client.MakeRequest(ctx, http.MethodPost, reqURL, &req, nil) return err } @@ -178,11 +163,7 @@ type RespUserRatelimit = RatelimitOverride // // https://matrix-org.github.io/synapse/latest/admin_api/user_admin_api.html#get-status-of-ratelimit func (cli *Client) GetUserRatelimit(ctx context.Context, userID id.UserID) (resp RespUserRatelimit, err error) { - _, err = cli.MakeFullRequest(ctx, mautrix.FullRequest{ - Method: http.MethodGet, - URL: cli.BuildAdminURL("v1", "users", userID, "override_ratelimit"), - ResponseJSON: &resp, - }) + _, err = cli.Client.MakeRequest(ctx, http.MethodGet, cli.BuildAdminURL("v1", "users", userID, "override_ratelimit"), nil, &resp) return } @@ -190,9 +171,6 @@ func (cli *Client) GetUserRatelimit(ctx context.Context, userID id.UserID) (resp // // https://matrix-org.github.io/synapse/latest/admin_api/user_admin_api.html#delete-ratelimit func (cli *Client) DeleteUserRatelimit(ctx context.Context, userID id.UserID) (err error) { - _, err = cli.MakeFullRequest(ctx, mautrix.FullRequest{ - Method: http.MethodDelete, - URL: cli.BuildAdminURL("v1", "users", userID, "override_ratelimit"), - }) + _, err = cli.Client.MakeRequest(ctx, http.MethodDelete, cli.BuildAdminURL("v1", "users", userID, "override_ratelimit"), nil, nil) return } diff --git a/sync.go b/sync.go index d4208404..598df8e0 100644 --- a/sync.go +++ b/sync.go @@ -90,6 +90,7 @@ func (s *DefaultSyncer) ProcessResponse(ctx context.Context, res *RespSync, sinc err = fmt.Errorf("ProcessResponse panicked! since=%s panic=%s\n%s", since, r, debug.Stack()) } }() + ctx = context.WithValue(ctx, SyncTokenContextKey, since) for _, listener := range s.syncListeners { if !listener(ctx, res, since) { @@ -97,33 +98,38 @@ func (s *DefaultSyncer) ProcessResponse(ctx context.Context, res *RespSync, sinc } } - s.processSyncEvents(ctx, "", res.ToDevice.Events, event.SourceToDevice) - s.processSyncEvents(ctx, "", res.Presence.Events, event.SourcePresence) - s.processSyncEvents(ctx, "", res.AccountData.Events, event.SourceAccountData) + s.processSyncEvents(ctx, "", res.ToDevice.Events, event.SourceToDevice, false) + s.processSyncEvents(ctx, "", res.Presence.Events, event.SourcePresence, false) + s.processSyncEvents(ctx, "", res.AccountData.Events, event.SourceAccountData, false) for roomID, roomData := range res.Rooms.Join { - s.processSyncEvents(ctx, roomID, roomData.State.Events, event.SourceJoin|event.SourceState) - s.processSyncEvents(ctx, roomID, roomData.Timeline.Events, event.SourceJoin|event.SourceTimeline) - s.processSyncEvents(ctx, roomID, roomData.Ephemeral.Events, event.SourceJoin|event.SourceEphemeral) - s.processSyncEvents(ctx, roomID, roomData.AccountData.Events, event.SourceJoin|event.SourceAccountData) + if roomData.StateAfter == nil { + s.processSyncEvents(ctx, roomID, roomData.State.Events, event.SourceJoin|event.SourceState, false) + s.processSyncEvents(ctx, roomID, roomData.Timeline.Events, event.SourceJoin|event.SourceTimeline, false) + } else { + s.processSyncEvents(ctx, roomID, roomData.Timeline.Events, event.SourceJoin|event.SourceTimeline, true) + s.processSyncEvents(ctx, roomID, roomData.StateAfter.Events, event.SourceJoin|event.SourceState, false) + } + s.processSyncEvents(ctx, roomID, roomData.Ephemeral.Events, event.SourceJoin|event.SourceEphemeral, false) + s.processSyncEvents(ctx, roomID, roomData.AccountData.Events, event.SourceJoin|event.SourceAccountData, false) } for roomID, roomData := range res.Rooms.Invite { - s.processSyncEvents(ctx, roomID, roomData.State.Events, event.SourceInvite|event.SourceState) + s.processSyncEvents(ctx, roomID, roomData.State.Events, event.SourceInvite|event.SourceState, false) } for roomID, roomData := range res.Rooms.Leave { - s.processSyncEvents(ctx, roomID, roomData.State.Events, event.SourceLeave|event.SourceState) - s.processSyncEvents(ctx, roomID, roomData.Timeline.Events, event.SourceLeave|event.SourceTimeline) + s.processSyncEvents(ctx, roomID, roomData.State.Events, event.SourceLeave|event.SourceState, false) + s.processSyncEvents(ctx, roomID, roomData.Timeline.Events, event.SourceLeave|event.SourceTimeline, false) } return } -func (s *DefaultSyncer) processSyncEvents(ctx context.Context, roomID id.RoomID, events []*event.Event, source event.Source) { +func (s *DefaultSyncer) processSyncEvents(ctx context.Context, roomID id.RoomID, events []*event.Event, source event.Source, ignoreState bool) { for _, evt := range events { - s.processSyncEvent(ctx, roomID, evt, source) + s.processSyncEvent(ctx, roomID, evt, source, ignoreState) } } -func (s *DefaultSyncer) processSyncEvent(ctx context.Context, roomID id.RoomID, evt *event.Event, source event.Source) { +func (s *DefaultSyncer) processSyncEvent(ctx context.Context, roomID id.RoomID, evt *event.Event, source event.Source, ignoreState bool) { evt.RoomID = roomID // Ensure the type class is correct. It's safe to mutate the class since the event type is not a pointer. @@ -149,6 +155,7 @@ func (s *DefaultSyncer) processSyncEvent(ctx context.Context, roomID id.RoomID, } evt.Mautrix.EventSource = source + evt.Mautrix.IgnoreState = ignoreState s.Dispatch(ctx, evt) } @@ -191,8 +198,8 @@ func (s *DefaultSyncer) OnFailedSync(res *RespSync, err error) (time.Duration, e } var defaultFilter = Filter{ - Room: RoomFilter{ - Timeline: FilterPart{ + Room: &RoomFilter{ + Timeline: &FilterPart{ Limit: 50, }, }, @@ -257,7 +264,7 @@ func dontProcessOldEvents(userID id.UserID, resp *RespSync, since string) bool { // cli.Syncer.(mautrix.ExtensibleSyncer).OnSync(cli.MoveInviteState) func (cli *Client) MoveInviteState(ctx context.Context, resp *RespSync, _ string) bool { for _, meta := range resp.Rooms.Invite { - var inviteState []event.StrippedState + var inviteState []*event.Event var inviteEvt *event.Event for _, evt := range meta.State.Events { if evt.Type == event.StateMember && evt.GetStateKey() == cli.UserID.String() { @@ -265,12 +272,7 @@ func (cli *Client) MoveInviteState(ctx context.Context, resp *RespSync, _ string } else { evt.Type.Class = event.StateEventType _ = evt.Content.ParseRaw(evt.Type) - inviteState = append(inviteState, event.StrippedState{ - Content: evt.Content, - Type: evt.Type, - StateKey: evt.GetStateKey(), - Sender: evt.Sender, - }) + inviteState = append(inviteState, evt) } } if inviteEvt != nil { diff --git a/url.go b/url.go index 4646b442..91b3d49d 100644 --- a/url.go +++ b/url.go @@ -57,13 +57,13 @@ func BuildURL(baseURL *url.URL, path ...any) *url.URL { // BuildURL builds a URL with the Client's homeserver and appservice user ID set already. func (cli *Client) BuildURL(urlPath PrefixableURLPath) string { - return cli.BuildURLWithQuery(urlPath, nil) + return cli.BuildURLWithFullQuery(urlPath, nil) } // BuildClientURL builds a URL with the Client's homeserver and appservice user ID set already. // This method also automatically prepends the client API prefix (/_matrix/client). func (cli *Client) BuildClientURL(urlPath ...any) string { - return cli.BuildURLWithQuery(ClientURLPath(urlPath), nil) + return cli.BuildURLWithFullQuery(ClientURLPath(urlPath), nil) } type PrefixableURLPath interface { @@ -97,15 +97,30 @@ func (saup SynapseAdminURLPath) FullPath() []any { // BuildURLWithQuery builds a URL with query parameters in addition to the Client's homeserver // and appservice user ID set already. func (cli *Client) BuildURLWithQuery(urlPath PrefixableURLPath, urlQuery map[string]string) string { + return cli.BuildURLWithFullQuery(urlPath, func(q url.Values) { + for k, v := range urlQuery { + q.Set(k, v) + } + }) +} + +// BuildURLWithQuery builds a URL with query parameters in addition to the Client's homeserver +// and appservice user ID set already. +func (cli *Client) BuildURLWithFullQuery(urlPath PrefixableURLPath, fn func(q url.Values)) string { + if cli == nil { + return "client is nil" + } hsURL := *BuildURL(cli.HomeserverURL, urlPath.FullPath()...) query := hsURL.Query() if cli.SetAppServiceUserID { query.Set("user_id", string(cli.UserID)) } - if urlQuery != nil { - for k, v := range urlQuery { - query.Set(k, v) - } + if cli.SetAppServiceDeviceID && cli.DeviceID != "" { + query.Set("device_id", string(cli.DeviceID)) + query.Set("org.matrix.msc3202.device_id", string(cli.DeviceID)) + } + if fn != nil { + fn(query) } hsURL.RawQuery = query.Encode() return hsURL.String() diff --git a/version.go b/version.go index 82817bca..f00bbf39 100644 --- a/version.go +++ b/version.go @@ -4,10 +4,11 @@ import ( "fmt" "regexp" "runtime" + "runtime/debug" "strings" ) -const Version = "v0.18.0" +const Version = "v0.26.3" var GoModVersion = "" var Commit = "" @@ -15,11 +16,20 @@ var VersionWithCommit = Version var DefaultUserAgent = "mautrix-go/" + Version + " go/" + strings.TrimPrefix(runtime.Version(), "go") -var goModVersionRegex = regexp.MustCompile(`v.+\d{14}-([0-9a-f]{12})`) - func init() { + if GoModVersion == "" { + info, _ := debug.ReadBuildInfo() + if info != nil { + for _, mod := range info.Deps { + if mod.Path == "maunium.net/go/mautrix" { + GoModVersion = mod.Version + break + } + } + } + } if GoModVersion != "" { - match := goModVersionRegex.FindStringSubmatch(GoModVersion) + match := regexp.MustCompile(`v.+\d{14}-([0-9a-f]{12})`).FindStringSubmatch(GoModVersion) if match != nil { Commit = match[1] } diff --git a/versions.go b/versions.go index d3dd3c67..61b2e4ea 100644 --- a/versions.go +++ b/versions.go @@ -19,6 +19,9 @@ type RespVersions struct { } func (versions *RespVersions) ContainsFunc(match func(found SpecVersion) bool) bool { + if versions == nil { + return false + } for _, found := range versions.Versions { if match(found) { return true @@ -40,6 +43,9 @@ func (versions *RespVersions) ContainsGreaterOrEqual(version SpecVersion) bool { } func (versions *RespVersions) GetLatest() (latest SpecVersion) { + if versions == nil { + return + } for _, ver := range versions.Versions { if ver.GreaterThan(latest) { latest = ver @@ -54,16 +60,34 @@ type UnstableFeature struct { } var ( - FeatureAppservicePing = UnstableFeature{UnstableFlag: "fi.mau.msc2659.stable", SpecVersion: SpecV17} + FeatureAsyncUploads = UnstableFeature{UnstableFlag: "fi.mau.msc2246.stable", SpecVersion: SpecV17} + FeatureAppservicePing = UnstableFeature{UnstableFlag: "fi.mau.msc2659.stable", SpecVersion: SpecV17} + FeatureAuthenticatedMedia = UnstableFeature{UnstableFlag: "org.matrix.msc3916.stable", SpecVersion: SpecV111} + FeatureUnstableMutualRooms = UnstableFeature{UnstableFlag: "uk.half-shot.msc2666.query_mutual_rooms"} + FeatureStableMutualRooms = UnstableFeature{UnstableFlag: "uk.half-shot.msc2666.query_mutual_rooms.stable" /*, SpecVersion: SpecV118*/} + FeatureUserRedaction = UnstableFeature{UnstableFlag: "org.matrix.msc4194"} + FeatureViewRedactedContent = UnstableFeature{UnstableFlag: "fi.mau.msc2815"} + FeatureUnstableAccountModeration = UnstableFeature{UnstableFlag: "uk.timedout.msc4323"} + FeatureStableAccountModeration = UnstableFeature{UnstableFlag: "uk.timedout.msc4323.stable" /*, SpecVersion: SpecV118*/} + FeatureUnstableProfileFields = UnstableFeature{UnstableFlag: "uk.tcpip.msc4133"} + FeatureArbitraryProfileFields = UnstableFeature{UnstableFlag: "uk.tcpip.msc4133.stable", SpecVersion: SpecV116} + FeatureRedactSendAsEvent = UnstableFeature{UnstableFlag: "com.beeper.msc4169"} - BeeperFeatureHungry = UnstableFeature{UnstableFlag: "com.beeper.hungry"} - BeeperFeatureBatchSending = UnstableFeature{UnstableFlag: "com.beeper.batch_sending"} - BeeperFeatureRoomYeeting = UnstableFeature{UnstableFlag: "com.beeper.room_yeeting"} - BeeperFeatureAutojoinInvites = UnstableFeature{UnstableFlag: "com.beeper.room_create_autojoin_invites"} - BeeperFeatureArbitraryProfileMeta = UnstableFeature{UnstableFlag: "com.beeper.arbitrary_profile_meta"} + BeeperFeatureHungry = UnstableFeature{UnstableFlag: "com.beeper.hungry"} + BeeperFeatureBatchSending = UnstableFeature{UnstableFlag: "com.beeper.batch_sending"} + BeeperFeatureRoomYeeting = UnstableFeature{UnstableFlag: "com.beeper.room_yeeting"} + BeeperFeatureAutojoinInvites = UnstableFeature{UnstableFlag: "com.beeper.room_create_autojoin_invites"} + BeeperFeatureArbitraryProfileMeta = UnstableFeature{UnstableFlag: "com.beeper.arbitrary_profile_meta"} + BeeperFeatureAccountDataMute = UnstableFeature{UnstableFlag: "com.beeper.account_data_mute"} + BeeperFeatureInboxState = UnstableFeature{UnstableFlag: "com.beeper.inbox_state"} + BeeperFeatureArbitraryMemberChange = UnstableFeature{UnstableFlag: "com.beeper.arbitrary_member_change"} + BeeperFeatureEphemeralEvents = UnstableFeature{UnstableFlag: "com.beeper.ephemeral"} ) func (versions *RespVersions) Supports(feature UnstableFeature) bool { + if versions == nil { + return false + } return versions.UnstableFeatures[feature.UnstableFlag] || (!feature.SpecVersion.IsEmpty() && versions.ContainsGreaterOrEqual(feature.SpecVersion)) } @@ -95,6 +119,14 @@ var ( SpecV17 = MustParseSpecVersion("v1.7") SpecV18 = MustParseSpecVersion("v1.8") SpecV19 = MustParseSpecVersion("v1.9") + SpecV110 = MustParseSpecVersion("v1.10") + SpecV111 = MustParseSpecVersion("v1.11") + SpecV112 = MustParseSpecVersion("v1.12") + SpecV113 = MustParseSpecVersion("v1.13") + SpecV114 = MustParseSpecVersion("v1.14") + SpecV115 = MustParseSpecVersion("v1.15") + SpecV116 = MustParseSpecVersion("v1.16") + SpecV117 = MustParseSpecVersion("v1.17") ) func (svf SpecVersionFormat) String() string {